# Source code for lalandre_rag.service

"""
RAG (Retrieval-Augmented Generation) Service
High-level orchestration of retrieval, context enrichment, and LLM generation
"""

from typing import Any, Callable, Dict, Iterator, List, Optional, Union

from lalandre_core.config import get_config, get_env_settings
from lalandre_core.linking import LegalEntityLinker
from lalandre_core.utils.api_key_pool import APIKeyPool
from langchain_core.messages import BaseMessage

from lalandre_rag.graph import GraphRAGService
from lalandre_rag.retrieval import RetrievalService
from lalandre_rag.retrieval.context.service import ContextService

from .adapters import LlamaIndexAdapter
from .llm import build_rag_llm_clients
from .modes.hybrid_mode import HybridMode
from .modes.llm_mode import LLMMode
from .modes.summarize_mode import CompareMode, SummarizeMode
from .prompts import get_langchain_prompt
from .summaries import ActSummaryService, QuestionSummaryService


class RAGService:
    """
    Complete RAG pipeline for legal document querying.

    Architecture:
    - Delegates LLM-only to LLMMode
    - Delegates hybrid RAG to HybridMode
    - Delegates summarization to SummarizeMode
    - Delegates comparison to CompareMode

    This is a thin orchestration layer that initializes modes and delegates
    requests.
    """

    def __init__(
        self,
        retrieval_service: RetrievalService,
        context_service: ContextService,
        llm_model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        graph_rag_service: Optional[GraphRAGService] = None,
        key_pool: Optional[APIKeyPool] = None,
        act_summary_service: Optional[ActSummaryService] = None,
        entity_linker: Optional[LegalEntityLinker] = None,
        external_detector: Optional[Callable[[str], Any]] = None,
    ):
        """
        Initialize RAG service.

        Args:
            retrieval_service: Service for document retrieval
            context_service: Service for context enrichment
            llm_model: LLM model name (default from config)
            temperature: LLM temperature (default from config)
            max_tokens: Maximum tokens for generation (default from config)
            api_key: Optional provider API key override
            graph_rag_service: Optional graph retrieval service handed to HybridMode
            key_pool: Optional API key pool shared with the LLM clients
            act_summary_service: Optional canonical act-summary service
            entity_linker: Optional legal entity linker handed to HybridMode
            external_detector: Optional callable handed to HybridMode
                (presumably detects external references in a question —
                TODO confirm against HybridMode)

        Raises:
            ValueError: If model name, temperature, or max tokens cannot be
                resolved from the arguments or the generation config.
        """
        self.retrieval_service = retrieval_service
        self.context_service = context_service
        self.act_summary_service = act_summary_service
        self.entity_linker = entity_linker
        self.external_detector = external_detector

        # Load configuration
        config = get_config()
        gen_config = config.generation

        # Resolve LLM runtime values; explicit arguments take precedence
        # over the generation config.
        settings = get_env_settings()
        resolved_provider = (gen_config.provider or "mistral").strip().lower()
        resolved_model = llm_model if llm_model is not None else gen_config.model_name
        if not resolved_model:
            raise ValueError("generation.model_name must be configured")
        resolved_temperature = temperature if temperature is not None else gen_config.temperature
        if resolved_temperature is None:
            raise ValueError("generation.temperature must be configured")
        resolved_max_tokens = max_tokens if max_tokens is not None else gen_config.max_tokens
        # BUGFIX: validate like model/temperature above. Previously a missing
        # config value surfaced later as an opaque TypeError from int(None).
        if resolved_max_tokens is None:
            raise ValueError("generation.max_tokens must be configured")
        resolved_timeout = float(gen_config.timeout_seconds)
        resolved_base_url = gen_config.base_url
        resolved_api_key = api_key if api_key is not None else (gen_config.api_key or settings.LLM_API_KEY)

        # Initialize provider-specific chat and llamaindex clients
        llm_clients = build_rag_llm_clients(
            provider=resolved_provider,
            model_name=resolved_model,
            temperature=float(resolved_temperature),
            max_tokens=int(resolved_max_tokens),
            timeout_seconds=resolved_timeout,
            base_url=resolved_base_url,
            mistral_base_url=gen_config.mistral_base_url,
            context_window=gen_config.context_window,
            api_key=resolved_api_key,
            mistral_api_key=settings.MISTRAL_API_KEY,
            key_pool=key_pool,
        )
        self._llm_provider: str = llm_clients.provider
        self._llm_model: str = resolved_model
        self._llm_temperature: float = float(resolved_temperature)
        self._llm_max_tokens: int = int(resolved_max_tokens)
        self.llm = llm_clients.chat_llm

        # Build a lighter/faster LLM for agentic sub-tasks (CRAG eval/refine).
        lightweight_model = gen_config.lightweight_model_name
        if lightweight_model and lightweight_model != resolved_model:
            lw_clients = build_rag_llm_clients(
                provider=resolved_provider,
                model_name=lightweight_model,
                temperature=0.0,
                max_tokens=512,
                timeout_seconds=resolved_timeout,
                base_url=resolved_base_url,
                mistral_base_url=gen_config.mistral_base_url,
                context_window=gen_config.context_window,
                api_key=resolved_api_key,
                mistral_api_key=settings.MISTRAL_API_KEY,
                key_pool=key_pool,
            )
            self.lightweight_llm = lw_clients.chat_llm
        else:
            self.lightweight_llm = self.llm

        # LlamaIndex adapter is optional (provider-dependent).
        llamaindex_adapter = (
            LlamaIndexAdapter(llm_clients.llamaindex_llm)
            if llm_clients.llamaindex_llm is not None
            else None
        )

        # Load prompts (with_history enables MessagesPlaceholder for multi-turn)
        rag_prompt = get_langchain_prompt("rag", with_history=True)
        question_summary_service = QuestionSummaryService(act_summary_service)

        # Initialize specialized modes
        self.llm_mode = LLMMode(self.llm)
        self.hybrid_mode = HybridMode(
            retrieval_service,
            context_service,
            self.llm,
            rag_prompt,
            graph_rag_service=graph_rag_service,
            lightweight_llm=self.lightweight_llm,
            key_pool=key_pool,
            entity_linker=entity_linker,
            external_detector=external_detector,
        )
        self.summarize_mode = SummarizeMode(
            retrieval_service,
            context_service,
            llamaindex_adapter,
            citation_llm=self.llm,
            act_summary_service=act_summary_service,
            question_summary_service=question_summary_service,
        )
        self.compare_mode = CompareMode(
            retrieval_service,
            context_service,
            llamaindex_adapter,
            citation_llm=self.llm,
            question_summary_service=question_summary_service,
        )

    # ===================================================================
    # MODE 2: LLM ONLY
    # ===================================================================
[docs] def query_llm_only(self, question: str, include_warning: bool = True) -> Dict[str, Any]: """ MODE 2: Pure LLM (100% Generation) Delegates to LLMMode. Args: question: User question include_warning: Include warning about no document grounding Returns: Dictionary with LLM answer (no sources) """ return self.llm_mode.query(question=question, include_warning=include_warning)
[docs] def stream_query_llm_only(self, question: str) -> Iterator[str]: """Stream LLM-only answer token by token.""" return self.llm_mode.stream_query(question=question)
# =================================================================== # MODE 3: HYBRID (RAG) # ===================================================================
[docs] def query( self, question: str, top_k: int = 10, score_threshold: Optional[float] = None, filters: Optional[Dict[str, Any]] = None, include_relations: bool = False, include_subjects: bool = False, return_sources: bool = True, include_full_content: bool = False, collections: Optional[List[str]] = None, granularity: Optional[str] = None, chat_history: Optional[List[BaseMessage]] = None, graph_depth: Optional[int] = None, use_graph: Optional[bool] = None, embedding_preset: Optional[str] = None, cypher_documents: Optional[List[Dict[str, Any]]] = None, cypher_query_meta: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ MODE 3: Hybrid RAG (default) Delegates to HybridMode. Args: question: User question top_k: Number of documents to retrieve filters: Metadata filters include_relations: Include act relations in context include_subjects: Include subject classifications return_sources: Return source documents in response collections: Specific collections to search granularity: Quick selector chat_history: Optional conversation history as LangChain messages Returns: Dictionary with answer, sources, and metadata """ return self.hybrid_mode.query( question=question, top_k=top_k, score_threshold=score_threshold, filters=filters, include_relations=include_relations, include_subjects=include_subjects, include_full_content=include_full_content, return_sources=return_sources, collections=collections, granularity=granularity, chat_history=chat_history, graph_depth=graph_depth, use_graph=use_graph, embedding_preset=embedding_preset, cypher_documents=cypher_documents, cypher_query_meta=cypher_query_meta, )
[docs] def stream_query( self, question: str, top_k: int = 10, score_threshold: Optional[float] = None, filters: Optional[Dict[str, Any]] = None, include_relations: bool = False, include_subjects: bool = False, return_sources: bool = True, include_full_content: bool = False, collections: Optional[List[str]] = None, granularity: Optional[str] = None, chat_history: Optional[List[BaseMessage]] = None, graph_depth: Optional[int] = None, use_graph: Optional[bool] = None, embedding_preset: Optional[str] = None, retrieval_depth: Optional[str] = None, cypher_documents: Optional[List[Dict[str, Any]]] = None, cypher_query_meta: Optional[Dict[str, Any]] = None, ) -> Iterator[Union[Dict[str, Any], str]]: """Stream hybrid RAG answer: yields preamble dict then string tokens.""" return self.hybrid_mode.stream_query( question=question, top_k=top_k, score_threshold=score_threshold, filters=filters, include_relations=include_relations, include_subjects=include_subjects, include_full_content=include_full_content, return_sources=return_sources, collections=collections, granularity=granularity, chat_history=chat_history, graph_depth=graph_depth, use_graph=use_graph, embedding_preset=embedding_preset, retrieval_depth=retrieval_depth, cypher_documents=cypher_documents, cypher_query_meta=cypher_query_meta, )
# =================================================================== # MODE 4: SUMMARIZE # ===================================================================
[docs] def summarize( self, topic: str, top_k: int = 10, score_threshold: Optional[float] = None, filters: Optional[Dict[str, Any]] = None, include_relations: bool = True, include_full_content: bool = False, ) -> Dict[str, Any]: """ MODE 4: Summarization Delegates to SummarizeMode. Args: topic: Topic or question to summarize top_k: Number of documents to retrieve filters: Metadata filters include_relations: Include relations in context Returns: Dictionary with summary and sources """ return self.summarize_mode.summarize( topic=topic, top_k=top_k, score_threshold=score_threshold, filters=filters, include_relations=include_relations, include_full_content=include_full_content, )
[docs] def summarize_canonical( self, *, celex: str, question: str, ) -> Optional[Dict[str, Any]]: """Return a canonical-summary response when one exists for *celex*.""" return self.summarize_mode.summarize_canonical(celex=celex, question=question)
# =================================================================== # MODE 5: COMPARE # ===================================================================
[docs] def compare( self, comparison_question: str, celex_list: Optional[List[str]] = None, top_k: int = 10, score_threshold: Optional[float] = None, include_full_content: bool = False, ) -> Dict[str, Any]: """ MODE 5: Comparison Delegates to CompareMode. Args: comparison_question: What to compare celex_list: Optional list of specific CELEX to compare top_k: Number of documents if CELEX not specified Returns: Dictionary with comparison and sources """ return self.compare_mode.compare( comparison_question=comparison_question, celex_list=celex_list, top_k=top_k, score_threshold=score_threshold, include_full_content=include_full_content, )
# =================================================================== # UTILITY METHODS # ===================================================================
[docs] def get_statistics(self) -> Dict[str, Any]: """ Get RAG pipeline statistics Returns: Dictionary with pipeline statistics """ retrieval_stats = self.retrieval_service.get_statistics() return { "rag_pipeline": { "llm_provider": self._llm_provider, "llm_model": self._llm_model, "llm_temperature": self._llm_temperature, "max_tokens": self._llm_max_tokens, }, "retrieval": retrieval_stats, }