"""Pydantic Graph orchestration for RAG planning phases."""
from __future__ import annotations
import asyncio
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, Optional, cast
from uuid import uuid4
from lalandre_core.config import get_config
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
from lalandre_rag.agentic.runtime import (
AgenticComplementaryQuery,
DecomposedQuery,
evaluate_retrieval,
plan_retrieval,
refine_retrieval,
)
from lalandre_rag.modes.hybrid_graph import fetch_graph_context
from lalandre_rag.modes.hybrid_helpers import emit_progress
from lalandre_rag.retrieval.context.compressor import compress_context
from lalandre_rag.retrieval.overview import build_retrieval_overview
from .deps import AgenticPlanningDeps
from .models import (
PhaseTraceEvent,
PlanningContext,
PlanningEarlyExit,
PlanningGraphState,
PlanningResult,
)
logger = logging.getLogger(__name__)
# Human-readable French labels for routing profiles; used when reporting the
# "routing done" progress event (falls back to the raw profile name if absent).
_PROFILE_LABELS = {
    "global_overview": "Vue d'ensemble globale",
    "citation_precision": "Recherche de citation précise",
    "relationship_focus": "Analyse de relations juridiques",
    "contextual_default": "Question contextuelle",
    "manual_collections": "Collections imposées",
    "manual_override": "Paramètres manuels",
}
def _hierarchy_detail(selected_counts: Dict[str, Any]) -> str:
acts = int(selected_counts.get("acts") or 0)
subdivisions = int(selected_counts.get("subdivisions") or 0)
chunks = int(selected_counts.get("chunks") or 0)
return (
f"{acts} acte{'s' if acts != 1 else ''} · "
f"{subdivisions} subdivision{'s' if subdivisions != 1 else ''} · "
f"{chunks} chunk{'s' if chunks != 1 else ''}"
)
def _record_progress(
    state: PlanningGraphState,
    deps: AgenticPlanningDeps,
    *,
    phase: str,
    status: str,
    label: str,
    detail: Optional[str] = None,
    count: Optional[int] = None,
    duration_ms: Optional[float] = None,
    meta: Optional[Dict[str, Any]] = None,
    tool: Optional[str] = None,
) -> None:
    """Record a phase trace event on the state and mirror it to the callback.

    The event is appended to ``state.trace_events`` (later surfaced through
    ``agentic_meta["tool_trace"]``) and forwarded to ``deps.progress_callback``
    via :func:`emit_progress` so callers can render live status updates.

    Args:
        state: Mutable planning graph state receiving the trace event.
        deps: Dependencies carrying the progress callback.
        phase: Pipeline phase identifier (e.g. ``"routing"``, ``"retrieval"``).
        status: Phase status, typically ``"active"`` or ``"done"``.
        label: Human-readable label for the update.
        detail: Optional longer description.
        count: Optional item count associated with the update.
        duration_ms: Optional duration; rounded to one decimal place.
        meta: Optional structured metadata. Stored as ``{}`` on the event when
            falsy, but omitted (``None``) from the callback — preserving the
            original payload semantics.
        tool: Optional logical tool name associated with the phase.
    """
    # Round once and reuse for both the stored event and the callback
    # (previously computed twice via an intermediate payload dict).
    rounded_ms = round(float(duration_ms), 1) if duration_ms is not None else None
    event = PhaseTraceEvent(
        phase=phase,
        status=status,
        label=label,
        detail=detail,
        count=count,
        duration_ms=rounded_ms,
        meta=meta or {},
        tool=tool,
    )
    state.trace_events.append(event)
    emit_progress(
        deps.progress_callback,
        phase=phase,
        status=status,
        label=label,
        detail=detail,
        count=count,
        duration_ms=rounded_ms,
        meta=meta if meta else None,
    )
def _append_path(state: PlanningGraphState, node_name: str) -> None:
state.planner_path.append(node_name)
def _finalize_agentic_meta(state: PlanningGraphState) -> None:
state.agentic_meta.setdefault("planner_runtime", "pydanticai")
state.agentic_meta["planner_run_id"] = state.planner_run_id
state.agentic_meta["planner_path"] = list(state.planner_path)
state.agentic_meta["tool_trace"] = [event.model_dump() for event in state.trace_events]
state.agentic_meta["output_validation_retries"] = state.output_validation_retries
@dataclass
class LoadConversationContext(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Bootstrap node for future conversation-memory loading.

    Currently a pass-through: it only records itself on the planner path
    before handing off to question decomposition.
    """

    async def run(self, ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps]) -> DecomposeQuestion:
        """Advance to the question-decomposition phase."""
        _append_path(ctx.state, self.get_node_id())
        return DecomposeQuestion()
@dataclass
class DecomposeQuestion(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Placeholder decomposition node for complex multi-part questions.

    Currently a pass-through: it only records itself on the planner path.
    Any decomposition result is expected to be supplied via
    ``state.decomposition_result`` and is consumed later by ``RunRetrieval``.
    """

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> RouteQuestion:
        """Advance to the routing phase."""
        _append_path(ctx.state, self.get_node_id())
        return RouteQuestion()
@dataclass
class RouteQuestion(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Route the user question toward the appropriate retrieval profile."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> PlanRetrieval:
        """Compute routing metadata and transition to planning.

        Side effects on state: sets ``retrieval_plan``, the ``effective_*``
        knobs, ``routing_ms`` and ``agentic_meta["routing"]``. When no
        explicit collections are provided, also launches ``plan_retrieval``
        in a background thread (``state.planning_future``) so planning runs
        concurrently with routing; ``PlanRetrieval`` awaits it.
        """
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        _record_progress(
            state,
            deps,
            phase="routing",
            status="active",
            label="Qualification de la demande",
            detail="Choix du profil de recherche et du niveau de granularité",
            tool="route",
        )
        # Launch planning LLM call in parallel with routing (they're independent).
        if not state.collections:
            _record_progress(
                state,
                deps,
                phase="planning",
                status="active",
                label="Planification de la recherche",
                detail="Préparation de la stratégie de récupération (en parallèle)",
                tool="plan",
            )
            planning_executor = ThreadPoolExecutor(max_workers=1)
            state.planning_future = planning_executor.submit(
                plan_retrieval,
                state.question,
                deps.lightweight_llm,
            )
            # Release the executor's worker thread once the task completes;
            # the submitted future stays valid after shutdown(wait=False).
            planning_executor.shutdown(wait=False)
        started_at = time.perf_counter()
        if state.collections:
            # NOTE(review): route() appears to be invoked here only so the
            # concrete plan class can be rebuilt below with manual values —
            # confirm whether this extra routing pass can be skipped.
            retrieval_plan = deps.query_router.route(
                question=state.question,
                top_k=state.top_k,
                requested_granularity=state.granularity,
            )
            retrieval_plan = retrieval_plan.__class__(
                profile="manual_collections",
                granularity=state.granularity,
                top_k=state.top_k,
                include_relations_hint=state.include_relations,
                rationale="Explicit collections provided: routing heuristics bypassed.",
                use_graph=bool(state.use_graph),
                execution_mode="hybrid",
                routing_source="heuristic",
            )
            effective_top_k = state.top_k
            effective_granularity = state.granularity
            effective_include_relations = state.include_relations
        else:
            retrieval_plan = deps.query_router.route(
                question=state.question,
                top_k=state.top_k,
                requested_granularity=state.granularity,
            )
            effective_top_k = retrieval_plan.top_k
            effective_granularity = retrieval_plan.granularity
            effective_include_relations = state.include_relations or retrieval_plan.include_relations_hint
            # An explicit use_graph request from the caller overrides the router.
            if state.use_graph is not None:
                retrieval_plan = retrieval_plan.__class__(
                    **{
                        **retrieval_plan.__dict__,
                        "use_graph": state.use_graph,
                    }
                )
        state.routing_ms = (time.perf_counter() - started_at) * 1000.0
        state.retrieval_plan = retrieval_plan
        state.effective_top_k = effective_top_k
        state.effective_granularity = effective_granularity
        state.effective_include_relations = effective_include_relations
        state.agentic_meta["routing"] = {
            "profile": retrieval_plan.profile,
            "execution_mode": retrieval_plan.execution_mode,
            "granularity": effective_granularity,
            "top_k": effective_top_k,
            "use_graph": retrieval_plan.use_graph,
            "rationale": retrieval_plan.rationale,
        }
        # Surface validation retries from the intent parser, if it exposes them.
        intent_parser = getattr(deps.query_router, "intent_parser", None)
        parsed_intent = getattr(intent_parser, "_last_parsed_intent", None)
        if parsed_intent is not None:
            state.output_validation_retries += int(getattr(parsed_intent, "output_validation_retries", 0))
        _record_progress(
            state,
            deps,
            phase="routing",
            status="done",
            label=str(_PROFILE_LABELS.get(retrieval_plan.profile, retrieval_plan.profile)),
            detail=retrieval_plan.rationale,
            duration_ms=state.routing_ms,
            meta={
                "profile": retrieval_plan.profile,
                "execution_mode": retrieval_plan.execution_mode,
                "granularity": effective_granularity,
                "top_k": effective_top_k,
                "use_graph": retrieval_plan.use_graph,
            },
            tool="route",
        )
        return PlanRetrieval()
@dataclass
class PlanRetrieval(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Build the retrieval plan and complementary-query strategy."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> RunRetrieval:
        """Resolve the planning step and transition to retrieval.

        When ``RouteQuestion`` already launched planning concurrently, this
        node only awaits ``state.planning_future``; otherwise it runs
        ``plan_retrieval`` inline. Explicit collections skip planning.
        """
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        # Explicit collections bypass agentic planning altogether.
        if state.collections:
            return RunRetrieval()
        # If routing already ran concurrently (via RouteQuestion) we just
        # need to await the planning future that was launched in parallel.
        if state.planning_future is not None:
            # Future.result() blocks; offload it so the event loop stays free.
            agentic_plan = await asyncio.to_thread(state.planning_future.result)
            state.planning_future = None
        else:
            _record_progress(
                state,
                deps,
                phase="planning",
                status="active",
                label="Planification de la recherche",
                detail="Préparation de la stratégie de récupération",
                tool="plan",
            )
            agentic_plan = await asyncio.to_thread(plan_retrieval, state.question, deps.lightweight_llm)
        state.planner_ms = agentic_plan.planning_ms
        state.agentic_plan = agentic_plan
        state.output_validation_retries += int(getattr(agentic_plan, "output_validation_retries", 0))
        state.agentic_meta["planner"] = {
            "used": agentic_plan.planner_used,
            "intent_class": agentic_plan.intent_class,
            "primary_query": agentic_plan.primary_query,
            "clarification_question": agentic_plan.clarification_question,
            "strict_grounding_requested": agentic_plan.strict_grounding_requested,
            "needs_complementary": agentic_plan.needs_complementary,
            "needs_compression": agentic_plan.needs_compression,
            "complementary_count": len(agentic_plan.complementary_queries),
            "rationale": agentic_plan.rationale,
            "planning_ms": state.planner_ms,
        }
        _record_progress(
            state,
            deps,
            phase="planning",
            status="done",
            label="Plan de recherche établi",
            detail=agentic_plan.rationale,
            duration_ms=state.planner_ms,
            meta={
                "planner_used": bool(agentic_plan.planner_used),
                "intent_class": agentic_plan.intent_class,
                "primary_query": agentic_plan.primary_query,
                "clarification_question": agentic_plan.clarification_question,
                "strict_grounding_requested": bool(agentic_plan.strict_grounding_requested),
                "needs_complementary": bool(agentic_plan.needs_complementary),
                "needs_compression": bool(agentic_plan.needs_compression),
                "complementary_queries": [cq.query for cq in agentic_plan.complementary_queries],
            },
            tool="plan",
        )
        return RunRetrieval()
@dataclass
class RunRetrieval(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Execute retrieval and context enrichment for the planned query."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> EvaluateEvidence | End[PlanningResult]:
        """Run retrieval and either finish early or evaluate sufficiency.

        Early exits (returning ``End`` with a ``PlanningEarlyExit``):
        - ``skip_retrieval``: the agentic plan decided no retrieval is needed;
        - ``relevance_gate``: the best score falls below the configured gate;
        - ``empty``: retrieval returned no results.
        Otherwise stores results/stats on state and moves to CRAG evaluation.
        """
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        config = get_config()
        search_cfg = config.search
        # Planner may have decided retrieval is unnecessary (e.g. needs
        # clarification first): exit immediately with an empty context.
        if state.agentic_plan is not None and state.agentic_plan.skip_retrieval:
            _record_progress(
                state,
                deps,
                phase="retrieval",
                status="done",
                label="Recherche évitée",
                detail=state.agentic_plan.rationale,
                count=0,
                duration_ms=0.0,
                tool="retrieve",
            )
            _finalize_agentic_meta(state)
            return End(
                PlanningEarlyExit(
                    kind="skip_retrieval",
                    routing_ms=state.routing_ms,
                    planner_ms=state.planner_ms,
                    retrieval_ms=0.0,
                    intent_class=state.agentic_plan.intent_class,
                    clarification_question=state.agentic_plan.clarification_question,
                    strict_grounding_requested=bool(state.agentic_plan.strict_grounding_requested),
                    agentic_rationale=state.agentic_plan.rationale,
                    agentic_meta=state.agentic_meta,
                )
            )
        _record_progress(
            state,
            deps,
            phase="retrieval",
            status="active",
            label="Recherche des preuves",
            detail=(
                "Exploration ciblée des collections imposées"
                if state.collections
                else "Exploration des passages les plus pertinents"
            ),
            tool="retrieve",
        )
        started_at = time.perf_counter()
        decomp_result = cast(Optional[DecomposedQuery], state.decomposition_result)
        if state.collections:
            # Explicit collections: single retrieval pass scoped to them.
            retrieval_results = deps.retrieval_service.retrieve(
                query=state.question,
                top_k=state.effective_top_k,
                score_threshold=state.score_threshold,
                filters=state.filters,
                collections=state.collections,
                granularity=state.effective_granularity,
                embedding_preset=state.embedding_preset,
            )
        elif decomp_result is not None and decomp_result.decomposed and decomp_result.sub_questions:
            # Decomposed question: retrieve each sub-question in parallel and
            # merge results, deduplicating by subdivision_id.
            def _sub_retrieve(query: str) -> list[Any]:
                # Best-effort per-sub-question retrieval; failures degrade to [].
                try:
                    return deps.retrieval_service.retrieve(
                        query=query,
                        top_k=state.effective_top_k,
                        score_threshold=state.score_threshold,
                        filters=state.filters,
                        collections=None,
                        granularity=state.effective_granularity,
                        embedding_preset=state.embedding_preset,
                    )
                except Exception as exc:
                    logger.warning("Sub-question retrieval failed (non-fatal): %s", exc)
                    return []

            seen_ids: set[Any] = set()
            retrieval_results = []
            max_w = min(len(decomp_result.sub_questions), search_cfg.max_parallel_workers)
            with ThreadPoolExecutor(max_workers=max_w) as executor:
                futures = [executor.submit(_sub_retrieve, sub_question) for sub_question in decomp_result.sub_questions]
                for future in futures:
                    for result in future.result():
                        if result.subdivision_id not in seen_ids:
                            retrieval_results.append(result)
                            seen_ids.add(result.subdivision_id)
            retrieval_results.sort(key=lambda item: item.score, reverse=True)
        else:
            retrieval_results = deps.retrieval_service.retrieve(
                query=state.question,
                top_k=state.effective_top_k,
                score_threshold=state.score_threshold,
                filters=state.filters,
                collections=state.collections,
                granularity=state.effective_granularity,
                embedding_preset=state.embedding_preset,
            )
        state.retrieval_ms = (time.perf_counter() - started_at) * 1000.0
        state.retrieval_results = retrieval_results
        state.retrieval_stats = deps.retrieval_service.last_retrieval_stats
        retrieval_overview = build_retrieval_overview(
            retrieval_results,
            effective_granularity=state.effective_granularity,
            candidate_counts={
                "after_fusion": int(state.retrieval_stats.candidates_after_fusion),
                "after_threshold": int(state.retrieval_stats.candidates_after_threshold),
                "after_rerank": int(state.retrieval_stats.candidates_after_rerank),
                "in_context": len(retrieval_results),
            },
        )
        _record_progress(
            state,
            deps,
            phase="retrieval",
            status="done",
            label=(
                f"{len(retrieval_results)} résultat"
                f"{'s' if len(retrieval_results) != 1 else ''} retrouvé"
                f"{'s' if len(retrieval_results) != 1 else ''}"
            ),
            detail=(
                f"{_hierarchy_detail(cast(Dict[str, Any], retrieval_overview['selected_counts']))} "
                "retenus après fusion, filtrage et reranking"
                + (
                    # Timing breakdown only when at least one stage reported time.
                    "\n"
                    f"Temps retrieval: embedding {state.retrieval_stats.embedding_ms:.0f} ms · "
                    f"sémantique {state.retrieval_stats.semantic_search_ms:.0f} ms · "
                    f"lexical {state.retrieval_stats.lexical_search_ms:.0f} ms · "
                    f"rerank {state.retrieval_stats.rerank_ms:.0f} ms"
                    if (
                        state.retrieval_stats.embedding_ms > 0.0
                        or state.retrieval_stats.semantic_search_ms > 0.0
                        or state.retrieval_stats.lexical_search_ms > 0.0
                        or state.retrieval_stats.rerank_ms > 0.0
                    )
                    else ""
                )
            ),
            count=len(retrieval_results),
            duration_ms=state.retrieval_ms,
            meta={
                **retrieval_overview,
                "cache_hit": state.retrieval_stats.cache_hit,
                "query_variants_count": state.retrieval_stats.query_variants_count,
                "embedding_ms": state.retrieval_stats.embedding_ms,
                "semantic_search_ms": state.retrieval_stats.semantic_search_ms,
                "lexical_search_ms": state.retrieval_stats.lexical_search_ms,
                "parallel_search_ms": state.retrieval_stats.parallel_search_ms,
                "fusion_ms": state.retrieval_stats.fusion_ms,
                "rerank_ms": state.retrieval_stats.rerank_ms,
            },
            tool="retrieve",
        )
        # Relevance gate: if even the best hit is below the threshold, bail out
        # rather than generating from weak evidence.
        gate_threshold = search_cfg.relevance_gate_threshold
        if gate_threshold is not None and retrieval_results and retrieval_results[0].score < gate_threshold:
            _finalize_agentic_meta(state)
            return End(
                PlanningEarlyExit(
                    kind="relevance_gate",
                    routing_ms=state.routing_ms,
                    planner_ms=state.planner_ms,
                    retrieval_ms=state.retrieval_ms,
                    intent_class=(state.agentic_plan.intent_class if state.agentic_plan is not None else "documentary"),
                    clarification_question=(
                        state.agentic_plan.clarification_question if state.agentic_plan is not None else None
                    ),
                    strict_grounding_requested=bool(
                        state.agentic_plan.strict_grounding_requested if state.agentic_plan is not None else False
                    ),
                    agentic_meta=state.agentic_meta,
                    best_score=retrieval_results[0].score,
                    gate_threshold=gate_threshold,
                    candidates_dropped=len(retrieval_results),
                )
            )
        if not retrieval_results:
            _finalize_agentic_meta(state)
            return End(
                PlanningEarlyExit(
                    kind="empty",
                    routing_ms=state.routing_ms,
                    planner_ms=state.planner_ms,
                    retrieval_ms=state.retrieval_ms,
                    intent_class=(state.agentic_plan.intent_class if state.agentic_plan is not None else "documentary"),
                    clarification_question=(
                        state.agentic_plan.clarification_question if state.agentic_plan is not None else None
                    ),
                    strict_grounding_requested=bool(
                        state.agentic_plan.strict_grounding_requested if state.agentic_plan is not None else False
                    ),
                    agentic_meta=state.agentic_meta,
                )
            )
        # Remember which query text effectively drove retrieval (planner's
        # primary query when planning ran, otherwise the routed/raw question).
        state.retrieval_query = (
            state.agentic_plan.primary_query
            if state.agentic_plan is not None and state.agentic_plan.planner_used
            else (state.retrieval_plan.search_query if state.retrieval_plan is not None else state.question)
        )
        return EvaluateEvidence()
@dataclass
class EvaluateEvidence(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Assess whether retrieved evidence is sufficient for answering."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> MaybeFetchGraphSupport:
        """Evaluate retrieval quality before optional graph augmentation.

        Implements a bounded CRAG loop: evaluate the current evidence and,
        when insufficient, refine the query from the gap hint and merge newly
        retrieved, previously unseen results. The whole phase is skipped when
        CRAG is disabled, no results exist, or the top score clears the
        configured skip threshold.
        """
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        config = get_config()
        search_cfg = config.search
        if not search_cfg.crag_enabled or not state.retrieval_results:
            return MaybeFetchGraphSupport()
        # Pre-launch graph fetch in parallel with CRAG — it only needs initial act_ids.
        retrieval_plan = state.retrieval_plan
        use_graph = (
            retrieval_plan is not None
            and getattr(retrieval_plan, "use_graph", False)
            and deps.graph_rag_service is not None
        )
        if use_graph:
            graph_act_ids = {r.act_id for r in state.retrieval_results if r.act_id is not None}
            executor = ThreadPoolExecutor(max_workers=1)
            state.graph_prefetch_future = executor.submit(
                fetch_graph_context,
                act_ids=graph_act_ids,
                graph_rag_service=cast(Any, deps.graph_rag_service),
                community_enricher=cast(Any, deps.community_enricher),
                max_depth=state.graph_depth,
            )
            # Release the worker thread once the fetch completes; the future
            # stays valid after shutdown(wait=False).
            executor.shutdown(wait=False)
        scored = [r.score for r in state.retrieval_results if r.score is not None]
        if scored and max(scored) >= search_cfg.crag_skip_score_threshold:
            # Confident retrieval: skip the (costly) LLM self-check entirely.
            state.agentic_meta["crag"] = {
                "skipped": True,
                "reason": "top_score_above_threshold",
                "top_score": round(max(scored), 4),
                "threshold": search_cfg.crag_skip_score_threshold,
            }
            _record_progress(
                state,
                deps,
                phase="crag",
                status="done",
                label="CRAG ignoré (score élevé)",
                detail=f"Score max {max(scored):.3f} ≥ seuil {search_cfg.crag_skip_score_threshold}",
                tool="evaluate",
            )
            return MaybeFetchGraphSupport()
        _record_progress(
            state,
            deps,
            phase="crag",
            status="active",
            label="Auto-vérification de la récupération",
            detail="Contrôle de suffisance des preuves récupérées",
            tool="evaluate",
        )
        started_at = time.perf_counter()
        crag_evals: list[dict[str, Any]] = []
        crag_iter = 0
        while crag_iter < search_cfg.crag_max_iterations:
            eval_result = await asyncio.to_thread(
                evaluate_retrieval,
                state.question,
                state.retrieval_results,
                deps.lightweight_llm,
            )
            state.output_validation_retries += int(getattr(eval_result, "output_validation_retries", 0))
            crag_evals.append(
                {
                    "iteration": crag_iter,
                    "status": eval_result.status,
                    "gap": eval_result.gap_hint,
                    "eval_ms": eval_result.eval_ms,
                    "fallback": eval_result.fallback,
                }
            )
            # Stop when evidence is sufficient or no actionable gap remains.
            if eval_result.status == "SUFFICIENT" or not eval_result.gap_hint:
                break
            # Insufficient: refine the query from the gap hint and fetch
            # additional, previously unseen results.
            refined_plan = await asyncio.to_thread(
                refine_retrieval,
                state.question,
                eval_result.gap_hint,
                deps.lightweight_llm,
            )
            state.output_validation_retries += int(getattr(refined_plan, "output_validation_retries", 0))
            existing_ids = {result.subdivision_id for result in state.retrieval_results}
            try:
                additional = deps.retrieval_service.retrieve(
                    query=refined_plan.primary_query,
                    top_k=state.effective_top_k,
                    score_threshold=state.score_threshold,
                    filters=state.filters,
                    collections=state.collections,
                    granularity=state.effective_granularity,
                    embedding_preset=state.embedding_preset,
                )
                new_results = [result for result in additional if result.subdivision_id not in existing_ids]
                if new_results:
                    state.retrieval_results = state.retrieval_results + new_results
            except Exception as exc:
                # Best-effort refinement: keep the evidence we already have.
                logger.warning("CRAG re-retrieval failed (non-fatal): %s", exc)
                break
            crag_iter += 1
        total_ms = (time.perf_counter() - started_at) * 1000.0
        state.agentic_meta["crag"] = {
            "iterations": crag_iter,
            "evaluations": crag_evals,
            "crag_ms": round(total_ms, 1),
        }
        _record_progress(
            state,
            deps,
            phase="crag",
            status="done",
            label=(
                f"CRAG terminé après {crag_iter} itération{'s' if crag_iter > 1 else ''}"
                if crag_iter > 0
                else "CRAG validé sans itération corrective"
            ),
            detail=(
                f"Dernier statut: {crag_evals[-1]['status']}"
                if crag_evals
                else "Aucune évaluation corrective nécessaire"
            ),
            duration_ms=total_ms,
            meta={"evaluations": crag_evals, "iterations": crag_iter},
            tool="evaluate",
        )
        return MaybeFetchGraphSupport()
@dataclass
class MaybeFetchGraphSupport(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Optionally augment retrieval context with graph-derived support."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> CompressContext:
        """Fetch graph support when the current plan allows it.

        Enriches retrieval results into context slices (in parallel with the
        graph fetch when graph support is enabled), then — in "deep" mode with
        a planner that requested it — runs complementary queries to widen the
        evidence, deduplicating by act_id against already-selected slices.
        """
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        retrieval_plan = state.retrieval_plan
        assert retrieval_plan is not None
        _record_progress(
            state,
            deps,
            phase="enrichment",
            status="active",
            label="Assemblage des preuves",
            detail=(
                "Enrichissement des passages et du graphe"
                if retrieval_plan.use_graph and deps.graph_rag_service is not None
                else "Enrichissement des passages retenus"
            ),
            tool="enrich",
        )
        started_at = time.perf_counter()
        graph_fetch = None
        graph_enrichment_ms = 0.0
        community_meta: Dict[str, Any] = {}
        use_graph = retrieval_plan.use_graph and deps.graph_rag_service is not None
        if use_graph and deps.graph_rag_service is not None:
            # Use prefetched graph result from EvaluateEvidence if available.
            prefetch = state.graph_prefetch_future
            if prefetch is not None:
                # Graph fetch already in flight: only context enrichment needs a thread.
                with ThreadPoolExecutor(max_workers=1) as executor:
                    ctx_future = executor.submit(
                        deps.context_service.enrich_results,
                        state.retrieval_results,
                        state.effective_include_relations,
                        state.include_subjects,
                    )
                    context_slices = ctx_future.result()
                graph_fetch = prefetch.result()
                state.graph_prefetch_future = None
            else:
                # No prefetch: run context enrichment and graph fetch side by side.
                graph_act_ids = {result.act_id for result in state.retrieval_results if result.act_id is not None}
                with ThreadPoolExecutor(max_workers=min(2, get_config().search.max_parallel_workers)) as executor:
                    ctx_future = executor.submit(
                        deps.context_service.enrich_results,
                        state.retrieval_results,
                        state.effective_include_relations,
                        state.include_subjects,
                    )
                    graph_future = executor.submit(
                        fetch_graph_context,
                        act_ids=graph_act_ids,
                        graph_rag_service=cast(Any, deps.graph_rag_service),
                        community_enricher=cast(Any, deps.community_enricher),
                        max_depth=state.graph_depth,
                    )
                    context_slices = ctx_future.result()
                    graph_fetch = graph_future.result()
            if graph_fetch is not None:
                graph_enrichment_ms = graph_fetch.duration_ms
                community_meta = graph_fetch.community_meta
        else:
            context_slices = deps.context_service.enrich_results(
                state.retrieval_results,
                include_relations=state.effective_include_relations,
                include_subjects=state.include_subjects,
            )
        state.context_enrichment_ms = (time.perf_counter() - started_at) * 1000.0
        state.graph_enrichment_ms = graph_enrichment_ms
        state.context_slices = context_slices
        state.graph_fetch = graph_fetch
        state.community_meta = community_meta
        enrichment_overview = build_retrieval_overview(
            context_slices,
            effective_granularity=state.effective_granularity,
        )
        _record_progress(
            state,
            deps,
            phase="enrichment",
            status="done",
            label=(
                f"{len(context_slices)} passage"
                f"{'s' if len(context_slices) != 1 else ''} retenu"
                f"{'s' if len(context_slices) != 1 else ''}"
            ),
            detail=(
                f"Contexte final: {_hierarchy_detail(cast(Dict[str, Any], enrichment_overview['selected_counts']))}"
                + (" + enrichissement graphe" if graph_fetch is not None else "")
            ),
            count=len(context_slices),
            duration_ms=state.context_enrichment_ms,
            meta={
                **enrichment_overview,
                "graph_used": graph_fetch is not None,
                "graph_nodes": len(graph_fetch.nodes) if graph_fetch is not None else 0,
                "graph_relationships": len(graph_fetch.relationships) if graph_fetch is not None else 0,
            },
            tool="enrich",
        )
        # Complementary retrieval: only in deep mode, and only when the planner
        # both asked for it and supplied queries.
        if (
            state.retrieval_depth == "deep"
            and state.agentic_plan is not None
            and state.agentic_plan.needs_complementary
            and state.agentic_plan.complementary_queries
        ):
            _record_progress(
                state,
                deps,
                phase="complementary",
                status="active",
                label="Recherche complémentaire",
                detail="Exploration de requêtes additionnelles pour élargir la preuve",
                count=len(state.agentic_plan.complementary_queries),
                tool="retrieve_complementary",
            )
            comp_started_at = time.perf_counter()
            search_cfg = get_config().search
            existing_act_ids = {slice_.act.act_id for slice_ in state.context_slices}
            complementary_count = 0

            def _run_comp_query(
                query: AgenticComplementaryQuery,
            ) -> tuple[AgenticComplementaryQuery, list[Any]]:
                # Best-effort: a failed complementary query yields no results.
                try:
                    return query, deps.retrieval_service.retrieve(
                        query=query.query,
                        top_k=search_cfg.complementary_top_k,
                        score_threshold=state.score_threshold,
                        filters=state.filters,
                        granularity=state.effective_granularity,
                        embedding_preset=state.embedding_preset,
                    )
                except Exception as exc:
                    logger.warning("Complementary retrieval failed (non-fatal): %s", exc)
                    return query, []

            queries_to_run = state.agentic_plan.complementary_queries[: search_cfg.complementary_max_queries]
            max_w = max(min(len(queries_to_run), get_config().search.max_parallel_workers), 1)
            with ThreadPoolExecutor(max_workers=max_w) as executor:
                futures = [executor.submit(_run_comp_query, query) for query in queries_to_run]
                for future in futures:
                    query, comp_results = future.result()
                    # Only keep results from acts not already represented.
                    new_results = [result for result in comp_results if result.act_id not in existing_act_ids]
                    if new_results:
                        new_slices = deps.context_service.enrich_results(
                            new_results,
                            include_relations=state.effective_include_relations,
                            include_subjects=state.include_subjects,
                        )
                        state.context_slices.extend(new_slices)
                        existing_act_ids.update(slice_.act.act_id for slice_ in new_slices)
                        complementary_count += len(new_slices)
                        logger.debug(
                            "Complementary retrieval '%s' added %d new slices",
                            query.query,
                            len(new_slices),
                        )
            state.complementary_ms = (time.perf_counter() - comp_started_at) * 1000.0
            state.agentic_meta["complementary"] = {
                "queries_executed": len(queries_to_run),
                "new_slices_added": complementary_count,
                "complementary_ms": round(state.complementary_ms, 1),
            }
            _record_progress(
                state,
                deps,
                phase="complementary",
                status="done",
                label="Recherche complémentaire terminée",
                detail=(
                    f"{complementary_count} passage{'s' if complementary_count != 1 else ''} ajouté"
                    if complementary_count > 0
                    else "Aucun passage supplémentaire retenu"
                ),
                count=complementary_count,
                duration_ms=state.complementary_ms,
                meta=cast(Dict[str, Any], state.agentic_meta.get("complementary") or {}),
                tool="retrieve_complementary",
            )
        return CompressContext()
@dataclass
class CompressContext(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Finalize and optionally compress context before generation."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> End[PlanningResult]:
        """Produce the terminal planning artifact for downstream generation.

        Compression runs when the planner requested it, or when the total
        context size exceeds the generation budget scaled by the configured
        threshold ratio. The node always ends the graph with a
        ``PlanningContext`` carrying the assembled context and timings.
        """
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        config = get_config()
        search_cfg = config.search
        pre_compression_count = len(state.context_slices)
        max_context_chars = config.generation.max_context_chars
        # Trigger: planner opted in, or context exceeds budget * ratio.
        if (state.agentic_plan is not None and state.agentic_plan.needs_compression) or sum(
            len(slice_.content or "") for slice_ in state.context_slices
        ) > max_context_chars * search_cfg.compression_threshold_ratio:
            _record_progress(
                state,
                deps,
                phase="compression",
                status="active",
                label="Compression du contexte",
                detail="Réduction du contexte pour respecter le budget de génération",
                count=pre_compression_count,
                tool="compress",
            )
            comp_started_at = time.perf_counter()
            state.context_slices = compress_context(
                state.context_slices,
                deps.lightweight_llm,
                budget_chars=max_context_chars,
            )
            state.compression_ms = (time.perf_counter() - comp_started_at) * 1000.0
            state.agentic_meta["compression"] = {
                "triggered": True,
                "pre_compression_slices": pre_compression_count,
                "post_compression_slices": len(state.context_slices),
                "compression_ms": round(state.compression_ms, 1),
            }
            _record_progress(
                state,
                deps,
                phase="compression",
                status="done",
                label="Compression terminée",
                detail=(
                    f"{pre_compression_count} → {len(state.context_slices)} passage"
                    f"{'s' if len(state.context_slices) != 1 else ''}"
                ),
                count=len(state.context_slices),
                duration_ms=state.compression_ms,
                meta=cast(Dict[str, Any], state.agentic_meta.get("compression") or {}),
                tool="compress",
            )
        _finalize_agentic_meta(state)
        retrieval_plan = state.retrieval_plan
        assert retrieval_plan is not None
        return End(
            PlanningContext(
                context_slices=state.context_slices,
                graph_fetch=state.graph_fetch,
                retrieval_plan=retrieval_plan,
                agentic_plan=state.agentic_plan,
                agentic_meta=state.agentic_meta,
                retrieval_query=state.retrieval_query or state.question,
                effective_top_k=state.effective_top_k,
                effective_granularity=state.effective_granularity,
                effective_include_relations=state.effective_include_relations,
                community_meta=state.community_meta,
                routing_ms=state.routing_ms,
                planner_ms=state.planner_ms,
                retrieval_ms=state.retrieval_ms,
                context_enrichment_ms=state.context_enrichment_ms,
                graph_enrichment_ms=state.graph_enrichment_ms,
                complementary_ms=state.complementary_ms,
                compression_ms=state.compression_ms,
                retrieval_stats=state.retrieval_stats,
            )
        )
# Static wiring of the planning pipeline. Node order follows the transitions
# returned by each node's run(): LoadConversationContext -> DecomposeQuestion
# -> RouteQuestion -> PlanRetrieval -> RunRetrieval -> EvaluateEvidence ->
# MaybeFetchGraphSupport -> CompressContext, with early End exits possible
# from RunRetrieval.
_PLANNING_GRAPH = Graph(
    nodes=(
        LoadConversationContext,
        DecomposeQuestion,
        RouteQuestion,
        PlanRetrieval,
        RunRetrieval,
        EvaluateEvidence,
        MaybeFetchGraphSupport,
        CompressContext,
    ),
    name="rag_planning_graph",
    auto_instrument=False,
)
def run_planning_graph(
    *,
    deps: AgenticPlanningDeps,
) -> PlanningResult:
    """Execute the planning graph synchronously and return the terminal artifact.

    Builds a fresh :class:`PlanningGraphState` from *deps* (stamped with a
    unique planner run id) and drives the graph from the bootstrap node until
    a terminal ``End`` is reached, returning its payload (a
    ``PlanningContext`` or ``PlanningEarlyExit``).
    """
    state = PlanningGraphState(
        question=deps.question,
        top_k=deps.top_k,
        score_threshold=deps.score_threshold,
        filters=deps.filters,
        include_relations=deps.include_relations,
        include_subjects=deps.include_subjects,
        collections=deps.collections,
        granularity=deps.granularity,
        graph_depth=deps.graph_depth,
        use_graph=deps.use_graph,
        embedding_preset=deps.embedding_preset,
        retrieval_depth=deps.retrieval_depth,
        planner_run_id=uuid4().hex,
    )
    result = _PLANNING_GRAPH.run_sync(LoadConversationContext(), state=state, deps=deps)
    return result.output