# Source: api_gateway.routers.rag_proxy

"""RAG proxy router — /query, /query/stream, /search."""

import json
import logging
import time
import uuid
from typing import Any, Optional

import httpx
from api_gateway.deps import extract_user_id, get_runtime_config
from api_gateway.rate_limit import LIMIT_QUERY, LIMIT_SEARCH, LIMIT_STREAM, limiter
from api_gateway.service_metrics import (
    observe_proxy_error,
    observe_query_request,
    observe_search_request,
)
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from lalandre_core.utils import MODE_ALIASES, VALID_QUERY_MODES, normalize_celex
from lalandre_rag.models import QueryResponse, SearchRequest, SearchResponse
from pydantic import BaseModel, Field, model_validator

# Module-level logger; messages carry this module's dotted path for filtering.
logger = logging.getLogger(__name__)

# All RAG proxy endpoints are versioned under /api/v1 and grouped under the
# "rag" tag in the generated OpenAPI schema.
router = APIRouter(prefix="/api/v1", tags=["rag"])


class QueryRequest(BaseModel):
    """RAG query request.

    All retrieval parameters are optional — the rag-service resolves omitted
    values from ``app_config.yaml`` (single source of truth).
    """

    question: str = Field(..., description="User's question")
    conversation_id: Optional[str] = Field(
        default=None,
        description="Conversation ID for multi-turn. Omit for stateless single-turn.",
    )
    mode: Optional[str] = Field(
        default=None,
        description=(
            "RAG mode: rag, llm_only, summarize, compare "
            "(legacy: semantic/hybrid/lexical/graph/search). Default from config."
        ),
    )
    top_k: Optional[int] = Field(
        default=None,
        ge=1,
        le=50,
        description="Number of results to retrieve. Default from config.",
    )
    min_score: Optional[float] = Field(
        default=None,
        ge=0.0,
        description="Minimum normalized retrieval score threshold (0-1). Default from config.",
    )
    sources: bool = Field(default=False, description="Return sources in response")
    graph_depth: Optional[int] = Field(
        default=None,
        description="Graph depth override for graph mode (default from config)",
    )
    use_graph: Optional[bool] = Field(
        default=None,
        description="Override graph enrichment. None = use config default.",
    )
    graph_use_cypher: bool = Field(
        default=False,
        description=(
            "If true (graph mode), generate a read-only Cypher query from natural "
            "language and answer from Neo4j results."
        ),
    )
    graph_cypher_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description="Timeout (seconds) for NL->Cypher generation. Default from config.",
    )
    graph_cypher_max_rows: Optional[int] = Field(
        default=None,
        ge=1,
        le=500,
        description="Maximum Cypher rows returned for graph_use_cypher mode. Default from config.",
    )
    graph_generation_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=("Optional server-side timeout (seconds) for graph mode answer generation (LLM phase only)."),
    )
    graph_retrieval_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=("Optional server-side timeout (seconds) for graph retrieval phase (semantic + Neo4j traversal)."),
    )
    granularity: Optional[str] = Field(
        default=None,
        description="Retrieval granularity: chunks or all. Default from config.",
    )
    embedding_preset: Optional[str] = Field(
        default=None,
        description="Embedding preset for semantic search and collection routing.",
    )
    include_full_content: bool = Field(
        default=False,
        description="Include full content in search results (default: false)",
    )
    celex: Optional[str] = Field(default=None, description="Primary CELEX (summarize/compare)")
    compare_celex: Optional[str] = Field(default=None, description="Secondary CELEX (compare)")

    # NOTE(review): despite its name, this validator checks CELEX requirements,
    # not the question field — runs after field validation (mode="after").
    @model_validator(mode="after")
    def validate_question(self) -> "QueryRequest":
        """Validate required CELEX fields for summarize and compare modes."""
        mode = (self.mode or "").lower()
        # summarize needs a primary CELEX; compare needs both primary and secondary.
        if mode == "summarize" and not (self.celex or "").strip():
            raise ValueError("celex is required for summarize mode")
        if mode == "compare":
            if not (self.celex or "").strip() or not (self.compare_celex or "").strip():
                raise ValueError("celex and compare_celex are required for compare mode")
        # Canonicalize CELEX identifiers so downstream services see one format.
        if self.celex:
            self.celex = normalize_celex(self.celex)
        if self.compare_celex:
            self.compare_celex = normalize_celex(self.compare_celex)
        return self
def _infer_stream_outcome_from_metadata(metadata: dict[str, Any]) -> str: response_policy = metadata.get("response_policy") if isinstance(response_policy, dict): state = response_policy.get("state") if isinstance(state, str) and state: return state auto_mode_fallback = metadata.get("auto_mode_fallback") if isinstance(auto_mode_fallback, str) and auto_mode_fallback: return auto_mode_fallback if metadata.get("blocked_reason"): return "fallback" return "success" def _stream_error_outcome(status_code: int) -> str: if status_code in {408, 504}: return "timeout" if status_code == 503: return "unavailable" return "client_error" if 400 <= status_code < 500 else "server_error"
@router.post("/query", response_model=QueryResponse)
@limiter.limit(LIMIT_QUERY)
async def query(request: Request, query_request: QueryRequest):
    """Proxies user queries to the RAG Service.

    Validates/normalizes the requested mode, forwards the payload (plus a
    generated ``query_id``) to ``{rag_service_url}/query``, records one
    query metric per request, and maps transport failures onto 502/503/504.
    """
    rag_service_url, rag_proxy_timeout_seconds = get_runtime_config(request)
    query_id = str(uuid.uuid4())
    started_at = time.perf_counter()

    logger.info(
        "Processing Query %s: %s... (Mode: %s)",
        query_id,
        query_request.question[:50],
        query_request.mode,
    )

    # Metrics label; refined below once the payload's mode is normalized.
    effective_mode = str(query_request.mode or "unknown")

    def _record(outcome: str) -> None:
        # Single metric emission point shared by every exit path.
        observe_query_request(
            mode=effective_mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k or 0,
            duration_seconds=time.perf_counter() - started_at,
            outcome=outcome,
        )

    try:
        async with httpx.AsyncClient(timeout=rag_proxy_timeout_seconds) as http_client:
            payload = query_request.model_dump(exclude_none=True)
            requested_mode = (payload.get("mode") or "").lower()

            # Reject unknown modes up front; legacy aliases are rewritten in place.
            if requested_mode and requested_mode not in VALID_QUERY_MODES and requested_mode not in MODE_ALIASES:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        "Invalid mode. Must be one of: rag, llm_only, "
                        "summarize, compare (legacy: semantic, lexical, hybrid, graph, search)."
                    ),
                )
            if requested_mode in MODE_ALIASES:
                payload["mode"] = MODE_ALIASES[requested_mode]
            effective_mode = str(payload.get("mode") or "rag")

            user_id = extract_user_id(request)
            headers = {"x-user-id": user_id} if user_id else {}
            response = await http_client.post(
                f"{rag_service_url}/query",
                json={"query_id": query_id, **payload},
                headers=headers,
            )

            if response.status_code != 200:
                logger.error(
                    "RAG service query error (status=%d): %s",
                    response.status_code,
                    response.text,
                )
                # Upstream 5xx is masked as a generic 502; 4xx detail is relayed.
                if response.status_code >= 500:
                    raise HTTPException(
                        status_code=502,
                        detail="An error occurred while processing your request",
                    )
                try:
                    detail = response.json().get("detail", "Request failed")
                except Exception:
                    detail = "Request failed"
                raise HTTPException(status_code=response.status_code, detail=detail)

            _record("success")
            return response.json()
    except httpx.TimeoutException:
        logger.error("Query %s timed out.", query_id)
        observe_proxy_error(endpoint="/api/v1/query", target="rag_service", exc_or_reason="timeout")
        _record("timeout")
        raise HTTPException(status_code=504, detail="RAG service timeout - query took too long")
    except httpx.RequestError as e:
        logger.error("Query %s connection failed: %s", query_id, e)
        observe_proxy_error(endpoint="/api/v1/query", target="rag_service", exc_or_reason=e)
        _record("unavailable")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")
    except HTTPException as exc:
        # Count proxy errors only for 5xx; record every HTTPException outcome.
        if exc.status_code >= 500:
            observe_proxy_error(endpoint="/api/v1/query", target="rag_service", exc_or_reason=exc.detail)
        _record("client_error" if exc.status_code < 500 else "server_error")
        raise
@router.post("/query/stream")
@limiter.limit(LIMIT_STREAM)
async def query_stream(request: Request, query_request: QueryRequest):
    """Proxies streaming (SSE) queries to the RAG Service."""
    rag_service_url, rag_proxy_timeout_seconds = get_runtime_config(request)
    query_id = str(uuid.uuid4())
    started_at = time.perf_counter()
    payload = query_request.model_dump(exclude_none=True)
    # Metrics label; updated again if a legacy alias is rewritten below.
    effective_mode = str(payload.get("mode") or "rag")
    mode = (payload.get("mode") or "").lower()
    if mode:
        # Unlike /query, invalid modes are rejected (and recorded) before any
        # upstream connection is opened, since errors mid-stream cost more.
        if mode not in VALID_QUERY_MODES and mode not in MODE_ALIASES:
            observe_query_request(
                mode=effective_mode,
                granularity=query_request.granularity,
                top_k=query_request.top_k or 0,
                duration_seconds=time.perf_counter() - started_at,
                outcome="client_error",
            )
            raise HTTPException(
                status_code=400,
                detail="Invalid mode. Must be one of: rag, llm_only, summarize, compare.",
            )
        if mode in MODE_ALIASES:
            payload["mode"] = MODE_ALIASES[mode]
            effective_mode = str(payload.get("mode") or "rag")
    user_id = extract_user_id(request)

    async def proxy_sse():
        """Relay the upstream SSE byte stream, watching events to classify the outcome."""
        timeout = httpx.Timeout(rag_proxy_timeout_seconds, connect=10.0)
        headers = {"x-user-id": user_id} if user_id else {}
        stream_outcome = "success"
        # Guard so exactly one query metric is emitted per stream.
        request_recorded = False
        # Name of the SSE event whose next "data:" line we are parsing.
        current_event = ""

        def _record_stream_request(outcome: str) -> None:
            nonlocal request_recorded
            if request_recorded:
                return
            observe_query_request(
                mode=effective_mode,
                granularity=query_request.granularity,
                top_k=query_request.top_k or 0,
                duration_seconds=time.perf_counter() - started_at,
                outcome=outcome,
            )
            request_recorded = True

        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream(
                    "POST",
                    f"{rag_service_url}/query/stream",
                    json={"query_id": query_id, **payload},
                    headers=headers,
                ) as response:
                    if response.status_code != 200:
                        # Non-200 before streaming: surface a single SSE error
                        # event to the client instead of raising.
                        error_body = await response.aread()
                        logger.error(
                            "RAG service stream error (status=%d): %s",
                            response.status_code,
                            error_body.decode(),
                        )
                        detail = (
                            "An error occurred while processing your request"
                            if response.status_code >= 500
                            else "Request failed"
                        )
                        stream_outcome = _stream_error_outcome(response.status_code)
                        if response.status_code >= 500:
                            observe_proxy_error(
                                endpoint="/api/v1/query/stream",
                                target="rag_service",
                                exc_or_reason=detail,
                            )
                        _record_stream_request(stream_outcome)
                        yield f"event: error\ndata: {json.dumps({'detail': detail, 'code': response.status_code})}\n\n"
                        return
                    async for line in response.aiter_lines():
                        # Forward every line verbatim, but peek at
                        # metadata/error events to infer the final outcome.
                        if line.startswith("event: "):
                            current_event = line[7:].strip()
                        elif line.startswith("data: ") and current_event:
                            try:
                                data = json.loads(line[6:])
                            except json.JSONDecodeError:
                                data = None
                            if current_event == "metadata" and isinstance(data, dict):
                                stream_outcome = _infer_stream_outcome_from_metadata(data)
                            elif current_event == "error" and isinstance(data, dict):
                                raw_code = data.get("code")
                                if isinstance(raw_code, int):
                                    stream_outcome = _stream_error_outcome(raw_code)
                            current_event = ""
                        yield line + "\n"
            # Stream drained cleanly — record whatever outcome was inferred.
            _record_stream_request(stream_outcome)
        except httpx.TimeoutException:
            observe_proxy_error(endpoint="/api/v1/query/stream", target="rag_service", exc_or_reason="timeout")
            _record_stream_request("timeout")
            yield f"event: error\ndata: {json.dumps({'detail': 'RAG service timeout', 'code': 504})}\n\n"
        except httpx.RequestError as e:
            logger.error("RAG service stream connection failed: %s", e)
            observe_proxy_error(endpoint="/api/v1/query/stream", target="rag_service", exc_or_reason=e)
            _record_stream_request("unavailable")
            yield f"event: error\ndata: {json.dumps({'detail': 'Service temporarily unavailable', 'code': 503})}\n\n"
        except Exception as exc:
            # Catch-all so the client always receives a terminal SSE error
            # event rather than an abruptly closed connection.
            logger.exception("RAG service stream proxy failed")
            observe_proxy_error(endpoint="/api/v1/query/stream", target="rag_service", exc_or_reason=exc)
            _record_stream_request("server_error")
            error_payload = {
                "detail": "An error occurred while processing your request",
                "code": 500,
            }
            yield f"event: error\ndata: {json.dumps(error_payload)}\n\n"

    return StreamingResponse(proxy_sse(), media_type="text/event-stream")