Source code for lalandre_rag.agentic.graph

"""Pydantic Graph orchestration for RAG planning phases."""

from __future__ import annotations

import asyncio
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, Optional, cast
from uuid import uuid4

from lalandre_core.config import get_config
from pydantic_graph import BaseNode, End, Graph, GraphRunContext

from lalandre_rag.agentic.runtime import (
    AgenticComplementaryQuery,
    DecomposedQuery,
    evaluate_retrieval,
    plan_retrieval,
    refine_retrieval,
)
from lalandre_rag.modes.hybrid_graph import fetch_graph_context
from lalandre_rag.modes.hybrid_helpers import emit_progress
from lalandre_rag.retrieval.context.compressor import compress_context
from lalandre_rag.retrieval.overview import build_retrieval_overview

from .deps import AgenticPlanningDeps
from .models import (
    PhaseTraceEvent,
    PlanningContext,
    PlanningEarlyExit,
    PlanningGraphState,
    PlanningResult,
)

logger = logging.getLogger(__name__)

_PROFILE_LABELS = {
    "global_overview": "Vue d'ensemble globale",
    "citation_precision": "Recherche de citation précise",
    "relationship_focus": "Analyse de relations juridiques",
    "contextual_default": "Question contextuelle",
    "manual_collections": "Collections imposées",
    "manual_override": "Paramètres manuels",
}


def _hierarchy_detail(selected_counts: Dict[str, Any]) -> str:
    acts = int(selected_counts.get("acts") or 0)
    subdivisions = int(selected_counts.get("subdivisions") or 0)
    chunks = int(selected_counts.get("chunks") or 0)
    return (
        f"{acts} acte{'s' if acts != 1 else ''} · "
        f"{subdivisions} subdivision{'s' if subdivisions != 1 else ''} · "
        f"{chunks} chunk{'s' if chunks != 1 else ''}"
    )
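
# Example (illustrative, based on the helper above):
#     _hierarchy_detail({"acts": 2, "subdivisions": 5, "chunks": 12})
#     -> "2 actes · 5 subdivisions · 12 chunks"
# Missing or None counts fall back to 0.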


def _record_progress(
    state: PlanningGraphState,
    deps: AgenticPlanningDeps,
    *,
    phase: str,
    status: str,
    label: str,
    detail: Optional[str] = None,
    count: Optional[int] = None,
    duration_ms: Optional[float] = None,
    meta: Optional[Dict[str, Any]] = None,
    tool: Optional[str] = None,
) -> None:
    payload: Dict[str, Any] = {
        "phase": phase,
        "status": status,
        "label": label,
    }
    if detail is not None:
        payload["detail"] = detail
    if count is not None:
        payload["count"] = count
    if duration_ms is not None:
        payload["duration_ms"] = round(float(duration_ms), 1)
    if meta:
        payload["meta"] = meta
    event = PhaseTraceEvent(
        phase=phase,
        status=status,
        label=label,
        detail=detail,
        count=count,
        duration_ms=round(float(duration_ms), 1) if duration_ms is not None else None,
        meta=meta or {},
        tool=tool,
    )
    state.trace_events.append(event)
    emit_progress(
        deps.progress_callback,
        phase=phase,
        status=status,
        label=label,
        detail=cast(Optional[str], payload.get("detail")),
        count=cast(Optional[int], payload.get("count")),
        duration_ms=cast(Optional[float], payload.get("duration_ms")),
        meta=cast(Optional[Dict[str, Any]], payload.get("meta")),
    )


def _append_path(state: PlanningGraphState, node_name: str) -> None:
    state.planner_path.append(node_name)


def _finalize_agentic_meta(state: PlanningGraphState) -> None:
    state.agentic_meta.setdefault("planner_runtime", "pydanticai")
    state.agentic_meta["planner_run_id"] = state.planner_run_id
    state.agentic_meta["planner_path"] = list(state.planner_path)
    state.agentic_meta["tool_trace"] = [event.model_dump() for event in state.trace_events]
    state.agentic_meta["output_validation_retries"] = state.output_validation_retries


@dataclass
class LoadConversationContext(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Bootstrap node for future conversation-memory loading."""

    async def run(self, ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps]) -> DecomposeQuestion:
        """Advance to the question-decomposition phase."""
        _append_path(ctx.state, self.get_node_id())
        return DecomposeQuestion()

@dataclass
class DecomposeQuestion(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Placeholder decomposition node for complex multi-part questions."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> RouteQuestion:
        """Advance to the routing phase."""
        _append_path(ctx.state, self.get_node_id())
        return RouteQuestion()

@dataclass
class RouteQuestion(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Route the user question toward the appropriate retrieval profile."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> PlanRetrieval:
        """Compute routing metadata and transition to planning."""
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        _record_progress(
            state,
            deps,
            phase="routing",
            status="active",
            label="Qualification de la demande",
            detail="Choix du profil de recherche et du niveau de granularité",
            tool="route",
        )
        # Launch planning LLM call in parallel with routing (they're independent).
        if not state.collections:
            _record_progress(
                state,
                deps,
                phase="planning",
                status="active",
                label="Planification de la recherche",
                detail="Préparation de la stratégie de récupération (en parallèle)",
                tool="plan",
            )
            planning_executor = ThreadPoolExecutor(max_workers=1)
            state.planning_future = planning_executor.submit(
                plan_retrieval,
                state.question,
                deps.lightweight_llm,
            )
        started_at = time.perf_counter()
        if state.collections:
            retrieval_plan = deps.query_router.route(
                question=state.question,
                top_k=state.top_k,
                requested_granularity=state.granularity,
            )
            retrieval_plan = retrieval_plan.__class__(
                profile="manual_collections",
                granularity=state.granularity,
                top_k=state.top_k,
                include_relations_hint=state.include_relations,
                rationale="Explicit collections provided: routing heuristics bypassed.",
                use_graph=bool(state.use_graph),
                execution_mode="hybrid",
                routing_source="heuristic",
            )
            effective_top_k = state.top_k
            effective_granularity = state.granularity
            effective_include_relations = state.include_relations
        else:
            retrieval_plan = deps.query_router.route(
                question=state.question,
                top_k=state.top_k,
                requested_granularity=state.granularity,
            )
            effective_top_k = retrieval_plan.top_k
            effective_granularity = retrieval_plan.granularity
            effective_include_relations = state.include_relations or retrieval_plan.include_relations_hint
            if state.use_graph is not None:
                retrieval_plan = retrieval_plan.__class__(
                    **{
                        **retrieval_plan.__dict__,
                        "use_graph": state.use_graph,
                    }
                )
        state.routing_ms = (time.perf_counter() - started_at) * 1000.0
        state.retrieval_plan = retrieval_plan
        state.effective_top_k = effective_top_k
        state.effective_granularity = effective_granularity
        state.effective_include_relations = effective_include_relations
        state.agentic_meta["routing"] = {
            "profile": retrieval_plan.profile,
            "execution_mode": retrieval_plan.execution_mode,
            "granularity": effective_granularity,
            "top_k": effective_top_k,
            "use_graph": retrieval_plan.use_graph,
            "rationale": retrieval_plan.rationale,
        }
        intent_parser = getattr(deps.query_router, "intent_parser", None)
        parsed_intent = getattr(intent_parser, "_last_parsed_intent", None)
        if parsed_intent is not None:
            state.output_validation_retries += int(getattr(parsed_intent, "output_validation_retries", 0))
        _record_progress(
            state,
            deps,
            phase="routing",
            status="done",
            label=str(_PROFILE_LABELS.get(retrieval_plan.profile, retrieval_plan.profile)),
            detail=retrieval_plan.rationale,
            duration_ms=state.routing_ms,
            meta={
                "profile": retrieval_plan.profile,
                "execution_mode": retrieval_plan.execution_mode,
                "granularity": effective_granularity,
                "top_k": effective_top_k,
                "use_graph": retrieval_plan.use_graph,
            },
            tool="route",
        )
        return PlanRetrieval()
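
# Note on the routing/planning overlap implemented above: when no explicit
# collections are provided, RouteQuestion submits plan_retrieval to a
# single-worker ThreadPoolExecutor and stores the Future on
# state.planning_future; PlanRetrieval below consumes that Future instead of
# issuing a second planning call, so the routing heuristics and the planning
# LLM call run concurrently.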

@dataclass
class PlanRetrieval(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Build the retrieval plan and complementary-query strategy."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> RunRetrieval:
        """Resolve the planning step and transition to retrieval."""
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        if state.collections:
            return RunRetrieval()
        # If routing already ran concurrently (via RouteQuestion) we just
        # need to await the planning future that was launched in parallel.
        if state.planning_future is not None:
            agentic_plan = state.planning_future.result()
            state.planning_future = None
        else:
            _record_progress(
                state,
                deps,
                phase="planning",
                status="active",
                label="Planification de la recherche",
                detail="Préparation de la stratégie de récupération",
                tool="plan",
            )
            agentic_plan = await asyncio.to_thread(plan_retrieval, state.question, deps.lightweight_llm)
        state.planner_ms = agentic_plan.planning_ms
        state.agentic_plan = agentic_plan
        state.output_validation_retries += int(getattr(agentic_plan, "output_validation_retries", 0))
        state.agentic_meta["planner"] = {
            "used": agentic_plan.planner_used,
            "intent_class": agentic_plan.intent_class,
            "primary_query": agentic_plan.primary_query,
            "clarification_question": agentic_plan.clarification_question,
            "strict_grounding_requested": agentic_plan.strict_grounding_requested,
            "needs_complementary": agentic_plan.needs_complementary,
            "needs_compression": agentic_plan.needs_compression,
            "complementary_count": len(agentic_plan.complementary_queries),
            "rationale": agentic_plan.rationale,
            "planning_ms": state.planner_ms,
        }
        _record_progress(
            state,
            deps,
            phase="planning",
            status="done",
            label="Plan de recherche établi",
            detail=agentic_plan.rationale,
            duration_ms=state.planner_ms,
            meta={
                "planner_used": bool(agentic_plan.planner_used),
                "intent_class": agentic_plan.intent_class,
                "primary_query": agentic_plan.primary_query,
                "clarification_question": agentic_plan.clarification_question,
                "strict_grounding_requested": bool(agentic_plan.strict_grounding_requested),
                "needs_complementary": bool(agentic_plan.needs_complementary),
                "needs_compression": bool(agentic_plan.needs_compression),
                "complementary_queries": [cq.query for cq in agentic_plan.complementary_queries],
            },
            tool="plan",
        )
        return RunRetrieval()

@dataclass
class RunRetrieval(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Execute retrieval and context enrichment for the planned query."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> EvaluateEvidence | End[PlanningResult]:
        """Run retrieval and either finish early or evaluate sufficiency."""
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        config = get_config()
        search_cfg = config.search
        if state.agentic_plan is not None and state.agentic_plan.skip_retrieval:
            _record_progress(
                state,
                deps,
                phase="retrieval",
                status="done",
                label="Recherche évitée",
                detail=state.agentic_plan.rationale,
                count=0,
                duration_ms=0.0,
                tool="retrieve",
            )
            _finalize_agentic_meta(state)
            return End(
                PlanningEarlyExit(
                    kind="skip_retrieval",
                    routing_ms=state.routing_ms,
                    planner_ms=state.planner_ms,
                    retrieval_ms=0.0,
                    intent_class=state.agentic_plan.intent_class,
                    clarification_question=state.agentic_plan.clarification_question,
                    strict_grounding_requested=bool(state.agentic_plan.strict_grounding_requested),
                    agentic_rationale=state.agentic_plan.rationale,
                    agentic_meta=state.agentic_meta,
                )
            )
        _record_progress(
            state,
            deps,
            phase="retrieval",
            status="active",
            label="Recherche des preuves",
            detail=(
                "Exploration ciblée des collections imposées"
                if state.collections
                else "Exploration des passages les plus pertinents"
            ),
            tool="retrieve",
        )
        started_at = time.perf_counter()
        decomp_result = cast(Optional[DecomposedQuery], state.decomposition_result)
        if state.collections:
            retrieval_results = deps.retrieval_service.retrieve(
                query=state.question,
                top_k=state.effective_top_k,
                score_threshold=state.score_threshold,
                filters=state.filters,
                collections=state.collections,
                granularity=state.effective_granularity,
                embedding_preset=state.embedding_preset,
            )
        elif decomp_result is not None and decomp_result.decomposed and decomp_result.sub_questions:

            def _sub_retrieve(query: str) -> list[Any]:
                try:
                    return deps.retrieval_service.retrieve(
                        query=query,
                        top_k=state.effective_top_k,
                        score_threshold=state.score_threshold,
                        filters=state.filters,
                        collections=None,
                        granularity=state.effective_granularity,
                        embedding_preset=state.embedding_preset,
                    )
                except Exception as exc:
                    logger.warning("Sub-question retrieval failed (non-fatal): %s", exc)
                    return []

            seen_ids: set[Any] = set()
            retrieval_results = []
            max_w = min(len(decomp_result.sub_questions), get_config().search.max_parallel_workers)
            with ThreadPoolExecutor(max_workers=max_w) as executor:
                futures = [executor.submit(_sub_retrieve, sub_question) for sub_question in decomp_result.sub_questions]
                for future in futures:
                    for result in future.result():
                        if result.subdivision_id not in seen_ids:
                            retrieval_results.append(result)
                            seen_ids.add(result.subdivision_id)
            retrieval_results.sort(key=lambda item: item.score, reverse=True)
        else:
            retrieval_results = deps.retrieval_service.retrieve(
                query=state.question,
                top_k=state.effective_top_k,
                score_threshold=state.score_threshold,
                filters=state.filters,
                collections=state.collections,
                granularity=state.effective_granularity,
                embedding_preset=state.embedding_preset,
            )
        state.retrieval_ms = (time.perf_counter() - started_at) * 1000.0
        state.retrieval_results = retrieval_results
        state.retrieval_stats = deps.retrieval_service.last_retrieval_stats
        retrieval_overview = build_retrieval_overview(
            retrieval_results,
            effective_granularity=state.effective_granularity,
            candidate_counts={
                "after_fusion": int(state.retrieval_stats.candidates_after_fusion),
                "after_threshold": int(state.retrieval_stats.candidates_after_threshold),
                "after_rerank": int(state.retrieval_stats.candidates_after_rerank),
                "in_context": len(retrieval_results),
            },
        )
        _record_progress(
            state,
            deps,
            phase="retrieval",
            status="done",
            label=(
                f"{len(retrieval_results)} résultat"
                f"{'s' if len(retrieval_results) != 1 else ''} retrouvé"
                f"{'s' if len(retrieval_results) != 1 else ''}"
            ),
            detail=(
                f"{_hierarchy_detail(cast(Dict[str, Any], retrieval_overview['selected_counts']))} "
                "retenus après fusion, filtrage et reranking"
                + (
                    "\n"
                    f"Temps retrieval: embedding {state.retrieval_stats.embedding_ms:.0f} ms · "
                    f"sémantique {state.retrieval_stats.semantic_search_ms:.0f} ms · "
                    f"lexical {state.retrieval_stats.lexical_search_ms:.0f} ms · "
                    f"rerank {state.retrieval_stats.rerank_ms:.0f} ms"
                    if (
                        state.retrieval_stats.embedding_ms > 0.0
                        or state.retrieval_stats.semantic_search_ms > 0.0
                        or state.retrieval_stats.lexical_search_ms > 0.0
                        or state.retrieval_stats.rerank_ms > 0.0
                    )
                    else ""
                )
            ),
            count=len(retrieval_results),
            duration_ms=state.retrieval_ms,
            meta={
                **retrieval_overview,
                "cache_hit": state.retrieval_stats.cache_hit,
                "query_variants_count": state.retrieval_stats.query_variants_count,
                "embedding_ms": state.retrieval_stats.embedding_ms,
                "semantic_search_ms": state.retrieval_stats.semantic_search_ms,
                "lexical_search_ms": state.retrieval_stats.lexical_search_ms,
                "parallel_search_ms": state.retrieval_stats.parallel_search_ms,
                "fusion_ms": state.retrieval_stats.fusion_ms,
                "rerank_ms": state.retrieval_stats.rerank_ms,
            },
            tool="retrieve",
        )
        gate_threshold = search_cfg.relevance_gate_threshold
        if gate_threshold is not None and retrieval_results and retrieval_results[0].score < gate_threshold:
            _finalize_agentic_meta(state)
            return End(
                PlanningEarlyExit(
                    kind="relevance_gate",
                    routing_ms=state.routing_ms,
                    planner_ms=state.planner_ms,
                    retrieval_ms=state.retrieval_ms,
                    intent_class=(state.agentic_plan.intent_class if state.agentic_plan is not None else "documentary"),
                    clarification_question=(
                        state.agentic_plan.clarification_question if state.agentic_plan is not None else None
                    ),
                    strict_grounding_requested=bool(
                        state.agentic_plan.strict_grounding_requested if state.agentic_plan is not None else False
                    ),
                    agentic_meta=state.agentic_meta,
                    best_score=retrieval_results[0].score,
                    gate_threshold=gate_threshold,
                    candidates_dropped=len(retrieval_results),
                )
            )
        if not retrieval_results:
            _finalize_agentic_meta(state)
            return End(
                PlanningEarlyExit(
                    kind="empty",
                    routing_ms=state.routing_ms,
                    planner_ms=state.planner_ms,
                    retrieval_ms=state.retrieval_ms,
                    intent_class=(state.agentic_plan.intent_class if state.agentic_plan is not None else "documentary"),
                    clarification_question=(
                        state.agentic_plan.clarification_question if state.agentic_plan is not None else None
                    ),
                    strict_grounding_requested=bool(
                        state.agentic_plan.strict_grounding_requested if state.agentic_plan is not None else False
                    ),
                    agentic_meta=state.agentic_meta,
                )
            )
        state.retrieval_query = (
            state.agentic_plan.primary_query
            if state.agentic_plan is not None and state.agentic_plan.planner_used
            else (state.retrieval_plan.search_query if state.retrieval_plan is not None else state.question)
        )
        return EvaluateEvidence()

@dataclass
class EvaluateEvidence(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Assess whether retrieved evidence is sufficient for answering."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> MaybeFetchGraphSupport:
        """Evaluate retrieval quality before optional graph augmentation."""
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        config = get_config()
        search_cfg = config.search
        if not search_cfg.crag_enabled or not state.retrieval_results:
            return MaybeFetchGraphSupport()
        # Pre-launch graph fetch in parallel with CRAG; it only needs the initial act_ids.
        retrieval_plan = state.retrieval_plan
        use_graph = (
            retrieval_plan is not None
            and getattr(retrieval_plan, "use_graph", False)
            and deps.graph_rag_service is not None
        )
        if use_graph:
            graph_act_ids = {r.act_id for r in state.retrieval_results if r.act_id is not None}
            executor = ThreadPoolExecutor(max_workers=1)
            state.graph_prefetch_future = executor.submit(
                fetch_graph_context,
                act_ids=graph_act_ids,
                graph_rag_service=cast(Any, deps.graph_rag_service),
                community_enricher=cast(Any, deps.community_enricher),
                max_depth=state.graph_depth,
            )
        scored = [r.score for r in state.retrieval_results if r.score is not None]
        if scored and max(scored) >= search_cfg.crag_skip_score_threshold:
            state.agentic_meta["crag"] = {
                "skipped": True,
                "reason": "top_score_above_threshold",
                "top_score": round(max(scored), 4),
                "threshold": search_cfg.crag_skip_score_threshold,
            }
            _record_progress(
                state,
                deps,
                phase="crag",
                status="done",
                label="CRAG ignoré (score élevé)",
                detail=f"Score max {max(scored):.3f} ≥ seuil {search_cfg.crag_skip_score_threshold}",
                tool="evaluate",
            )
            return MaybeFetchGraphSupport()
        _record_progress(
            state,
            deps,
            phase="crag",
            status="active",
            label="Auto-vérification de la récupération",
            detail="Contrôle de suffisance des preuves récupérées",
            tool="evaluate",
        )
        started_at = time.perf_counter()
        crag_evals: list[dict[str, Any]] = []
        crag_iter = 0
        while crag_iter < search_cfg.crag_max_iterations:
            eval_result = await asyncio.to_thread(
                evaluate_retrieval,
                state.question,
                state.retrieval_results,
                deps.lightweight_llm,
            )
            state.output_validation_retries += int(getattr(eval_result, "output_validation_retries", 0))
            crag_evals.append(
                {
                    "iteration": crag_iter,
                    "status": eval_result.status,
                    "gap": eval_result.gap_hint,
                    "eval_ms": eval_result.eval_ms,
                    "fallback": eval_result.fallback,
                }
            )
            if eval_result.status == "SUFFICIENT" or not eval_result.gap_hint:
                break
            refined_plan = await asyncio.to_thread(
                refine_retrieval,
                state.question,
                eval_result.gap_hint,
                deps.lightweight_llm,
            )
            state.output_validation_retries += int(getattr(refined_plan, "output_validation_retries", 0))
            existing_ids = {result.subdivision_id for result in state.retrieval_results}
            try:
                additional = deps.retrieval_service.retrieve(
                    query=refined_plan.primary_query,
                    top_k=state.effective_top_k,
                    score_threshold=state.score_threshold,
                    filters=state.filters,
                    collections=state.collections,
                    granularity=state.effective_granularity,
                    embedding_preset=state.embedding_preset,
                )
                new_results = [result for result in additional if result.subdivision_id not in existing_ids]
                if new_results:
                    state.retrieval_results = state.retrieval_results + new_results
            except Exception as exc:
                logger.warning("CRAG re-retrieval failed (non-fatal): %s", exc)
                break
            crag_iter += 1
        total_ms = (time.perf_counter() - started_at) * 1000.0
        state.agentic_meta["crag"] = {
            "iterations": crag_iter,
            "evaluations": crag_evals,
            "crag_ms": round(total_ms, 1),
        }
        _record_progress(
            state,
            deps,
            phase="crag",
            status="done",
            label=(
                f"CRAG terminé après {crag_iter} itération{'s' if crag_iter > 1 else ''}"
                if crag_iter > 0
                else "CRAG validé sans itération corrective"
            ),
            detail=(
                f"Dernier statut: {crag_evals[-1]['status']}"
                if crag_evals
                else "Aucune évaluation corrective nécessaire"
            ),
            duration_ms=total_ms,
            meta={"evaluations": crag_evals, "iterations": crag_iter},
            tool="evaluate",
        )
        return MaybeFetchGraphSupport()
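
# CRAG loop summary (for the node above): each iteration scores the current
# evidence with evaluate_retrieval; when the verdict is not SUFFICIENT and a
# gap hint is returned, refine_retrieval produces a refined query whose new
# (deduplicated) results are appended before re-evaluating, bounded by
# search.crag_max_iterations. A top score above search.crag_skip_score_threshold
# bypasses the loop entirely.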

@dataclass
class MaybeFetchGraphSupport(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Optionally augment retrieval context with graph-derived support."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> CompressContext:
        """Fetch graph support when the current plan allows it."""
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        retrieval_plan = state.retrieval_plan
        assert retrieval_plan is not None
        _record_progress(
            state,
            deps,
            phase="enrichment",
            status="active",
            label="Assemblage des preuves",
            detail=(
                "Enrichissement des passages et du graphe"
                if retrieval_plan.use_graph and deps.graph_rag_service is not None
                else "Enrichissement des passages retenus"
            ),
            tool="enrich",
        )
        started_at = time.perf_counter()
        graph_fetch = None
        graph_enrichment_ms = 0.0
        community_meta: Dict[str, Any] = {}
        use_graph = retrieval_plan.use_graph and deps.graph_rag_service is not None
        if use_graph and deps.graph_rag_service is not None:
            # Use prefetched graph result from EvaluateEvidence if available.
            prefetch = state.graph_prefetch_future
            if prefetch is not None:
                with ThreadPoolExecutor(max_workers=1) as executor:
                    ctx_future = executor.submit(
                        deps.context_service.enrich_results,
                        state.retrieval_results,
                        state.effective_include_relations,
                        state.include_subjects,
                    )
                    context_slices = ctx_future.result()
                    graph_fetch = prefetch.result()
                state.graph_prefetch_future = None
            else:
                graph_act_ids = {result.act_id for result in state.retrieval_results if result.act_id is not None}
                with ThreadPoolExecutor(max_workers=min(2, get_config().search.max_parallel_workers)) as executor:
                    ctx_future = executor.submit(
                        deps.context_service.enrich_results,
                        state.retrieval_results,
                        state.effective_include_relations,
                        state.include_subjects,
                    )
                    graph_future = executor.submit(
                        fetch_graph_context,
                        act_ids=graph_act_ids,
                        graph_rag_service=cast(Any, deps.graph_rag_service),
                        community_enricher=cast(Any, deps.community_enricher),
                        max_depth=state.graph_depth,
                    )
                    context_slices = ctx_future.result()
                    graph_fetch = graph_future.result()
            if graph_fetch is not None:
                graph_enrichment_ms = graph_fetch.duration_ms
                community_meta = graph_fetch.community_meta
        else:
            context_slices = deps.context_service.enrich_results(
                state.retrieval_results,
                include_relations=state.effective_include_relations,
                include_subjects=state.include_subjects,
            )
        state.context_enrichment_ms = (time.perf_counter() - started_at) * 1000.0
        state.graph_enrichment_ms = graph_enrichment_ms
        state.context_slices = context_slices
        state.graph_fetch = graph_fetch
        state.community_meta = community_meta
        enrichment_overview = build_retrieval_overview(
            context_slices,
            effective_granularity=state.effective_granularity,
        )
        _record_progress(
            state,
            deps,
            phase="enrichment",
            status="done",
            label=(
                f"{len(context_slices)} passage"
                f"{'s' if len(context_slices) != 1 else ''} retenu"
                f"{'s' if len(context_slices) != 1 else ''}"
            ),
            detail=(
                f"Contexte final: {_hierarchy_detail(cast(Dict[str, Any], enrichment_overview['selected_counts']))}"
                + (" + enrichissement graphe" if graph_fetch is not None else "")
            ),
            count=len(context_slices),
            duration_ms=state.context_enrichment_ms,
            meta={
                **enrichment_overview,
                "graph_used": graph_fetch is not None,
                "graph_nodes": len(graph_fetch.nodes) if graph_fetch is not None else 0,
                "graph_relationships": len(graph_fetch.relationships) if graph_fetch is not None else 0,
            },
            tool="enrich",
        )
        if (
            state.retrieval_depth == "deep"
            and state.agentic_plan is not None
            and state.agentic_plan.needs_complementary
            and state.agentic_plan.complementary_queries
        ):
            _record_progress(
                state,
                deps,
                phase="complementary",
                status="active",
                label="Recherche complémentaire",
                detail="Exploration de requêtes additionnelles pour élargir la preuve",
                count=len(state.agentic_plan.complementary_queries),
                tool="retrieve_complementary",
            )
            comp_started_at = time.perf_counter()
            search_cfg = get_config().search
            existing_act_ids = {slice_.act.act_id for slice_ in state.context_slices}
            complementary_count = 0

            def _run_comp_query(
                query: AgenticComplementaryQuery,
            ) -> tuple[AgenticComplementaryQuery, list[Any]]:
                try:
                    return query, deps.retrieval_service.retrieve(
                        query=query.query,
                        top_k=search_cfg.complementary_top_k,
                        score_threshold=state.score_threshold,
                        filters=state.filters,
                        granularity=state.effective_granularity,
                        embedding_preset=state.embedding_preset,
                    )
                except Exception as exc:
                    logger.warning("Complementary retrieval failed (non-fatal): %s", exc)
                    return query, []

            queries_to_run = state.agentic_plan.complementary_queries[: search_cfg.complementary_max_queries]
            max_w = max(min(len(queries_to_run), get_config().search.max_parallel_workers), 1)
            with ThreadPoolExecutor(max_workers=max_w) as executor:
                futures = [executor.submit(_run_comp_query, query) for query in queries_to_run]
                for future in futures:
                    query, comp_results = future.result()
                    new_results = [result for result in comp_results if result.act_id not in existing_act_ids]
                    if new_results:
                        new_slices = deps.context_service.enrich_results(
                            new_results,
                            include_relations=state.effective_include_relations,
                            include_subjects=state.include_subjects,
                        )
                        state.context_slices.extend(new_slices)
                        existing_act_ids.update(slice_.act.act_id for slice_ in new_slices)
                        complementary_count += len(new_slices)
                        logger.debug(
                            "Complementary retrieval '%s' added %d new slices",
                            query.query,
                            len(new_slices),
                        )
            state.complementary_ms = (time.perf_counter() - comp_started_at) * 1000.0
            state.agentic_meta["complementary"] = {
                "queries_executed": len(queries_to_run),
                "new_slices_added": complementary_count,
                "complementary_ms": round(state.complementary_ms, 1),
            }
            _record_progress(
                state,
                deps,
                phase="complementary",
                status="done",
                label="Recherche complémentaire terminée",
                detail=(
                    f"{complementary_count} passage{'s' if complementary_count != 1 else ''} ajouté"
                    if complementary_count > 0
                    else "Aucun passage supplémentaire retenu"
                ),
                count=complementary_count,
                duration_ms=state.complementary_ms,
                meta=cast(Dict[str, Any], state.agentic_meta.get("complementary") or {}),
                tool="retrieve_complementary",
            )
        return CompressContext()

@dataclass
class CompressContext(BaseNode[PlanningGraphState, AgenticPlanningDeps, PlanningResult]):
    """Finalize and optionally compress context before generation."""

    async def run(
        self,
        ctx: GraphRunContext[PlanningGraphState, AgenticPlanningDeps],
    ) -> End[PlanningResult]:
        """Produce the terminal planning artifact for downstream generation."""
        _append_path(ctx.state, self.get_node_id())
        state = ctx.state
        deps = ctx.deps
        config = get_config()
        search_cfg = config.search
        pre_compression_count = len(state.context_slices)
        max_context_chars = config.generation.max_context_chars
        if (state.agentic_plan is not None and state.agentic_plan.needs_compression) or sum(
            len(slice_.content or "") for slice_ in state.context_slices
        ) > max_context_chars * search_cfg.compression_threshold_ratio:
            _record_progress(
                state,
                deps,
                phase="compression",
                status="active",
                label="Compression du contexte",
                detail="Réduction du contexte pour respecter le budget de génération",
                count=pre_compression_count,
                tool="compress",
            )
            comp_started_at = time.perf_counter()
            state.context_slices = compress_context(
                state.context_slices,
                deps.lightweight_llm,
                budget_chars=max_context_chars,
            )
            state.compression_ms = (time.perf_counter() - comp_started_at) * 1000.0
            state.agentic_meta["compression"] = {
                "triggered": True,
                "pre_compression_slices": pre_compression_count,
                "post_compression_slices": len(state.context_slices),
                "compression_ms": round(state.compression_ms, 1),
            }
            _record_progress(
                state,
                deps,
                phase="compression",
                status="done",
                label="Compression terminée",
                detail=(
                    f"{pre_compression_count} → {len(state.context_slices)} passage"
                    f"{'s' if len(state.context_slices) != 1 else ''}"
                ),
                count=len(state.context_slices),
                duration_ms=state.compression_ms,
                meta=cast(Dict[str, Any], state.agentic_meta.get("compression") or {}),
                tool="compress",
            )
        _finalize_agentic_meta(state)
        retrieval_plan = state.retrieval_plan
        assert retrieval_plan is not None
        return End(
            PlanningContext(
                context_slices=state.context_slices,
                graph_fetch=state.graph_fetch,
                retrieval_plan=retrieval_plan,
                agentic_plan=state.agentic_plan,
                agentic_meta=state.agentic_meta,
                retrieval_query=state.retrieval_query or state.question,
                effective_top_k=state.effective_top_k,
                effective_granularity=state.effective_granularity,
                effective_include_relations=state.effective_include_relations,
                community_meta=state.community_meta,
                routing_ms=state.routing_ms,
                planner_ms=state.planner_ms,
                retrieval_ms=state.retrieval_ms,
                context_enrichment_ms=state.context_enrichment_ms,
                graph_enrichment_ms=state.graph_enrichment_ms,
                complementary_ms=state.complementary_ms,
                compression_ms=state.compression_ms,
                retrieval_stats=state.retrieval_stats,
            )
        )

_PLANNING_GRAPH = Graph(
    nodes=(
        LoadConversationContext,
        DecomposeQuestion,
        RouteQuestion,
        PlanRetrieval,
        RunRetrieval,
        EvaluateEvidence,
        MaybeFetchGraphSupport,
        CompressContext,
    ),
    name="rag_planning_graph",
    auto_instrument=False,
)

def run_planning_graph(
    *,
    deps: AgenticPlanningDeps,
) -> PlanningResult:
    """Execute the planning graph synchronously and return the terminal artifact."""
    state = PlanningGraphState(
        question=deps.question,
        top_k=deps.top_k,
        score_threshold=deps.score_threshold,
        filters=deps.filters,
        include_relations=deps.include_relations,
        include_subjects=deps.include_subjects,
        collections=deps.collections,
        granularity=deps.granularity,
        graph_depth=deps.graph_depth,
        use_graph=deps.use_graph,
        embedding_preset=deps.embedding_preset,
        retrieval_depth=deps.retrieval_depth,
        planner_run_id=uuid4().hex,
    )
    result = _PLANNING_GRAPH.run_sync(LoadConversationContext(), state=state, deps=deps)
    return result.output
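
# Usage sketch (illustrative only; the construction of AgenticPlanningDeps is
# application-specific and not shown here). run_planning_graph is synchronous
# and returns either a PlanningContext or a PlanningEarlyExit:
#
#     deps = ...  # an AgenticPlanningDeps carrying the question, retrieval knobs,
#                 # and services (query_router, retrieval_service, context_service,
#                 # lightweight_llm, optional graph_rag_service/community_enricher)
#     result = run_planning_graph(deps=deps)
#     if isinstance(result, PlanningEarlyExit):
#         ...  # e.g. surface result.clarification_question or report an empty hit list
#     else:
#         ...  # PlanningContext: feed result.context_slices to the generation step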