"""
RAG (Retrieval-Augmented Generation) Service
High-level orchestration of retrieval, context enrichment, and LLM generation
"""
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from lalandre_core.config import get_config, get_env_settings
from lalandre_core.linking import LegalEntityLinker
from lalandre_core.utils.api_key_pool import APIKeyPool
from langchain_core.messages import BaseMessage
from lalandre_rag.graph import GraphRAGService
from lalandre_rag.retrieval import RetrievalService
from lalandre_rag.retrieval.context.service import ContextService
from .adapters import LlamaIndexAdapter
from .llm import build_rag_llm_clients
from .modes.hybrid_mode import HybridMode
from .modes.llm_mode import LLMMode
from .modes.summarize_mode import CompareMode, SummarizeMode
from .prompts import get_langchain_prompt
from .summaries import ActSummaryService, QuestionSummaryService
class RAGService:
    """
    Complete RAG pipeline for legal document querying.

    Architecture:
    - Delegates LLM-only generation to LLMMode
    - Delegates hybrid RAG to HybridMode
    - Delegates summarization to SummarizeMode
    - Delegates comparison to CompareMode

    This is a thin orchestration layer that initializes the specialized modes
    and delegates requests to them.
    """

    def __init__(
        self,
        retrieval_service: RetrievalService,
        context_service: ContextService,
        llm_model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        graph_rag_service: Optional[GraphRAGService] = None,
        key_pool: Optional[APIKeyPool] = None,
        act_summary_service: Optional[ActSummaryService] = None,
        entity_linker: Optional[LegalEntityLinker] = None,
        external_detector: Optional[Callable[[str], Any]] = None,
    ):
        """
        Initialize RAG service.

        Args:
            retrieval_service: Service for document retrieval
            context_service: Service for context enrichment
            llm_model: LLM model name (default from config)
            temperature: LLM temperature (default from config)
            max_tokens: Maximum tokens for generation (default from config)
            api_key: Optional provider API key override
            graph_rag_service: Optional graph RAG service, forwarded to HybridMode
            key_pool: Optional API key pool, forwarded to the LLM client builder
                and HybridMode
            act_summary_service: Optional act-summary service, forwarded to
                SummarizeMode and used to build the question-summary service
            entity_linker: Optional legal-entity linker, forwarded to HybridMode
            external_detector: Optional callable applied to the question text,
                forwarded to HybridMode

        Raises:
            ValueError: If the model name, temperature, or max_tokens cannot
                be resolved from the arguments or the generation config.
        """
        self.retrieval_service = retrieval_service
        self.context_service = context_service
        self.act_summary_service = act_summary_service
        self.entity_linker = entity_linker
        self.external_detector = external_detector

        # Load configuration
        config = get_config()
        gen_config = config.generation

        # Resolve LLM runtime values (explicit constructor arguments win over
        # config; config wins over environment settings where applicable).
        settings = get_env_settings()
        resolved_provider = (gen_config.provider or "mistral").strip().lower()
        resolved_model = llm_model if llm_model is not None else gen_config.model_name
        if not resolved_model:
            raise ValueError("generation.model_name must be configured")
        resolved_temperature = temperature if temperature is not None else gen_config.temperature
        if resolved_temperature is None:
            raise ValueError("generation.temperature must be configured")
        resolved_max_tokens = max_tokens if max_tokens is not None else gen_config.max_tokens
        # Validate like model/temperature above; otherwise int(None) below
        # would raise an opaque TypeError instead of an actionable ValueError.
        if resolved_max_tokens is None:
            raise ValueError("generation.max_tokens must be configured")
        resolved_timeout = float(gen_config.timeout_seconds)
        resolved_base_url = gen_config.base_url
        resolved_api_key = api_key if api_key is not None else (gen_config.api_key or settings.LLM_API_KEY)

        # Initialize provider-specific chat and llamaindex clients
        llm_clients = build_rag_llm_clients(
            provider=resolved_provider,
            model_name=resolved_model,
            temperature=float(resolved_temperature),
            max_tokens=int(resolved_max_tokens),
            timeout_seconds=resolved_timeout,
            base_url=resolved_base_url,
            mistral_base_url=gen_config.mistral_base_url,
            context_window=gen_config.context_window,
            api_key=resolved_api_key,
            mistral_api_key=settings.MISTRAL_API_KEY,
            key_pool=key_pool,
        )
        self._llm_provider: str = llm_clients.provider
        self._llm_model: str = resolved_model
        self._llm_temperature: float = float(resolved_temperature)
        self._llm_max_tokens: int = int(resolved_max_tokens)
        self.llm = llm_clients.chat_llm

        # Build a lighter/faster LLM for agentic sub-tasks (CRAG eval/refine).
        # Falls back to the main LLM when no distinct lightweight model is set.
        lightweight_model = gen_config.lightweight_model_name
        if lightweight_model and lightweight_model != resolved_model:
            lw_clients = build_rag_llm_clients(
                provider=resolved_provider,
                model_name=lightweight_model,
                temperature=0.0,
                max_tokens=512,
                timeout_seconds=resolved_timeout,
                base_url=resolved_base_url,
                mistral_base_url=gen_config.mistral_base_url,
                context_window=gen_config.context_window,
                api_key=resolved_api_key,
                mistral_api_key=settings.MISTRAL_API_KEY,
                key_pool=key_pool,
            )
            self.lightweight_llm = lw_clients.chat_llm
        else:
            self.lightweight_llm = self.llm

        # LlamaIndex adapter is optional (provider-dependent).
        llamaindex_adapter = (
            LlamaIndexAdapter(llm_clients.llamaindex_llm) if llm_clients.llamaindex_llm is not None else None
        )

        # Load prompts (with_history enables MessagesPlaceholder for multi-turn)
        rag_prompt = get_langchain_prompt("rag", with_history=True)
        question_summary_service = QuestionSummaryService(act_summary_service)

        # Initialize specialized modes
        self.llm_mode = LLMMode(self.llm)
        self.hybrid_mode = HybridMode(
            retrieval_service,
            context_service,
            self.llm,
            rag_prompt,
            graph_rag_service=graph_rag_service,
            lightweight_llm=self.lightweight_llm,
            key_pool=key_pool,
            entity_linker=entity_linker,
            external_detector=external_detector,
        )
        self.summarize_mode = SummarizeMode(
            retrieval_service,
            context_service,
            llamaindex_adapter,
            citation_llm=self.llm,
            act_summary_service=act_summary_service,
            question_summary_service=question_summary_service,
        )
        self.compare_mode = CompareMode(
            retrieval_service,
            context_service,
            llamaindex_adapter,
            citation_llm=self.llm,
            question_summary_service=question_summary_service,
        )

    # ===================================================================
    # MODE 2: LLM ONLY
    # ===================================================================

    def query_llm_only(self, question: str, include_warning: bool = True) -> Dict[str, Any]:
        """
        MODE 2: Pure LLM (100% Generation).

        Delegates to LLMMode.

        Args:
            question: User question
            include_warning: Include warning about no document grounding

        Returns:
            Dictionary with LLM answer (no sources)
        """
        return self.llm_mode.query(question=question, include_warning=include_warning)

    def stream_query_llm_only(self, question: str) -> Iterator[str]:
        """Stream LLM-only answer token by token. Delegates to LLMMode."""
        return self.llm_mode.stream_query(question=question)

    # ===================================================================
    # MODE 3: HYBRID (RAG)
    # ===================================================================

    def query(
        self,
        question: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        include_relations: bool = False,
        include_subjects: bool = False,
        return_sources: bool = True,
        include_full_content: bool = False,
        collections: Optional[List[str]] = None,
        granularity: Optional[str] = None,
        chat_history: Optional[List[BaseMessage]] = None,
        graph_depth: Optional[int] = None,
        use_graph: Optional[bool] = None,
        embedding_preset: Optional[str] = None,
        cypher_documents: Optional[List[Dict[str, Any]]] = None,
        cypher_query_meta: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        MODE 3: Hybrid RAG (default).

        Delegates to HybridMode.

        Args:
            question: User question
            top_k: Number of documents to retrieve
            score_threshold: Minimum retrieval score, forwarded to HybridMode
            filters: Metadata filters
            include_relations: Include act relations in context
            include_subjects: Include subject classifications
            return_sources: Return source documents in response
            include_full_content: Forwarded to HybridMode
            collections: Specific collections to search
            granularity: Quick selector
            chat_history: Optional conversation history as LangChain messages
            graph_depth: Graph traversal depth, forwarded to HybridMode
            use_graph: Toggle graph-assisted retrieval, forwarded to HybridMode
            embedding_preset: Embedding preset name, forwarded to HybridMode
            cypher_documents: Pre-fetched Cypher documents, forwarded to HybridMode
            cypher_query_meta: Metadata for the Cypher query, forwarded to HybridMode

        Returns:
            Dictionary with answer, sources, and metadata
        """
        return self.hybrid_mode.query(
            question=question,
            top_k=top_k,
            score_threshold=score_threshold,
            filters=filters,
            include_relations=include_relations,
            include_subjects=include_subjects,
            include_full_content=include_full_content,
            return_sources=return_sources,
            collections=collections,
            granularity=granularity,
            chat_history=chat_history,
            graph_depth=graph_depth,
            use_graph=use_graph,
            embedding_preset=embedding_preset,
            cypher_documents=cypher_documents,
            cypher_query_meta=cypher_query_meta,
        )

    def stream_query(
        self,
        question: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        include_relations: bool = False,
        include_subjects: bool = False,
        return_sources: bool = True,
        include_full_content: bool = False,
        collections: Optional[List[str]] = None,
        granularity: Optional[str] = None,
        chat_history: Optional[List[BaseMessage]] = None,
        graph_depth: Optional[int] = None,
        use_graph: Optional[bool] = None,
        embedding_preset: Optional[str] = None,
        retrieval_depth: Optional[str] = None,
        cypher_documents: Optional[List[Dict[str, Any]]] = None,
        cypher_query_meta: Optional[Dict[str, Any]] = None,
    ) -> Iterator[Union[Dict[str, Any], str]]:
        """
        Stream hybrid RAG answer: yields preamble dict then string tokens.

        Delegates to HybridMode. Accepts the same arguments as :meth:`query`
        plus ``retrieval_depth``, all forwarded unchanged.
        """
        return self.hybrid_mode.stream_query(
            question=question,
            top_k=top_k,
            score_threshold=score_threshold,
            filters=filters,
            include_relations=include_relations,
            include_subjects=include_subjects,
            include_full_content=include_full_content,
            return_sources=return_sources,
            collections=collections,
            granularity=granularity,
            chat_history=chat_history,
            graph_depth=graph_depth,
            use_graph=use_graph,
            embedding_preset=embedding_preset,
            retrieval_depth=retrieval_depth,
            cypher_documents=cypher_documents,
            cypher_query_meta=cypher_query_meta,
        )

    # ===================================================================
    # MODE 4: SUMMARIZE
    # ===================================================================

    def summarize(
        self,
        topic: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        include_relations: bool = True,
        include_full_content: bool = False,
    ) -> Dict[str, Any]:
        """
        MODE 4: Summarization.

        Delegates to SummarizeMode.

        Args:
            topic: Topic or question to summarize
            top_k: Number of documents to retrieve
            score_threshold: Minimum retrieval score, forwarded to SummarizeMode
            filters: Metadata filters
            include_relations: Include relations in context
            include_full_content: Forwarded to SummarizeMode

        Returns:
            Dictionary with summary and sources
        """
        return self.summarize_mode.summarize(
            topic=topic,
            top_k=top_k,
            score_threshold=score_threshold,
            filters=filters,
            include_relations=include_relations,
            include_full_content=include_full_content,
        )

    def summarize_canonical(
        self,
        *,
        celex: str,
        question: str,
    ) -> Optional[Dict[str, Any]]:
        """Return a canonical-summary response when one exists for *celex*."""
        return self.summarize_mode.summarize_canonical(celex=celex, question=question)

    # ===================================================================
    # MODE 5: COMPARE
    # ===================================================================

    def compare(
        self,
        comparison_question: str,
        celex_list: Optional[List[str]] = None,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        include_full_content: bool = False,
    ) -> Dict[str, Any]:
        """
        MODE 5: Comparison.

        Delegates to CompareMode.

        Args:
            comparison_question: What to compare
            celex_list: Optional list of specific CELEX to compare
            top_k: Number of documents if CELEX not specified
            score_threshold: Minimum retrieval score, forwarded to CompareMode
            include_full_content: Forwarded to CompareMode

        Returns:
            Dictionary with comparison and sources
        """
        return self.compare_mode.compare(
            comparison_question=comparison_question,
            celex_list=celex_list,
            top_k=top_k,
            score_threshold=score_threshold,
            include_full_content=include_full_content,
        )

    # ===================================================================
    # UTILITY METHODS
    # ===================================================================

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get RAG pipeline statistics.

        Returns:
            Dictionary with pipeline statistics: the resolved LLM settings
            under ``"rag_pipeline"`` and the retrieval service statistics
            under ``"retrieval"``.
        """
        retrieval_stats = self.retrieval_service.get_statistics()
        return {
            "rag_pipeline": {
                "llm_provider": self._llm_provider,
                "llm_model": self._llm_model,
                "llm_temperature": self._llm_temperature,
                "max_tokens": self._llm_max_tokens,
            },
            "retrieval": retrieval_stats,
        }