# Source code for lalandre_rag.agentic.runtime

"""Concrete planning runtime for the PydanticAI-driven RAG pipeline."""

from __future__ import annotations

import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, List, Optional

from lalandre_core.config import get_config

from lalandre_rag.retrieval.result import RetrievalResult

from .tools import (
    run_decomposition_agent,
    run_evaluation_agent,
    run_planner_agent,
    run_refinement_agent,
)

logger = logging.getLogger(__name__)

_MULTI_CELEX_RE = re.compile(
    r"(?:(?:\d{4}[LR]\d{3,})|(?:\d{4}/\d+/[A-Z]{2,}))",
    re.IGNORECASE,
)

_MULTI_FACET_PATTERNS = [
    r"\b(différence|différences|distinction|comparer|comparaison|versus|vs\.?)\b",
    r"\bd'une part\b.*\bd'autre part\b",
    r"\b(comment|quell?e?s?)\b.{5,60}\bet\b.{5,60}\b(comment|quell?e?s?)\b",
    r"\bà la fois\b",
    r"\bobligation[sx]?\b.{3,50}\bet\b.{3,50}\bobligation[sx]?\b",
]
_COMPILED_PATTERNS = [re.compile(pattern, re.IGNORECASE | re.DOTALL) for pattern in _MULTI_FACET_PATTERNS]


@dataclass(frozen=True)
class AgenticComplementaryQuery:
    """A targeted follow-up query proposed by the planner."""

    # Text of the follow-up retrieval query.
    query: str
    # Optional hint about which document level the query targets.
    level_hint: Optional[str] = None
@dataclass
class DecomposedQuery:
    """Structured decomposition used by the planning graph."""

    # Independent sub-questions extracted from the original question.
    sub_questions: List[str] = field(default_factory=list)
    # Whether the answers to the sub-questions need a synthesis step.
    synthesize: bool = False
    # True only when decomposition actually produced sub-questions.
    decomposed: bool = False
    # Wall-clock time spent on decomposition, in milliseconds.
    decompose_ms: float = 0.0
    # How many times the structured-output validation had to retry.
    output_validation_retries: int = 0
@dataclass
class AgenticRetrievalPlan:
    """Planner decision for retrieval/refinement phases."""

    # Query to run for the primary retrieval pass.
    primary_query: str
    # Coarse intent classification of the user question.
    intent_class: str = "documentary"
    # When True, the pipeline should answer without retrieving.
    skip_retrieval: bool = False
    # When True, the complementary queries below should also be run.
    needs_complementary: bool = False
    # Targeted follow-up queries (capped by the planner wrapper).
    complementary_queries: List[AgenticComplementaryQuery] = field(default_factory=list)
    # Whether retrieved context should be compressed before answering.
    needs_compression: bool = False
    # Question to send back to the user when the input is ambiguous.
    clarification_question: Optional[str] = None
    # Whether the planner asked for strict grounding of the answer.
    strict_grounding_requested: bool = False
    # Free-text justification emitted by the planner.
    rationale: str = ""
    # Wall-clock planning time, in milliseconds.
    planning_ms: float = 0.0
    # False when this plan is a fallback built without the planner LLM.
    planner_used: bool = False
    # How many times the structured-output validation had to retry.
    output_validation_retries: int = 0
@dataclass
class EvalResult:
    """Sufficiency evaluation for CRAG correction."""

    # Either "SUFFICIENT" or "INSUFFICIENT" (evaluator verdict).
    status: str
    # Short description of what is missing, when insufficient.
    gap_hint: Optional[str]
    # Wall-clock evaluation time, in milliseconds.
    eval_ms: float
    # True when the verdict is a default chosen after an evaluator failure.
    fallback: bool = False
    # How many times the structured-output validation had to retry.
    output_validation_retries: int = 0
def _heuristic_should_decompose(question: str) -> bool:
    """Cheap lexical check: does *question* look multi-faceted?

    Returns True when the question cites two or more CELEX-style
    identifiers, or matches any of the multi-facet phrasing patterns.
    """
    # Two distinct legal-act references is a strong decomposition signal.
    if len(_MULTI_CELEX_RE.findall(question)) >= 2:
        return True
    # Otherwise fall back to the compiled phrasing cues.
    for compiled in _COMPILED_PATTERNS:
        if compiled.search(question):
            return True
    return False
def decompose_query(
    question: str,
    llm: Any,
    *,
    heuristic_only: bool = True,
    max_sub_questions: int = 3,
) -> DecomposedQuery:
    """Decompose a complex question into independent sub-questions.

    With ``heuristic_only`` (the default), the LLM is only invoked when the
    cheap lexical heuristic flags the question as multi-faceted. Any failure
    is non-fatal and yields an empty (undecomposed) result.
    """
    t0 = time.perf_counter()
    if heuristic_only and not _heuristic_should_decompose(question):
        # Heuristic says the question is simple: skip the LLM entirely.
        return DecomposedQuery()
    try:
        output, retries = run_decomposition_agent(question=question, llm=llm)
        elapsed = (time.perf_counter() - t0) * 1000.0
        subs = list(output.sub_questions)[:max_sub_questions]
        if not subs:
            logger.debug("QueryDecomposer: LLM returned no sub-questions for '%s'", question[:60])
            return DecomposedQuery(
                decompose_ms=round(elapsed, 1),
                output_validation_retries=retries,
            )
        logger.info(
            "QueryDecomposer: %d sub-questions in %.1f ms for '%s'",
            len(subs),
            elapsed,
            question[:60],
        )
        return DecomposedQuery(
            sub_questions=subs,
            synthesize=bool(output.synthesize),
            decomposed=True,
            decompose_ms=round(elapsed, 1),
            output_validation_retries=retries,
        )
    except Exception as exc:
        # Best-effort: log and carry on with the original question.
        logger.warning("QueryDecomposer failed (non-fatal): %s", exc)
        return DecomposedQuery(
            decompose_ms=round((time.perf_counter() - t0) * 1000.0, 1),
            output_validation_retries=0,
        )
def plan_retrieval(question: str, llm: Any) -> AgenticRetrievalPlan:
    """Run the planner LLM to decide the retrieval strategy.

    Any planner failure is non-fatal: a passthrough plan using *question*
    verbatim is returned, with ``planner_used`` set to False.
    """
    t0 = time.perf_counter()
    try:
        output, retries = run_planner_agent(question=question, llm=llm)
        elapsed = (time.perf_counter() - t0) * 1000.0
        followups = []
        for item in output.complementary_queries:
            followups.append(
                AgenticComplementaryQuery(query=item.query, level_hint=item.level_hint)
            )
        return AgenticRetrievalPlan(
            primary_query=output.primary_query or question,
            intent_class=output.intent_class,
            skip_retrieval=bool(output.skip_retrieval),
            needs_complementary=bool(output.needs_complementary),
            complementary_queries=followups[:2],  # cap at two follow-ups
            needs_compression=bool(output.needs_compression),
            clarification_question=output.clarification_question,
            strict_grounding_requested=bool(output.strict_grounding_requested),
            rationale=output.rationale,
            planning_ms=round(elapsed, 1),
            planner_used=True,
            output_validation_retries=retries,
        )
    except Exception as exc:
        elapsed = (time.perf_counter() - t0) * 1000.0
        logger.warning("Planner failed (non-fatal), using passthrough: %s", exc)
        return AgenticRetrievalPlan(
            primary_query=question,
            planning_ms=round(elapsed, 1),
            planner_used=False,
            rationale=f"Planner fallback: {exc}",
            output_validation_retries=0,
        )
def refine_retrieval(question: str, gap_hint: str, llm: Any) -> AgenticRetrievalPlan:
    """Generate a refined retrieval query targeting the identified gap.

    On failure the fallback plan simply appends *gap_hint* to the original
    question (or reuses the question verbatim when the hint is empty).
    """
    t0 = time.perf_counter()
    try:
        output, retries = run_refinement_agent(question=question, gap_hint=gap_hint, llm=llm)
        elapsed = (time.perf_counter() - t0) * 1000.0
        new_query = output.refined_query or question
        why = output.rationale or f"CRAG refinement: {gap_hint}"
        logger.debug("CRAG refine: '%s' → '%s' (%.1f ms)", question[:60], new_query[:60], elapsed)
        return AgenticRetrievalPlan(
            primary_query=new_query,
            planning_ms=round(elapsed, 1),
            planner_used=True,
            rationale=why,
            output_validation_retries=retries,
        )
    except Exception as exc:
        elapsed = (time.perf_counter() - t0) * 1000.0
        logger.warning("CRAG refine failed (non-fatal): %s", exc)
        # Degraded mode: widen the original query with the gap hint.
        fallback = f"{question} {gap_hint}" if gap_hint else question
        return AgenticRetrievalPlan(
            primary_query=fallback,
            planning_ms=round(elapsed, 1),
            planner_used=False,
            rationale=f"CRAG refine fallback: {exc}",
            output_validation_retries=0,
        )
def _build_eval_context(results: List[RetrievalResult], max_chars: int = 2500) -> str: parts: List[str] = [] remaining = max_chars for index, result in enumerate(results, start=1): snippet = (result.content or "")[: get_config().context_budget.snippet_preview_chars] celex_part = f"{result.celex}: " if result.celex else "" entry = f"[S{index}] {celex_part}{snippet}" if len(entry) > remaining: break parts.append(entry) remaining -= len(entry) return "\n".join(parts)
def evaluate_retrieval(
    question: str,
    results: List[RetrievalResult],
    llm: Any,
) -> EvalResult:
    """Evaluate whether current retrieval evidence is sufficient.

    An empty result list is immediately INSUFFICIENT; an evaluator failure
    defaults to SUFFICIENT with ``fallback=True`` so the pipeline proceeds.
    """
    t0 = time.perf_counter()
    if not results:
        # Nothing retrieved: trivially insufficient, no LLM call needed.
        return EvalResult(
            status="INSUFFICIENT",
            gap_hint="Aucun document récupéré.",
            eval_ms=0.0,
        )
    preview = _build_eval_context(results)
    try:
        output, retries = run_evaluation_agent(
            question=question,
            context_preview=preview,
            llm=llm,
        )
        elapsed = (time.perf_counter() - t0) * 1000.0
        logger.debug("CRAG eval: status=%s gap=%s (%.1f ms)", output.status, output.gap, elapsed)
        return EvalResult(
            status=output.status,
            gap_hint=output.gap,
            eval_ms=round(elapsed, 1),
            output_validation_retries=retries,
        )
    except Exception as exc:
        elapsed = (time.perf_counter() - t0) * 1000.0
        logger.warning(
            "RetrievalEvaluator failed (non-fatal, defaulting to SUFFICIENT): %s",
            exc,
        )
        return EvalResult(
            status="SUFFICIENT",
            gap_hint=None,
            eval_ms=round(elapsed, 1),
            fallback=True,
            output_validation_retries=0,
        )
# Public API of this module, kept in alphabetical order.
__all__ = [
    "AgenticComplementaryQuery",
    "AgenticRetrievalPlan",
    "DecomposedQuery",
    "EvalResult",
    "decompose_query",
    "evaluate_retrieval",
    "plan_retrieval",
    "refine_retrieval",
]