"""
Graph node ranking for Graph RAG.

Scores and ranks graph-expanded nodes by relevance to avoid sending
noise to the LLM. Three signals are combined:

1. **Hop distance** – nodes closer to the seed acts score higher.
2. **Semantic overlap** – nodes that also appear in the Qdrant results
   get a boost (they matched the query both semantically and structurally).
3. **Relation type weight** – each node is scored by the average configured
   weight of its incident relations, so strong legal ties (e.g. AMENDS /
   IMPLEMENTS) can outweigh weaker references (e.g. CITES / DEROGATES),
   depending on ``graph.ranking_relation_weights``.

Usage::

    ranked = rank_graph_nodes(
        graph_context=graph_context,
        relationships=relationships,
        semantic_act_ids=semantic_act_ids,
        seed_act_ids=seed_act_ids,
    )
    # ranked is sorted best-first; slice to your budget
"""
import logging
from typing import Any, Dict, List, Set
from lalandre_core.config import get_config
from lalandre_rag.scoring import (
clamp_unit_interval,
non_negative,
normalize_by_max,
round_score,
)
logger = logging.getLogger(__name__)
def _relation_weight(relation_type: str) -> float:
    """Look up the configured ranking weight for a single relation type.

    The type is upper-cased before the lookup in
    ``graph.ranking_relation_weights``; unknown types fall back to
    ``graph.ranking_default_relation_weight``.
    """
    graph_cfg = get_config().graph
    weights_by_type = graph_cfg.ranking_relation_weights
    lookup_key = relation_type.upper()
    fallback = graph_cfg.ranking_default_relation_weight
    return weights_by_type.get(lookup_key, fallback)
def _max_relation_weight() -> float:
    """Return the strongest configured relation weight, used as a normalizer.

    Candidates are every value of ``graph.ranking_relation_weights`` plus
    ``graph.ranking_default_relation_weight``. Each is converted to ``float``
    and passed through ``non_negative``. If the strongest candidate is zero
    (falsy), ``1.0`` is returned instead, so the normalizer is never zero.
    """
    graph_cfg = get_config().graph
    raw_values = list(graph_cfg.ranking_relation_weights.values())
    # The default weight always participates, so the candidate list is never empty.
    raw_values.append(graph_cfg.ranking_default_relation_weight)
    strongest = max(non_negative(float(value)) for value in raw_values)
    return strongest if strongest else 1.0
# ── Hop distance estimation ───────────────────────────────────────────────
def _estimate_hop_distances(
seed_act_ids: Set[int],
relationships: List[Dict[str, Any]],
max_depth: int = 5,
) -> Dict[int, int]:
"""
BFS-style hop distance from the seed set.
Returns ``{act_id: min_hop_distance}`` for every reachable node.
Seed nodes have distance 0.
"""
distances: Dict[int, int] = {aid: 0 for aid in seed_act_ids}
frontier = set(seed_act_ids)
adjacency: Dict[int, Set[int]] = {}
for rel in relationships:
src = rel.get("start_node")
tgt = rel.get("end_node")
if src is None or tgt is None:
continue
src, tgt = int(src), int(tgt)
adjacency.setdefault(src, set()).add(tgt)
adjacency.setdefault(tgt, set()).add(src)
for hop in range(1, max_depth + 1):
next_frontier: Set[int] = set()
for node in frontier:
for neighbor in adjacency.get(node, set()):
if neighbor not in distances:
distances[neighbor] = hop
next_frontier.add(neighbor)
frontier = next_frontier
if not frontier:
break
return distances
# ── Main scoring function ─────────────────────────────────────────────────
def _incident_relation_weights(relationships: List[Dict[str, Any]]) -> Dict[int, List[float]]:
    """Collect the weight of every relation touching each node.

    Each relationship contributes its weight (``_relation_weight`` passed
    through ``non_negative``) to both endpoints; ``None`` endpoints are
    skipped. A missing, ``None`` or empty ``type`` is scored as
    ``RELATED_TO``. Weights are resolved once per distinct relation type
    rather than once per relationship.
    """
    weight_by_type: Dict[str, float] = {}
    per_node: Dict[int, List[float]] = {}
    for rel in relationships:
        # ``or`` keeps an explicit ``"type": None`` from becoming "NONE".
        rtype = str(rel.get("type") or "RELATED_TO").upper()
        if rtype not in weight_by_type:
            weight_by_type[rtype] = non_negative(float(_relation_weight(rtype)))
        weight = weight_by_type[rtype]
        for endpoint in (rel.get("start_node"), rel.get("end_node")):
            if endpoint is not None:
                per_node.setdefault(int(endpoint), []).append(weight)
    return per_node


def rank_graph_nodes(
    *,
    graph_context: List[Dict[str, Any]],
    relationships: List[Dict[str, Any]],
    semantic_act_ids: Set[int],
    seed_act_ids: Set[int],
    max_depth: int = 5,
    hop_decay: float = 0.5,
    semantic_boost: float = 0.3,
    relation_weight_factor: float = 0.25,
) -> List[Dict[str, Any]]:
    """Score and rank graph-expanded nodes.

    Each node receives a composite score::

        raw        = 1.0 * hop_signal
                     + semantic_weight * semantic_signal
                     + relation_weight * relation_signal
        normalized = raw / (1.0 + semantic_weight + relation_weight)

    where:

    - ``hop_signal = hop_decay ** hop_distance``. Nodes not reached within
      ``max_depth`` hops use ``max_depth + 1``.
    - ``semantic_signal`` is 1.0 when the node was also returned by semantic
      search, else 0.0.
    - ``relation_signal`` is the average weight of the node's incident
      relations, normalized by the strongest configured relation weight.

    Args:
        graph_context: Graph-expanded act nodes to score.
        relationships: Graph relationships connecting the candidate nodes.
        semantic_act_ids: Act identifiers also returned by semantic search.
        seed_act_ids: Seed act identifiers used to start graph expansion.
        max_depth: Maximum BFS depth used to estimate hop distance.
            Negative values are treated as 0.
        hop_decay: Exponential decay applied to hop distance (passed through
            ``clamp_unit_interval``).
        semantic_boost: Weight applied to semantic overlap (passed through
            ``non_negative``).
        relation_weight_factor: Weight applied to relation strength (passed
            through ``non_negative``).

    Returns:
        Shallow copies of the input nodes, sorted best-first and enriched with
        ``_rank_score`` (normalized), ``_rank_score_raw`` and ``_rank_trace``.
        Nodes without an ``id`` score 0.0, and ties keep their input order.
        An empty ``graph_context`` returns ``[]``.
    """
    if not graph_context:
        return []

    hop_weight = 1.0
    semantic_weight = non_negative(float(semantic_boost))
    relation_weight = non_negative(float(relation_weight_factor))
    effective_hop_decay = clamp_unit_interval(float(hop_decay))
    # A negative depth would make the "unreachable" hop count <= 0, so
    # decay**hop would exceed 1 (or divide by zero when the decay is 0).
    depth_limit = max(0, int(max_depth))
    max_relation_weight = _max_relation_weight()
    total_weight = hop_weight + semantic_weight + relation_weight
    weight_trace = {
        "hop": hop_weight,
        "semantic": round_score(semantic_weight),
        "relation": round_score(relation_weight),
    }

    # 1. Hop distances from the seed set
    hop_distances = _estimate_hop_distances(seed_act_ids, relationships, depth_limit)
    # 2. Incident relation weights per node (averaged below)
    node_relation_weights = _incident_relation_weights(relationships)

    # 3. Score each node
    scored: List[Dict[str, Any]] = []
    for node in graph_context:
        act_id = node.get("id")
        if act_id is None:
            # Same trace schema as scored nodes so consumers need no special case.
            scored.append(
                {
                    **node,
                    "_rank_score": 0.0,
                    "_rank_score_raw": 0.0,
                    "_rank_trace": {
                        "hop_distance": None,
                        "hop_signal": 0.0,
                        "semantic_signal": 0.0,
                        "relation_signal": 0.0,
                        "weights": dict(weight_trace),
                        "avg_relation_weight_raw": 0.0,
                        "max_relation_weight_raw": round_score(max_relation_weight),
                        "raw_score": 0.0,
                        "normalized_score": 0.0,
                    },
                }
            )
            continue
        act_id = int(act_id)

        # Unreachable nodes are treated as one hop beyond the search depth.
        hop = hop_distances.get(act_id, depth_limit + 1)
        hop_signal = effective_hop_decay**hop

        semantic_signal = 1.0 if act_id in semantic_act_ids else 0.0

        incident = node_relation_weights.get(act_id, [])
        avg_rel_weight = (sum(incident) / len(incident)) if incident else 0.0
        relation_signal = normalize_by_max(avg_rel_weight, max_relation_weight)

        raw_score = hop_weight * hop_signal + semantic_weight * semantic_signal + relation_weight * relation_signal
        normalized_score = raw_score / total_weight if total_weight > 0 else 0.0
        scored.append(
            {
                **node,
                "_rank_score": round_score(normalized_score),
                "_rank_score_raw": round_score(raw_score),
                "_rank_trace": {
                    "hop_distance": hop,
                    "hop_signal": round_score(hop_signal),
                    "semantic_signal": round_score(semantic_signal),
                    "relation_signal": round_score(relation_signal),
                    "weights": dict(weight_trace),
                    "avg_relation_weight_raw": round_score(avg_rel_weight),
                    "max_relation_weight_raw": round_score(max_relation_weight),
                    "raw_score": round_score(raw_score),
                    "normalized_score": round_score(normalized_score),
                },
            }
        )

    # Stable sort: nodes with equal scores keep their input order.
    scored.sort(key=lambda n: n.get("_rank_score", 0.0), reverse=True)
    return scored
def rank_relationships(
    *,
    relationships: List[Dict[str, Any]],
    top_act_ids: Set[int],
) -> List[Dict[str, Any]]:
    """
    Keep only relationships that connect two nodes in ``top_act_ids``
    and sort them by relation-type weight, strongest first.

    Args:
        relationships: Candidate relationships. Each is read for
            ``start_node`` / ``end_node`` ids and an optional ``type``.
        top_act_ids: Act ids retained after node ranking.

    Returns:
        Shallow copies of the kept relationships, enriched with:

        - ``_rel_weight``: the raw weight normalized by the strongest
          configured relation weight.
        - ``_rel_weight_raw``: the raw weight.
        - ``_rel_trace``: a trace of the values above.

        Relationships missing either endpoint are dropped. A missing, ``None``
        or empty ``type`` is scored as ``RELATED_TO``. Equal weights keep
        their input order.
    """
    max_relation_weight = _max_relation_weight()
    # Configured weight per relation type, resolved once per distinct type.
    weight_by_type: Dict[str, float] = {}
    kept: List[Dict[str, Any]] = []
    for rel in relationships:
        src = rel.get("start_node")
        tgt = rel.get("end_node")
        if src is None or tgt is None:
            continue
        if int(src) not in top_act_ids or int(tgt) not in top_act_ids:
            continue
        # ``or`` keeps an explicit ``"type": None`` from becoming "NONE".
        rtype = str(rel.get("type") or "RELATED_TO").upper()
        if rtype not in weight_by_type:
            weight_by_type[rtype] = non_negative(float(_relation_weight(rtype)))
        raw_weight = weight_by_type[rtype]
        normalized_weight = normalize_by_max(raw_weight, max_relation_weight)
        kept.append(
            {
                **rel,
                "_rel_weight": round_score(normalized_weight),
                "_rel_weight_raw": round_score(raw_weight),
                "_rel_trace": {
                    "type": rtype,
                    "raw_weight": round_score(raw_weight),
                    "max_weight": round_score(max_relation_weight),
                    "normalized_score": round_score(normalized_weight),
                },
            }
        )
    # Stable sort: equal weights keep their input order.
    kept.sort(key=lambda r: r.get("_rel_weight", 0.0), reverse=True)
    return kept