# Source code for rag_service.routers.query

"""POST /query endpoint — synchronous RAG query handler."""

import asyncio
import logging
import time
from typing import Any, Dict, Optional

from fastapi import APIRouter, Depends, HTTPException
from lalandre_core.config import get_config
from lalandre_core.utils import VALID_QUERY_MODES, as_dict, as_document_list
from lalandre_rag.response import merge_sources_payload
from rag_service.bootstrap import RagComponents
from rag_service.graph.cypher_branch import collect_cypher_support
from rag_service.metrics import (
    infer_query_outcome,
    observe_provider_error,
    observe_provider_fallbacks,
    observe_query_request,
)
from rag_service.mode_handlers import (
    handle_compare_mode,
    handle_llm_only_mode,
    handle_rag_mode,
    handle_summarize_mode,
)
from rag_service.models import (
    QueryMetadata,
    QueryRequest,
    QueryResponse,
    SourcesResponse,
)
from rag_service.routers._deps import extract_user_id, get_components

logger = logging.getLogger(__name__)
router = APIRouter()

# Mode values accepted by validation; aliased from lalandre_core so the
# service and core library stay in sync (includes legacy search modes).
_VALID_QUERY_MODES = VALID_QUERY_MODES
# Granularity is optional: None means "use config default" (filled in by
# apply_config_defaults_query before validation runs).
_ALLOWED_QUERY_GRANULARITIES: set[Optional[str]] = {None, "chunks", "all"}
# Modes for which conversation history is neither loaded nor saved.
_CONVERSATION_SKIP_MODES = {"summarize", "compare"}


def apply_config_defaults_query(query_request: QueryRequest) -> None:
    """Fill None fields with config defaults (mutates in place).

    Args:
        query_request: Incoming request; ``mode``, ``top_k`` and
            ``granularity`` are backfilled from the search config when unset.
    """
    search_cfg = get_config().search
    defaults = (
        ("mode", search_cfg.default_mode),
        ("top_k", search_cfg.default_limit),
        ("granularity", search_cfg.default_granularity),
    )
    # Only fields the caller left as None are overwritten.
    for field_name, default_value in defaults:
        if getattr(query_request, field_name) is None:
            setattr(query_request, field_name, default_value)
def validate_query_mode_and_granularity(query_request: QueryRequest) -> None:
    """Validate the requested query mode and granularity combination.

    Raises:
        HTTPException: 400 when the mode or the granularity is not one of
            the supported values.
    """
    # Guard clause: reject unknown modes first.
    if query_request.mode not in _VALID_QUERY_MODES:
        raise HTTPException(
            status_code=400,
            detail=(
                "Invalid mode. Must be one of: rag, llm_only, "
                "summarize, compare (legacy: semantic, hybrid, lexical, graph, search)."
            ),
        )
    # None is allowed here; defaults are applied before validation.
    if query_request.granularity not in _ALLOWED_QUERY_GRANULARITIES:
        raise HTTPException(
            status_code=400,
            detail="Invalid granularity. Must be one of: chunks, all.",
        )
def build_query_response(
    *,
    query_request: QueryRequest,
    answer: str,
    sources: Optional[Dict[str, Any]],
    search_metadata: Optional[Dict[str, Any]],
) -> QueryResponse:
    """Assemble the normalized `/query` response payload.

    Args:
        query_request: The validated, defaults-applied request.
        answer: Generated answer text.
        sources: Raw sources payload from the mode handler (may be None).
        search_metadata: Extra metadata from retrieval; merged on top of the
            base metadata, so handler values win on key collisions.

    Returns:
        A fully populated ``QueryResponse``.
    """
    metadata: Dict[str, Any] = {
        "top_k": query_request.top_k or 0,
        "num_sources": sources.get("total", 0) if sources else 0,
    }
    if query_request.min_score is not None:
        metadata["min_score"] = query_request.min_score
    # Handler-provided metadata overrides the base keys above.
    metadata.update(search_metadata or {})

    parsed_sources: Optional[SourcesResponse] = None
    if isinstance(sources, dict):
        acts_raw = sources.get("acts")
        graph_query_raw = sources.get("graph_query")
        parsed_sources = SourcesResponse(
            total=int(sources.get("total", 0)),
            documents=as_document_list(sources.get("documents")),
            acts=as_dict(acts_raw) if isinstance(acts_raw, dict) else None,
            graph_nodes=as_document_list(sources.get("graph_nodes")),
            graph_edges=as_document_list(sources.get("graph_edges")),
            cypher_rows=as_document_list(sources.get("cypher_rows")),
            graph_query=(
                as_dict(graph_query_raw) if isinstance(graph_query_raw, dict) else None
            ),
        )

    return QueryResponse(
        query_id=query_request.query_id,
        question=query_request.question,
        answer=answer,
        mode=query_request.mode or "rag",
        sources=parsed_sources,
        metadata=metadata,
    )
@router.post("/query", response_model=QueryResponse)
async def process_query(
    query_request: QueryRequest,
    components: RagComponents = Depends(get_components),
    user_id: Optional[str] = Depends(extract_user_id),
) -> QueryResponse:
    """
    Process a RAG query

    Modes:
    - llm_only: LLM generation without retrieval
    - rag: Standard RAG (retrieval + LLM)
    - summarize: Summarization of retrieved documents
    - compare: Compare two acts
    """
    # Wall-clock start for the duration metric emitted on every exit path.
    started_at = time.perf_counter()
    rag_service = components.rag_service
    conversation_manager = components.conversation_manager
    try:
        # Backfill None fields from config, then reject bad mode/granularity.
        apply_config_defaults_query(query_request)
        validate_query_mode_and_granularity(query_request)
        conv_context = None
        chat_history = None
        # Conversation history is skipped for summarize/compare modes and
        # when no conversation manager is configured.
        if conversation_manager is not None and query_request.mode not in _CONVERSATION_SKIP_MODES:
            conv_context = conversation_manager.load_history(
                conversation_id=query_request.conversation_id,
                question=query_request.question,
                user_id=user_id,
            )
            chat_history = conv_context.history_messages or None
        # Mode dispatch. Handlers are synchronous, so each runs in a worker
        # thread to keep the event loop free.
        if query_request.mode == "llm_only":
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_llm_only_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                )
            )
        elif query_request.mode == "summarize":
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_summarize_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                )
            )
        elif query_request.mode == "compare":
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_compare_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                )
            )
        elif query_request.mode == "rag":
            # Only rag mode receives chat history.
            answer, sources, search_metadata = await asyncio.to_thread(
                lambda: handle_rag_mode(
                    query_request=query_request,
                    rag_service=rag_service,
                    chat_history=chat_history,
                )
            )
            # Decide whether to enrich the answer with a Cypher graph branch:
            # prefer the handler's response_policy state when present,
            # otherwise skip only if the handler fell back to llm_only.
            response_policy = as_dict(as_dict(search_metadata).get("response_policy"))
            should_run_cypher = (
                response_policy.get("state") in {"grounded", "weakly_grounded"}
                if response_policy
                else as_dict(search_metadata).get("auto_mode_fallback") != "llm_only"
            )
            if query_request.graph_use_cypher and should_run_cypher:
                config = get_config()
                cypher_sources, cypher_metadata = await collect_cypher_support(
                    query_request=query_request,
                    graph_rag_service=components.graph_rag_service,
                    rag_service=rag_service,
                    config=config,
                    graph_depth=query_request.graph_depth or config.graph.depth,
                    strict_graph_mode=bool(config.graph.strict_mode),
                )
                # Cypher results are merged into (not replacing) the
                # existing sources/metadata; cypher keys win on collision.
                sources = merge_sources_payload(sources, cypher_sources)
                search_metadata = {**search_metadata, **cypher_metadata}
        else:
            # NOTE(review): legacy modes accepted by validation (semantic,
            # hybrid, lexical, graph, search) appear to fall through to this
            # 400, whose message omits them — confirm this is intended.
            raise HTTPException(
                status_code=400,
                detail="Invalid mode. Must be one of: rag, llm_only, summarize, compare.",
            )
        response = build_query_response(
            query_request=query_request,
            answer=answer,
            sources=sources,
            search_metadata=search_metadata,
        )
        # Persist the turn and surface conversation ids only when history
        # was loaded above.
        if conv_context is not None and conversation_manager is not None:
            message_id = conversation_manager.save_turn(
                conversation_id=conv_context.conversation_id,
                question=query_request.question,
                answer=answer,
                query_id=query_request.query_id,
                mode=query_request.mode,
                sources=sources,
            )
            response.conversation_id = conv_context.conversation_id
            response.message_id = message_id
        # Success-path metrics.
        metadata = dict(response.metadata)
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - started_at,
            outcome=infer_query_outcome(metadata),
            metadata=metadata,
        )
        observe_provider_fallbacks(mode=query_request.mode, metadata=metadata)
        return response
    except HTTPException as exc:
        # Known HTTP errors: record outcome by status class, re-raise as-is.
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - started_at,
            outcome="client_error" if exc.status_code < 500 else "server_error",
        )
        if exc.status_code >= 500:
            observe_provider_error(
                mode=query_request.mode,
                stage="request",
                exc_or_reason=exc.detail,
            )
        raise
    except Exception as e:
        # Unexpected failures: record metrics, log traceback, and hide the
        # internal error behind a generic 500.
        observe_query_request(
            mode=query_request.mode,
            granularity=query_request.granularity,
            top_k=query_request.top_k,
            duration_seconds=time.perf_counter() - started_at,
            outcome="server_error",
        )
        observe_provider_error(
            mode=query_request.mode,
            stage="request",
            exc_or_reason=e,
        )
        logger.exception("Query processing failed")
        raise HTTPException(status_code=500, detail="An error occurred while processing your request")