# Source code for lalandre_rag.retrieval.query_parser

"""
LLM-assisted query parsing for legal retrieval routing.

Note d'architecture
-------------------
Ici, le parser d'intention utilise le LLM de génération principal
(config.generation.*) — même provider, même modèle, clé depuis Vault.
"""

import logging
from typing import Any, Dict, Optional

from lalandre_core.config import get_config
from lalandre_core.http.llm_client import JSONHTTPLLMClient, SharedKeyPoolJSONHTTPLLMClient
from lalandre_core.llm import normalize_base_url, normalize_provider
from lalandre_core.utils.api_key_pool import APIKeyPool
from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator

from lalandre_rag.agentic.models import RoutingIntentOutput
from lalandre_rag.agentic.tools import run_intent_parser_agent

logger = logging.getLogger(__name__)

# Retrieval routing profiles the parser is allowed to emit; anything else is rejected.
_ALLOWED_PROFILES = frozenset(
    {
        "contextual_default",
        "citation_precision",
        "relationship_focus",
        "global_overview",
    }
)
# Accepted granularity hints; "auto" is treated as "no explicit preference".
_ALLOWED_GRANULARITIES = frozenset({"subdivisions", "chunks", "all", "auto"})
# Supported retrieval execution modes; unknown values fall back to "hybrid".
_ALLOWED_EXECUTION_MODES = frozenset({"hybrid", "global"})

# Shorthand / legacy profile names (as the LLM may emit them) mapped to canonical profiles.
_PROFILE_ALIASES: dict[str, str] = {
    "default": "contextual_default",
    "citation": "citation_precision",
    "relation": "relationship_focus",
    "relationships": "relationship_focus",
    "global": "global_overview",
    "overview": "global_overview",
}


class ParsedQueryIntent(BaseModel, frozen=True):
    """Normalized interpretation returned by the intent parser.

    Immutable (``frozen=True``) value object produced from the LLM routing
    output. Unknown extra keys from the LLM payload are ignored rather than
    rejected, so a chatty model cannot break validation.
    """

    model_config = {"extra": "ignore"}

    profile: str  # canonical routing profile, see _ALLOWED_PROFILES
    granularity: Optional[str] = None  # None means "auto" / no preference
    top_k: int = Field(default=10, ge=1)
    include_relations_hint: bool = False
    execution_mode: str = "hybrid"
    rationale: str = "LLM parser selected retrieval profile."
    use_graph: bool = False
    normalized_query: Optional[str] = None
    intent_label: Optional[str] = None
    confidence: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    output_validation_retries: int = Field(default=0, ge=0)

    @field_validator("profile", mode="before")
    @classmethod
    def normalize_profile(cls, v: Any) -> str:
        """Resolve profile aliases and reject unsupported routing profiles."""
        raw = str(v).strip().lower() if v is not None else ""
        resolved = _PROFILE_ALIASES.get(raw, raw)
        if resolved not in _ALLOWED_PROFILES:
            raise ValueError(f"Invalid profile: {resolved!r}")
        return resolved

    @field_validator("granularity", mode="before")
    @classmethod
    def normalize_granularity(cls, v: Any) -> Optional[str]:
        """Normalize requested granularity and collapse ``auto`` to ``None``."""
        if v is None:
            return None
        # Non-string inputs (e.g. ints) are treated as "auto" -> None.
        g = str(v).strip().lower() if isinstance(v, str) else "auto"
        if g not in _ALLOWED_GRANULARITIES:
            g = "auto"
        return None if g == "auto" else g

    @field_validator("top_k", mode="before")
    @classmethod
    def coerce_top_k(cls, v: Any) -> int:
        """Clamp ``top_k`` to the configured parser-safe range.

        Non-numeric input falls back to the default of 10.
        """
        cap = get_config().search.query_parser_max_top_k
        if isinstance(v, int):
            return max(1, min(v, cap))
        if isinstance(v, str) and v.strip().isdigit():
            return max(1, min(int(v.strip()), cap))
        return 10

    @field_validator("include_relations_hint", mode="before")
    @classmethod
    def coerce_bool(cls, v: Any) -> bool:
        """Coerce common truthy string values to booleans."""
        if isinstance(v, bool):
            return v
        if isinstance(v, str):
            return v.strip().lower() in {"true", "1", "yes"}
        return False

    @field_validator("confidence", mode="before")
    @classmethod
    def coerce_confidence(cls, v: Any) -> Optional[float]:
        """Normalize optional confidence scores to the ``[0, 1]`` range."""
        if v is None:
            return None
        if isinstance(v, (int, float)):
            return max(0.0, min(float(v), 1.0))
        if isinstance(v, str):
            stripped = v.strip()
            if not stripped:
                return None
            try:
                return max(0.0, min(float(stripped), 1.0))
            except ValueError:
                return None
        return None

    @field_validator("execution_mode", mode="before")
    @classmethod
    def normalize_execution_mode(cls, v: Any) -> str:
        """Normalize the execution mode and default to ``hybrid``."""
        if isinstance(v, str):
            mode = v.strip().lower()
            if mode in _ALLOWED_EXECUTION_MODES:
                return mode
        return "hybrid"

    @field_validator("normalized_query", mode="before")
    @classmethod
    def clean_normalized_query(cls, v: Any) -> Optional[str]:
        """Trim the optional normalized query field."""
        if not isinstance(v, str):
            return None
        stripped = v.strip()
        return stripped if stripped else None

    @field_validator("rationale", mode="before")
    @classmethod
    def clean_rationale(cls, v: Any) -> str:
        """Normalize routing rationales and provide a fallback sentence."""
        if isinstance(v, str) and v.strip():
            return v.strip()
        return "LLM parser selected retrieval profile."

    @field_validator("intent_label", mode="before")
    @classmethod
    def clean_intent_label(cls, v: Any) -> Optional[str]:
        """Normalize the optional intent label emitted by the LLM."""
        if isinstance(v, str) and v.strip():
            return v.strip()
        return None

    @model_validator(mode="after")
    def apply_cross_field_defaults(self) -> "ParsedQueryIntent":
        """Apply derived defaults after model validation succeeds.

        ``object.__setattr__`` is required here because the model is frozen.
        """
        if self.execution_mode not in _ALLOWED_EXECUTION_MODES:
            new_mode = "global" if self.profile == "global_overview" else "hybrid"
            object.__setattr__(self, "execution_mode", new_mode)
        if not self.intent_label:
            object.__setattr__(self, "intent_label", self.profile)
        return self

    @classmethod
    def from_routing_output(
        cls,
        output: RoutingIntentOutput,
        *,
        requested_top_k: int,
        requested_granularity: Optional[str],
        output_validation_retries: int,
    ) -> Optional["ParsedQueryIntent"]:
        """Convert validated agent output into a normalized immutable intent.

        Returns ``None`` when the agent payload cannot be coerced into a valid
        intent. An explicit caller-requested granularity (``subdivisions`` or
        ``chunks``) always overrides whatever the LLM chose.
        """
        payload: Dict[str, Any] = {
            "profile": output.profile,
            "granularity": output.granularity,
            "top_k": output.top_k or max(int(requested_top_k), 1),
            "include_relations_hint": output.include_relations_hint,
            "execution_mode": output.execution_mode,
            "rationale": output.rationale,
            "use_graph": output.use_graph,
            "normalized_query": output.normalized_query,
            "intent_label": output.intent_label,
            "confidence": output.confidence,
            "output_validation_retries": output_validation_retries,
        }
        try:
            intent = cls.model_validate(payload)
        except ValidationError:
            return None
        if requested_granularity in {"subdivisions", "chunks"}:
            return intent.model_copy(update={"granularity": requested_granularity})
        return intent
class LLMQueryParserClient:
    """Intent query parser backed by the main generation LLM.

    Degrades gracefully: if parsing fails, the QueryRouter falls back to
    deterministic heuristics.

    Architecture note: do NOT configure this on the extraction LLM.
    It uses ``config.generation.*`` — Mistral, OpenAI-compatible.
    """

    def __init__(
        self,
        *,
        provider: str,
        model: str,
        base_url: str,
        timeout_seconds: float,
        api_key: Optional[str] = None,
        max_output_tokens: int = 180,
        temperature: float = 0.0,
        key_pool: Optional[APIKeyPool] = None,
    ) -> None:
        normalized_provider = normalize_provider(provider)
        self.provider = normalized_provider
        self.model = model.strip()
        normalized_base_url = base_url.strip().rstrip("/")
        self.base_url = normalized_base_url
        # Floor the timeout so a misconfigured 0 does not disable the call.
        self.timeout_seconds = max(0.3, float(timeout_seconds))
        self.api_key = api_key.strip() if api_key else None
        # Guarantee enough room for the structured JSON answer.
        min_tokens = get_config().search.intent_parser_min_output_tokens
        self.max_output_tokens = max(min_tokens, int(max_output_tokens))
        self.temperature = max(0.0, float(temperature))
        # Mistral speaks the OpenAI-compatible wire protocol.
        transport_provider = "openai_compatible" if self.provider == "mistral" else self.provider
        if self.provider == "mistral" and key_pool is not None and len(key_pool) > 1:
            # Rotate across the pool only when there is more than one key.
            self._http_client: Any = SharedKeyPoolJSONHTTPLLMClient.from_key_pool(
                key_pool=key_pool,
                provider=transport_provider,
                model=self.model,
                base_url=self.base_url,
                timeout_seconds=self.timeout_seconds,
                max_output_tokens=self.max_output_tokens,
                temperature=self.temperature,
                system_prompt="Return valid JSON only.",
                error_preview_chars=240,
            )
        else:
            self._http_client = JSONHTTPLLMClient(
                provider=transport_provider,
                model=self.model,
                base_url=self.base_url,
                timeout_seconds=self.timeout_seconds,
                api_key=self.api_key,
                max_output_tokens=self.max_output_tokens,
                temperature=self.temperature,
                system_prompt="Return valid JSON only.",
                error_preview_chars=240,
            )

    @classmethod
    def from_runtime(
        cls,
        *,
        config: Any,
        settings: Any,
        key_pool: Optional[APIKeyPool] = None,
    ) -> Optional["LLMQueryParserClient"]:
        """Factory from the runtime config.

        Uses ``config.generation.*`` as the primary source.
        ``search.intent_parser_*`` may override provider/model/base_url when
        needed (e.g. a smaller model dedicated to routing).

        No longer falls back on ``extraction.llm_*`` (reserved for the
        extraction LLM). Returns ``None`` when the parser is disabled or the
        configuration is incomplete, so callers use heuristics instead.
        """
        search_cfg = config.search
        if not search_cfg.intent_parser_enabled:
            return None
        gen_cfg = config.generation
        # Provider: intent_parser override > generation
        provider_raw = search_cfg.intent_parser_provider or gen_cfg.provider
        # Model: intent_parser override > generation.lightweight_model_name > generation.model_name
        model_raw = search_cfg.intent_parser_model or gen_cfg.lightweight_model_name or gen_cfg.model_name
        # Base URL: intent_parser override > generation > Mistral public API
        base_url_raw: Optional[str] = search_cfg.intent_parser_base_url or gen_cfg.base_url
        if not base_url_raw and (provider_raw or "").strip().lower() == "mistral":
            base_url_raw = gen_cfg.mistral_base_url
        # API key: intent_parser override > LLM_API_KEY > MISTRAL_API_KEY
        api_key_raw = (
            search_cfg.intent_parser_api_key
            or settings.SEARCH_INTENT_PARSER_API_KEY
            or gen_cfg.api_key
            or settings.LLM_API_KEY
            or settings.MISTRAL_API_KEY
        )
        provider = normalize_provider(provider_raw) if isinstance(provider_raw, str) else ""
        model = model_raw.strip() if isinstance(model_raw, str) else ""
        base_url = normalize_base_url(
            provider=provider,
            base_url=base_url_raw.strip() if isinstance(base_url_raw, str) else "",
        )
        timeout_seconds = search_cfg.intent_parser_timeout_seconds
        max_output_tokens = search_cfg.intent_parser_max_output_tokens
        temperature = search_cfg.intent_parser_temperature
        if provider == "mistral" and not base_url:
            base_url = gen_cfg.mistral_base_url
        if not provider or not model or not base_url:
            logger.warning(
                "LLM query parser: configuration incomplète "
                "(provider/model/base_url manquants). Heuristiques activées."
            )
            return None
        if provider not in {"mistral", "openai_compatible"}:
            logger.warning(
                "Provider LLM query parser non supporté %r; heuristiques activées.",
                provider,
            )
            return None
        logger.info(
            "LLM query parser initialisé : provider=%r model=%r",
            provider,
            model,
        )
        return cls(
            provider=provider,
            model=model,
            base_url=base_url,
            timeout_seconds=timeout_seconds,
            api_key=api_key_raw,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            key_pool=key_pool,
        )

    def parse(
        self,
        *,
        question: str,
        top_k: int,
        requested_granularity: Optional[str],
    ) -> Optional[ParsedQueryIntent]:
        """Parse one user question into a normalized routing intent.

        Returns ``None`` on any agent failure or non-coercible output so the
        router can fall back to deterministic heuristics.
        """
        try:
            output, retries = run_intent_parser_agent(
                question=question,
                top_k=top_k,
                requested_granularity=requested_granularity,
                generate_text=self._generate,
                model_name=f"{self.provider}:{self.model}",
            )
        except Exception as exc:
            logger.warning("LLM query parser call failed: %s", exc)
            return None
        intent = ParsedQueryIntent.from_routing_output(
            output,
            requested_top_k=top_k,
            requested_granularity=requested_granularity,
            output_validation_retries=retries,
        )
        if intent is None:
            logger.debug("LLM query parser returned non-coercible structured output")
            return None
        # Plain assignment: this class is not frozen, object.__setattr__ was
        # needless indirection here.
        self._last_parsed_intent = intent
        return intent

    def _generate(self, prompt: str) -> str:
        """Forward a prompt to the underlying JSON-mode HTTP LLM client."""
        return self._http_client.generate(prompt)