"""
Helpers for HybridMode — context assembly, source building, metadata, and citation.
Extracted from hybrid_mode.py to keep the orchestrator focused on the pipeline.
"""
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from lalandre_rag.graph.context_budget import GraphContextBudget
from lalandre_rag.graph.ranker import rank_graph_nodes, rank_relationships
from lalandre_rag.response import (
format_doc_location,
format_source_header,
validate_citations,
)
from lalandre_rag.retrieval.context.community_reports import CommunityReport
from lalandre_rag.retrieval.context.models import ContextSlice
from lalandre_rag.retrieval.query_router import RetrievalPlan
from lalandre_rag.scoring import SCORE_SCALE_NORMALIZED_0_1
ProgressCallback = Optional[Callable[[Dict[str, Any]], None]]
# [docs]  (Sphinx viewcode artifact from HTML export; commented out — not valid Python)
def emit_progress(
    callback: ProgressCallback,
    *,
    phase: str,
    status: str,
    label: str,
    detail: Optional[str] = None,
    count: Optional[int] = None,
    duration_ms: Optional[float] = None,
    meta: Optional[Dict[str, Any]] = None,
) -> None:
    """Send a structured progress event to ``callback``, if one is configured.

    The event dict always carries ``phase``, ``status`` and ``label``.
    ``count`` and ``duration_ms`` are added when not ``None`` (durations are
    rounded to one decimal); ``detail`` and ``meta`` only when truthy.
    """
    if callback is None:
        return
    payload: Dict[str, Any] = {"phase": phase, "status": status, "label": label}
    if detail:
        payload["detail"] = detail
    if count is not None:
        payload["count"] = count
    if duration_ms is not None:
        # Keep durations compact — one decimal place is enough for progress UI.
        payload["duration_ms"] = round(float(duration_ms), 1)
    if meta:
        payload["meta"] = meta
    callback(payload)
# ---------------------------------------------------------------------------
# Context assembly
# ---------------------------------------------------------------------------
# [docs]  (Sphinx viewcode artifact from HTML export; commented out — not valid Python)
def build_source_context(
    *,
    context_slices: List[ContextSlice],
    max_context_chars: int,
    min_chars_per_source: int,
    max_sources: int,
) -> tuple[str, List[Dict[str, Any]], int]:
    """Assemble the numbered source-context block for the prompt.

    At most ``max_sources`` slices are considered.  Each selected slice is
    rendered as an ``S<n>`` header plus (possibly truncated) content; while
    spending the ``max_context_chars`` budget, a minimum of
    ``min_chars_per_source`` characters is reserved for every source still
    to come, so late sources are not starved entirely.

    Returns ``(context_text, refs, remaining_chars)`` where ``refs`` records,
    per source, the content actually used and whether it was truncated.
    """
    budget = max(max_context_chars, 0)
    chosen = context_slices[:max_sources]
    doc_count = len(chosen)
    blocks: List[str] = []
    refs: List[Dict[str, Any]] = []
    for position, item in enumerate(chosen, start=1):
        sid = f"S{position}"
        location = format_doc_location(
            item.doc.chunk_id,
            item.doc.chunk_index,
            item.doc.subdivision_type,
            item.doc.subdivision_id,
        )
        header = format_source_header(
            sid,
            item.act.celex,
            location,
            item.act.title,
            regulatory_level=item.act.regulatory_level,
        )
        header_cost = len(header) + 1  # +1 for the newline separating header and body
        # Stop as soon as the budget cannot even cover the next header.
        if budget <= 0 or header_cost > budget:
            break
        budget = max(0, budget - header_cost)
        body = item.content or ""
        if budget <= 0:
            used = ""
        else:
            # Reserve the per-source minimum for each source still to come.
            later_docs = doc_count - position
            reserved = min_chars_per_source * max(later_docs, 0)
            alloc = max(0, budget - reserved)
            if alloc < min_chars_per_source:
                alloc = min(budget, min_chars_per_source)
            used = body[:alloc]
            budget = max(0, budget - len(used))
        blocks.append(f"{header}\n{used}")
        refs.append(
            {
                "doc": item,
                "content_used": used,
                "content_truncated": bool(body and len(used) < len(body)),
                "source_id": sid,
            }
        )
    return "\n\n---\n\n".join(blocks), refs, budget
# [docs]  (Sphinx viewcode artifact from HTML export; commented out — not valid Python)
def build_relation_summary(
    *,
    context_slices: List[ContextSlice],
    line_limit: int,
) -> str:
    """Render a compact, deduplicated "Relation signals" block for the LLM.

    Each line has the shape ``- <source> -[TYPE]-> <target> | <description>``.
    Act ids are resolved to CELEX numbers via the supplied slices; an unknown
    target falls back to the relation's ``target_celex`` string, or the
    literal ``EXTERNAL``.  At most ``line_limit`` relation lines are emitted;
    an empty string is returned when the limit is non-positive or no usable
    relation exists.
    """
    if line_limit <= 0:
        return ""
    celex_by_act = {entry.act.act_id: entry.act.celex for entry in context_slices}
    output: List[str] = ["--- Relation signals ---"]
    emitted: set[tuple[int, int | None, str, str | None]] = set()
    for entry in context_slices:
        for rel in entry.act.relations or []:
            src = rel.get("source_act_id")
            # Relations without an integer source act cannot be resolved — skip.
            if not isinstance(src, int):
                continue
            dst = rel.get("target_act_id")
            kind_raw = rel.get("relation_type")
            kind = "RELATED_TO" if kind_raw is None else str(kind_raw).upper()
            desc_raw = rel.get("description")
            desc = (
                str(desc_raw).strip()
                if isinstance(desc_raw, str) and desc_raw.strip()
                else None
            )
            dedup_key = (src, dst if isinstance(dst, int) else None, kind, desc)
            if dedup_key in emitted:
                continue
            emitted.add(dedup_key)
            src_celex = celex_by_act.get(src, f"ACT-{src}")
            if isinstance(dst, int):
                dst_celex = celex_by_act.get(dst, f"ACT-{dst}")
            else:
                ext_raw = rel.get("target_celex")
                dst_celex = (
                    str(ext_raw).strip()
                    if isinstance(ext_raw, str) and ext_raw.strip()
                    else "EXTERNAL"
                )
            rendered = f"- {src_celex} -[{kind}]-> {dst_celex}"
            if desc:
                rendered = f"{rendered} | {desc}"
            output.append(rendered)
            # -1 discounts the block title from the emitted-line count.
            if len(output) - 1 >= line_limit:
                return "\n".join(output)
    if len(output) == 1:
        return ""
    return "\n".join(output)
# ---------------------------------------------------------------------------
# Metadata & validation
# ---------------------------------------------------------------------------
# [docs]  (Sphinx viewcode artifact from HTML export; commented out — not valid Python)
def attach_citation_validation(
    *,
    response: Dict[str, Any],
    answer: str,
    sources: List[Dict[str, Any]],
) -> None:
    """Validate the answer's source citations and record the result.

    No-op when *sources* is empty.  Otherwise the non-empty string
    ``source_id`` values are gathered and the output of
    ``validate_citations`` is stored under
    ``response["metadata"]["citation_validation"]``.
    """
    if not sources:
        return
    valid_ids = [
        sid
        for entry in sources
        if isinstance(sid := entry.get("source_id"), str) and sid
    ]
    response["metadata"]["citation_validation"] = validate_citations(answer, valid_ids)
# ── Graph fetch result + ranked context builder ─────────────────────────
# [docs]  (Sphinx viewcode artifact from HTML export; commented out — not valid Python)
@dataclass
class GraphFetchResult:
    """Raw output from Neo4j graph expansion (before ranking)."""
    # Act nodes returned by the expansion, as raw dicts (ranked later by
    # ``rank_graph_nodes``).
    nodes: List[Dict[str, Any]]
    # Relationship rows between acts, as raw dicts (ranked later by
    # ``rank_relationships``).
    relationships: List[Dict[str, Any]]
    # Act ids that seeded the expansion; also treated as the semantic hits
    # when scoring nodes in ``build_ranked_graph_context``.
    seed_act_ids: Set[int]
    # Act ids reached by expanding outward from the seeds — presumably the
    # non-seed portion of ``nodes``; not consumed in this module.
    expanded_act_ids: Set[int]
    # Pre-rendered community-report context block; prepended to the combined
    # context when non-empty ("" when unavailable).
    community_block: str = ""
    # Metadata describing the community reports behind ``community_block``.
    community_meta: Dict[str, Any] = field(default_factory=dict)
    # Wall-clock duration of the graph fetch, in milliseconds.
    duration_ms: float = 0.0
# [docs]  (Sphinx viewcode artifact from HTML export; commented out — not valid Python)
def build_ranked_graph_context(
    *,
    fetch_result: GraphFetchResult,
    semantic_results: List[Any],
    max_context_chars: int,
    graph_acts_limit: int,
    graph_relationships_limit: int,
    hop_decay: float,
    semantic_boost: float,
    relation_weight_factor: float,
    budget_semantic_share: float,
    budget_graph_share: float,
    budget_relation_share: float,
    min_chars_per_source: int,
    max_depth: int,
) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Rank graph nodes/relationships and assemble budget-aware context.

    Nodes are scored (hop decay, semantic boost, relation weight) and capped
    at ``graph_acts_limit``; relationships touching the surviving acts or the
    seeds are ranked and capped at ``graph_relationships_limit``.  A
    ``GraphContextBudget`` then splits ``max_context_chars`` across the
    semantic/graph/relation shares.  Any pre-built community block from the
    fetch is prepended to the combined context.

    Returns ``(combined_context_str, metadata_dict, graph_node_refs,
    relationship_refs)``.
    """
    # Seed acts double as the semantic hit set for ranking purposes.
    seeds = fetch_result.seed_act_ids
    nodes_by_rank = rank_graph_nodes(
        graph_context=fetch_result.nodes,
        relationships=fetch_result.relationships,
        semantic_act_ids=seeds,
        seed_act_ids=seeds,
        max_depth=max_depth,
        hop_decay=hop_decay,
        semantic_boost=semantic_boost,
        relation_weight_factor=relation_weight_factor,
    )[:graph_acts_limit]
    kept_ids = {int(node["id"]) for node in nodes_by_rank if node.get("id") is not None}
    rels_by_rank = rank_relationships(
        relationships=fetch_result.relationships,
        # Seeds stay eligible even if they fell out of the node ranking.
        top_act_ids=kept_ids | seeds,
    )[:graph_relationships_limit]
    ctx = GraphContextBudget(
        max_chars=max_context_chars,
        semantic_share=budget_semantic_share,
        graph_share=budget_graph_share,
        relation_share=budget_relation_share,
        min_chars_per_source=min_chars_per_source,
    ).build(
        semantic_results=semantic_results,
        ranked_nodes=nodes_by_rank,
        ranked_relationships=rels_by_rank,
    )
    combined = ctx.combined_context
    if fetch_result.community_block:
        combined = f"{fetch_result.community_block}\n\n{combined}"
    # Top-ranked entries (empty dict sentinels keep the .get defaults below).
    best_node = nodes_by_rank[0] if nodes_by_rank else {}
    best_rel = rels_by_rank[0] if rels_by_rank else {}
    meta: Dict[str, Any] = {
        "ranked_nodes_used": len(nodes_by_rank),
        "ranked_relationships_used": len(rels_by_rank),
        "total_graph_nodes": len(fetch_result.nodes),
        "total_relationships": len(fetch_result.relationships),
        "graph_acts_used_for_context": ctx.graph_nodes_used,
        "graph_relationships_used_for_context": ctx.relationships_used,
        "context_budget": ctx.chars_used,
        "top_rank_score": best_node.get("_rank_score", 0),
        "top_rank_score_raw": best_node.get("_rank_score_raw", 0),
        "top_relation_score": best_rel.get("_rel_weight", 0),
        "top_relation_score_raw": best_rel.get("_rel_weight_raw", 0),
        "score_scale": SCORE_SCALE_NORMALIZED_0_1,
    }
    return combined, meta, ctx.graph_node_refs, ctx.relationship_refs