# Source code for lalandre_rag.graph.ranker

"""
Graph node ranking for Graph RAG.

Scores and ranks graph-expanded nodes by relevance to avoid sending
noise to the LLM.  Three signals are combined:

1. **Hop distance** – nodes closer to the seed acts score higher.
2. **Semantic overlap** – nodes that also appear in the Qdrant results
   get a boost (they matched the query both semantically and structurally).
3. **Relation type weight** – AMENDS / IMPLEMENTS (strong legal ties)
   outweigh CITES / DEROGATES (weaker references).

Usage::

    ranked = rank_graph_nodes(
        graph_context=graph_context,
        relationships=relationships,
        semantic_act_ids=semantic_act_ids,
        seed_act_ids=seed_act_ids,
    )
    # ranked is sorted best-first; slice to your budget
"""

import logging
from typing import Any, Dict, List, Set

from lalandre_core.config import get_config

from lalandre_rag.scoring import (
    clamp_unit_interval,
    non_negative,
    normalize_by_max,
    round_score,
)

logger = logging.getLogger(__name__)


def _relation_weight(relation_type: str) -> float:
    """Return the configured weight for *relation_type* (case-insensitive).

    Falls back to ``ranking_default_relation_weight`` for unknown types.
    """
    graph_cfg = get_config().graph
    configured = graph_cfg.ranking_relation_weights
    return configured.get(relation_type.upper(), graph_cfg.ranking_default_relation_weight)


def _max_relation_weight() -> float:
    """Return the strongest configured relation weight for normalization."""
    graph_cfg = get_config().graph
    # Default weight always participates, so the candidate list is never empty.
    candidates = [non_negative(float(graph_cfg.ranking_default_relation_weight))]
    candidates.extend(non_negative(float(w)) for w in graph_cfg.ranking_relation_weights.values())
    # `or 1.0` guards against an all-zero configuration (avoids division by zero downstream).
    return max(candidates, default=1.0) or 1.0


# ── Hop distance estimation ───────────────────────────────────────────────


def _estimate_hop_distances(
    seed_act_ids: Set[int],
    relationships: List[Dict[str, Any]],
    max_depth: int = 5,
) -> Dict[int, int]:
    """
    BFS-style hop distance from the seed set.

    Returns ``{act_id: min_hop_distance}`` for every reachable node.
    Seed nodes have distance 0.
    """
    distances: Dict[int, int] = {aid: 0 for aid in seed_act_ids}
    frontier = set(seed_act_ids)
    adjacency: Dict[int, Set[int]] = {}
    for rel in relationships:
        src = rel.get("start_node")
        tgt = rel.get("end_node")
        if src is None or tgt is None:
            continue
        src, tgt = int(src), int(tgt)
        adjacency.setdefault(src, set()).add(tgt)
        adjacency.setdefault(tgt, set()).add(src)

    for hop in range(1, max_depth + 1):
        next_frontier: Set[int] = set()
        for node in frontier:
            for neighbor in adjacency.get(node, set()):
                if neighbor not in distances:
                    distances[neighbor] = hop
                    next_frontier.add(neighbor)
        frontier = next_frontier
        if not frontier:
            break
    return distances


# ── Main scoring function ─────────────────────────────────────────────────


def rank_graph_nodes(
    *,
    graph_context: List[Dict[str, Any]],
    relationships: List[Dict[str, Any]],
    semantic_act_ids: Set[int],
    seed_act_ids: Set[int],
    max_depth: int = 5,
    hop_decay: float = 0.5,
    semantic_boost: float = 0.3,
    relation_weight_factor: float = 0.25,
) -> List[Dict[str, Any]]:
    """Score and rank graph-expanded nodes.

    Each node receives a normalized composite score in ``[0, 1]`` based on
    hop distance, semantic overlap, and incident relation strength.

    Args:
        graph_context: Graph-expanded act nodes to score.
        relationships: Graph relationships connecting the candidate nodes.
        semantic_act_ids: Act identifiers also returned by semantic search.
        seed_act_ids: Seed act identifiers used to start graph expansion.
        max_depth: Maximum BFS depth used to estimate hop distance.
        hop_decay: Exponential decay applied to hop distance.
        semantic_boost: Non-negative weight applied to semantic overlap.
        relation_weight_factor: Non-negative weight applied to relation strength.

    Returns:
        The input nodes enriched with ranking metadata and sorted best-first.
    """
    if not graph_context:
        return []

    # Sanitize tunables: weights must be non-negative, decay clamped to [0, 1].
    hop_weight = 1.0
    semantic_weight = non_negative(float(semantic_boost))
    relation_weight = non_negative(float(relation_weight_factor))
    effective_hop_decay = clamp_unit_interval(float(hop_decay))
    max_relation_weight = _max_relation_weight()
    total_weight = hop_weight + semantic_weight + relation_weight

    # 1. Hop distances from the seed set.
    hop_distances = _estimate_hop_distances(seed_act_ids, relationships, max_depth)

    # 2. Relation-type score per node (average weight of incident relations).
    node_relation_weights: Dict[int, List[float]] = {}
    for rel in relationships:
        rtype = str(rel.get("type", "RELATED_TO")).upper()
        incident_weight = non_negative(float(_relation_weight(rtype)))
        for endpoint in (rel.get("start_node"), rel.get("end_node")):
            if endpoint is not None:
                node_relation_weights.setdefault(int(endpoint), []).append(incident_weight)

    # 3. Score each node.
    scored: List[Dict[str, Any]] = []
    for node in graph_context:
        act_id = node.get("id")
        if act_id is None:
            # No identifier: keep the node but rank it last with a zeroed trace.
            scored.append(
                {
                    **node,
                    "_rank_score": 0.0,
                    "_rank_score_raw": 0.0,
                    "_rank_trace": {
                        "hop_distance": None,
                        "hop_signal": 0.0,
                        "semantic_signal": 0.0,
                        "relation_signal": 0.0,
                        "raw_score": 0.0,
                        "normalized_score": 0.0,
                    },
                }
            )
            continue
        act_id = int(act_id)

        # Hop score: exponential decay; unreachable nodes get depth max_depth + 1.
        hop = hop_distances.get(act_id, max_depth + 1)
        hop_signal = effective_hop_decay**hop

        # Semantic overlap bonus (binary: the act also matched the query vector).
        semantic_signal = 1.0 if act_id in semantic_act_ids else 0.0

        # Relation-type score: mean incident weight, normalized to [0, 1].
        incident = node_relation_weights.get(act_id, [])
        avg_rel_weight = (sum(incident) / len(incident)) if incident else 0.0
        relation_signal = normalize_by_max(avg_rel_weight, max_relation_weight)

        raw_score = (
            hop_weight * hop_signal
            + semantic_weight * semantic_signal
            + relation_weight * relation_signal
        )
        normalized_score = raw_score / total_weight if total_weight > 0 else 0.0

        scored.append(
            {
                **node,
                "_rank_score": round_score(normalized_score),
                "_rank_score_raw": round_score(raw_score),
                "_rank_trace": {
                    "hop_distance": hop,
                    "hop_signal": round_score(hop_signal),
                    "semantic_signal": round_score(semantic_signal),
                    "relation_signal": round_score(relation_signal),
                    "weights": {
                        "hop": hop_weight,
                        "semantic": round_score(semantic_weight),
                        "relation": round_score(relation_weight),
                    },
                    "avg_relation_weight_raw": round_score(avg_rel_weight),
                    "max_relation_weight_raw": round_score(max_relation_weight),
                    "raw_score": round_score(raw_score),
                    "normalized_score": round_score(normalized_score),
                },
            }
        )

    scored.sort(key=lambda n: n.get("_rank_score", 0.0), reverse=True)
    return scored
def rank_relationships(
    *,
    relationships: List[Dict[str, Any]],
    top_act_ids: Set[int],
) -> List[Dict[str, Any]]:
    """
    Keep only relationships that connect nodes in ``top_act_ids`` and
    sort by relation-type weight descending.
    """
    max_relation_weight = _max_relation_weight()
    kept: List[Dict[str, Any]] = []
    for rel in relationships:
        src = rel.get("start_node")
        tgt = rel.get("end_node")
        # Skip incomplete edges and edges whose endpoints are not both retained.
        if src is None or tgt is None:
            continue
        if int(src) not in top_act_ids or int(tgt) not in top_act_ids:
            continue
        rtype = str(rel.get("type", "RELATED_TO")).upper()
        raw_weight = non_negative(float(_relation_weight(rtype)))
        normalized_weight = normalize_by_max(raw_weight, max_relation_weight)
        kept.append(
            {
                **rel,
                "_rel_weight": round_score(normalized_weight),
                "_rel_weight_raw": round_score(raw_weight),
                "_rel_trace": {
                    "type": rtype,
                    "raw_weight": round_score(raw_weight),
                    "max_weight": round_score(max_relation_weight),
                    "normalized_score": round_score(normalized_weight),
                },
            }
        )
    kept.sort(key=lambda r: r.get("_rel_weight", 0.0), reverse=True)
    return kept