# Source code for lalandre_rag.retrieval.context.service

"""
Context Service
Enriches retrieval results with metadata, relationships, and formatting
"""

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, TypedDict, cast

from lalandre_core.config import get_config
from lalandre_core.utils.regulatory_level import level_to_label
from lalandre_db_postgres import (
    ActRelationsSQL,
    ActsSQL,
    ActSubjectsSQL,
    ChunksSQL,
    PostgresRepository,
    SubdivisionsSQL,
)
from sqlalchemy.orm import joinedload

from lalandre_rag.retrieval.trace import TRACE_KEYS_SET, extract_trace

from ..result import RetrievalResult
from .models import ActContext, ContextSlice, DocumentMeta


[docs] class RelationPayload(TypedDict): """Normalized relation payload attached to enriched act contexts.""" relation_type: Any source_act_id: int target_act_id: int | None target_celex: str | None description: str | None confidence: float | None
def _as_optional_int(value: Any) -> int | None: return value if isinstance(value, int) else None def _as_str(value: Any, default: str) -> str: if isinstance(value, str): return value if value is None: return default return str(value) def _as_optional_str(value: Any) -> str | None: if value is None: return None return value if isinstance(value, str) else str(value) def _serialize_date(value: Any) -> str | None: """Serialize a date/datetime to ISO string, or pass through strings.""" if value is None: return None if hasattr(value, "isoformat"): return value.isoformat() return str(value)
class ContextService:
    """
    Enriches retrieval results into context slices

    Responsibilities:
    - Enrich results with act metadata (title, type, CELEX)
    - Add relationships between documents
    - Format context for LLM consumption
    - Generate context summaries
    """

    def __init__(self, pg_repo: PostgresRepository):
        """
        Initialize context service

        Args:
            pg_repo: PostgreSQL repository for metadata queries
        """
        self.pg_repo = pg_repo

    def enrich_results(
        self,
        results: List[RetrievalResult],
        include_relations: bool = False,
        include_subjects: bool = False,
        hydrate_content: bool = True,
    ) -> List[ContextSlice]:
        """
        Enrich retrieval results with act metadata, optional relations and subjects.

        Args:
            results: Retrieval hits to enrich. An empty list returns ``[]``
                without touching the database.
            include_relations: When true, load relations whose source or
                target is one of the result acts and attach them to the
                matching act contexts.
            include_subjects: When true, load each act's subjects and attach
                them as ``"<eurovoc_code>: <label>"`` strings.
            hydrate_content: When true, results with empty content get it
                filled from the chunk or subdivision table when a matching
                row exists.

        Returns:
            One context slice per input result, in input order, with
            explicit act/doc separation.
        """
        if not results:
            return []

        act_ids = sorted({r.act_id for r in results if r.act_id})

        def _is_chunk_result(result: RetrievalResult) -> bool:
            metadata = result.metadata or {}
            return (
                metadata.get("collection") == "chunks"
                or metadata.get("source_collection") == "chunks"
                or "chunk_id" in metadata
                or "chunk_index" in metadata
            )

        def _enum_value(value: Any) -> Any:
            return value.value if hasattr(value, "value") else value

        # Precompute IDs needed for hydration
        subdivision_ids: List[int] = []
        chunk_ids: List[int] = []
        if hydrate_content:
            # None ids are dropped up front: sorting a set that mixes None
            # and int raises TypeError, and None can never match a row.
            subdivision_ids = sorted(
                {
                    r.subdivision_id
                    for r in results
                    if r.subdivision_id is not None and not r.content and not _is_chunk_result(r)
                }
            )
            chunk_ids = sorted(
                {
                    chunk_id
                    for r in results
                    if not r.content
                    for chunk_id in [_as_optional_int((r.metadata or {}).get("chunk_id"))]
                    if chunk_id is not None
                }
            )

        # ── Run independent DB queries in parallel ────────────────────
        # Each fetcher opens its own session through pg_repo.get_session().

        def _fetch_subdivisions() -> Dict[int, str]:
            if not subdivision_ids:
                return {}
            with self.pg_repo.get_session() as s:
                rows = cast(
                    List[tuple[int, str]],
                    s.query(
                        SubdivisionsSQL.id,
                        SubdivisionsSQL.content,
                    )
                    .filter(SubdivisionsSQL.id.in_(subdivision_ids))
                    .all(),
                )
                return {sid: content for sid, content in rows}

        def _fetch_chunks() -> Dict[int, str]:
            if not chunk_ids:
                return {}
            with self.pg_repo.get_session() as s:
                rows = cast(
                    List[tuple[int, str]],
                    s.query(
                        ChunksSQL.id,
                        ChunksSQL.content,
                    )
                    .filter(ChunksSQL.id.in_(chunk_ids))
                    .all(),
                )
                return {cid: content for cid, content in rows}

        def _fetch_acts() -> tuple[Dict[int, ActsSQL], Dict[int, List[str]]]:
            """Load acts and, when requested, their subject labels."""
            if not act_ids:
                return {}, {}
            with self.pg_repo.get_session() as s:
                query = s.query(ActsSQL).filter(ActsSQL.id.in_(act_ids))
                if include_subjects:
                    query = query.options(joinedload(ActsSQL.subjects).joinedload(ActSubjectsSQL.subject))
                acts = cast(List[ActsSQL], query.all())
                acts_by_id: Dict[int, ActsSQL] = {act.id: act for act in acts}

                # Subject labels are read while the session is still open so
                # they never depend on related rows staying loaded afterwards.
                labels_by_act: Dict[int, List[str]] = {}
                if include_subjects:
                    for act in acts:
                        labels: List[str] = []
                        for subj_rel in act.subjects or []:
                            subject = subj_rel.subject
                            if not subject:
                                continue
                            label = subject.label_fr or subject.label_en or subject.eurovoc_code
                            labels.append(f"{subject.eurovoc_code}: {label}")
                        labels_by_act[act.id] = labels

                # Detach from session so objects remain usable
                for act in acts:
                    s.expunge(act)
                return acts_by_id, labels_by_act

        def _fetch_relations() -> List[RelationPayload]:
            """Load relations touching the result acts as plain payload dicts."""
            if not act_ids or not include_relations:
                return []
            with self.pg_repo.get_session() as s:
                rows = cast(
                    List[ActRelationsSQL],
                    s.query(ActRelationsSQL)
                    .filter((ActRelationsSQL.source_act_id.in_(act_ids)) | (ActRelationsSQL.target_act_id.in_(act_ids)))
                    .all(),
                )
                # Convert to plain dicts before the session closes. Whether
                # ORM attributes stay readable afterwards depends on how
                # get_session() commits/expires, which is not visible here.
                payloads: List[RelationPayload] = []
                for rel in rows:
                    payload: RelationPayload = {
                        "relation_type": _enum_value(rel.relation_type),
                        "source_act_id": rel.source_act_id,
                        "target_act_id": rel.target_act_id,
                        "target_celex": rel.target_celex,
                        "description": rel.description,
                        "confidence": rel.confidence,
                    }
                    payloads.append(payload)
                return payloads

        with ThreadPoolExecutor(max_workers=get_config().search.max_parallel_workers) as executor:
            f_subs = executor.submit(_fetch_subdivisions)
            f_chunks = executor.submit(_fetch_chunks)
            f_acts = executor.submit(_fetch_acts)
            f_rels = executor.submit(_fetch_relations)

            subdivision_content_map: Dict[int, str] = f_subs.result()
            chunk_content_map: Dict[int, str] = f_chunks.result()
            act_map, subjects_map = f_acts.result()
            relation_payloads = f_rels.result()

        # Index every relation under both of its endpoints
        relations_map: Dict[int, List[RelationPayload]] = {}
        for payload in relation_payloads:
            source_id = payload["source_act_id"]
            target_id = payload["target_act_id"]
            relations_map.setdefault(source_id, []).append(payload)
            # A self-referencing relation is already listed under its source;
            # appending again would duplicate it for the same act.
            if target_id is not None and target_id != source_id:
                relations_map.setdefault(target_id, []).append(payload)

        def _relation_dicts(act_id: int) -> List[Dict[str, Any]] | None:
            if not include_relations:
                return None
            found = relations_map.get(act_id)
            if found is None:
                return None
            return [dict(rel) for rel in found]

        act_context_map: Dict[int, ActContext] = {}
        for act in act_map.values():
            act_context_map[act.id] = ActContext(
                act_id=act.id,
                celex=act.celex or "Unknown",
                title=act.title or act.celex or "Unknown",
                act_type=str(_enum_value(act.act_type) or "unknown"),
                regulatory_level=level_to_label(getattr(act, "level", None)),
                url_eurlex=act.url_eurlex,
                relations=_relation_dicts(act.id),
                subjects=subjects_map.get(act.id) if include_subjects else None,
                adoption_date=_serialize_date(getattr(act, "adoption_date", None)),
                force_date=_serialize_date(getattr(act, "force_date", None)),
            )

        slices: List[ContextSlice] = []
        for result in results:
            metadata = result.metadata or {}
            chunk_id = _as_optional_int(metadata.get("chunk_id"))

            content = result.content
            if hydrate_content and not content:
                # Prefer the chunk row; fall back to the subdivision row.
                if chunk_id is not None and chunk_id in chunk_content_map:
                    content = chunk_content_map[chunk_id]
                elif result.subdivision_id in subdivision_content_map:
                    content = subdivision_content_map[result.subdivision_id]

            act_ctx = act_context_map.get(result.act_id)
            if act_ctx is None:
                # No act row was loaded for this result: rebuild the act
                # context from the result's own metadata.
                celex = _as_str(result.celex or metadata.get("celex"), "Unknown")
                title = _as_str(metadata.get("title") or metadata.get("act_title"), celex)
                act_type = metadata.get("act_type") or "unknown"
                url_eurlex = _as_optional_str(metadata.get("url_eurlex") or metadata.get("url"))
                act_ctx = ActContext(
                    act_id=result.act_id,
                    celex=celex,
                    title=title,
                    act_type=str(act_type),
                    regulatory_level=level_to_label(metadata.get("level")),
                    url_eurlex=url_eurlex,
                    relations=_relation_dicts(result.act_id),
                    subjects=subjects_map.get(result.act_id) if include_subjects else None,
                    adoption_date=_serialize_date(metadata.get("adoption_date")),
                    force_date=_serialize_date(metadata.get("force_date")),
                )
                # Cache only under a real act id. A falsy id (None/0) does
                # not identify an act, so caching there would hand the first
                # result's metadata to unrelated results.
                if result.act_id:
                    act_context_map[result.act_id] = act_ctx

            subdivision_type = _as_str(_enum_value(result.subdivision_type), "UNKNOWN")
            source_kind = "chunk" if _is_chunk_result(result) else "subdivision"
            trace = extract_trace(metadata)
            # Trace keys are exposed through `trace`; keep them out of the payload.
            payload_metadata = {key: value for key, value in metadata.items() if key not in TRACE_KEYS_SET}

            doc_meta = DocumentMeta(
                source_kind=source_kind,
                subdivision_id=result.subdivision_id,
                subdivision_type=subdivision_type,
                sequence_order=result.sequence_order,
                chunk_id=chunk_id,
                chunk_index=_as_optional_int(metadata.get("chunk_index")),
                char_start=_as_optional_int(metadata.get("char_start")),
                char_end=_as_optional_int(metadata.get("char_end")),
                payload=payload_metadata,
            )

            slices.append(
                ContextSlice(
                    content=content,
                    score=result.score,
                    act=act_ctx,
                    doc=doc_meta,
                    trace=trace,
                )
            )

        return slices