"""POST /query/stream endpoint — SSE streaming query handler."""
import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, cast
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from lalandre_core.config import get_config
from lalandre_core.utils import sanitize_error_text
from lalandre_rag.response import merge_sources_payload
from rag_service.bootstrap import RagComponents
from rag_service.graph.cypher_branch import collect_cypher_support
from rag_service.metrics import (
infer_query_outcome,
observe_provider_error,
observe_provider_fallbacks,
observe_query_request,
)
from rag_service.mode_handlers import handle_compare_mode, handle_summarize_mode
from rag_service.models import QueryRequest
from rag_service.routers._deps import extract_user_id, get_components
from rag_service.routers.query import (
_CONVERSATION_SKIP_MODES,
apply_config_defaults_query,
validate_query_mode_and_granularity,
)
# Module logger, keyed to this module's import path.
logger = logging.getLogger(__name__)
# FastAPI router exposing POST /query/stream; included by the service app.
router = APIRouter()
_HEARTBEAT_INTERVAL = 15  # seconds between SSE heartbeats
def _sse_event(event: str, data: Any) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _sse_comment(comment: str = "keepalive") -> str:
return f": {comment}\n\n"
def _first_non_blank_text(*values: Any) -> str:
for value in values:
if value is None:
continue
text = str(value).strip()
if text:
return text
return ""
@router.post("/query/stream")
async def process_query_stream(
    query_request: QueryRequest,
    components: RagComponents = Depends(get_components),
    user_id: Optional[str] = Depends(extract_user_id),
):
    """SSE streaming variant of /query. Streams LLM tokens as they are generated.

    Event types emitted on the stream: ``status`` (pipeline phase updates),
    ``sources``, ``token``, ``final_answer``, ``timings``, ``metadata``,
    ``error`` and ``done``.  SSE comment lines serve as keep-alive heartbeats
    while waiting for the pipeline's first item.
    """
    request_started_at = time.perf_counter()
    rag_service = components.rag_service
    conversation_manager = components.conversation_manager
    # Validate/normalise the request before the stream opens so failures
    # surface as regular HTTP error responses, not mid-stream SSE events.
    try:
        apply_config_defaults_query(query_request)
        validate_query_mode_and_granularity(query_request)
    except HTTPException as exc:
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - request_started_at,
            outcome="client_error" if exc.status_code < 500 else "server_error",
        )
        if exc.status_code >= 500:
            observe_provider_error(
                mode=query_request.mode,
                stage="stream_request",
                exc_or_reason=exc.detail,
            )
        raise
    except Exception as exc:
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - request_started_at,
            outcome="server_error",
        )
        observe_provider_error(
            mode=query_request.mode,
            stage="stream_request",
            exc_or_reason=exc,
        )
        raise
    assert query_request.top_k is not None  # set by apply_config_defaults_query
    # Conversation history is only loaded for modes that support it.
    conv_context = None
    chat_history = None
    if conversation_manager is not None and query_request.mode not in _CONVERSATION_SKIP_MODES:
        conv_context = conversation_manager.load_history(
            conversation_id=query_request.conversation_id,
            question=query_request.question,
            user_id=user_id,
        )
        chat_history = conv_context.history_messages or None
    # Sentinel distinguishing "iterator exhausted" from a legitimate None item.
    _sentinel = object()

    async def _anext(it):
        """Run one next() call on a sync iterator in a thread."""
        return await asyncio.to_thread(next, it, _sentinel)

    async def event_stream():
        # State accumulated across the stream for the trailing timings /
        # metadata events and for conversation persistence.
        answer_chunks: list[str] = []
        final_answer_override: Optional[str] = None
        phase_timings: Dict[str, Any] = {}
        preamble_at: Optional[float] = None
        sources_data: Any = None
        response_metadata: Dict[str, Any] = {}
        collected_steps: list[Dict[str, Any]] = []

        def _emit_status(data: Dict[str, Any]) -> str:
            """Emit a status SSE event and collect the step for persistence."""
            phase = data.get("phase", "")
            # Update existing step or append new one
            for existing in collected_steps:
                if existing.get("phase") == phase:
                    existing.update(data)
                    return _sse_event("status", data)
            collected_steps.append(dict(data))
            return _sse_event("status", data)

        def _extract_preamble_timings(item: dict) -> None:
            """Capture metadata and phase timings from the pipeline preamble."""
            nonlocal phase_timings, preamble_at, response_metadata
            preamble_at = time.perf_counter()
            meta = item.get("metadata") or {}
            response_metadata = dict(meta) if isinstance(meta, dict) else {}
            phase_timings = dict(response_metadata.get("phase_timings_ms") or {})

        try:
            if query_request.mode == "llm_only":
                # No retrieval: announce empty sources, then stream tokens.
                yield _sse_event(
                    "sources",
                    {
                        "sources": None,
                        "query_id": query_request.query_id,
                        "mode": "llm_only",
                    },
                )
                preamble_at = time.perf_counter()
                yield _emit_status(
                    {
                        "phase": "generation",
                        "status": "active",
                        "label": "Rédaction de la réponse",
                        "detail": "Génération directe sans recherche documentaire",
                    }
                )
                gen = iter(
                    rag_service.stream_query_llm_only(
                        question=query_request.question,
                    )
                )
                while True:
                    token = await _anext(gen)
                    if token is _sentinel:
                        break
                    answer_chunks.append(str(token))
                    yield _sse_event("token", {"t": token})
                yield _emit_status(
                    {
                        "phase": "generation",
                        "status": "done",
                        "label": "Réponse générée",
                        "duration_ms": round((time.perf_counter() - preamble_at) * 1000.0, 1),
                    }
                )
            elif query_request.mode == "rag":
                config = get_config()
                cypher_documents: Optional[List[Dict[str, Any]]] = None
                cypher_query_meta: Optional[Dict[str, Any]] = None
                cypher_sources_payload: Optional[Dict[str, Any]] = None
                cypher_meta_payload: Optional[Dict[str, Any]] = None
                # ── Phase 0bis: synchronous Cypher BEFORE generation ─────────
                # Generation is blocked so that the [Cx] rows land in the LLM
                # context and are citable like the other native tags.
                # Latency cost: +500-2000ms on graph questions.
                if query_request.graph_use_cypher:
                    yield _emit_status(
                        {
                            "phase": "cypher",
                            "status": "active",
                            "label": "Appui Cypher",
                            "detail": "Interrogation Neo4j pour ancrer la réponse",
                        }
                    )
                    try:
                        cypher_sources_payload, cypher_meta_payload = await collect_cypher_support(
                            query_request=query_request,
                            graph_rag_service=components.graph_rag_service,
                            rag_service=rag_service,
                            config=config,
                            graph_depth=query_request.graph_depth or config.graph.depth,
                            strict_graph_mode=bool(config.graph.strict_mode),
                        )
                    except Exception as exc:  # pragma: no cover - defensive
                        # Cypher support is best-effort: log and degrade.
                        logger.warning("cypher pre-fetch failed (non-fatal): %s", exc)
                        cypher_sources_payload, cypher_meta_payload = (
                            None,
                            {"graph_cypher_support": {"enabled": True, "status": "error"}},
                        )
                    if cypher_sources_payload is not None:
                        cypher_documents = cypher_sources_payload.get("cypher_rows") or None
                        cypher_query_meta = cypher_sources_payload.get("graph_query") or None
                    cypher_support = (
                        cypher_meta_payload.get("graph_cypher_support")
                        if isinstance(cypher_meta_payload, dict)
                        else None
                    ) or {}
                    cypher_status = str(cypher_support.get("status") or "unknown")
                    if cypher_status == "ok":
                        row_count = cypher_support.get("row_count")
                        strategy = cypher_support.get("strategy") or ""
                        detail_parts: list[str] = []
                        if isinstance(strategy, str) and strategy:
                            detail_parts.append(strategy)
                        if isinstance(row_count, int):
                            detail_parts.append(f"{row_count} ligne{'s' if row_count != 1 else ''} Neo4j")
                        yield _emit_status(
                            {
                                "phase": "cypher",
                                "status": "done",
                                "label": "Appui Cypher prêt",
                                "detail": " · ".join(detail_parts) or "Résultats Cypher disponibles",
                                "count": row_count if isinstance(row_count, int) else None,
                                "meta": cypher_support,
                            }
                        )
                        cypher_phase_timings = cypher_support.get("phase_timings_ms")
                        if isinstance(cypher_phase_timings, dict):
                            phase_timings["cypher_support_ms"] = float(
                                cypher_phase_timings.get("total_graph_pipeline_ms") or 0.0
                            )
                    elif cypher_status == "skipped":
                        yield _emit_status(
                            {
                                "phase": "cypher",
                                "status": "done",
                                "label": "Appui Cypher non utilisé",
                                "detail": _first_non_blank_text(
                                    cypher_support.get("reason"),
                                    cypher_support.get("detail"),
                                )
                                or "Appui Cypher non utilisé pour cette question",
                                "meta": cypher_support,
                            }
                        )
                    else:
                        yield _emit_status(
                            {
                                "phase": "cypher",
                                "status": "done",
                                "label": "Appui Cypher indisponible",
                                "detail": _first_non_blank_text(
                                    cypher_support.get("error"),
                                    cypher_support.get("reason"),
                                )
                                or "La branche Cypher n'a pas pu être exploitée",
                                "meta": cypher_support,
                            }
                        )
                # Legacy parallel-cypher state retained as no-ops for compatibility
                # with the rest of this function (now Cypher is fully resolved).
                cypher_task: asyncio.Task[tuple[Optional[Dict[str, Any]], Dict[str, Any]]] | None = None
                cypher_emitted = True

                def _can_run_cypher(preamble_metadata: Optional[Dict[str, Any]] = None) -> bool:
                    # Always False: the parallel-cypher path is disabled.
                    return False

                def _ensure_cypher_task(preamble_metadata: Optional[Dict[str, Any]] = None) -> None:
                    # Dead code while _can_run_cypher() is hard-wired to False;
                    # kept so the call sites below stay unchanged.
                    nonlocal cypher_task
                    if cypher_task is not None or not _can_run_cypher(preamble_metadata):
                        return
                    yield_status = {
                        "phase": "cypher",
                        "status": "active",
                        "label": "Préparation de l'appui Cypher",
                        "detail": "Interrogation complémentaire de Neo4j en parallèle",
                    }
                    pending_events.append(_emit_status(yield_status))
                    cypher_task = asyncio.create_task(
                        collect_cypher_support(
                            query_request=query_request,
                            graph_rag_service=components.graph_rag_service,
                            rag_service=rag_service,
                            config=config,
                            graph_depth=query_request.graph_depth or config.graph.depth,
                            strict_graph_mode=bool(config.graph.strict_mode),
                        )
                    )

                async def _flush_cypher_support() -> list[str]:
                    """Drain a finished parallel-Cypher task into SSE events."""
                    nonlocal sources_data, cypher_emitted
                    if cypher_task is None or cypher_emitted:
                        return []
                    cypher_sources, cypher_meta = await cypher_task
                    cypher_emitted = True
                    support = cast(Dict[str, Any], cypher_meta.get("graph_cypher_support") or {})
                    events: list[str] = []
                    status_val = str(support.get("status") or "unknown")
                    if status_val == "ok":
                        row_count_raw = support.get("row_count")
                        row_count = int(row_count_raw) if isinstance(row_count_raw, int) else None
                        detail_parts = []
                        strategy = support.get("strategy")
                        if isinstance(strategy, str) and strategy:
                            detail_parts.append(strategy)
                        if row_count is not None:
                            detail_parts.append(f"{row_count} ligne{'s' if row_count != 1 else ''} Neo4j")
                        events.append(
                            _emit_status(
                                {
                                    "phase": "cypher",
                                    "status": "done",
                                    "label": "Appui Cypher terminé",
                                    "detail": " · ".join(detail_parts) or "Résultats Cypher disponibles",
                                    "count": row_count,
                                    "meta": support,
                                }
                            )
                        )
                        # Re-emit sources if the Cypher rows merged in new entries.
                        merged_sources = merge_sources_payload(
                            sources_data if isinstance(sources_data, dict) else None,
                            cypher_sources,
                        )
                        if merged_sources is not None:
                            sources_data = merged_sources
                            events.append(
                                _sse_event(
                                    "sources",
                                    {
                                        "sources": sources_data,
                                        "query_id": query_request.query_id,
                                        "mode": "rag",
                                    },
                                )
                            )
                        cypher_phase = cast(
                            Optional[Dict[str, Any]],
                            support.get("phase_timings_ms"),
                        )
                        if isinstance(cypher_phase, dict):
                            phase_timings["cypher_support_ms"] = float(
                                cypher_phase.get("total_graph_pipeline_ms") or 0.0
                            )
                    elif status_val == "skipped":
                        detail = (
                            _first_non_blank_text(
                                support.get("reason"),
                                support.get("detail"),
                            )
                            or "Appui Cypher non utilisé pour cette question"
                        )
                        events.append(
                            _emit_status(
                                {
                                    "phase": "cypher",
                                    "status": "done",
                                    "label": "Appui Cypher non utilisé",
                                    "detail": detail,
                                    "meta": support,
                                }
                            )
                        )
                    else:
                        detail = (
                            _first_non_blank_text(
                                support.get("error"),
                                support.get("reason"),
                            )
                            or "La branche Cypher n'a pas pu être exploitée"
                        )
                        events.append(
                            _emit_status(
                                {
                                    "phase": "cypher",
                                    "status": "done",
                                    "label": "Appui Cypher indisponible",
                                    "detail": detail,
                                    "meta": support,
                                }
                            )
                        )
                    return events

                gen = iter(
                    rag_service.stream_query(
                        question=query_request.question,
                        top_k=cast(int, query_request.top_k),
                        score_threshold=query_request.min_score,
                        return_sources=True,
                        include_full_content=bool(query_request.include_full_content),
                        granularity=query_request.granularity,
                        chat_history=chat_history,
                        graph_depth=query_request.graph_depth,
                        use_graph=query_request.use_graph,
                        embedding_preset=query_request.embedding_preset,
                        retrieval_depth=query_request.retrieval_depth,
                        cypher_documents=cypher_documents,
                        cypher_query_meta=cypher_query_meta,
                    )
                )
                pending_events: list[str] = []
                # get_running_loop() is the correct call inside a coroutine;
                # get_event_loop() is deprecated for this use since 3.10.
                loop = asyncio.get_running_loop()
                first_future = loop.run_in_executor(None, next, gen, _sentinel)
                first_item: Any = None
                waiting_for_first_item = True
                # Heartbeat loop: while the pipeline has not produced its first
                # item, emit SSE comments every _HEARTBEAT_INTERVAL seconds so
                # intermediaries do not drop the idle connection.  shield()
                # keeps the executor future alive across wait_for timeouts.
                while waiting_for_first_item:
                    try:
                        first_item = await asyncio.wait_for(
                            asyncio.shield(first_future),
                            timeout=_HEARTBEAT_INTERVAL,
                        )
                        waiting_for_first_item = False
                    except asyncio.TimeoutError:
                        if cypher_task is not None and cypher_task.done() and not cypher_emitted:
                            for event in await _flush_cypher_support():
                                yield event
                        yield _sse_comment()
                item = first_item
                # Phase 1: consume items until the preamble (with sources) arrives.
                while item is not _sentinel:
                    if isinstance(item, dict) and "_status" in item:
                        yield _emit_status(item["_status"])
                        if cypher_task is not None and cypher_task.done() and not cypher_emitted:
                            for event in await _flush_cypher_support():
                                yield event
                        item = await _anext(gen)
                        continue
                    if isinstance(item, dict) and "_final_answer" in item:
                        # The pipeline replaced the streamed text wholesale.
                        final_answer_override = str(item["_final_answer"] or "")
                        answer_chunks = [final_answer_override]
                        yield _sse_event("final_answer", {"answer": final_answer_override})
                        item = await _anext(gen)
                        continue
                    if isinstance(item, dict) and item.get("_preamble"):
                        _extract_preamble_timings(item)
                        _ensure_cypher_task(cast(Optional[Dict[str, Any]], item.get("metadata")))
                        while pending_events:
                            yield pending_events.pop(0)
                        sources_data = item.get("sources")
                        yield _sse_event(
                            "sources",
                            {
                                "sources": sources_data,
                                "query_id": query_request.query_id,
                                "mode": "rag",
                            },
                        )
                        break
                    if not isinstance(item, dict):
                        answer_chunks.append(str(item))
                        yield _sse_event("token", {"t": item})
                    item = await _anext(gen)
                # Phase 2: stream the remaining tokens and status items.
                while True:
                    if cypher_task is not None and cypher_task.done() and not cypher_emitted:
                        for event in await _flush_cypher_support():
                            yield event
                    item = await _anext(gen)
                    if item is _sentinel:
                        break
                    if isinstance(item, dict):
                        if "_status" in item:
                            yield _emit_status(item["_status"])
                            continue
                        if "_final_answer" in item:
                            final_answer_override = str(item["_final_answer"] or "")
                            answer_chunks = [final_answer_override]
                            yield _sse_event("final_answer", {"answer": final_answer_override})
                            continue
                    answer_chunks.append(str(item))
                    yield _sse_event("token", {"t": item})
                # Final chance to surface a parallel-Cypher result.
                if cypher_task is not None and not cypher_emitted:
                    for event in await _flush_cypher_support():
                        yield event
            elif query_request.mode in ("summarize", "compare"):
                # Fallback: run non-streaming, emit full answer as one token
                if query_request.mode == "summarize":
                    answer, sources, meta = await asyncio.to_thread(
                        lambda: handle_summarize_mode(
                            query_request=query_request,
                            rag_service=rag_service,
                        )
                    )
                else:
                    answer, sources, meta = await asyncio.to_thread(
                        lambda: handle_compare_mode(
                            query_request=query_request,
                            rag_service=rag_service,
                        )
                    )
                sources_data = sources
                preamble_at = time.perf_counter()
                yield _sse_event(
                    "sources",
                    {
                        "sources": sources_data,
                        "query_id": query_request.query_id,
                        "mode": query_request.mode,
                    },
                )
                answer_chunks.append(answer)
                yield _sse_event("token", {"t": answer})
            else:
                # Mode cannot be streamed: report via SSE (headers already sent).
                observe_query_request(
                    mode=query_request.mode,
                    granularity=query_request.granularity,
                    top_k=query_request.top_k,
                    duration_seconds=time.perf_counter() - request_started_at,
                    outcome="client_error",
                )
                yield _sse_event(
                    "error",
                    {
                        "detail": f"Streaming not supported for mode '{query_request.mode}'",
                        "code": 400,
                    },
                )
                return
        except Exception as exc:
            # The response is already open, so failures become SSE error
            # events rather than HTTP error responses.
            observe_query_request(
                mode=query_request.mode,
                granularity=query_request.granularity,
                top_k=query_request.top_k,
                duration_seconds=time.perf_counter() - request_started_at,
                outcome="server_error",
            )
            observe_provider_error(
                mode=query_request.mode,
                stage="stream",
                exc_or_reason=exc,
            )
            logger.exception("Streaming query failed")
            yield _sse_event("error", {"detail": sanitize_error_text(exc), "code": 500})
            return
        # Success epilogue: timings, conversation persistence, metadata, done.
        total_ms = round((time.perf_counter() - request_started_at) * 1000.0, 1)
        generation_ms = (
            round((time.perf_counter() - preamble_at) * 1000.0, 1)
            if preamble_at
            else float(phase_timings.get("generation_ms") or 0.0)
        )
        if "generation_ms" in phase_timings and generation_ms <= 0.0:
            generation_ms = float(phase_timings.get("generation_ms") or 0.0)
        final_timings = {
            **phase_timings,
            "generation_ms": generation_ms,
            "total_ms": total_ms,
        }
        yield _sse_event(
            "timings",
            final_timings,
        )
        full_answer = final_answer_override if final_answer_override is not None else "".join(answer_chunks)
        # Only a whitelisted subset of pipeline metadata reaches the client.
        metadata: Dict[str, Any] = {
            key: response_metadata[key]
            for key in (
                "response_policy",
                "auto_mode_fallback",
                "blocked_reason",
                "blocked_no_sources",
            )
            if key in response_metadata
        }
        if conv_context is not None and conversation_manager is not None:
            message_id = conversation_manager.save_turn(
                conversation_id=conv_context.conversation_id,
                question=query_request.question,
                answer=full_answer,
                query_id=query_request.query_id,
                mode=query_request.mode,
                sources=sources_data,
                timings=final_timings,
                steps=collected_steps,
            )
            metadata["conversation_id"] = conv_context.conversation_id
            metadata["message_id"] = message_id
        metric_metadata = {**response_metadata, "phase_timings_ms": final_timings}
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - request_started_at,
            outcome=infer_query_outcome(metric_metadata),
            metadata=metric_metadata,
        )
        observe_provider_fallbacks(mode=query_request.mode, metadata=metric_metadata)
        yield _sse_event("metadata", metadata)
        yield _sse_event("done", {})

    return StreamingResponse(event_stream(), media_type="text/event-stream")