"""
Hybrid mode for QA with deterministic routing and context budgeting.
Pipeline orchestrator — delegates generation to hybrid_generation and
graph enrichment to hybrid_graph.
"""
import logging
import time
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast
from lalandre_core.config import get_config, get_env_settings
from lalandre_core.linking import LegalEntityLinker
from lalandre_core.utils.api_key_pool import APIKeyPool
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from lalandre_rag.agentic.deps import AgenticPlanningDeps
from lalandre_rag.agentic.graph import run_planning_graph
from lalandre_rag.agentic.models import PlanningContext
from lalandre_rag.agentic.models import PlanningEarlyExit as _EarlyExit
from lalandre_rag.agentic.models import PlanningResult as _PrepareResult
from lalandre_rag.graph import CommunityContextEnricher, GraphRAGService
from lalandre_rag.modes.hybrid_generation import query_global_mode, query_standard_mode
from lalandre_rag.modes.hybrid_helpers import ProgressCallback, build_plan_metadata
from lalandre_rag.modes.llm_mode import LLMMode
from lalandre_rag.response import (
build_clarification_answer,
build_no_source_blocked_answer,
create_blocked_sourced_response,
normalize_sources_payload,
)
from lalandre_rag.response.policy import (
decide_post_generation,
decide_pre_generation,
infer_citation_status,
infer_evidence_grade,
infer_intent_class,
)
from lalandre_rag.retrieval import RetrievalService
from lalandre_rag.retrieval.context.service import ContextService
from lalandre_rag.retrieval.overview import build_retrieval_overview
from lalandre_rag.retrieval.query_parser import LLMQueryParserClient
from lalandre_rag.retrieval.query_router import QueryRouter
logger = logging.getLogger(__name__)
def _yield_text_chunks(text: str, *, chunk_size: int = 160) -> Iterator[str]:
"""Yield a validated answer in small chunks for SSE clients."""
if not text:
return
for start in range(0, len(text), chunk_size):
yield text[start : start + chunk_size]
def _attach_planner_runtime_metadata(metadata: Dict[str, Any], agentic_meta: Dict[str, Any]) -> None:
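    """Copy optional planner telemetry fields (runtime, run id, path, tool
    trace, validation retries) from agentic metadata into the response
    metadata, skipping any that are absent."""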
planner_runtime = agentic_meta.get("planner_runtime")
planner_run_id = agentic_meta.get("planner_run_id")
planner_path = agentic_meta.get("planner_path")
tool_trace = agentic_meta.get("tool_trace")
output_validation_retries = agentic_meta.get("output_validation_retries")
if planner_runtime is not None:
metadata["planner_runtime"] = planner_runtime
if planner_run_id is not None:
metadata["planner_run_id"] = planner_run_id
if planner_path is not None:
metadata["planner_path"] = planner_path
if tool_trace is not None:
metadata["tool_trace"] = tool_trace
if output_validation_retries is not None:
metadata["output_validation_retries"] = output_validation_retries
def _build_candidate_counts(stats: Any, *, candidates_in_context: int) -> Dict[str, int]:
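    """Summarize how many candidates survived each retrieval stage."""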
return {
"after_fusion": int(stats.candidates_after_fusion),
"after_threshold": int(stats.candidates_after_threshold),
"after_rerank": int(stats.candidates_after_rerank),
"in_context": int(candidates_in_context),
}
def _build_hybrid_retrieval_metadata(
*,
ctx: PlanningContext,
config: Any,
retrieval_service: RetrievalService,
question: str,
requested_top_k: int,
requested_granularity: Optional[str],
requested_include_relations: bool,
) -> Dict[str, Any]:
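    """Assemble retrieval plan, trace, overview, and per-chunk metadata from a
    completed planning context."""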
stats = ctx.retrieval_stats
candidate_counts = _build_candidate_counts(
stats,
candidates_in_context=len(ctx.context_slices),
)
metadata: Dict[str, Any] = {
"retrieval_plan": build_plan_metadata(
retrieval_plan=ctx.retrieval_plan,
requested_top_k=requested_top_k,
effective_top_k=ctx.effective_top_k,
requested_granularity=requested_granularity,
effective_granularity=ctx.effective_granularity,
requested_include_relations=requested_include_relations,
effective_include_relations=ctx.effective_include_relations,
retrieval_query=ctx.retrieval_query,
original_question=question,
),
"retrieval_trace": {
"embedding_model": retrieval_service.embedding_service.model_name,
"rerank_model": config.search.rerank_model if config.search.rerank_enabled else None,
"fusion_method": config.search.fusion_method,
"fusion_weights": {
"lexical": stats.fusion_lexical_weight,
"semantic": stats.fusion_semantic_weight,
},
"cache_hit": stats.cache_hit,
"query_variants_count": stats.query_variants_count,
"candidates_after_fusion": stats.candidates_after_fusion,
"candidates_after_threshold": stats.candidates_after_threshold,
"candidates_after_rerank": stats.candidates_after_rerank,
"candidates_in_context": len(ctx.context_slices),
"adaptive_cutoff_applied": stats.adaptive_cutoff_applied,
"score_threshold_applied": stats.effective_score_threshold,
"query_expansion_enabled": config.search.query_expansion_enabled,
"step_timings_ms": {
"embedding_ms": stats.embedding_ms,
"semantic_search_ms": stats.semantic_search_ms,
"lexical_search_ms": stats.lexical_search_ms,
"parallel_search_ms": stats.parallel_search_ms,
"fusion_ms": stats.fusion_ms,
"rerank_ms": stats.rerank_ms,
"total_retrieve_ms": stats.total_retrieve_ms,
},
},
"retrieval_overview": build_retrieval_overview(
ctx.context_slices,
effective_granularity=ctx.effective_granularity,
candidate_counts=candidate_counts,
),
"chunks_used": [
{
"subdivision_id": cs.doc.subdivision_id,
"chunk_id": cs.doc.chunk_id,
"celex": cs.act.celex,
"score": round(cs.score, 4) if cs.score is not None else None,
"source_kind": cs.doc.source_kind,
}
for cs in ctx.context_slices
],
}
if ctx.community_meta:
metadata["community_context"] = ctx.community_meta
if ctx.agentic_meta:
metadata["agentic"] = ctx.agentic_meta
_attach_planner_runtime_metadata(metadata, ctx.agentic_meta)
return metadata
# ── Main class ────────────────────────────────────────────────────────────────
class HybridMode:
"""
MODE 3: retrieval + generation.
Includes a global community-aware path for broad queries.
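
    Usage sketch (construction of the injected services is deployment-specific;
    the service names below are illustrative, not defined in this module):

        mode = HybridMode(
            retrieval_service=retrieval_service,
            context_service=context_service,
            llm=llm,
            rag_prompt=rag_prompt,
        )
        result = mode.query("What does Article 5 require?", top_k=5)
        print(result["answer"])
        print(result["metadata"]["phase_timings_ms"])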
"""
def __init__(
self,
retrieval_service: RetrievalService,
context_service: ContextService,
llm: Any,
rag_prompt: ChatPromptTemplate,
graph_rag_service: Optional[GraphRAGService] = None,
lightweight_llm: Any = None,
key_pool: Optional[APIKeyPool] = None,
entity_linker: Optional[LegalEntityLinker] = None,
external_detector: Optional[Callable[[str], Any]] = None,
) -> None:
self.retrieval_service = retrieval_service
self.context_service = context_service
self.llm = llm
self.lightweight_llm = lightweight_llm or llm
self.rag_prompt = rag_prompt
self.graph_rag_service = graph_rag_service
self.entity_linker = entity_linker
self.external_detector = external_detector
self.llm_mode = LLMMode(llm)
self.community_enricher = (
CommunityContextEnricher(getattr(graph_rag_service, "neo4j", None))
if graph_rag_service is not None
else None
)
config = get_config()
settings = get_env_settings()
intent_parser = LLMQueryParserClient.from_runtime(
config=config,
settings=settings,
key_pool=key_pool,
)
self.query_router = QueryRouter(intent_parser=intent_parser)
# ── Shared phases 1–4 ────────────────────────────────────────────────────
def _prepare_retrieval_context(
self,
question: str,
top_k: int,
score_threshold: Optional[float],
filters: Optional[Dict[str, Any]],
include_relations: bool,
include_subjects: bool,
collections: Optional[List[str]],
granularity: Optional[str],
graph_depth: Optional[int],
use_graph: Optional[bool] = None,
embedding_preset: Optional[str] = None,
retrieval_depth: Optional[str] = None,
progress_callback: ProgressCallback = None,
) -> _PrepareResult:
"""Run phases 1–4 through the PydanticAI planning graph."""
deps = AgenticPlanningDeps(
retrieval_service=cast(Any, self.retrieval_service),
context_service=cast(Any, self.context_service),
llm=self.llm,
lightweight_llm=self.lightweight_llm,
rag_prompt=self.rag_prompt,
query_router=cast(Any, self.query_router),
graph_rag_service=cast(Any, self.graph_rag_service),
community_enricher=cast(Any, self.community_enricher),
question=question,
top_k=top_k,
score_threshold=score_threshold,
filters=filters,
include_relations=include_relations,
include_subjects=include_subjects,
include_full_content=False,
return_sources=True,
collections=collections,
granularity=granularity,
graph_depth=graph_depth,
use_graph=use_graph,
embedding_preset=embedding_preset,
retrieval_depth=retrieval_depth,
chat_history=None,
progress_callback=progress_callback,
)
return run_planning_graph(deps=deps)
@staticmethod
def _mark_auto_llm_fallback(
*,
metadata: Dict[str, Any],
prepare: _EarlyExit,
total_started_at: float,
generation_ms: Optional[float] = None,
) -> None:
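        """Record phase timings and fallback markers on the metadata of a
        response that fell back to llm_only mode."""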
phase_timings = cast(Dict[str, Any], metadata.setdefault("phase_timings_ms", {}))
phase_timings["routing_ms"] = round(prepare.routing_ms, 1)
phase_timings["planner_ms"] = round(prepare.planner_ms, 1)
if prepare.retrieval_ms:
phase_timings["retrieval_ms"] = round(prepare.retrieval_ms, 1)
if generation_ms is not None:
phase_timings["generation_ms"] = round(generation_ms, 1)
phase_timings["total_ms"] = round((time.perf_counter() - total_started_at) * 1000.0, 1)
metadata["auto_mode_fallback"] = "llm_only"
metadata["auto_mode_fallback_reason"] = prepare.kind
if prepare.agentic_rationale:
metadata["auto_mode_fallback_rationale"] = prepare.agentic_rationale
metadata["skip_retrieval"] = {
"triggered": True,
"rationale": prepare.agentic_rationale,
}
if prepare.agentic_meta:
metadata["agentic"] = prepare.agentic_meta
_attach_planner_runtime_metadata(metadata, prepare.agentic_meta)
@staticmethod
def _routing_profile_from_agentic_meta(agentic_meta: Optional[Dict[str, Any]]) -> Optional[str]:
if not isinstance(agentic_meta, dict):
return None
routing = agentic_meta.get("routing")
if isinstance(routing, dict):
profile = routing.get("profile")
if isinstance(profile, str) and profile:
return profile
return None
@classmethod
def _prepare_policy_decision(
cls,
*,
question: str,
prepare: _EarlyExit,
) -> Any:
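        """Run the pre-generation policy for an early-exit planning result,
        where no retrieved evidence is available."""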
intent_class = infer_intent_class(
intent_class=prepare.intent_class,
skip_retrieval=(prepare.kind == "skip_retrieval"),
)
return decide_pre_generation(
intent_class=intent_class,
evidence_grade="none",
question=question,
retrieval_profile=cls._routing_profile_from_agentic_meta(prepare.agentic_meta),
clarification_question=prepare.clarification_question,
strict_grounding_requested=prepare.strict_grounding_requested,
)
@classmethod
def _context_policy_decision(
cls,
*,
question: str,
ctx: PlanningContext,
response_sources: Optional[Dict[str, Any]] = None,
citation_validation: Optional[Dict[str, Any]] = None,
repaired: bool = False,
) -> Any:
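        """Run the post-generation policy for a full planning context, grading
        evidence and citation status from the generated response."""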
agentic_plan = ctx.agentic_plan
intent_class = infer_intent_class(
intent_class=(agentic_plan.intent_class if agentic_plan is not None else None),
skip_retrieval=False,
)
has_sources = bool(response_sources and response_sources.get("total", 0) > 0)
evidence_grade = infer_evidence_grade(
has_sources=has_sources,
crag_meta=cast(Optional[Dict[str, Any]], ctx.agentic_meta.get("crag")),
)
citation_status = infer_citation_status(
validation=citation_validation,
repaired=repaired,
)
return decide_post_generation(
intent_class=intent_class,
evidence_grade=evidence_grade,
citation_status=citation_status,
question=question,
has_sources=has_sources,
retrieval_profile=getattr(ctx.retrieval_plan, "profile", None),
clarification_question=(agentic_plan.clarification_question if agentic_plan is not None else None),
strict_grounding_requested=bool(
agentic_plan.strict_grounding_requested if agentic_plan is not None else False
),
)
@staticmethod
def _apply_policy_metadata(
*,
metadata: Dict[str, Any],
decision: Any,
) -> None:
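        """Merge current and legacy policy metadata keys into the response metadata."""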
metadata.update(decision.metadata())
metadata.update(decision.legacy_metadata())
@staticmethod
def _finalize_policy_response(
*,
response: Dict[str, Any],
question: str,
decision: Any,
) -> None:
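        """Apply the policy decision to the response: clarify and hard_block
        replace the answer and drop sources; weakly_grounded records the
        evidence grade."""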
metadata = cast(Dict[str, Any], response.setdefault("metadata", {}))
HybridMode._apply_policy_metadata(metadata=metadata, decision=decision)
if decision.state == "clarify":
response["answer"] = build_clarification_answer(
clarification_question=decision.clarification_question,
)
response["sources"] = None
elif decision.state == "hard_block":
response["answer"] = build_no_source_blocked_answer("rag")
response["sources"] = None
elif decision.state == "weakly_grounded":
metadata.setdefault("evidence_grade", decision.evidence_grade)
# ── Phase 5: synchronous query ────────────────────────────────────────────
def query(
self,
question: str,
top_k: int = 10,
score_threshold: Optional[float] = None,
filters: Optional[Dict[str, Any]] = None,
include_relations: bool = False,
include_subjects: bool = True,
include_full_content: bool = False,
return_sources: bool = True,
collections: Optional[List[str]] = None,
granularity: Optional[str] = None,
chat_history: Optional[List[BaseMessage]] = None,
graph_depth: Optional[int] = None,
use_graph: Optional[bool] = None,
embedding_preset: Optional[str] = None,
retrieval_depth: Optional[str] = None,
cypher_documents: Optional[List[Dict[str, Any]]] = None,
cypher_query_meta: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Run the full hybrid pipeline and return a policy-compliant response."""
total_started_at = time.perf_counter()
config = get_config()
prepare = self._prepare_retrieval_context(
question=question,
top_k=top_k,
score_threshold=score_threshold,
filters=filters,
include_relations=include_relations,
include_subjects=include_subjects,
collections=collections,
granularity=granularity,
graph_depth=graph_depth,
use_graph=use_graph,
embedding_preset=embedding_preset,
retrieval_depth=retrieval_depth,
)
# ── Early-exit handling ───────────────────────────────────────────────
if isinstance(prepare, _EarlyExit):
decision = self._prepare_policy_decision(question=question, prepare=prepare)
if decision.state == "llm_only":
llm_response = self.llm_mode.query(question=question, include_warning=True)
metadata = cast(Dict[str, Any], llm_response.get("metadata") or {})
self._mark_auto_llm_fallback(
metadata=metadata,
prepare=prepare,
total_started_at=total_started_at,
generation_ms=cast(
Optional[float],
cast(Dict[str, Any], metadata.get("phase_timings_ms") or {}).get("generation_ms"),
),
)
self._apply_policy_metadata(metadata=metadata, decision=decision)
llm_response["metadata"] = metadata
return llm_response
response = create_blocked_sourced_response(
mode="rag",
query=question,
reason=decision.reason,
answer=(
build_clarification_answer(clarification_question=decision.clarification_question)
if decision.state == "clarify"
else None
),
)
timings: Dict[str, Any] = {
"routing_ms": round(prepare.routing_ms, 1),
"planner_ms": round(prepare.planner_ms, 1),
"generation_ms": 0.0,
"total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1),
}
if prepare.retrieval_ms:
timings["retrieval_ms"] = round(prepare.retrieval_ms, 1)
response["metadata"]["phase_timings_ms"] = timings
if prepare.kind == "relevance_gate":
response["metadata"]["relevance_gate"] = {
"triggered": True,
"best_score": round(prepare.best_score, 4),
"threshold": prepare.gate_threshold,
"candidates_dropped": prepare.candidates_dropped,
}
elif prepare.kind == "empty":
response["metadata"]["empty_retrieval"] = {"triggered": True}
logger.warning("RAG query returned no retrieval results for question: %.80s", question)
if prepare.agentic_meta:
response["metadata"]["agentic"] = prepare.agentic_meta
_attach_planner_runtime_metadata(response["metadata"], prepare.agentic_meta)
self._apply_policy_metadata(metadata=cast(Dict[str, Any], response["metadata"]), decision=decision)
return response
ctx = cast(PlanningContext, prepare)
# ── PHASE 5: Generation ───────────────────────────────────────────────
if ctx.retrieval_plan.execution_mode == "global" and not collections:
response = query_global_mode(
question=question,
context_slices=ctx.context_slices,
llm=self.llm,
rag_prompt=self.rag_prompt,
include_full_content=include_full_content,
include_subjects=include_subjects,
return_sources=return_sources,
graph_fetch=ctx.graph_fetch,
chat_history=chat_history,
entity_linker=self.entity_linker,
external_detector=self.external_detector,
lightweight_llm=self.lightweight_llm,
cypher_documents=cypher_documents,
cypher_query_meta=cypher_query_meta,
)
else:
response = query_standard_mode(
question=question,
context_slices=ctx.context_slices,
llm=self.llm,
rag_prompt=self.rag_prompt,
include_relations=ctx.effective_include_relations,
include_subjects=include_subjects,
include_full_content=include_full_content,
return_sources=return_sources,
graph_fetch=ctx.graph_fetch,
chat_history=chat_history,
entity_linker=self.entity_linker,
external_detector=self.external_detector,
lightweight_llm=self.lightweight_llm,
cypher_documents=cypher_documents,
cypher_query_meta=cypher_query_meta,
)
response["sources"] = normalize_sources_payload(cast(Optional[Dict[str, Any]], response.get("sources")))
# ── Metadata assembly ─────────────────────────────────────────────────
response_metadata = cast(Dict[str, Any], response.setdefault("metadata", {}))
phase_timings = cast(Dict[str, Any], response_metadata.setdefault("phase_timings_ms", {}))
phase_timings.update(
{
"routing_ms": round(ctx.routing_ms, 1),
"planner_ms": round(ctx.planner_ms, 1),
"retrieval_ms": round(ctx.retrieval_ms, 1),
"complementary_ms": round(ctx.complementary_ms, 1),
"compression_ms": round(ctx.compression_ms, 1),
"context_enrichment_ms": round(ctx.context_enrichment_ms, 1),
"graph_enrichment_ms": round(ctx.graph_enrichment_ms, 1),
"total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1),
}
)
response_metadata.update(
_build_hybrid_retrieval_metadata(
ctx=ctx,
config=config,
retrieval_service=self.retrieval_service,
question=question,
requested_top_k=top_k,
requested_granularity=granularity,
requested_include_relations=include_relations,
)
)
decision = self._context_policy_decision(
question=question,
ctx=ctx,
response_sources=cast(Optional[Dict[str, Any]], response.get("sources")),
citation_validation=cast(Optional[Dict[str, Any]], response_metadata.get("citation_validation")),
repaired=bool(response_metadata.get("citation_repaired")),
)
self._finalize_policy_response(
response=response,
question=question,
decision=decision,
)
return response
# ── Phase 5: streaming query ──────────────────────────────────────────────
def stream_query(
self,
question: str,
top_k: int = 10,
score_threshold: Optional[float] = None,
filters: Optional[Dict[str, Any]] = None,
include_relations: bool = False,
include_subjects: bool = True,
include_full_content: bool = False,
return_sources: bool = True,
collections: Optional[List[str]] = None,
granularity: Optional[str] = None,
chat_history: Optional[List[BaseMessage]] = None,
graph_depth: Optional[int] = None,
use_graph: Optional[bool] = None,
embedding_preset: Optional[str] = None,
retrieval_depth: Optional[str] = None,
cypher_documents: Optional[List[Dict[str, Any]]] = None,
cypher_query_meta: Optional[Dict[str, Any]] = None,
) -> Iterator[Union[Dict[str, Any], str]]:
"""Stream query with live progress events emitted from a worker thread."""
config = get_config()
queue: Queue[object] = Queue()
sentinel = object()
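        # The worker thread pushes events and answer chunks onto the queue;
        # this unique sentinel object marks end-of-stream for the consumer
        # loop at the bottom of this method.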
def _push(item: object) -> None:
queue.put(item)
def _status_callback(event: Dict[str, Any]) -> None:
_push({"_status": event})
def _worker() -> None:
try:
preamble_emitted = False
streamed_tokens = False
total_started_at = time.perf_counter()
prepare = self._prepare_retrieval_context(
question=question,
top_k=top_k,
score_threshold=score_threshold,
filters=filters,
include_relations=include_relations,
include_subjects=include_subjects,
collections=collections,
granularity=granularity,
graph_depth=graph_depth,
use_graph=use_graph,
embedding_preset=embedding_preset,
retrieval_depth=retrieval_depth,
progress_callback=_status_callback,
)
if isinstance(prepare, _EarlyExit):
decision = self._prepare_policy_decision(question=question, prepare=prepare)
if decision.state == "llm_only":
preamble_meta: Dict[str, Any] = {
"warning": (
"WARNING: Cette réponse est générée uniquement par le LLM "
"et n'est pas basée sur votre base documentaire. "
"Elle peut contenir des inexactitudes."
),
}
self._mark_auto_llm_fallback(
metadata=preamble_meta,
prepare=prepare,
total_started_at=total_started_at,
)
self._apply_policy_metadata(metadata=preamble_meta, decision=decision)
_push({"_preamble": True, "sources": None, "metadata": preamble_meta})
_push(
{
"_status": {
"phase": "generation",
"status": "active",
"label": "Réponse rédigée",
"detail": "Bascule automatique vers llm_only pour une requête non documentaire",
}
}
)
generation_started_at = time.perf_counter()
for chunk in self.llm_mode.stream_query(question=question):
_push(chunk)
_push(
{
"_status": {
"phase": "generation",
"status": "done",
"label": "Réponse générée",
"duration_ms": round((time.perf_counter() - generation_started_at) * 1000.0, 1),
}
}
)
return
response = create_blocked_sourced_response(
mode="rag",
query=question,
reason=decision.reason,
answer=(
build_clarification_answer(
clarification_question=decision.clarification_question,
)
if decision.state == "clarify"
else None
),
)
preamble_meta = cast(Dict[str, Any], response["metadata"])
preamble_meta["phase_timings_ms"] = {
"routing_ms": round(prepare.routing_ms, 1),
"planner_ms": round(prepare.planner_ms, 1),
"generation_ms": 0.0,
}
if prepare.retrieval_ms:
preamble_meta["phase_timings_ms"]["retrieval_ms"] = round(prepare.retrieval_ms, 1)
if prepare.kind == "skip_retrieval":
preamble_meta["skip_retrieval"] = True
elif prepare.kind == "relevance_gate":
preamble_meta["relevance_gate"] = True
elif prepare.kind == "empty":
preamble_meta["empty_retrieval"] = True
if prepare.agentic_meta:
preamble_meta["agentic"] = prepare.agentic_meta
_attach_planner_runtime_metadata(preamble_meta, prepare.agentic_meta)
self._apply_policy_metadata(metadata=preamble_meta, decision=decision)
_push({"_preamble": True, "sources": None, "metadata": preamble_meta})
for chunk in _yield_text_chunks(str(response["answer"] or "")):
_push(chunk)
return
ctx = prepare
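                # Pre-generation policy computed from the retrieved slices
                # (a proxy for the final sources), so the preamble can carry
                # grounding metadata before generation completes.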
pre_policy = self._context_policy_decision(
question=question,
ctx=ctx,
response_sources={"total": len(ctx.context_slices)},
)
def _emit_preamble(
sources_payload: Optional[Dict[str, Any]],
query_metadata: Dict[str, Any],
) -> None:
nonlocal preamble_emitted
merged_metadata = dict(query_metadata)
preamble_phase_timings = cast(
Dict[str, Any],
merged_metadata.setdefault("phase_timings_ms", {}),
)
preamble_phase_timings.update(
{
"routing_ms": round(ctx.routing_ms, 1),
"planner_ms": round(ctx.planner_ms, 1),
"retrieval_ms": round(ctx.retrieval_ms, 1),
"context_enrichment_ms": round(ctx.context_enrichment_ms, 1),
"graph_enrichment_ms": round(ctx.graph_enrichment_ms, 1),
"complementary_ms": round(ctx.complementary_ms, 1),
"compression_ms": round(ctx.compression_ms, 1),
}
)
stats = ctx.retrieval_stats
merged_metadata["retrieval_step_timings_ms"] = {
"parallel_search_ms": stats.parallel_search_ms,
"fusion_ms": stats.fusion_ms,
"rerank_ms": stats.rerank_ms,
"total_retrieve_ms": stats.total_retrieve_ms,
}
merged_metadata.update(
_build_hybrid_retrieval_metadata(
ctx=ctx,
config=config,
retrieval_service=self.retrieval_service,
question=question,
requested_top_k=top_k,
requested_granularity=granularity,
requested_include_relations=include_relations,
)
)
self._apply_policy_metadata(metadata=merged_metadata, decision=pre_policy)
if pre_policy.state == "weakly_grounded":
merged_metadata.setdefault("evidence_grade", pre_policy.evidence_grade)
preamble_emitted = True
_push(
{
"_preamble": True,
"sources": sources_payload,
"metadata": merged_metadata,
}
)
def _emit_token(chunk: str) -> None:
nonlocal streamed_tokens
if not chunk:
return
streamed_tokens = True
_push(chunk)
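                # Emitted only after tokens were streamed, so clients can
                # replace the streamed text with the validated (possibly
                # repaired) answer.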
def _emit_final_answer(answer_text: str) -> None:
if streamed_tokens:
_push({"_final_answer": answer_text})
if ctx.retrieval_plan.execution_mode == "global" and not collections:
response = query_global_mode(
question=question,
context_slices=ctx.context_slices,
llm=self.llm,
rag_prompt=self.rag_prompt,
include_full_content=include_full_content,
include_subjects=include_subjects,
return_sources=return_sources,
graph_fetch=ctx.graph_fetch,
chat_history=chat_history,
progress_callback=_status_callback,
preamble_callback=_emit_preamble,
token_callback=_emit_token,
final_answer_callback=_emit_final_answer,
entity_linker=self.entity_linker,
external_detector=self.external_detector,
lightweight_llm=self.lightweight_llm,
cypher_documents=cypher_documents,
cypher_query_meta=cypher_query_meta,
)
else:
response = query_standard_mode(
question=question,
context_slices=ctx.context_slices,
llm=self.llm,
rag_prompt=self.rag_prompt,
include_relations=ctx.effective_include_relations,
include_subjects=include_subjects,
include_full_content=include_full_content,
return_sources=return_sources,
graph_fetch=ctx.graph_fetch,
chat_history=chat_history,
progress_callback=_status_callback,
preamble_callback=_emit_preamble,
token_callback=_emit_token,
final_answer_callback=_emit_final_answer,
entity_linker=self.entity_linker,
external_detector=self.external_detector,
lightweight_llm=self.lightweight_llm,
cypher_documents=cypher_documents,
cypher_query_meta=cypher_query_meta,
)
response["sources"] = normalize_sources_payload(cast(Optional[Dict[str, Any]], response.get("sources")))
response_metadata = cast(Dict[str, Any], response.setdefault("metadata", {}))
decision = self._context_policy_decision(
question=question,
ctx=ctx,
response_sources=cast(Optional[Dict[str, Any]], response.get("sources")),
citation_validation=cast(Optional[Dict[str, Any]], response_metadata.get("citation_validation")),
repaired=bool(response_metadata.get("citation_repaired")),
)
self._finalize_policy_response(
response=response,
question=question,
decision=decision,
)
if streamed_tokens and decision.state == "weakly_grounded":
_emit_final_answer(str(response.get("answer") or ""))
if not preamble_emitted:
_emit_preamble(
normalize_sources_payload(cast(Optional[Dict[str, Any]], response.get("sources"))),
cast(Dict[str, Any], response.get("metadata") or {}),
)
if not streamed_tokens:
for chunk in _yield_text_chunks(str(response.get("answer") or "")):
_push(chunk)
except Exception as exc: # pragma: no cover - surfaced through stream.py
logger.exception("Hybrid RAG streaming worker failed")
_push({"_error": exc})
finally:
_push(sentinel)
Thread(target=_worker, daemon=True).start()
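        # Drain the queue on the caller's thread, re-raising any worker
        # exception so it can be surfaced by stream.py.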
while True:
item = queue.get()
if item is sentinel:
break
if isinstance(item, dict) and "_error" in item:
error = item["_error"]
if isinstance(error, Exception):
raise error
raise RuntimeError(str(error))
yield cast(Union[Dict[str, Any], str], item)