# Source code for lalandre_rag.modes.hybrid_helpers

"""
Helpers for HybridMode — context assembly, source building, metadata, and citation.

Extracted from hybrid_mode.py to keep the orchestrator focused on the pipeline.
"""

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from lalandre_rag.graph.context_budget import GraphContextBudget
from lalandre_rag.graph.ranker import rank_graph_nodes, rank_relationships
from lalandre_rag.response import (
    format_doc_location,
    format_source_header,
    validate_citations,
)
from lalandre_rag.retrieval.context.community_reports import CommunityReport
from lalandre_rag.retrieval.context.models import ContextSlice
from lalandre_rag.retrieval.query_router import RetrievalPlan
from lalandre_rag.scoring import SCORE_SCALE_NORMALIZED_0_1

# Signature of an optional progress sink: called with one structured event
# dict per emission (see emit_progress for the event fields).
ProgressCallback = Optional[Callable[[Dict[str, Any]], None]]


def emit_progress(
    callback: ProgressCallback,
    *,
    phase: str,
    status: str,
    label: str,
    detail: Optional[str] = None,
    count: Optional[int] = None,
    duration_ms: Optional[float] = None,
    meta: Optional[Dict[str, Any]] = None,
) -> None:
    """Send one structured progress event to *callback*, when one is set.

    The event always carries ``phase``/``status``/``label``. Optional fields
    are attached selectively: ``detail`` and ``meta`` only when truthy,
    ``count`` and ``duration_ms`` only when not None. ``duration_ms`` is
    coerced to float and rounded to one decimal place. No-op if *callback*
    is None.
    """
    if callback is None:
        return
    payload: Dict[str, Any] = {
        "phase": phase,
        "status": status,
        "label": label,
    }
    if detail:
        payload["detail"] = detail
    if count is not None:
        payload["count"] = count
    if duration_ms is not None:
        payload["duration_ms"] = round(float(duration_ms), 1)
    if meta:
        payload["meta"] = meta
    callback(payload)
# ---------------------------------------------------------------------------
# Context assembly
# ---------------------------------------------------------------------------
def build_source_context(
    *,
    context_slices: List[ContextSlice],
    max_context_chars: int,
    min_chars_per_source: int,
    max_sources: int,
) -> tuple[str, List[Dict[str, Any]], int]:
    """Assemble the numbered source-context block for the LLM prompt.

    Walks up to *max_sources* slices and spends the character budget
    header-first. While allocating content for a source, it reserves at
    least *min_chars_per_source* characters for each source still to come,
    but never starves the current source below that same floor when any
    budget remains.

    Returns ``(context_text, refs, remaining_chars)`` where each entry of
    ``refs`` records the slice, the content actually included, whether it
    was truncated, and the assigned ``S<n>`` source id.
    """
    budget = max(max_context_chars, 0)
    picked = context_slices[:max_sources]
    doc_count = len(picked)

    blocks: List[str] = []
    refs: List[Dict[str, Any]] = []

    for position, item in enumerate(picked, start=1):
        sid = f"S{position}"
        location = format_doc_location(
            item.doc.chunk_id,
            item.doc.chunk_index,
            item.doc.subdivision_type,
            item.doc.subdivision_id,
        )
        header = format_source_header(
            sid,
            item.act.celex,
            location,
            item.act.title,
            regulatory_level=item.act.regulatory_level,
        )
        header_cost = len(header) + 1  # +1 for the newline after the header

        # Stop as soon as even the header no longer fits.
        if budget <= 0 or header_cost > budget:
            break
        budget = max(0, budget - header_cost)

        sources_left = doc_count - position
        body = item.content or ""
        if budget <= 0:
            used = ""
        else:
            # Hold back a floor for every source still to come...
            reserve = min_chars_per_source * max(sources_left, 0)
            share = max(0, budget - reserve)
            # ...but give the current source at least the floor if possible.
            if share < min_chars_per_source:
                share = min(budget, min_chars_per_source)
            used = body[:share]
            budget = max(0, budget - len(used))

        blocks.append(f"{header}\n{used}")
        refs.append(
            {
                "doc": item,
                "content_used": used,
                "content_truncated": bool(body and len(used) < len(body)),
                "source_id": sid,
            }
        )

    return "\n\n---\n\n".join(blocks), refs, budget
def build_relation_summary(
    *,
    context_slices: List[ContextSlice],
    line_limit: int,
) -> str:
    """Render a deduplicated, capped list of act-relation lines for the LLM.

    Emits at most *line_limit* relation lines under a "Relation signals"
    banner. Returns "" when the limit is non-positive or no usable relation
    is found. Act ids are rendered as their CELEX when known, otherwise
    ``ACT-<id>``; a non-integer target falls back to its ``target_celex``
    string or ``EXTERNAL``.
    """
    if line_limit <= 0:
        return ""

    celex_by_act = {slice_.act.act_id: slice_.act.celex for slice_ in context_slices}
    out: List[str] = ["--- Relation signals ---"]
    emitted: set[tuple[int, int | None, str, str | None]] = set()

    for slice_ in context_slices:
        for rel in slice_.act.relations or []:
            src = rel.get("source_act_id")
            dst = rel.get("target_act_id")
            raw_type = rel.get("relation_type")
            rel_type = "RELATED_TO" if raw_type is None else str(raw_type).upper()
            raw_desc = rel.get("description")
            if isinstance(raw_desc, str) and raw_desc.strip():
                desc = str(raw_desc).strip()
            else:
                desc = None
            # Relations without an integer source act cannot be attributed.
            if not isinstance(src, int):
                continue
            fingerprint = (
                src,
                dst if isinstance(dst, int) else None,
                rel_type,
                desc,
            )
            if fingerprint in emitted:
                continue
            emitted.add(fingerprint)

            src_label = celex_by_act.get(src, f"ACT-{src}")
            if isinstance(dst, int):
                dst_label = celex_by_act.get(dst, f"ACT-{dst}")
            else:
                raw_target = rel.get("target_celex")
                if isinstance(raw_target, str) and raw_target.strip():
                    dst_label = str(raw_target).strip()
                else:
                    dst_label = "EXTERNAL"

            entry = f"- {src_label} -[{rel_type}]-> {dst_label}"
            if desc:
                entry += f" | {desc}"
            out.append(entry)
            # Banner line doesn't count toward the cap.
            if len(out) - 1 >= line_limit:
                return "\n".join(out)

    return "" if len(out) == 1 else "\n".join(out)
def format_reports_block(reports: List[CommunityReport]) -> str:
    """Render community reports as a labelled text block ("" when empty).

    Each report contributes its summary line plus, when present, its pivot
    acts (with degrees), relation-type counts, and one line per evidence.
    """
    if not reports:
        return ""

    out: List[str] = ["--- Community Reports ---"]
    for rep in reports:
        out.append(f"[{rep.community_id}] {rep.summary}")
        if rep.central_acts:
            out.append(
                "Pivots: "
                + ", ".join(
                    f"{act['celex']} (deg={act['degree']})" for act in rep.central_acts
                )
            )
        if rep.top_relation_types:
            out.append(
                "Types: "
                + ", ".join(
                    f"{rt['relation_type']}:{rt['count']}"
                    for rt in rep.top_relation_types
                )
            )
        out.extend(f"Evidence: {item}" for item in rep.evidences)
    return "\n".join(out)
# ---------------------------------------------------------------------------
# Metadata & validation
# ---------------------------------------------------------------------------
def attach_citation_validation(
    *,
    response: Dict[str, Any],
    answer: str,
    sources: List[Dict[str, Any]],
) -> None:
    """Validate *answer*'s citations and attach the result in-place.

    Collects every non-empty string ``source_id`` from *sources*, validates
    the answer against them, and writes the outcome under
    ``response["metadata"]["citation_validation"]``. No-op when *sources*
    is empty.
    """
    if not sources:
        return
    valid_ids: List[str] = [
        sid
        for entry in sources
        if isinstance(sid := entry.get("source_id"), str) and sid
    ]
    response["metadata"]["citation_validation"] = validate_citations(answer, valid_ids)
def build_plan_metadata(
    *,
    retrieval_plan: RetrievalPlan,
    requested_top_k: int,
    effective_top_k: int,
    requested_granularity: Optional[str],
    effective_granularity: Optional[str],
    requested_include_relations: bool,
    effective_include_relations: bool,
    retrieval_query: str,
    original_question: str,
) -> Dict[str, Any]:
    """Flatten a retrieval plan plus requested/effective knobs into an
    audit-friendly metadata dict."""
    metadata: Dict[str, Any] = {
        "profile": retrieval_plan.profile,
        "use_graph": retrieval_plan.use_graph,
        "rationale": retrieval_plan.rationale,
        "execution_mode": retrieval_plan.execution_mode,
        "routing_source": retrieval_plan.routing_source,
        "intent_label": retrieval_plan.intent_label,
        "parser_confidence": retrieval_plan.parser_confidence,
    }
    metadata.update(
        requested_top_k=requested_top_k,
        effective_top_k=effective_top_k,
        requested_granularity=requested_granularity,
        effective_granularity=effective_granularity,
        requested_include_relations=requested_include_relations,
        effective_include_relations=effective_include_relations,
        retrieval_query=retrieval_query,
        # Did the query sent to retrieval diverge from the user's question?
        query_rewritten=retrieval_query.strip() != original_question.strip(),
        query_rewritten_by_parser=bool(retrieval_plan.search_query),
    )
    return metadata
# ── Graph fetch result + ranked context builder ─────────────────────────
@dataclass
class GraphFetchResult:
    """Raw output from Neo4j graph expansion (before ranking)."""

    # Graph nodes returned by the expansion query.
    nodes: List[Dict[str, Any]]
    # Relationship records among the fetched nodes.
    relationships: List[Dict[str, Any]]
    # Act ids the expansion started from (used as the semantic hit set by
    # build_ranked_graph_context).
    seed_act_ids: Set[int]
    # Act ids reached via traversal — presumably beyond the seeds; not
    # consumed in this module. TODO confirm against the producer.
    expanded_act_ids: Set[int]
    # Pre-rendered community-report text; prepended to the combined context
    # when non-empty.
    community_block: str = ""
    # Free-form metadata about the community reports.
    community_meta: Dict[str, Any] = field(default_factory=dict)
    # Wall-clock duration of the graph fetch, in milliseconds.
    duration_ms: float = 0.0
def build_ranked_graph_context(
    *,
    fetch_result: GraphFetchResult,
    semantic_results: List[Any],
    max_context_chars: int,
    graph_acts_limit: int,
    graph_relationships_limit: int,
    hop_decay: float,
    semantic_boost: float,
    relation_weight_factor: float,
    budget_semantic_share: float,
    budget_graph_share: float,
    budget_relation_share: float,
    min_chars_per_source: int,
    max_depth: int,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Rank graph nodes/relationships and build a budget-aware context block.

    Pipeline: score and cap the fetched nodes, keep relationships touching
    the surviving (or seed) acts, then let GraphContextBudget split
    *max_context_chars* across semantic/graph/relation content. A non-empty
    community block from the fetch is prepended to the combined context.

    Returns ``(combined_context, metadata, graph_node_refs, relationship_refs)``.
    """
    seeds = fetch_result.seed_act_ids

    # Score every fetched node, then keep only the strongest ones.
    scored_nodes = rank_graph_nodes(
        graph_context=fetch_result.nodes,
        relationships=fetch_result.relationships,
        semantic_act_ids=seeds,  # semantic hits double as the seed set here
        seed_act_ids=seeds,
        max_depth=max_depth,
        hop_decay=hop_decay,
        semantic_boost=semantic_boost,
        relation_weight_factor=relation_weight_factor,
    )[:graph_acts_limit]

    kept_ids = {int(node["id"]) for node in scored_nodes if node.get("id") is not None}
    scored_rels = rank_relationships(
        relationships=fetch_result.relationships,
        top_act_ids=kept_ids | seeds,
    )[:graph_relationships_limit]

    allocator = GraphContextBudget(
        max_chars=max_context_chars,
        semantic_share=budget_semantic_share,
        graph_share=budget_graph_share,
        relation_share=budget_relation_share,
        min_chars_per_source=min_chars_per_source,
    )
    built = allocator.build(
        semantic_results=semantic_results,
        ranked_nodes=scored_nodes,
        ranked_relationships=scored_rels,
    )

    combined = built.combined_context
    if fetch_result.community_block:
        combined = f"{fetch_result.community_block}\n\n{combined}"

    def _top(items: List[Dict[str, Any]], key: str) -> Any:
        # Best-ranked score for audit metadata; 0 when the list is empty.
        return items[0].get(key, 0) if items else 0

    meta: Dict[str, Any] = {
        "ranked_nodes_used": len(scored_nodes),
        "ranked_relationships_used": len(scored_rels),
        "total_graph_nodes": len(fetch_result.nodes),
        "total_relationships": len(fetch_result.relationships),
        "graph_acts_used_for_context": built.graph_nodes_used,
        "graph_relationships_used_for_context": built.relationships_used,
        "context_budget": built.chars_used,
        "top_rank_score": _top(scored_nodes, "_rank_score"),
        "top_rank_score_raw": _top(scored_nodes, "_rank_score_raw"),
        "top_relation_score": _top(scored_rels, "_rel_weight"),
        "top_relation_score_raw": _top(scored_rels, "_rel_weight_raw"),
        "score_scale": SCORE_SCALE_NORMALIZED_0_1,
    }
    return combined, meta, built.graph_node_refs, built.relationship_refs