"""
LLM-assisted query parsing for legal retrieval routing.
Note d'architecture
-------------------
Ici, le parser d'intention utilise le LLM de génération principal
(config.generation.*) — même provider, même modèle, clé depuis Vault.
"""
import logging
from typing import Any, Dict, Optional
from lalandre_core.config import get_config
from lalandre_core.http.llm_client import JSONHTTPLLMClient, SharedKeyPoolJSONHTTPLLMClient
from lalandre_core.llm import normalize_base_url, normalize_provider
from lalandre_core.utils.api_key_pool import APIKeyPool
from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
from lalandre_rag.agentic.models import RoutingIntentOutput
from lalandre_rag.agentic.tools import run_intent_parser_agent
logger = logging.getLogger(__name__)
# Retrieval profiles the router understands; anything else is rejected
# by ParsedQueryIntent's profile validator.
_ALLOWED_PROFILES = frozenset({
    "contextual_default",
    "citation_precision",
    "relationship_focus",
    "global_overview",
})

# Granularity values accepted from the LLM ("auto" is collapsed to None
# by the granularity validator).
_ALLOWED_GRANULARITIES = frozenset({"subdivisions", "chunks", "all", "auto"})

# Supported retrieval execution modes; unknown modes default to "hybrid".
_ALLOWED_EXECUTION_MODES = frozenset({"hybrid", "global"})

# Short-hand names the LLM may emit, mapped onto canonical profile names.
_PROFILE_ALIASES: dict[str, str] = {
    "default": "contextual_default",
    "citation": "citation_precision",
    "relation": "relationship_focus",
    "relationships": "relationship_focus",
    "global": "global_overview",
    "overview": "global_overview",
}
class ParsedQueryIntent(BaseModel, frozen=True):
    """Normalized interpretation returned by the intent parser.

    Instances are immutable (``frozen=True``) so routing code can cache and
    share them safely; unknown keys in the LLM payload are ignored.
    """

    model_config = {"extra": "ignore"}

    profile: str
    granularity: Optional[str] = None
    top_k: int = Field(default=10, ge=1)
    include_relations_hint: bool = False
    execution_mode: str = "hybrid"
    rationale: str = "LLM parser selected retrieval profile."
    use_graph: bool = False
    normalized_query: Optional[str] = None
    intent_label: Optional[str] = None
    confidence: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    output_validation_retries: int = Field(default=0, ge=0)

    @field_validator("profile", mode="before")
    @classmethod
    def normalize_profile(cls, v: Any) -> str:
        """Resolve profile aliases and reject unsupported routing profiles.

        Raises:
            ValueError: if the value, after alias resolution, is not one of
                the allowed routing profiles (pydantic surfaces this as a
                ValidationError to callers).
        """
        raw = str(v).strip().lower() if v is not None else ""
        resolved = _PROFILE_ALIASES.get(raw, raw)
        if resolved not in _ALLOWED_PROFILES:
            raise ValueError(f"Invalid profile: {resolved!r}")
        return resolved

    @field_validator("granularity", mode="before")
    @classmethod
    def normalize_granularity(cls, v: Any) -> Optional[str]:
        """Normalize requested granularity and collapse ``auto`` to ``None``."""
        if v is None:
            return None
        # Non-string values are treated as "auto", i.e. no explicit choice.
        g = v.strip().lower() if isinstance(v, str) else "auto"
        if g not in _ALLOWED_GRANULARITIES:
            g = "auto"
        return None if g == "auto" else g

    @field_validator("top_k", mode="before")
    @classmethod
    def coerce_top_k(cls, v: Any) -> int:
        """Clamp ``top_k`` to the configured parser-safe range.

        Non-numeric or unparsable values fall back to the default of 10.
        """
        cap = get_config().search.query_parser_max_top_k
        if isinstance(v, int):
            return max(1, min(v, cap))
        if isinstance(v, str) and v.strip().isdigit():
            return max(1, min(int(v.strip()), cap))
        return 10

    @field_validator("include_relations_hint", mode="before")
    @classmethod
    def coerce_bool(cls, v: Any) -> bool:
        """Coerce common truthy string values to booleans."""
        if isinstance(v, bool):
            return v
        if isinstance(v, str):
            return v.strip().lower() in {"true", "1", "yes"}
        return False

    @field_validator("confidence", mode="before")
    @classmethod
    def coerce_confidence(cls, v: Any) -> Optional[float]:
        """Normalize optional confidence scores to the ``[0, 1]`` range.

        Unparsable strings and unsupported types map to ``None`` rather
        than failing validation.
        """
        if v is None:
            return None
        if isinstance(v, (int, float)):
            return max(0.0, min(float(v), 1.0))
        if isinstance(v, str):
            stripped = v.strip()
            if not stripped:
                return None
            try:
                return max(0.0, min(float(stripped), 1.0))
            except ValueError:
                return None
        return None

    @field_validator("execution_mode", mode="before")
    @classmethod
    def normalize_execution_mode(cls, v: Any) -> str:
        """Normalize the execution mode and default to ``hybrid``."""
        if isinstance(v, str):
            mode = v.strip().lower()
            if mode in _ALLOWED_EXECUTION_MODES:
                return mode
        return "hybrid"

    @field_validator("normalized_query", mode="before")
    @classmethod
    def clean_normalized_query(cls, v: Any) -> Optional[str]:
        """Trim the optional normalized query field."""
        if not isinstance(v, str):
            return None
        stripped = v.strip()
        return stripped if stripped else None

    @field_validator("rationale", mode="before")
    @classmethod
    def clean_rationale(cls, v: Any) -> str:
        """Normalize routing rationales and provide a fallback sentence."""
        if isinstance(v, str) and v.strip():
            return v.strip()
        return "LLM parser selected retrieval profile."

    @field_validator("intent_label", mode="before")
    @classmethod
    def clean_intent_label(cls, v: Any) -> Optional[str]:
        """Normalize the optional intent label emitted by the LLM."""
        if isinstance(v, str) and v.strip():
            return v.strip()
        return None

    @model_validator(mode="after")
    def apply_cross_field_defaults(self) -> "ParsedQueryIntent":
        """Apply derived defaults after model validation succeeds.

        ``object.__setattr__`` is required here because the model is frozen.
        """
        if self.execution_mode not in _ALLOWED_EXECUTION_MODES:
            # Defensive only: normalize_execution_mode already coerces unknown
            # modes to "hybrid", so this branch should be unreachable.
            new_mode = "global" if self.profile == "global_overview" else "hybrid"
            object.__setattr__(self, "execution_mode", new_mode)
        if not self.intent_label:
            object.__setattr__(self, "intent_label", self.profile)
        return self

    @classmethod
    def from_routing_output(
        cls,
        output: RoutingIntentOutput,
        *,
        requested_top_k: int,
        requested_granularity: Optional[str],
        output_validation_retries: int,
    ) -> Optional["ParsedQueryIntent"]:
        """Convert validated agent output into a normalized immutable intent.

        Returns ``None`` when the payload cannot be coerced into a valid
        intent, letting the caller fall back to heuristics. An explicit
        caller-requested granularity ("subdivisions"/"chunks") always
        overrides whatever the LLM chose.
        """
        payload: Dict[str, Any] = {
            "profile": output.profile,
            "granularity": output.granularity,
            # A missing/zero top_k from the agent falls back to the caller's
            # request, floored at 1.
            "top_k": output.top_k or max(int(requested_top_k), 1),
            "include_relations_hint": output.include_relations_hint,
            "execution_mode": output.execution_mode,
            "rationale": output.rationale,
            "use_graph": output.use_graph,
            "normalized_query": output.normalized_query,
            "intent_label": output.intent_label,
            "confidence": output.confidence,
            "output_validation_retries": output_validation_retries,
        }
        try:
            intent = cls.model_validate(payload)
        except ValidationError:
            return None
        if requested_granularity in {"subdivisions", "chunks"}:
            return intent.model_copy(update={"granularity": requested_granularity})
        return intent
class LLMQueryParserClient:
    """Intent query parser backed by the main generation LLM.

    Degrades gracefully: when parsing fails, the QueryRouter falls back to
    deterministic heuristics.

    Architecture note: do NOT point this at the extraction LLM. It uses
    ``config.generation.*`` — Mistral, OpenAI-compatible.
    """

    def __init__(
        self,
        *,
        provider: str,
        model: str,
        base_url: str,
        timeout_seconds: float,
        api_key: Optional[str] = None,
        max_output_tokens: int = 180,
        temperature: float = 0.0,
        key_pool: Optional[APIKeyPool] = None,
    ) -> None:
        self.provider = normalize_provider(provider)
        self.model = model.strip()
        self.base_url = base_url.strip().rstrip("/")
        # Floor the timeout so a misconfigured 0/negative value cannot slip through.
        self.timeout_seconds = max(0.3, float(timeout_seconds))
        self.api_key = api_key.strip() if api_key else None
        min_tokens = get_config().search.intent_parser_min_output_tokens
        self.max_output_tokens = max(min_tokens, int(max_output_tokens))
        self.temperature = max(0.0, float(temperature))
        # Mistral is reached through the OpenAI-compatible transport.
        transport_provider = "openai_compatible" if self.provider == "mistral" else self.provider
        if self.provider == "mistral" and key_pool is not None and len(key_pool) > 1:
            # Several Mistral keys available: rotate through the shared pool.
            self._http_client: Any = SharedKeyPoolJSONHTTPLLMClient.from_key_pool(
                key_pool=key_pool,
                provider=transport_provider,
                model=self.model,
                base_url=self.base_url,
                timeout_seconds=self.timeout_seconds,
                max_output_tokens=self.max_output_tokens,
                temperature=self.temperature,
                system_prompt="Return valid JSON only.",
                error_preview_chars=240,
            )
        else:
            self._http_client = JSONHTTPLLMClient(
                provider=transport_provider,
                model=self.model,
                base_url=self.base_url,
                timeout_seconds=self.timeout_seconds,
                api_key=self.api_key,
                max_output_tokens=self.max_output_tokens,
                temperature=self.temperature,
                system_prompt="Return valid JSON only.",
                error_preview_chars=240,
            )

    @classmethod
    def from_runtime(
        cls,
        *,
        config: Any,
        settings: Any,
        key_pool: Optional[APIKeyPool] = None,
    ) -> Optional["LLMQueryParserClient"]:
        """Factory from the runtime configuration.

        Uses ``config.generation.*`` as the primary source; ``search.intent_parser_*``
        may override provider/model/base_url (e.g. to route through a smaller
        dedicated model). No fallback on ``extraction.llm_*`` (reserved for the
        extraction LLM).

        Returns ``None`` — meaning "use deterministic heuristics" — when the
        parser is disabled or the configuration is incomplete/unsupported.
        """
        search_cfg = config.search
        if not search_cfg.intent_parser_enabled:
            return None
        gen_cfg = config.generation
        # Provider: intent_parser override > generation.
        provider_raw = search_cfg.intent_parser_provider or gen_cfg.provider
        # Model: intent_parser override > generation.lightweight_model_name > generation.model_name.
        model_raw = search_cfg.intent_parser_model or gen_cfg.lightweight_model_name or gen_cfg.model_name
        # Base URL: intent_parser override > generation > Mistral public API.
        base_url_raw: Optional[str] = search_cfg.intent_parser_base_url or gen_cfg.base_url
        if not base_url_raw and (provider_raw or "").strip().lower() == "mistral":
            base_url_raw = gen_cfg.mistral_base_url
        # API key: intent_parser override > generation > LLM_API_KEY > MISTRAL_API_KEY.
        api_key_raw = (
            search_cfg.intent_parser_api_key
            or settings.SEARCH_INTENT_PARSER_API_KEY
            or gen_cfg.api_key
            or settings.LLM_API_KEY
            or settings.MISTRAL_API_KEY
        )
        provider = normalize_provider(provider_raw) if isinstance(provider_raw, str) else ""
        model = model_raw.strip() if isinstance(model_raw, str) else ""
        base_url = normalize_base_url(
            provider=provider,
            base_url=base_url_raw.strip() if isinstance(base_url_raw, str) else "",
        )
        timeout_seconds = search_cfg.intent_parser_timeout_seconds
        max_output_tokens = search_cfg.intent_parser_max_output_tokens
        temperature = search_cfg.intent_parser_temperature
        # Safety net: normalize_base_url may have emptied the URL for Mistral.
        if provider == "mistral":
            if not base_url:
                base_url = gen_cfg.mistral_base_url
        if not provider or not model or not base_url:
            logger.warning(
                "LLM query parser: configuration incomplète (provider/model/base_url manquants). Heuristiques activées."
            )
            return None
        if provider not in {"mistral", "openai_compatible"}:
            logger.warning(
                "Provider LLM query parser non supporté %r; heuristiques activées.",
                provider,
            )
            return None
        logger.info(
            "LLM query parser initialisé : provider=%r model=%r",
            provider,
            model,
        )
        return cls(
            provider=provider,
            model=model,
            base_url=base_url,
            timeout_seconds=timeout_seconds,
            api_key=api_key_raw,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            key_pool=key_pool,
        )

    def parse(
        self,
        *,
        question: str,
        top_k: int,
        requested_granularity: Optional[str],
    ) -> Optional[ParsedQueryIntent]:
        """Parse one user question into a normalized routing intent.

        Returns ``None`` on any failure (transport error, non-coercible
        structured output) so the caller can fall back to heuristics.
        """
        try:
            output, retries = run_intent_parser_agent(
                question=question,
                top_k=top_k,
                requested_granularity=requested_granularity,
                generate_text=self._generate,
                model_name=f"{self.provider}:{self.model}",
            )
        except Exception as exc:
            # Best-effort by design: log and degrade to heuristics.
            logger.warning("LLM query parser call failed: %s", exc)
            return None
        intent = ParsedQueryIntent.from_routing_output(
            output,
            requested_top_k=top_k,
            requested_granularity=requested_granularity,
            output_validation_retries=retries,
        )
        if intent is None:
            logger.debug("LLM query parser returned non-coercible structured output")
            return None
        # Plain attribute assignment: this class is not frozen, so the
        # original object.__setattr__ indirection was unnecessary.
        self._last_parsed_intent = intent
        return intent

    def _generate(self, prompt: str) -> str:
        """Delegate raw text generation to the configured HTTP LLM client."""
        return self._http_client.generate(prompt)