"""Concrete planning runtime for the PydanticAI-driven RAG pipeline."""
from __future__ import annotations
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, List, Optional
from lalandre_core.config import get_config
from lalandre_rag.retrieval.result import RetrievalResult
from .tools import (
run_decomposition_agent,
run_evaluation_agent,
run_planner_agent,
run_refinement_agent,
)
logger = logging.getLogger(__name__)
# Matches CELEX-like identifiers (e.g. "2016L0680") or "YYYY/NN/XX"
# directive-style references; two or more hits suggest a multi-document question.
_MULTI_CELEX_RE = re.compile(
    r"(?:(?:\d{4}[LR]\d{3,})|(?:\d{4}/\d+/[A-Z]{2,}))",
    re.IGNORECASE,
)
# French-language phrasing cues that a question has several facets
# (comparisons, "on the one hand / on the other", coordinated sub-questions)
# and may therefore benefit from decomposition.
_MULTI_FACET_PATTERNS = [
    r"\b(différence|différences|distinction|comparer|comparaison|versus|vs\.?)\b",
    r"\bd'une part\b.*\bd'autre part\b",
    r"\b(comment|quell?e?s?)\b.{5,60}\bet\b.{5,60}\b(comment|quell?e?s?)\b",
    r"\bà la fois\b",
    r"\bobligation[sx]?\b.{3,50}\bet\b.{3,50}\bobligation[sx]?\b",
]
# Compiled once at import time; DOTALL lets ".*"/".{m,n}" span newlines in multi-line questions.
_COMPILED_PATTERNS = [re.compile(pattern, re.IGNORECASE | re.DOTALL) for pattern in _MULTI_FACET_PATTERNS]
@dataclass(frozen=True)
class AgenticComplementaryQuery:
    """A targeted follow-up query proposed by the planner.

    Frozen (immutable) so instances can be shared safely between
    planning phases without defensive copies.
    """

    # The follow-up query text to run against the retriever.
    query: str
    # Optional hint about the document level/granularity to target.
    level_hint: Optional[str] = None
@dataclass
class DecomposedQuery:
    """Structured decomposition result used by the planning graph."""

    # Independent sub-questions extracted from the original question.
    sub_questions: List[str] = field(default_factory=list)
    # Whether the sub-answers should be synthesized into one final answer.
    synthesize: bool = False
    # True only when a usable decomposition was actually produced.
    decomposed: bool = False
    # Wall-clock time spent on decomposition, in milliseconds.
    decompose_ms: float = 0.0
    # Output-schema validation retries performed by the agent.
    output_validation_retries: int = 0
@dataclass
class AgenticRetrievalPlan:
    """Planner decision for the retrieval/refinement phases."""

    # Query actually sent to the retriever (may be a rewrite of the question).
    primary_query: str
    # Coarse classification of the user's intent; "documentary" is the default.
    intent_class: str = "documentary"
    # True when the planner decides retrieval is unnecessary.
    skip_retrieval: bool = False
    # True when additional, targeted follow-up queries should also be run.
    needs_complementary: bool = False
    complementary_queries: List[AgenticComplementaryQuery] = field(default_factory=list)
    # True when the retrieved context should be compressed before answering.
    needs_compression: bool = False
    # Question to send back to the user when the request is ambiguous.
    clarification_question: Optional[str] = None
    strict_grounding_requested: bool = False
    # Free-text explanation of the planner's decision (or fallback reason).
    rationale: str = ""
    # Wall-clock planning time, in milliseconds.
    planning_ms: float = 0.0
    # False when the plan is a passthrough fallback (planner failed/skipped).
    planner_used: bool = False
    # Output-schema validation retries performed by the agent.
    output_validation_retries: int = 0
@dataclass
class EvalResult:
    """Sufficiency evaluation for CRAG correction."""

    # Evaluator verdict, e.g. "SUFFICIENT" / "INSUFFICIENT".
    status: str
    # Short description of the evidence gap when insufficient, else None.
    gap_hint: Optional[str]
    # Wall-clock evaluation time, in milliseconds.
    eval_ms: float
    # True when the evaluator failed and the status is a default fallback.
    fallback: bool = False
    # Output-schema validation retries performed by the agent.
    output_validation_retries: int = 0
def _heuristic_should_decompose(question: str) -> bool:
    """Cheap pre-filter: does *question* look multi-document or multi-facet?"""
    # Two or more CELEX-like references strongly suggest a comparison question.
    if len(_MULTI_CELEX_RE.findall(question)) >= 2:
        return True
    # Otherwise look for French multi-facet phrasing cues.
    for compiled in _COMPILED_PATTERNS:
        if compiled.search(question):
            return True
    return False
def decompose_query(
    question: str,
    llm: Any,
    *,
    heuristic_only: bool = True,
    max_sub_questions: int = 3,
) -> DecomposedQuery:
    """Decompose a complex question into independent sub-questions.

    Args:
        question: The raw user question.
        llm: LLM handle passed through to the decomposition agent.
        heuristic_only: When True (default), the LLM is only invoked if the
            cheap lexical heuristic suggests the question is multi-facet.
        max_sub_questions: Hard cap on the number of sub-questions kept.

    Returns:
        A ``DecomposedQuery``. A default (non-decomposed) instance is
        returned when the heuristic declines, the LLM yields no
        sub-questions, or the agent fails — decomposition is best-effort
        and never fatal.
    """
    started_at = time.perf_counter()
    if heuristic_only and not _heuristic_should_decompose(question):
        return DecomposedQuery()
    try:
        output, retries = run_decomposition_agent(question=question, llm=llm)
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        sub_questions = list(output.sub_questions)[:max_sub_questions]
        synthesize = bool(output.synthesize)
        if not sub_questions:
            logger.debug("QueryDecomposer: LLM returned no sub-questions for '%s'", question[:60])
            return DecomposedQuery(
                decompose_ms=round(elapsed_ms, 1),
                output_validation_retries=retries,
            )
        logger.info(
            "QueryDecomposer: %d sub-questions in %.1f ms for '%s'",
            len(sub_questions),
            elapsed_ms,
            question[:60],
        )
        return DecomposedQuery(
            sub_questions=sub_questions,
            synthesize=synthesize,
            decomposed=True,
            decompose_ms=round(elapsed_ms, 1),
            output_validation_retries=retries,
        )
    except Exception as exc:  # non-fatal by design: fall back to no decomposition
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.warning("QueryDecomposer failed (non-fatal): %s", exc)
        return DecomposedQuery(
            decompose_ms=round(elapsed_ms, 1),
            output_validation_retries=0,
        )
def plan_retrieval(question: str, llm: Any) -> AgenticRetrievalPlan:
    """Run the planner LLM to decide the retrieval strategy.

    Falls back to a passthrough plan (the question used verbatim as the
    primary query, ``planner_used=False``) when the planner agent raises,
    so planning is never fatal to the pipeline.

    Args:
        question: The user question to plan retrieval for.
        llm: LLM handle passed through to the planner agent.

    Returns:
        An ``AgenticRetrievalPlan`` describing the chosen strategy.
    """
    started_at = time.perf_counter()
    try:
        output, retries = run_planner_agent(question=question, llm=llm)
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        complementary = [
            AgenticComplementaryQuery(query=query.query, level_hint=query.level_hint)
            for query in output.complementary_queries
        ]
        return AgenticRetrievalPlan(
            primary_query=output.primary_query or question,
            intent_class=output.intent_class,
            skip_retrieval=bool(output.skip_retrieval),
            needs_complementary=bool(output.needs_complementary),
            # Cap at two complementary queries to bound extra retrieval passes.
            complementary_queries=complementary[:2],
            needs_compression=bool(output.needs_compression),
            clarification_question=output.clarification_question,
            strict_grounding_requested=bool(output.strict_grounding_requested),
            rationale=output.rationale,
            planning_ms=round(elapsed_ms, 1),
            planner_used=True,
            output_validation_retries=retries,
        )
    except Exception as exc:  # non-fatal by design: degrade to passthrough
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.warning("Planner failed (non-fatal), using passthrough: %s", exc)
        return AgenticRetrievalPlan(
            primary_query=question,
            planning_ms=round(elapsed_ms, 1),
            planner_used=False,
            rationale=f"Planner fallback: {exc}",
            output_validation_retries=0,
        )
def refine_retrieval(question: str, gap_hint: str, llm: Any) -> AgenticRetrievalPlan:
    """Generate a refined retrieval query targeting the identified gap.

    Used by the CRAG correction loop after an ``INSUFFICIENT`` evaluation.
    On agent failure, falls back to appending *gap_hint* to the question
    (``planner_used=False``) so refinement is never fatal.

    Args:
        question: The original user question.
        gap_hint: Description of the missing evidence to target.
        llm: LLM handle passed through to the refinement agent.

    Returns:
        An ``AgenticRetrievalPlan`` whose ``primary_query`` is the refined query.
    """
    started_at = time.perf_counter()
    try:
        output, retries = run_refinement_agent(question=question, gap_hint=gap_hint, llm=llm)
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        refined_query = output.refined_query or question
        rationale = output.rationale or f"CRAG refinement: {gap_hint}"
        logger.debug("CRAG refine: '%s' → '%s' (%.1f ms)", question[:60], refined_query[:60], elapsed_ms)
        return AgenticRetrievalPlan(
            primary_query=refined_query,
            planning_ms=round(elapsed_ms, 1),
            planner_used=True,
            rationale=rationale,
            output_validation_retries=retries,
        )
    except Exception as exc:  # non-fatal by design: degrade to naive query expansion
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.warning("CRAG refine failed (non-fatal): %s", exc)
        fallback_query = f"{question} {gap_hint}" if gap_hint else question
        return AgenticRetrievalPlan(
            primary_query=fallback_query,
            planning_ms=round(elapsed_ms, 1),
            planner_used=False,
            rationale=f"CRAG refine fallback: {exc}",
            output_validation_retries=0,
        )
def _build_eval_context(results: List[RetrievalResult], max_chars: int = 2500) -> str:
    """Build a compact, budget-capped preview of retrieval snippets.

    Each entry is formatted ``[S<n>] <celex>: <snippet>``. Iteration stops at
    the first entry that would exceed the remaining *max_chars* budget
    (the joining newlines are not counted against the budget).
    """
    # Hoisted out of the loop: the per-snippet preview budget is invariant.
    snippet_chars = get_config().context_budget.snippet_preview_chars
    parts: List[str] = []
    remaining = max_chars
    for index, result in enumerate(results, start=1):
        snippet = (result.content or "")[:snippet_chars]
        celex_part = f"{result.celex}: " if result.celex else ""
        entry = f"[S{index}] {celex_part}{snippet}"
        if len(entry) > remaining:
            break
        parts.append(entry)
        remaining -= len(entry)
    return "\n".join(parts)
def evaluate_retrieval(
    question: str,
    results: List[RetrievalResult],
    llm: Any,
) -> EvalResult:
    """Evaluate whether current retrieval evidence is sufficient.

    An empty *results* list short-circuits to ``INSUFFICIENT`` without
    calling the LLM. If the evaluator agent fails, the result defaults to
    ``SUFFICIENT`` with ``fallback=True`` so the pipeline keeps going.

    Args:
        question: The user question being answered.
        results: Retrieval hits to judge for sufficiency.
        llm: LLM handle passed through to the evaluation agent.

    Returns:
        An ``EvalResult`` with the evaluator's status and optional gap hint.
    """
    started_at = time.perf_counter()
    if not results:
        return EvalResult(
            status="INSUFFICIENT",
            gap_hint="Aucun document récupéré.",
            eval_ms=0.0,
        )
    context_preview = _build_eval_context(results)
    try:
        output, retries = run_evaluation_agent(
            question=question,
            context_preview=context_preview,
            llm=llm,
        )
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.debug("CRAG eval: status=%s gap=%s (%.1f ms)", output.status, output.gap, elapsed_ms)
        return EvalResult(
            status=output.status,
            gap_hint=output.gap,
            eval_ms=round(elapsed_ms, 1),
            output_validation_retries=retries,
        )
    except Exception as exc:  # non-fatal by design: assume evidence is usable
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.warning(
            "RetrievalEvaluator failed (non-fatal, defaulting to SUFFICIENT): %s",
            exc,
        )
        return EvalResult(
            status="SUFFICIENT",
            gap_hint=None,
            eval_ms=round(elapsed_ms, 1),
            fallback=True,
            output_validation_retries=0,
        )
# Public API: the four planning entry points plus their result dataclasses.
__all__ = [
    "AgenticComplementaryQuery",
    "AgenticRetrievalPlan",
    "DecomposedQuery",
    "EvalResult",
    "decompose_query",
    "evaluate_retrieval",
    "plan_retrieval",
    "refine_retrieval",
]