"""
Pydantic models for the RAG service API.
"""
from typing import Any, Optional
from lalandre_core.utils import MODE_ALIASES, normalize_celex
from lalandre_rag.models import (
QueryMetadata,
QueryResponse,
SearchRequest,
SearchResponse,
SearchResult,
SourcesResponse,
)
from pydantic import BaseModel, Field, model_validator
# Public surface of this module: re-exports from ``lalandre_rag.models``
# plus the request/response models defined in this file.
__all__ = [
    "QueryMetadata",
    "SourcesResponse",
    "QueryResponse",
    "SearchRequest",
    "SearchResult",
    "SearchResponse",
    "QueryRequest",
    "HealthResponse",
]

# Module-private handle on the shared legacy-mode alias table so the
# validator below does not reach for the imported name directly.
_MODE_ALIASES = MODE_ALIASES
class QueryRequest(BaseModel):
    """Request model for RAG queries.

    All retrieval parameters are optional. When omitted (``None``), the
    rag-service resolves them from ``SearchConfig`` defaults at request time
    so that ``app_config.yaml`` remains the single source of truth.
    """

    # Required request identifiers.
    query_id: str
    question: str
    conversation_id: Optional[str] = Field(
        default=None,
        description="Conversation ID for multi-turn. Omit for stateless single-turn.",
    )
    mode: Optional[str] = Field(
        default=None,
        description=(
            "RAG mode: rag, llm_only, summarize, compare. "
            "Legacy aliases such as semantic, lexical, hybrid, graph, search, and qa resolve to rag."
        ),
    )
    top_k: Optional[int] = Field(
        default=None,
        ge=1,
        le=50,
        description="Number of results to retrieve. Default from config (search.default_limit).",
    )
    min_score: Optional[float] = Field(
        default=None,
        ge=0.0,
        description="Minimum normalized retrieval score threshold (0-1). Default from config.",
    )
    # When true, include source metadata in the response.
    sources: bool = False
    include_full_content: bool = Field(
        default=False,
        description="Include full source content in responses (default: false)",
    )
    # Graph-mode knobs. None means "use the service-side config default".
    graph_depth: Optional[int] = None
    use_graph: Optional[bool] = Field(
        default=None,
        description="Override graph enrichment for this request. None = use config default (use_graph_in_rag).",
    )
    graph_generation_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=("Optional server-side timeout (seconds) for graph mode answer generation (LLM phase only)."),
    )
    graph_retrieval_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description=("Optional server-side timeout (seconds) for graph retrieval phase (semantic + Neo4j traversal)."),
    )
    graph_use_cypher: bool = Field(
        default=False,
        description=(
            "If true (graph mode), generate a read-only Cypher query from natural "
            "language and answer from Neo4j results."
        ),
    )
    graph_cypher_timeout_seconds: Optional[float] = Field(
        default=None,
        gt=0.0,
        description="Timeout (seconds) for NL->Cypher generation. Default from config.",
    )
    graph_cypher_max_rows: Optional[int] = Field(
        default=None,
        ge=1,
        le=500,
        description="Maximum Cypher rows returned for graph_use_cypher mode. Default from config.",
    )
    granularity: Optional[str] = Field(
        default=None,
        description="Retrieval granularity: chunks or all. Default from config.",
    )
    embedding_preset: Optional[str] = Field(
        default=None,
        description="Embedding preset for semantic search and collection routing.",
    )
    retrieval_depth: Optional[str] = Field(
        default=None,
        description="Retrieval depth: standard or deep. Affects complementary search.",
    )
    # CELEX document identifiers; normalized after validation (see below).
    celex: Optional[str] = None
    compare_celex: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def resolve_mode_aliases(cls, data: Any) -> Any:
        """Resolve legacy mode aliases before validation.

        The incoming ``mode`` is lowercased only for the alias lookup; a
        value that is not an alias passes through unchanged, including its
        original casing. Non-dict payloads are returned untouched.
        """
        if isinstance(data, dict):
            mode = (data.get("mode") or "").lower()
            if mode and mode in _MODE_ALIASES:
                data["mode"] = _MODE_ALIASES[mode]
        return data

    @model_validator(mode="after")
    def normalize_celex_fields(self) -> "QueryRequest":
        """Normalize CELEX fields after model validation succeeds.

        Runs only on successful validation, so field constraints have
        already been enforced; empty/None CELEX values are left as-is.
        """
        if self.celex:
            self.celex = normalize_celex(self.celex)
        if self.compare_celex:
            self.compare_celex = normalize_celex(self.compare_celex)
        return self
class HealthResponse(BaseModel):
    """Health-check response payload for the RAG service."""

    # Overall service status string (e.g. as reported by the health endpoint).
    status: str
    # Name of the responding service.
    service: str
    # Whether the service's components have finished initialization.
    components_initialized: bool