"""
Context Service
Enriches retrieval results with metadata, relationships, and formatting
"""
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, TypedDict, cast
from lalandre_core.config import get_config
from lalandre_core.utils.regulatory_level import level_to_label
from lalandre_db_postgres import (
ActRelationsSQL,
ActsSQL,
ActSubjectsSQL,
ChunksSQL,
PostgresRepository,
SubdivisionsSQL,
)
from sqlalchemy.orm import joinedload
from lalandre_rag.retrieval.trace import TRACE_KEYS_SET, extract_trace
from ..result import RetrievalResult
from .models import ActContext, ContextSlice, DocumentMeta
class RelationPayload(TypedDict):
"""Normalized relation payload attached to enriched act contexts."""
relation_type: Any
source_act_id: int
target_act_id: int | None
target_celex: str | None
description: str | None
confidence: float | None
def _as_optional_int(value: Any) -> int | None:
return value if isinstance(value, int) else None
def _as_str(value: Any, default: str) -> str:
if isinstance(value, str):
return value
if value is None:
return default
return str(value)
def _as_optional_str(value: Any) -> str | None:
if value is None:
return None
return value if isinstance(value, str) else str(value)
def _serialize_date(value: Any) -> str | None:
"""Serialize a date/datetime to ISO string, or pass through strings."""
if value is None:
return None
if hasattr(value, "isoformat"):
return value.isoformat()
return str(value)
class ContextService:
    """
    Enriches retrieval results into context slices

    For each retrieval result the service:
    - attaches act metadata (title, type, CELEX, regulatory level, dates)
    - optionally attaches act-to-act relations and subject labels
    - optionally fills in missing content from the chunk/subdivision tables

    NOTE(review): earlier documentation also listed "format context for LLM
    consumption" and "generate context summaries"; no such methods are
    defined on this class here - confirm whether they live elsewhere.
    """

    def __init__(self, pg_repo: PostgresRepository):
        """
        Initialize context service

        Args:
            pg_repo: PostgreSQL repository for metadata queries
        """
        self.pg_repo = pg_repo

    def enrich_results(
        self,
        results: List[RetrievalResult],
        include_relations: bool = False,
        include_subjects: bool = False,
        hydrate_content: bool = True,
    ) -> List[ContextSlice]:
        """
        Enrich retrieval results with act metadata, optional relations and subjects.

        Args:
            results: Retrieval results to enrich; output order follows input order.
            include_relations: When True, relations in which an act is source or
                target are attached to its act context.
            include_subjects: When True, subject labels ("<eurovoc_code>: <label>")
                are attached to each act context loaded from the database.
            hydrate_content: When True, results with empty content are filled
                from the chunk table (by metadata ``chunk_id``) or, failing
                that, from the subdivision table.

        Returns:
            One context slice per input result, with explicit act/doc separation.
            An empty input yields an empty list without touching the database.
        """
        if not results:
            return []

        act_ids = sorted({r.act_id for r in results if r.act_id})

        def _is_chunk_result(result: RetrievalResult) -> bool:
            metadata = result.metadata or {}
            return (
                metadata.get("collection") == "chunks"
                or metadata.get("source_collection") == "chunks"
                or "chunk_id" in metadata
                or "chunk_index" in metadata
            )

        def _enum_value(value: Any) -> Any:
            return value.value if hasattr(value, "value") else value

        def _subject_labels(act: ActsSQL) -> List[str]:
            labels: List[str] = []
            for subj_rel in act.subjects or []:
                subject = subj_rel.subject
                if not subject:
                    continue
                label = subject.label_fr or subject.label_en or subject.eurovoc_code
                labels.append(f"{subject.eurovoc_code}: {label}")
            return labels

        # Precompute IDs needed for hydration
        subdivision_ids: List[int] = []
        chunk_ids: List[int] = []
        if hydrate_content:
            # A None subdivision_id is skipped: it cannot match a row, and
            # sorting None together with integers raises TypeError.
            subdivision_ids = sorted(
                {
                    r.subdivision_id
                    for r in results
                    if r.subdivision_id is not None and not r.content and not _is_chunk_result(r)
                }
            )
            chunk_ids = sorted(
                {
                    chunk_id
                    for r in results
                    if not r.content
                    for chunk_id in [_as_optional_int((r.metadata or {}).get("chunk_id"))]
                    if chunk_id is not None
                }
            )

        # -- Run independent DB queries in parallel --------------------
        def _fetch_subdivisions() -> Dict[int, str]:
            if not subdivision_ids:
                return {}
            with self.pg_repo.get_session() as s:
                rows = cast(
                    List[tuple[int, str]],
                    s.query(
                        SubdivisionsSQL.id,
                        SubdivisionsSQL.content,
                    )
                    .filter(SubdivisionsSQL.id.in_(subdivision_ids))
                    .all(),
                )
                return {sid: content for sid, content in rows}

        def _fetch_chunks() -> Dict[int, str]:
            if not chunk_ids:
                return {}
            with self.pg_repo.get_session() as s:
                rows = cast(
                    List[tuple[int, str]],
                    s.query(
                        ChunksSQL.id,
                        ChunksSQL.content,
                    )
                    .filter(ChunksSQL.id.in_(chunk_ids))
                    .all(),
                )
                return {cid: content for cid, content in rows}

        def _fetch_acts() -> tuple[Dict[int, ActsSQL], Dict[int, List[str]]]:
            if not act_ids:
                return {}, {}
            with self.pg_repo.get_session() as s:
                query = s.query(ActsSQL).filter(ActsSQL.id.in_(act_ids))
                if include_subjects:
                    query = query.options(joinedload(ActsSQL.subjects).joinedload(ActSubjectsSQL.subject))
                acts = cast(List[ActsSQL], query.all())

                fetched: Dict[int, ActsSQL] = {}
                labels_by_act: Dict[int, List[str]] = {}
                for act in acts:
                    if include_subjects:
                        # Subject labels are built while the session is still
                        # open: only the act itself is expunged below, so the
                        # related subject rows may no longer be readable once
                        # the session is finalized (depends on how
                        # get_session() commits/closes).
                        labels_by_act[act.id] = _subject_labels(act)
                    fetched[act.id] = act

                # Detach from session so objects remain usable
                for act in acts:
                    s.expunge(act)
                return fetched, labels_by_act

        def _fetch_relations() -> List[RelationPayload]:
            if not act_ids or not include_relations:
                return []
            with self.pg_repo.get_session() as s:
                rows = cast(
                    List[ActRelationsSQL],
                    s.query(ActRelationsSQL)
                    .filter((ActRelationsSQL.source_act_id.in_(act_ids)) | (ActRelationsSQL.target_act_id.in_(act_ids)))
                    .all(),
                )
                # Convert to plain payloads inside the session so nothing
                # reads ORM attributes after the session has been finalized.
                return [
                    RelationPayload(
                        relation_type=_enum_value(rel.relation_type),
                        source_act_id=rel.source_act_id,
                        target_act_id=rel.target_act_id,
                        target_celex=rel.target_celex,
                        description=rel.description,
                        confidence=rel.confidence,
                    )
                    for rel in rows
                ]

        with ThreadPoolExecutor(max_workers=get_config().search.max_parallel_workers) as executor:
            f_subs = executor.submit(_fetch_subdivisions)
            f_chunks = executor.submit(_fetch_chunks)
            f_acts = executor.submit(_fetch_acts)
            f_rels = executor.submit(_fetch_relations)

            subdivision_content_map: Dict[int, str] = f_subs.result()
            chunk_content_map: Dict[int, str] = f_chunks.result()
            act_map, subjects_map = f_acts.result()
            relation_payloads = f_rels.result()

        # Index each relation under both of its endpoints
        relations_map: Dict[int, List[RelationPayload]] = {}
        for payload in relation_payloads:
            source_id = payload["source_act_id"]
            target_id = payload["target_act_id"]
            relations_map.setdefault(source_id, []).append(payload)
            # A self-referencing relation is listed once, not twice.
            if target_id is not None and target_id != source_id:
                relations_map.setdefault(target_id, []).append(payload)

        def _relation_dicts(act_id: int) -> List[Dict[str, Any]] | None:
            if not include_relations:
                return None
            payloads = relations_map.get(act_id)
            if payloads is None:
                return None
            # Copy so each act context owns its own dicts
            return [dict(rel) for rel in payloads]

        act_context_map: Dict[int, ActContext] = {}
        for act in act_map.values():
            act_context_map[act.id] = ActContext(
                act_id=act.id,
                celex=act.celex or "Unknown",
                title=act.title or act.celex or "Unknown",
                act_type=str(_enum_value(act.act_type) or "unknown"),
                regulatory_level=level_to_label(getattr(act, "level", None)),
                url_eurlex=act.url_eurlex,
                relations=_relation_dicts(act.id),
                subjects=subjects_map.get(act.id) if include_subjects else None,
                adoption_date=_serialize_date(getattr(act, "adoption_date", None)),
                force_date=_serialize_date(getattr(act, "force_date", None)),
            )

        slices: List[ContextSlice] = []
        for result in results:
            metadata = result.metadata or {}
            chunk_id = _as_optional_int(metadata.get("chunk_id"))

            content = result.content
            if hydrate_content and not content:
                if chunk_id is not None and chunk_id in chunk_content_map:
                    content = chunk_content_map[chunk_id]
                elif result.subdivision_id in subdivision_content_map:
                    content = subdivision_content_map[result.subdivision_id]

            act_ctx = act_context_map.get(result.act_id)
            if act_ctx is None:
                # Act not found in the database: fall back to result metadata
                celex = _as_str(result.celex or metadata.get("celex"), "Unknown")
                title = _as_str(metadata.get("title") or metadata.get("act_title"), celex)
                act_type = metadata.get("act_type") or "unknown"
                url_eurlex = _as_optional_str(metadata.get("url_eurlex") or metadata.get("url"))
                act_ctx = ActContext(
                    act_id=result.act_id,
                    celex=celex,
                    title=title,
                    act_type=str(act_type),
                    regulatory_level=level_to_label(metadata.get("level")),
                    url_eurlex=url_eurlex,
                    relations=_relation_dicts(result.act_id),
                    subjects=subjects_map.get(result.act_id) if include_subjects else None,
                    adoption_date=_serialize_date(metadata.get("adoption_date")),
                    force_date=_serialize_date(metadata.get("force_date")),
                )
                # Cache only under a real act ID. Caching under None/0 would
                # make every later ID-less result reuse this result's
                # CELEX/title even when it belongs to a different act.
                if result.act_id:
                    act_context_map[result.act_id] = act_ctx

            subdivision_type = _as_str(_enum_value(result.subdivision_type), "UNKNOWN")
            source_kind = "chunk" if _is_chunk_result(result) else "subdivision"
            trace = extract_trace(metadata)
            # Trace keys are carried in `trace`, so keep them out of the payload
            payload_metadata = {key: value for key, value in metadata.items() if key not in TRACE_KEYS_SET}

            doc_meta = DocumentMeta(
                source_kind=source_kind,
                subdivision_id=result.subdivision_id,
                subdivision_type=subdivision_type,
                sequence_order=result.sequence_order,
                chunk_id=chunk_id,
                chunk_index=_as_optional_int(metadata.get("chunk_index")),
                char_start=_as_optional_int(metadata.get("char_start")),
                char_end=_as_optional_int(metadata.get("char_end")),
                payload=payload_metadata,
            )
            slices.append(
                ContextSlice(
                    content=content,
                    score=result.score,
                    act=act_ctx,
                    doc=doc_meta,
                    trace=trace,
                )
            )

        return slices