"""RAG proxy router — /query, /query/stream, /search."""
import json
import logging
import time
import uuid
from typing import Any, Optional
import httpx
from api_gateway.deps import extract_user_id, get_runtime_config
from api_gateway.rate_limit import LIMIT_QUERY, LIMIT_SEARCH, LIMIT_STREAM, limiter
from api_gateway.service_metrics import (
observe_proxy_error,
observe_query_request,
observe_search_request,
)
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from lalandre_core.utils import MODE_ALIASES, VALID_QUERY_MODES, normalize_celex
from lalandre_rag.models import QueryResponse, SearchRequest, SearchResponse
from pydantic import BaseModel, Field, model_validator
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1", tags=["rag"])
class QueryRequest(BaseModel):
    """RAG query request.

    All retrieval parameters are optional — the rag-service resolves
    omitted values from ``app_config.yaml`` (single source of truth).
    """

    question: str = Field(..., description="User's question")
    conversation_id: Optional[str] = Field(
        default=None,
        description="Conversation ID for multi-turn. Omit for stateless single-turn.",
    )
    mode: Optional[str] = Field(
        default=None,
        description=(
            "RAG mode: rag, llm_only, summarize, compare "
            "(legacy: semantic/hybrid/lexical/graph/search). Default from config."
        ),
    )
    top_k: Optional[int] = Field(
        default=None,
        ge=1,
        le=50,
        description="Number of results to retrieve. Default from config.",
    )
    min_score: Optional[float] = Field(
        default=None,
        ge=0.0,
        description="Minimum normalized retrieval score threshold (0-1). Default from config.",
    )
    sources: bool = Field(default=False, description="Return sources in response")
    graph_depth: Optional[int] = Field(
        default=None,
        description="Graph depth override for graph mode (default from config)",
    )
    use_graph: Optional[bool] = Field(
        default=None,
        description="Override graph enrichment. None = use config default.",
    )
    graph_use_cypher: bool = Field(
        default=False,
        description=(
            "If true (graph mode), generate a read-only Cypher query from natural "
            "language and answer from Neo4j results."
        ),
    )
    graph_cypher_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description="Timeout (seconds) for NL->Cypher generation. Default from config.",
    )
    graph_cypher_max_rows: Optional[int] = Field(
        default=None,
        ge=1,
        le=500,
        description="Maximum Cypher rows returned for graph_use_cypher mode. Default from config.",
    )
    graph_generation_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=("Optional server-side timeout (seconds) for graph mode answer generation (LLM phase only)."),
    )
    graph_retrieval_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=("Optional server-side timeout (seconds) for graph retrieval phase (semantic + Neo4j traversal)."),
    )
    granularity: Optional[str] = Field(
        default=None,
        description="Retrieval granularity: chunks or all. Default from config.",
    )
    embedding_preset: Optional[str] = Field(
        default=None,
        description="Embedding preset for semantic search and collection routing.",
    )
    include_full_content: bool = Field(
        default=False,
        description="Include full content in search results (default: false)",
    )
    celex: Optional[str] = Field(default=None, description="Primary CELEX (summarize/compare)")
    compare_celex: Optional[str] = Field(default=None, description="Secondary CELEX (compare)")

    @model_validator(mode="after")
    def validate_celex_requirements(self):
        """Enforce and normalize CELEX identifiers for summarize/compare modes.

        Renamed from ``validate_question``: this validator never touches the
        question — it checks that ``celex`` (and, for compare, ``compare_celex``)
        is present and normalizes both via ``normalize_celex``. Pydantic binds
        validators through the decorator, so the rename is interface-safe.

        Raises:
            ValueError: if a required CELEX field is missing or blank for the
                requested mode.
        """
        mode = (self.mode or "").lower()
        if mode == "summarize" and not (self.celex or "").strip():
            raise ValueError("celex is required for summarize mode")
        if mode == "compare":
            if not (self.celex or "").strip() or not (self.compare_celex or "").strip():
                raise ValueError("celex and compare_celex are required for compare mode")
        # Normalize whichever CELEX identifiers were supplied, in any mode.
        if self.celex:
            self.celex = normalize_celex(self.celex)
        if self.compare_celex:
            self.compare_celex = normalize_celex(self.compare_celex)
        return self
def _infer_stream_outcome_from_metadata(metadata: dict[str, Any]) -> str:
response_policy = metadata.get("response_policy")
if isinstance(response_policy, dict):
state = response_policy.get("state")
if isinstance(state, str) and state:
return state
auto_mode_fallback = metadata.get("auto_mode_fallback")
if isinstance(auto_mode_fallback, str) and auto_mode_fallback:
return auto_mode_fallback
if metadata.get("blocked_reason"):
return "fallback"
return "success"
def _stream_error_outcome(status_code: int) -> str:
if status_code in {408, 504}:
return "timeout"
if status_code == 503:
return "unavailable"
return "client_error" if 400 <= status_code < 500 else "server_error"
@router.post("/query", response_model=QueryResponse)
@limiter.limit(LIMIT_QUERY)
async def query(request: Request, query_request: QueryRequest):
    """Proxies user queries to the RAG Service.

    Validates/normalizes the requested mode, forwards the payload (plus a
    generated ``query_id``) to the rag-service, and records one
    latency/outcome metric per request on every path: success, timeout,
    connection failure, and upstream or client error.

    Raises:
        HTTPException: 400 for an invalid mode, 502 for upstream 5xx,
            503 on connection failure, 504 on proxy timeout; upstream
            4xx statuses are forwarded with their detail.
    """
    rag_service_url, rag_proxy_timeout_seconds = get_runtime_config(request)
    query_id = str(uuid.uuid4())
    started_at = time.perf_counter()
    logger.info(
        "Processing Query %s: %s... (Mode: %s)",
        query_id,
        query_request.question[:50],
        query_request.mode,
    )
    effective_mode = str(query_request.mode or "unknown")
    try:
        payload = query_request.model_dump(exclude_none=True)
        mode = (payload.get("mode") or "").lower()
        if mode:
            # Validate before opening an upstream connection (same order as
            # query_stream); the 400 raised here is still metered as a
            # client_error by the HTTPException handler at the bottom.
            if mode not in VALID_QUERY_MODES and mode not in MODE_ALIASES:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        "Invalid mode. Must be one of: rag, llm_only, "
                        "summarize, compare (legacy: semantic, lexical, hybrid, graph, search)."
                    ),
                )
            if mode in MODE_ALIASES:
                # Map legacy mode names to their canonical equivalents.
                payload["mode"] = MODE_ALIASES[mode]
            effective_mode = str(payload.get("mode") or "rag")
        user_id = extract_user_id(request)
        headers = {"x-user-id": user_id} if user_id else {}
        async with httpx.AsyncClient(timeout=rag_proxy_timeout_seconds) as client:
            response = await client.post(
                f"{rag_service_url}/query",
                json={"query_id": query_id, **payload},
                headers=headers,
            )
            if response.status_code != 200:
                logger.error("RAG service query error (status=%d): %s", response.status_code, response.text)
                if response.status_code >= 500:
                    # Hide upstream 5xx details from the client.
                    raise HTTPException(status_code=502, detail="An error occurred while processing your request")
                try:
                    detail = response.json().get("detail", "Request failed")
                except Exception:
                    detail = "Request failed"
                raise HTTPException(status_code=response.status_code, detail=detail)
            observe_query_request(
                mode=effective_mode,
                granularity=query_request.granularity,
                top_k=query_request.top_k or 0,
                duration_seconds=time.perf_counter() - started_at,
                outcome="success",
            )
            return response.json()
    except httpx.TimeoutException:
        logger.error("Query %s timed out.", query_id)
        observe_proxy_error(endpoint="/api/v1/query", target="rag_service", exc_or_reason="timeout")
        observe_query_request(
            mode=effective_mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k or 0,
            duration_seconds=time.perf_counter() - started_at,
            outcome="timeout",
        )
        raise HTTPException(status_code=504, detail="RAG service timeout - query took too long")
    except httpx.RequestError as e:
        logger.error("Query %s connection failed: %s", query_id, e)
        observe_proxy_error(endpoint="/api/v1/query", target="rag_service", exc_or_reason=e)
        observe_query_request(
            mode=effective_mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k or 0,
            duration_seconds=time.perf_counter() - started_at,
            outcome="unavailable",
        )
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")
    except HTTPException as exc:
        # Both locally-raised (400/502) and re-raised exceptions land here;
        # only 5xx are counted as proxy errors.
        if exc.status_code >= 500:
            observe_proxy_error(endpoint="/api/v1/query", target="rag_service", exc_or_reason=exc.detail)
        observe_query_request(
            mode=effective_mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k or 0,
            duration_seconds=time.perf_counter() - started_at,
            outcome="client_error" if exc.status_code < 500 else "server_error",
        )
        raise
@router.post("/query/stream")
@limiter.limit(LIMIT_STREAM)
async def query_stream(request: Request, query_request: QueryRequest):
    """Proxies streaming (SSE) queries to the RAG Service.

    Validates/normalizes the requested mode before streaming begins, then
    relays the upstream SSE stream line-by-line while passively parsing
    ``metadata`` and ``error`` events so a single query metric can be
    recorded with the final stream outcome. Failures that occur after the
    stream has started are reported to the client as SSE ``error`` events
    (the HTTP status is already committed at that point).
    """
    rag_service_url, rag_proxy_timeout_seconds = get_runtime_config(request)
    query_id = str(uuid.uuid4())
    started_at = time.perf_counter()
    payload = query_request.model_dump(exclude_none=True)
    effective_mode = str(payload.get("mode") or "rag")
    mode = (payload.get("mode") or "").lower()
    if mode:
        if mode not in VALID_QUERY_MODES and mode not in MODE_ALIASES:
            # Record the rejected request before failing fast with a 400.
            observe_query_request(
                mode=effective_mode,
                granularity=query_request.granularity,
                top_k=query_request.top_k or 0,
                duration_seconds=time.perf_counter() - started_at,
                outcome="client_error",
            )
            raise HTTPException(
                status_code=400,
                detail="Invalid mode. Must be one of: rag, llm_only, summarize, compare.",
            )
        if mode in MODE_ALIASES:
            # Map legacy mode names to their canonical equivalents.
            payload["mode"] = MODE_ALIASES[mode]
        effective_mode = str(payload.get("mode") or "rag")
    user_id = extract_user_id(request)

    async def proxy_sse():
        """Async generator that relays the upstream SSE stream verbatim."""
        timeout = httpx.Timeout(rag_proxy_timeout_seconds, connect=10.0)
        headers = {"x-user-id": user_id} if user_id else {}
        stream_outcome = "success"  # refined as metadata/error events are seen
        request_recorded = False
        current_event = ""  # name of the SSE event currently being parsed

        def _record_stream_request(outcome: str) -> None:
            """Record the query metric exactly once per stream."""
            nonlocal request_recorded
            if request_recorded:
                return
            observe_query_request(
                mode=effective_mode,
                granularity=query_request.granularity,
                top_k=query_request.top_k or 0,
                duration_seconds=time.perf_counter() - started_at,
                outcome=outcome,
            )
            request_recorded = True

        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream(
                    "POST",
                    f"{rag_service_url}/query/stream",
                    json={"query_id": query_id, **payload},
                    headers=headers,
                ) as response:
                    if response.status_code != 200:
                        # Upstream refused the stream: read the body for the
                        # log, then emit a single SSE error event and stop.
                        error_body = await response.aread()
                        logger.error(
                            "RAG service stream error (status=%d): %s",
                            response.status_code,
                            error_body.decode(),
                        )
                        # Hide upstream 5xx details from the client.
                        detail = (
                            "An error occurred while processing your request"
                            if response.status_code >= 500
                            else "Request failed"
                        )
                        stream_outcome = _stream_error_outcome(response.status_code)
                        if response.status_code >= 500:
                            observe_proxy_error(
                                endpoint="/api/v1/query/stream",
                                target="rag_service",
                                exc_or_reason=detail,
                            )
                        _record_stream_request(stream_outcome)
                        yield f"event: error\ndata: {json.dumps({'detail': detail, 'code': response.status_code})}\n\n"
                        return
                    async for line in response.aiter_lines():
                        # Sniff SSE framing to infer the final outcome;
                        # every line is still forwarded unchanged below.
                        if line.startswith("event: "):
                            current_event = line[7:].strip()
                        elif line.startswith("data: ") and current_event:
                            try:
                                data = json.loads(line[6:])
                            except json.JSONDecodeError:
                                # Non-JSON payloads are forwarded untouched.
                                data = None
                            if current_event == "metadata" and isinstance(data, dict):
                                stream_outcome = _infer_stream_outcome_from_metadata(data)
                            elif current_event == "error" and isinstance(data, dict):
                                raw_code = data.get("code")
                                if isinstance(raw_code, int):
                                    stream_outcome = _stream_error_outcome(raw_code)
                            current_event = ""
                        # aiter_lines strips line endings; re-add "\n" so the
                        # SSE framing (incl. blank separators) stays intact.
                        yield line + "\n"
            _record_stream_request(stream_outcome)
        except httpx.TimeoutException:
            observe_proxy_error(endpoint="/api/v1/query/stream", target="rag_service", exc_or_reason="timeout")
            _record_stream_request("timeout")
            yield f"event: error\ndata: {json.dumps({'detail': 'RAG service timeout', 'code': 504})}\n\n"
        except httpx.RequestError as e:
            logger.error("RAG service stream connection failed: %s", e)
            observe_proxy_error(endpoint="/api/v1/query/stream", target="rag_service", exc_or_reason=e)
            _record_stream_request("unavailable")
            yield f"event: error\ndata: {json.dumps({'detail': 'Service temporarily unavailable', 'code': 503})}\n\n"
        except Exception as exc:
            # Catch-all so the client always receives a terminal error event.
            logger.exception("RAG service stream proxy failed")
            observe_proxy_error(endpoint="/api/v1/query/stream", target="rag_service", exc_or_reason=exc)
            _record_stream_request("server_error")
            error_payload = {
                "detail": "An error occurred while processing your request",
                "code": 500,
            }
            yield f"event: error\ndata: {json.dumps(error_payload)}\n\n"

    return StreamingResponse(proxy_sse(), media_type="text/event-stream")
@router.post("/search", response_model=SearchResponse)
@limiter.limit(LIMIT_SEARCH)
async def search(request: Request, search_request: SearchRequest):
    """Proxies search requests to the RAG Service.

    Forwards the search payload to the rag-service and records one
    latency/outcome metric per request: success, timeout, connection
    failure, or upstream/client error.
    """
    rag_service_url, rag_proxy_timeout_seconds = get_runtime_config(request)
    t_start = time.perf_counter()

    def _record(outcome: str) -> None:
        # Single place that tags the metric with the final outcome.
        observe_search_request(
            mode=search_request.mode,
            granularity=search_request.granularity,
            top_k=search_request.top_k or 0,
            duration_seconds=time.perf_counter() - t_start,
            outcome=outcome,
        )

    try:
        async with httpx.AsyncClient(timeout=rag_proxy_timeout_seconds) as client:
            upstream = await client.post(
                f"{rag_service_url}/search",
                json=search_request.model_dump(exclude_none=True),
            )
            if upstream.status_code == 200:
                _record("success")
                return upstream.json()
            logger.error("RAG service search error (status=%d): %s", upstream.status_code, upstream.text)
            if upstream.status_code >= 500:
                # Upstream 5xx details are not exposed to the client.
                raise HTTPException(status_code=502, detail="An error occurred while processing your search")
            try:
                detail = upstream.json().get("detail", "Request failed")
            except Exception:
                detail = "Request failed"
            raise HTTPException(status_code=upstream.status_code, detail=detail)
    except httpx.TimeoutException:
        logger.error("Search request timed out.")
        observe_proxy_error(endpoint="/api/v1/search", target="rag_service", exc_or_reason="timeout")
        _record("timeout")
        raise HTTPException(status_code=504, detail="RAG service timeout - search took too long")
    except httpx.RequestError as e:
        logger.error("Search request connection failed: %s", e)
        observe_proxy_error(endpoint="/api/v1/search", target="rag_service", exc_or_reason=e)
        _record("unavailable")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")
    except HTTPException as exc:
        if exc.status_code >= 500:
            observe_proxy_error(endpoint="/api/v1/search", target="rag_service", exc_or_reason=exc.detail)
        _record("client_error" if exc.status_code < 500 else "server_error")
        raise