# Source code for rag_service.routers.stream

"""POST /query/stream endpoint — SSE streaming query handler."""

import asyncio
import json
import logging
import time
from typing import Any, Dict, List, Optional, cast

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from lalandre_core.config import get_config
from lalandre_core.utils import sanitize_error_text
from lalandre_rag.response import merge_sources_payload
from rag_service.bootstrap import RagComponents
from rag_service.graph.cypher_branch import collect_cypher_support
from rag_service.metrics import (
    infer_query_outcome,
    observe_provider_error,
    observe_provider_fallbacks,
    observe_query_request,
)
from rag_service.mode_handlers import handle_compare_mode, handle_summarize_mode
from rag_service.models import QueryRequest
from rag_service.routers._deps import extract_user_id, get_components
from rag_service.routers.query import (
    _CONVERSATION_SKIP_MODES,
    apply_config_defaults_query,
    validate_query_mode_and_granularity,
)

# Module-level logger, named after this module for filterable log output.
logger = logging.getLogger(__name__)
# FastAPI router hosting the streaming query endpoint defined below.
router = APIRouter()

# Seconds between SSE keepalive comments while waiting for the first token.
_HEARTBEAT_INTERVAL = 15  # seconds between SSE heartbeats


def _sse_event(event: str, data: Any) -> str:
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def _sse_comment(comment: str = "keepalive") -> str:
    return f": {comment}\n\n"


def _first_non_blank_text(*values: Any) -> str:
    for value in values:
        if value is None:
            continue
        text = str(value).strip()
        if text:
            return text
    return ""


@router.post("/query/stream")
async def process_query_stream(
    query_request: QueryRequest,
    components: RagComponents = Depends(get_components),
    user_id: Optional[str] = Depends(extract_user_id),
):
    """SSE streaming variant of /query. Streams LLM tokens as they are generated.

    Emits ``sources``, ``status``, ``token``, ``timings``, ``metadata`` and
    ``done`` SSE events; on failure an ``error`` event is emitted instead of
    an HTTP error so the open stream can report it.
    """
    request_started_at = time.perf_counter()
    rag_service = components.rag_service
    conversation_manager = components.conversation_manager

    # Validation happens before the stream opens, so HTTP errors are still
    # possible here; record metrics for both client and server failures.
    try:
        apply_config_defaults_query(query_request)
        validate_query_mode_and_granularity(query_request)
    except HTTPException as exc:
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - request_started_at,
            outcome="client_error" if exc.status_code < 500 else "server_error",
        )
        if exc.status_code >= 500:
            observe_provider_error(
                mode=query_request.mode,
                stage="stream_request",
                exc_or_reason=exc.detail,
            )
        raise
    except Exception as exc:
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - request_started_at,
            outcome="server_error",
        )
        observe_provider_error(
            mode=query_request.mode,
            stage="stream_request",
            exc_or_reason=exc,
        )
        raise

    assert query_request.top_k is not None  # set by apply_config_defaults_query

    # Load conversation history unless the mode opts out of conversations.
    conv_context = None
    chat_history = None
    if conversation_manager is not None and query_request.mode not in _CONVERSATION_SKIP_MODES:
        conv_context = conversation_manager.load_history(
            conversation_id=query_request.conversation_id,
            question=query_request.question,
            user_id=user_id,
        )
        chat_history = conv_context.history_messages or None

    _sentinel = object()

    async def _anext(it):
        """Run one next() call on a sync iterator in a thread."""
        return await asyncio.to_thread(next, it, _sentinel)

    async def event_stream():
        """Async generator yielding SSE-formatted strings for StreamingResponse."""
        answer_chunks: list[str] = []
        final_answer_override: Optional[str] = None
        phase_timings: Dict[str, Any] = {}
        preamble_at: Optional[float] = None
        sources_data: Any = None
        response_metadata: Dict[str, Any] = {}
        collected_steps: list[Dict[str, Any]] = []

        def _emit_status(data: Dict[str, Any]) -> str:
            """Emit a status SSE event and collect the step for persistence."""
            phase = data.get("phase", "")
            # Update existing step or append new one
            for existing in collected_steps:
                if existing.get("phase") == phase:
                    existing.update(data)
                    return _sse_event("status", data)
            collected_steps.append(dict(data))
            return _sse_event("status", data)

        def _extract_preamble_timings(item: dict) -> None:
            """Capture metadata and upstream phase timings from the preamble item."""
            nonlocal phase_timings, preamble_at, response_metadata
            preamble_at = time.perf_counter()
            meta = item.get("metadata") or {}
            response_metadata = dict(meta) if isinstance(meta, dict) else {}
            phase_timings = dict(response_metadata.get("phase_timings_ms") or {})

        try:
            if query_request.mode == "llm_only":
                # No retrieval: announce empty sources, then stream tokens.
                yield _sse_event(
                    "sources",
                    {
                        "sources": None,
                        "query_id": query_request.query_id,
                        "mode": "llm_only",
                    },
                )
                preamble_at = time.perf_counter()
                yield _emit_status(
                    {
                        "phase": "generation",
                        "status": "active",
                        "label": "Rédaction de la réponse",
                        "detail": "Génération directe sans recherche documentaire",
                    }
                )
                gen = iter(
                    rag_service.stream_query_llm_only(
                        question=query_request.question,
                    )
                )
                while True:
                    token = await _anext(gen)
                    if token is _sentinel:
                        break
                    answer_chunks.append(str(token))
                    yield _sse_event("token", {"t": token})
                yield _emit_status(
                    {
                        "phase": "generation",
                        "status": "done",
                        "label": "Réponse générée",
                        "duration_ms": round((time.perf_counter() - preamble_at) * 1000.0, 1),
                    }
                )
            elif query_request.mode == "rag":
                config = get_config()
                cypher_documents: Optional[List[Dict[str, Any]]] = None
                cypher_query_meta: Optional[Dict[str, Any]] = None
                cypher_sources_payload: Optional[Dict[str, Any]] = None
                cypher_meta_payload: Optional[Dict[str, Any]] = None
                # ── Phase 0bis: synchronous Cypher BEFORE generation ─────────
                # Generation is blocked so that the [Cx] tags land in the LLM
                # context and are citable like the other native tags.
                # Latency cost: +500-2000ms on graph questions.
                if query_request.graph_use_cypher:
                    yield _emit_status(
                        {
                            "phase": "cypher",
                            "status": "active",
                            "label": "Appui Cypher",
                            "detail": "Interrogation Neo4j pour ancrer la réponse",
                        }
                    )
                    try:
                        cypher_sources_payload, cypher_meta_payload = await collect_cypher_support(
                            query_request=query_request,
                            graph_rag_service=components.graph_rag_service,
                            rag_service=rag_service,
                            config=config,
                            graph_depth=query_request.graph_depth or config.graph.depth,
                            strict_graph_mode=bool(config.graph.strict_mode),
                        )
                    except Exception as exc:  # pragma: no cover - defensive
                        logger.warning("cypher pre-fetch failed (non-fatal): %s", exc)
                        cypher_sources_payload, cypher_meta_payload = (
                            None,
                            {"graph_cypher_support": {"enabled": True, "status": "error"}},
                        )
                    if cypher_sources_payload is not None:
                        cypher_documents = cypher_sources_payload.get("cypher_rows") or None
                        cypher_query_meta = cypher_sources_payload.get("graph_query") or None
                    cypher_support = (
                        cypher_meta_payload.get("graph_cypher_support")
                        if isinstance(cypher_meta_payload, dict)
                        else None
                    ) or {}
                    cypher_status = str(cypher_support.get("status") or "unknown")
                    if cypher_status == "ok":
                        row_count = cypher_support.get("row_count")
                        strategy = cypher_support.get("strategy") or ""
                        detail_parts: list[str] = []
                        if isinstance(strategy, str) and strategy:
                            detail_parts.append(strategy)
                        if isinstance(row_count, int):
                            detail_parts.append(f"{row_count} ligne{'s' if row_count != 1 else ''} Neo4j")
                        yield _emit_status(
                            {
                                "phase": "cypher",
                                "status": "done",
                                "label": "Appui Cypher prêt",
                                "detail": " · ".join(detail_parts) or "Résultats Cypher disponibles",
                                "count": row_count if isinstance(row_count, int) else None,
                                "meta": cypher_support,
                            }
                        )
                        cypher_phase_timings = cypher_support.get("phase_timings_ms")
                        if isinstance(cypher_phase_timings, dict):
                            phase_timings["cypher_support_ms"] = float(
                                cypher_phase_timings.get("total_graph_pipeline_ms") or 0.0
                            )
                    elif cypher_status == "skipped":
                        yield _emit_status(
                            {
                                "phase": "cypher",
                                "status": "done",
                                "label": "Appui Cypher non utilisé",
                                "detail": _first_non_blank_text(
                                    cypher_support.get("reason"),
                                    cypher_support.get("detail"),
                                )
                                or "Appui Cypher non utilisé pour cette question",
                                "meta": cypher_support,
                            }
                        )
                    else:
                        yield _emit_status(
                            {
                                "phase": "cypher",
                                "status": "done",
                                "label": "Appui Cypher indisponible",
                                "detail": _first_non_blank_text(
                                    cypher_support.get("error"),
                                    cypher_support.get("reason"),
                                )
                                or "La branche Cypher n'a pas pu être exploitée",
                                "meta": cypher_support,
                            }
                        )

                # Legacy parallel-cypher state retained as no-ops for compatibility
                # with the rest of this function (now Cypher is fully resolved).
                cypher_task: asyncio.Task[tuple[Optional[Dict[str, Any]], Dict[str, Any]]] | None = None
                cypher_emitted = True

                def _can_run_cypher(preamble_metadata: Optional[Dict[str, Any]] = None) -> bool:
                    # Always False: the parallel path is disabled (see note above).
                    return False

                def _ensure_cypher_task(preamble_metadata: Optional[Dict[str, Any]] = None) -> None:
                    """Legacy hook: would start a background Cypher task; now a no-op."""
                    nonlocal cypher_task
                    if cypher_task is not None or not _can_run_cypher(preamble_metadata):
                        return
                    yield_status = {
                        "phase": "cypher",
                        "status": "active",
                        "label": "Préparation de l'appui Cypher",
                        "detail": "Interrogation complémentaire de Neo4j en parallèle",
                    }
                    pending_events.append(_emit_status(yield_status))
                    cypher_task = asyncio.create_task(
                        collect_cypher_support(
                            query_request=query_request,
                            graph_rag_service=components.graph_rag_service,
                            rag_service=rag_service,
                            config=config,
                            graph_depth=query_request.graph_depth or config.graph.depth,
                            strict_graph_mode=bool(config.graph.strict_mode),
                        )
                    )

                async def _flush_cypher_support() -> list[str]:
                    """Await the legacy background Cypher task and turn its result
                    into SSE events (status + merged sources). No-op when the task
                    was never started or already flushed."""
                    nonlocal sources_data, cypher_emitted
                    if cypher_task is None or cypher_emitted:
                        return []
                    cypher_sources, cypher_meta = await cypher_task
                    cypher_emitted = True
                    support = cast(Dict[str, Any], cypher_meta.get("graph_cypher_support") or {})
                    events: list[str] = []
                    status_val = str(support.get("status") or "unknown")
                    if status_val == "ok":
                        row_count_raw = support.get("row_count")
                        row_count = int(row_count_raw) if isinstance(row_count_raw, int) else None
                        detail_parts = []
                        strategy = support.get("strategy")
                        if isinstance(strategy, str) and strategy:
                            detail_parts.append(strategy)
                        if row_count is not None:
                            detail_parts.append(f"{row_count} ligne{'s' if row_count != 1 else ''} Neo4j")
                        events.append(
                            _emit_status(
                                {
                                    "phase": "cypher",
                                    "status": "done",
                                    "label": "Appui Cypher terminé",
                                    "detail": " · ".join(detail_parts) or "Résultats Cypher disponibles",
                                    "count": row_count,
                                    "meta": support,
                                }
                            )
                        )
                        merged_sources = merge_sources_payload(
                            sources_data if isinstance(sources_data, dict) else None,
                            cypher_sources,
                        )
                        if merged_sources is not None:
                            sources_data = merged_sources
                            events.append(
                                _sse_event(
                                    "sources",
                                    {
                                        "sources": sources_data,
                                        "query_id": query_request.query_id,
                                        "mode": "rag",
                                    },
                                )
                            )
                        cypher_phase = cast(
                            Optional[Dict[str, Any]],
                            support.get("phase_timings_ms"),
                        )
                        if isinstance(cypher_phase, dict):
                            phase_timings["cypher_support_ms"] = float(
                                cypher_phase.get("total_graph_pipeline_ms") or 0.0
                            )
                    elif status_val == "skipped":
                        detail = (
                            _first_non_blank_text(
                                support.get("reason"),
                                support.get("detail"),
                            )
                            or "Appui Cypher non utilisé pour cette question"
                        )
                        events.append(
                            _emit_status(
                                {
                                    "phase": "cypher",
                                    "status": "done",
                                    "label": "Appui Cypher non utilisé",
                                    "detail": detail,
                                    "meta": support,
                                }
                            )
                        )
                    else:
                        detail = (
                            _first_non_blank_text(
                                support.get("error"),
                                support.get("reason"),
                            )
                            or "La branche Cypher n'a pas pu être exploitée"
                        )
                        events.append(
                            _emit_status(
                                {
                                    "phase": "cypher",
                                    "status": "done",
                                    "label": "Appui Cypher indisponible",
                                    "detail": detail,
                                    "meta": support,
                                }
                            )
                        )
                    return events

                gen = iter(
                    rag_service.stream_query(
                        question=query_request.question,
                        top_k=cast(int, query_request.top_k),
                        score_threshold=query_request.min_score,
                        return_sources=True,
                        include_full_content=bool(query_request.include_full_content),
                        granularity=query_request.granularity,
                        chat_history=chat_history,
                        graph_depth=query_request.graph_depth,
                        use_graph=query_request.use_graph,
                        embedding_preset=query_request.embedding_preset,
                        retrieval_depth=query_request.retrieval_depth,
                        cypher_documents=cypher_documents,
                        cypher_query_meta=cypher_query_meta,
                    )
                )
                pending_events: list[str] = []
                # get_running_loop(): we are inside a coroutine, so a loop is
                # guaranteed (get_event_loop() is deprecated in this context).
                loop = asyncio.get_running_loop()
                first_future = loop.run_in_executor(None, next, gen, _sentinel)
                first_item: Any = None
                waiting_for_first_item = True
                # Heartbeat while the (blocking) pipeline prepares its first item,
                # so proxies do not drop the idle SSE connection.
                while waiting_for_first_item:
                    try:
                        first_item = await asyncio.wait_for(
                            asyncio.shield(first_future),
                            timeout=_HEARTBEAT_INTERVAL,
                        )
                        waiting_for_first_item = False
                    except asyncio.TimeoutError:
                        if cypher_task is not None and cypher_task.done() and not cypher_emitted:
                            for event in await _flush_cypher_support():
                                yield event
                        yield _sse_comment()
                # Phase 1: consume items until the preamble (sources) arrives.
                item = first_item
                while item is not _sentinel:
                    if isinstance(item, dict) and "_status" in item:
                        yield _emit_status(item["_status"])
                        if cypher_task is not None and cypher_task.done() and not cypher_emitted:
                            for event in await _flush_cypher_support():
                                yield event
                        item = await _anext(gen)
                        continue
                    if isinstance(item, dict) and "_final_answer" in item:
                        final_answer_override = str(item["_final_answer"] or "")
                        answer_chunks = [final_answer_override]
                        yield _sse_event("final_answer", {"answer": final_answer_override})
                        item = await _anext(gen)
                        continue
                    if isinstance(item, dict) and item.get("_preamble"):
                        _extract_preamble_timings(item)
                        _ensure_cypher_task(cast(Optional[Dict[str, Any]], item.get("metadata")))
                        while pending_events:
                            yield pending_events.pop(0)
                        sources_data = item.get("sources")
                        yield _sse_event(
                            "sources",
                            {
                                "sources": sources_data,
                                "query_id": query_request.query_id,
                                "mode": "rag",
                            },
                        )
                        break
                    if not isinstance(item, dict):
                        answer_chunks.append(str(item))
                        yield _sse_event("token", {"t": item})
                    item = await _anext(gen)
                # Phase 2: stream generation tokens until the iterator is drained.
                while True:
                    if cypher_task is not None and cypher_task.done() and not cypher_emitted:
                        for event in await _flush_cypher_support():
                            yield event
                    item = await _anext(gen)
                    if item is _sentinel:
                        break
                    if isinstance(item, dict):
                        if "_status" in item:
                            yield _emit_status(item["_status"])
                            continue
                        if "_final_answer" in item:
                            final_answer_override = str(item["_final_answer"] or "")
                            answer_chunks = [final_answer_override]
                            yield _sse_event("final_answer", {"answer": final_answer_override})
                            continue
                    answer_chunks.append(str(item))
                    yield _sse_event("token", {"t": item})
                if cypher_task is not None and not cypher_emitted:
                    for event in await _flush_cypher_support():
                        yield event
            elif query_request.mode in ("summarize", "compare"):
                # Fallback: run non-streaming, emit full answer as one token
                if query_request.mode == "summarize":
                    answer, sources, _meta = await asyncio.to_thread(
                        lambda: handle_summarize_mode(
                            query_request=query_request,
                            rag_service=rag_service,
                        )
                    )
                else:
                    answer, sources, _meta = await asyncio.to_thread(
                        lambda: handle_compare_mode(
                            query_request=query_request,
                            rag_service=rag_service,
                        )
                    )
                sources_data = sources
                preamble_at = time.perf_counter()
                yield _sse_event(
                    "sources",
                    {
                        "sources": sources_data,
                        "query_id": query_request.query_id,
                        "mode": query_request.mode,
                    },
                )
                answer_chunks.append(answer)
                yield _sse_event("token", {"t": answer})
            else:
                # Unknown/unsupported mode: report via SSE error event (the
                # response is already streaming, so no HTTP error is possible).
                observe_query_request(
                    mode=query_request.mode,
                    granularity=query_request.granularity,
                    top_k=query_request.top_k,
                    duration_seconds=time.perf_counter() - request_started_at,
                    outcome="client_error",
                )
                yield _sse_event(
                    "error",
                    {
                        "detail": f"Streaming not supported for mode '{query_request.mode}'",
                        "code": 400,
                    },
                )
                return
        except Exception as exc:
            observe_query_request(
                mode=query_request.mode,
                granularity=query_request.granularity,
                top_k=query_request.top_k,
                duration_seconds=time.perf_counter() - request_started_at,
                outcome="server_error",
            )
            observe_provider_error(
                mode=query_request.mode,
                stage="stream",
                exc_or_reason=exc,
            )
            logger.exception("Streaming query failed")
            yield _sse_event("error", {"detail": sanitize_error_text(exc), "code": 500})
            return

        # Success epilogue: timings, persistence, metrics, final events.
        total_ms = round((time.perf_counter() - request_started_at) * 1000.0, 1)
        generation_ms = (
            round((time.perf_counter() - preamble_at) * 1000.0, 1)
            if preamble_at
            else float(phase_timings.get("generation_ms") or 0.0)
        )
        if "generation_ms" in phase_timings and generation_ms <= 0.0:
            generation_ms = float(phase_timings.get("generation_ms") or 0.0)
        final_timings = {
            **phase_timings,
            "generation_ms": generation_ms,
            "total_ms": total_ms,
        }
        yield _sse_event(
            "timings",
            final_timings,
        )
        full_answer = final_answer_override if final_answer_override is not None else "".join(answer_chunks)
        metadata: Dict[str, Any] = {
            key: response_metadata[key]
            for key in (
                "response_policy",
                "auto_mode_fallback",
                "blocked_reason",
                "blocked_no_sources",
            )
            if key in response_metadata
        }
        if conv_context is not None and conversation_manager is not None:
            message_id = conversation_manager.save_turn(
                conversation_id=conv_context.conversation_id,
                question=query_request.question,
                answer=full_answer,
                query_id=query_request.query_id,
                mode=query_request.mode,
                sources=sources_data,
                timings=final_timings,
                steps=collected_steps,
            )
            metadata["conversation_id"] = conv_context.conversation_id
            metadata["message_id"] = message_id
        metric_metadata = {**response_metadata, "phase_timings_ms": final_timings}
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - request_started_at,
            outcome=infer_query_outcome(metric_metadata),
            metadata=metric_metadata,
        )
        observe_provider_fallbacks(mode=query_request.mode, metadata=metric_metadata)
        yield _sse_event("metadata", metadata)
        yield _sse_event("done", {})

    return StreamingResponse(event_stream(), media_type="text/event-stream")