Source code for rag_service.mode_handlers

"""
Mode-specific query handlers for the RAG service.
"""

import logging
from typing import Any, Dict, Optional

from fastapi import HTTPException
from lalandre_core.utils import (
    as_dict,
    as_optional_dict,
    as_str,
    sanitize_error_text,
)
from lalandre_rag.response import (
    create_blocked_sourced_response,
    normalize_sources_payload,
)
from lalandre_rag.summaries import is_default_summary_question
from rag_service.models import QueryRequest


def handle_llm_only_mode(
    *,
    query_request: QueryRequest,
    rag_service: Any,
) -> tuple[str, Optional[Dict[str, Any]], Dict[str, Any]]:
    """Execute the direct LLM-only answer path (no retrieval).

    Args:
        query_request: Incoming request; only ``question`` is read here.
        rag_service: Service object exposing ``query_llm_only(question=...)``.

    Returns:
        A ``(answer, sources, metadata)`` tuple. ``sources`` is always ``None``
        in this mode. If the LLM call raises or yields a blank answer, a
        degraded answer is returned instead and ``metadata`` contains
        ``llm_fallback=True`` plus the sanitized failure reason.
    """
    try:
        result = as_dict(rag_service.query_llm_only(question=query_request.question))
        answer = as_str(result.get("answer")).strip()
        if not answer:
            # A blank completion is treated like a runtime failure so the
            # caller still gets a non-empty, explicitly flagged answer.
            raise ValueError("Empty LLM response")
        metadata = as_dict(result.get("metadata"))
        return answer, None, metadata
    except Exception as exc:  # deliberate catch-all: degrade instead of failing the request
        reason = sanitize_error_text(exc)
        # The failure used to vanish without a trace. Log the sanitized reason
        # and the exception class only; the raw exception is presumably what
        # sanitize_error_text is meant to scrub, so it is not logged here.
        logging.getLogger(__name__).warning(
            "llm_only mode degraded (%s): %s", type(exc).__name__, reason
        )
        answer = f"Mode llm_only en reponse degradee: le runtime LLM est indisponible ({reason})."
        return (
            answer,
            None,
            {
                "llm_fallback": True,
                "llm_fallback_reason": reason,
            },
        )
def handle_summarize_mode(
    *,
    query_request: QueryRequest,
    rag_service: Any,
) -> tuple[str, Optional[Dict[str, Any]], Dict[str, Any]]:
    """Execute summarize mode and return answer, sources, and metadata.

    When the question is the default summary question for the act (as decided
    by ``is_default_summary_question``), a canonical summary from
    ``rag_service.summarize_canonical`` is tried first. If it returns ``None``,
    or the question is not the default one, a retrieval-based summary is built
    with ``rag_service.summarize`` restricted to the requested CELEX.

    Args:
        query_request: Incoming request; ``celex`` is required, ``question``
            is optional and defaults to ``"Résumé <celex>"``.
        rag_service: Service exposing ``summarize_canonical`` and ``summarize``.

    Returns:
        ``(answer, sources, metadata)`` extracted from the chosen result.

    Raises:
        HTTPException: 400 when ``celex`` is missing.
    """
    celex = query_request.celex
    if not celex:
        raise HTTPException(status_code=400, detail="celex required for summarize mode")

    summary_topic = (query_request.question or "").strip() or f"Résumé {celex}"

    result: Optional[Dict[str, Any]] = None
    if is_default_summary_question(summary_topic, celex):
        canonical = rag_service.summarize_canonical(
            celex=celex,
            question=summary_topic,
        )
        if canonical is not None:
            # Normalize like every other service payload in this module; the
            # original called .get() on the raw object directly.
            result = as_dict(canonical)

    if result is None:
        # No canonical summary available (or not requested): fall back to a
        # retrieval-based summary scoped to this CELEX.
        result = as_dict(
            rag_service.summarize(
                topic=summary_topic,
                top_k=query_request.top_k,
                score_threshold=query_request.min_score,
                filters={"celex": celex},
                include_full_content=query_request.include_full_content,
            )
        )

    answer = as_str(result.get("answer"))
    sources = normalize_sources_payload(as_optional_dict(result.get("sources")))
    metadata = as_dict(result.get("metadata"))
    return answer, sources, metadata
def handle_compare_mode(
    *,
    query_request: QueryRequest,
    rag_service: Any,
) -> tuple[str, Optional[Dict[str, Any]], Dict[str, Any]]:
    """Compare two acts identified by CELEX and return answer, sources, metadata.

    Args:
        query_request: Incoming request; both ``celex`` and ``compare_celex``
            must be set. ``question``, ``top_k``, ``min_score`` and
            ``include_full_content`` are forwarded to the service.
        rag_service: Service exposing ``compare(...)``.

    Returns:
        ``(answer, sources, metadata)`` extracted from the comparison result.

    Raises:
        HTTPException: 400 when either CELEX identifier is missing.
    """
    primary, secondary = query_request.celex, query_request.compare_celex
    if not (primary and secondary):
        raise HTTPException(
            status_code=400,
            detail="celex and compare_celex required for compare mode",
        )

    raw_result = rag_service.compare(
        comparison_question=query_request.question,
        celex_list=[primary, secondary],
        top_k=query_request.top_k,
        score_threshold=query_request.min_score,
        include_full_content=query_request.include_full_content,
    )
    payload = as_dict(raw_result)

    return (
        as_str(payload.get("answer")),
        normalize_sources_payload(as_optional_dict(payload.get("sources"))),
        as_dict(payload.get("metadata")),
    )
def handle_rag_mode(
    *,
    query_request: QueryRequest,
    rag_service: Any,
    chat_history: Any = None,
) -> tuple[str, Optional[Dict[str, Any]], Dict[str, Any]]:
    """Execute the default retrieval-augmented generation path.

    Args:
        query_request: Incoming request. ``question`` plus the retrieval
            settings (``top_k``, ``min_score``, ``granularity``,
            ``graph_depth``, ``use_graph``, ``embedding_preset``) and
            ``include_full_content`` are forwarded to ``rag_service.query``.
        rag_service: Service exposing ``query(...)``.
        chat_history: Optional prior conversation, forwarded unchanged.

    Returns:
        ``(answer, sources, metadata)``. If anything in the normal path
        raises, the answer produced by ``create_blocked_sourced_response`` is
        returned with ``sources=None`` and its metadata extended with
        ``llm_fallback=True`` and the sanitized failure reason.
    """
    try:
        result = as_dict(
            rag_service.query(
                question=query_request.question,
                top_k=query_request.top_k,
                score_threshold=query_request.min_score,
                return_sources=True,
                include_full_content=query_request.include_full_content,
                granularity=query_request.granularity,
                chat_history=chat_history,
                graph_depth=query_request.graph_depth,
                use_graph=query_request.use_graph,
                embedding_preset=query_request.embedding_preset,
            )
        )
        answer = as_str(result.get("answer"))
        sources = normalize_sources_payload(as_optional_dict(result.get("sources")))
        metadata = as_dict(result.get("metadata"))
        return answer, sources, metadata
    except Exception as exc:  # deliberate catch-all: degrade instead of failing the request
        reason = sanitize_error_text(exc)
        # The failure used to be swallowed without a trace. Log the sanitized
        # reason and the exception class only (not the raw exception).
        logging.getLogger(__name__).warning(
            "rag mode degraded (%s): %s", type(exc).__name__, reason
        )
        response = create_blocked_sourced_response(
            mode="rag",
            query=query_request.question,
            reason=reason,
        )
        response_metadata = as_dict(response.get("metadata"))
        response_metadata["llm_fallback"] = True
        response_metadata["llm_fallback_reason"] = reason
        return as_str(response.get("answer")), None, response_metadata