"""POST /query endpoint — synchronous RAG query handler."""
import asyncio
import logging
import time
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException
from lalandre_core.config import get_config
from lalandre_core.utils import VALID_QUERY_MODES, as_dict, as_document_list
from lalandre_rag.response import merge_sources_payload
from rag_service.bootstrap import RagComponents
from rag_service.graph.cypher_branch import collect_cypher_support
from rag_service.metrics import (
infer_query_outcome,
observe_provider_error,
observe_provider_fallbacks,
observe_query_request,
)
from rag_service.mode_handlers import (
handle_compare_mode,
handle_llm_only_mode,
handle_rag_mode,
handle_summarize_mode,
)
from rag_service.models import (
QueryMetadata,
QueryRequest,
QueryResponse,
SourcesResponse,
)
from rag_service.routers._deps import extract_user_id, get_components
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)
# Router mounted by the service application (provides POST /query).
router = APIRouter()
# Accepted values for QueryRequest.mode (re-exported from lalandre_core.utils).
_VALID_QUERY_MODES = VALID_QUERY_MODES
# Granularity may be omitted (None) or one of the explicit string values.
_ALLOWED_QUERY_GRANULARITIES: set[Optional[str]] = {None, "chunks", "all"}
# Modes that never load or persist conversation history.
_CONVERSATION_SKIP_MODES = {"summarize", "compare"}
def apply_config_defaults_query(query_request: QueryRequest) -> None:
    """Populate unset request fields from the search configuration.

    Mutates ``query_request`` in place; only fields that are currently
    ``None`` are overwritten with their configured defaults.
    """
    search_cfg = get_config().search
    defaults = (
        ("mode", search_cfg.default_mode),
        ("top_k", search_cfg.default_limit),
        ("granularity", search_cfg.default_granularity),
    )
    for field_name, default_value in defaults:
        if getattr(query_request, field_name) is None:
            setattr(query_request, field_name, default_value)
def validate_query_mode_and_granularity(query_request: QueryRequest) -> None:
    """Reject requests whose mode/granularity combination is unsupported.

    Raises:
        HTTPException: 400 when the mode is unknown or the granularity is
            not one of the allowed values.
    """
    mode_is_known = query_request.mode in _VALID_QUERY_MODES
    if not mode_is_known:
        raise HTTPException(
            status_code=400,
            detail=(
                "Invalid mode. Must be one of: rag, llm_only, "
                "summarize, compare (legacy: semantic, hybrid, lexical, graph, search)."
            ),
        )
    granularity_is_known = query_request.granularity in _ALLOWED_QUERY_GRANULARITIES
    if not granularity_is_known:
        raise HTTPException(
            status_code=400,
            detail="Invalid granularity. Must be one of: chunks, all.",
        )
def build_query_response(
    *,
    query_request: QueryRequest,
    answer: str,
    sources: Optional[Dict[str, Any]],
    search_metadata: Optional[Dict[str, Any]],
) -> QueryResponse:
    """Assemble the normalized `/query` response payload.

    Builds the response metadata (request echo + search-layer metadata,
    with the search layer winning on key collisions), parses the raw
    sources dict into a ``SourcesResponse`` when present, and wraps
    everything in a ``QueryResponse``.
    """
    metadata: Dict[str, Any] = {
        "top_k": query_request.top_k or 0,
        "num_sources": sources.get("total", 0) if sources else 0,
    }
    if query_request.min_score is not None:
        metadata["min_score"] = query_request.min_score
    # Search-layer metadata overrides the base keys on collision.
    metadata.update(search_metadata or {})

    parsed_sources: Optional[SourcesResponse] = None
    if isinstance(sources, dict):
        acts_value = sources.get("acts")
        graph_query_value = sources.get("graph_query")
        parsed_sources = SourcesResponse(
            total=int(sources.get("total", 0)),
            documents=as_document_list(sources.get("documents")),
            acts=as_dict(acts_value) if isinstance(acts_value, dict) else None,
            graph_nodes=as_document_list(sources.get("graph_nodes")),
            graph_edges=as_document_list(sources.get("graph_edges")),
            cypher_rows=as_document_list(sources.get("cypher_rows")),
            graph_query=(
                as_dict(graph_query_value)
                if isinstance(graph_query_value, dict)
                else None
            ),
        )

    return QueryResponse(
        query_id=query_request.query_id,
        question=query_request.question,
        answer=answer,
        mode=query_request.mode or "rag",
        sources=parsed_sources,
        metadata=metadata,
    )
@router.post("/query", response_model=QueryResponse)
async def process_query(
    query_request: QueryRequest,
    components: RagComponents = Depends(get_components),
    user_id: Optional[str] = Depends(extract_user_id),
) -> QueryResponse:
    """
    Process a RAG query.

    Modes:
    - llm_only: LLM generation without retrieval
    - rag: Standard RAG (retrieval + LLM)
    - summarize: Summarization of retrieved documents
    - compare: Compare two acts

    Metrics are recorded via ``observe_query_request`` /
    ``observe_provider_fallbacks`` / ``observe_provider_error`` on both the
    success and failure paths.

    Raises:
        HTTPException: 400 for an invalid mode or granularity; 500 (with a
            generic detail message) on unexpected internal errors.
    """
    # Wall-clock start used for the duration metric in every exit path.
    started_at = time.perf_counter()
    rag_service = components.rag_service
    conversation_manager = components.conversation_manager
    try:
        # Fill None fields from config, then validate the result.
        apply_config_defaults_query(query_request)
        validate_query_mode_and_granularity(query_request)
        conv_context = None
        chat_history = None
        # Conversation history is skipped for summarize/compare modes.
        # NOTE(review): chat_history is also loaded for llm_only mode but is
        # only passed to handle_rag_mode below — confirm that is intended.
        if conversation_manager is not None and query_request.mode not in _CONVERSATION_SKIP_MODES:
            conv_context = conversation_manager.load_history(
                conversation_id=query_request.conversation_id,
                question=query_request.question,
                user_id=user_id,
            )
            chat_history = conv_context.history_messages or None
        # Mode handlers are synchronous; run them in a worker thread so the
        # event loop stays responsive.
        if query_request.mode == "llm_only":
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_llm_only_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                )
            )
        elif query_request.mode == "summarize":
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_summarize_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                )
            )
        elif query_request.mode == "compare":
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_compare_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                )
            )
        elif query_request.mode == "rag":
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_rag_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                    chat_history=chat_history,
                )
            )
            # Decide whether the optional Cypher graph branch should run:
            # prefer the grounding state from the response policy when one is
            # present; otherwise fall back to checking that auto-mode did not
            # degrade the request to llm_only.
            response_policy = as_dict(as_dict(search_metadata).get("response_policy"))
            should_run_cypher = (
                response_policy.get("state") in {"grounded", "weakly_grounded"}
                if response_policy
                else as_dict(search_metadata).get("auto_mode_fallback") != "llm_only"
            )
            if query_request.graph_use_cypher and should_run_cypher:
                config = get_config()
                cypher_sources, cypher_metadata = await collect_cypher_support(
                    query_request=query_request,
                    graph_rag_service=components.graph_rag_service,
                    rag_service=rag_service,
                    config=config,
                    graph_depth=query_request.graph_depth or config.graph.depth,
                    strict_graph_mode=bool(config.graph.strict_mode),
                )
                # Merge graph evidence into the base payload; Cypher metadata
                # overrides colliding keys in search_metadata.
                sources = merge_sources_payload(sources, cypher_sources)
                search_metadata = {**search_metadata, **cypher_metadata}
        else:
            # Defensive branch: validate_query_mode_and_granularity may accept
            # legacy modes that have no handler here.
            raise HTTPException(
                status_code=400,
                detail="Invalid mode. Must be one of: rag, llm_only, summarize, compare.",
            )
        response = build_query_response(
            query_request=query_request,
            answer=answer,
            sources=sources,
            search_metadata=search_metadata,
        )
        # Persist the turn and stamp conversation identifiers on the response.
        if conv_context is not None and conversation_manager is not None:
            message_id = conversation_manager.save_turn(
                conversation_id=conv_context.conversation_id,
                question=query_request.question,
                answer=answer,
                query_id=query_request.query_id,
                mode=query_request.mode,
                sources=sources,
            )
            response.conversation_id = conv_context.conversation_id
            response.message_id = message_id
        # Success-path metrics.
        metadata = dict(response.metadata)
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - started_at,
            outcome=infer_query_outcome(metadata),
            metadata=metadata,
        )
        observe_provider_fallbacks(mode=query_request.mode, metadata=metadata)
        return response
    except HTTPException as exc:
        # Record client vs. server outcome, then re-raise the same exception.
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - started_at,
            outcome="client_error" if exc.status_code < 500 else "server_error",
        )
        if exc.status_code >= 500:
            observe_provider_error(
                mode=query_request.mode,
                stage="request",
                exc_or_reason=exc.detail,
            )
        raise
    except Exception as e:
        # Unexpected failure: record metrics, log the traceback, and hide
        # internal details behind a generic 500.
        # NOTE(review): consider `raise ... from e` to preserve the chain.
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - started_at,
            outcome="server_error",
        )
        observe_provider_error(
            mode=query_request.mode,
            stage="request",
            exc_or_reason=e,
        )
        logger.exception("Query processing failed")
        raise HTTPException(status_code=500, detail="An error occurred while processing your request")