# Source code for lalandre_extraction.llm.client

"""
LLM client for GraphRAG-style relation extraction.

Extracts legal relations directly from text chunks using LLM APIs
with API key pool round-robin for rate-limit distribution.
"""

import importlib.resources
import itertools
import logging
from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from time import perf_counter
from typing import Any, Iterator, List, Optional

from lalandre_core.http.llm_client import JSONHTTPLLMClient
from lalandre_core.utils import APIKeyPool

from ..metrics import (
    observe_extraction_llm_call,
    observe_extraction_llm_error,
    observe_extraction_llm_relations,
)
from .agent import run_extraction_agent
from .models import ALLOWED_RELATION_TYPES

logger = logging.getLogger(__name__)

_DEFAULT_PROMPT_TEMPLATE: str = (
    importlib.resources.files("lalandre_extraction").joinpath("prompts/extraction.txt").read_text(encoding="utf-8")
)


[docs] def load_prompt_template(path: Optional[str] = None) -> str: """Load a prompt template from *path*, or fall back to the built-in default.""" if path: return Path(path).read_text(encoding="utf-8") return _DEFAULT_PROMPT_TEMPLATE
# ═══════════════════════════════════════════════════════════════════════════
# Data types
# ═══════════════════════════════════════════════════════════════════════════
[docs] @dataclass(frozen=True) class RawExtractedRelation: """Relation as returned by the LLM (before entity resolution).""" target_reference: str relation_type: str text_evidence: str relation_rationale: str = ""
# ═══════════════════════════════════════════════════════════════════════════
# Extraction LLM client
# ═══════════════════════════════════════════════════════════════════════════
[docs] class ExtractionLLMClient: """ GraphRAG-style LLM extraction with API key pool round-robin. The model reads raw text chunks and identifies all legal relations directly """ def __init__( self, *, provider: str, model: str, base_url: str, key_pool: APIKeyPool, timeout_seconds: float, max_output_tokens: int, temperature: float = 0.0, min_evidence_chars: int = 8, min_rationale_chars: int = 24, system_prompt: str = "You are an EU/FR legal relation extractor. Return valid JSON only.", min_output_tokens: int = 80, ) -> None: self._provider = provider.strip().lower() self._model = model.strip() self._base_url = base_url.strip() self._min_evidence_chars = max(1, int(min_evidence_chars)) self._min_rationale_chars = max(0, int(min_rationale_chars)) self._clients: List[JSONHTTPLLMClient] = [ JSONHTTPLLMClient( provider=self._provider, model=self._model, base_url=self._base_url, timeout_seconds=timeout_seconds, api_key=key, max_output_tokens=max(min_output_tokens, max_output_tokens), temperature=temperature, system_prompt=system_prompt, error_preview_chars=400, ) for key in key_pool ] self._cycle: Iterator[JSONHTTPLLMClient] = itertools.cycle(self._clients) self._lock = Lock() self._prompt_template = load_prompt_template() logger.info( "ExtractionLLMClient initialized: provider=%s model=%s keys=%d", self._provider, self._model, len(self._clients), ) @property def provider(self) -> str: """Return the normalized provider name used by the client.""" return self._provider @property def model(self) -> str: """Return the model identifier used by the client.""" return self._model
[docs] @classmethod def from_runtime( cls, *, config: Any, key_pool: Optional[APIKeyPool] = None, ) -> Optional["ExtractionLLMClient"]: """Build an extraction client from the application runtime config.""" extraction_cfg = config.extraction provider = extraction_cfg.llm_provider.strip().lower() model = extraction_cfg.llm_model.strip() base_url = extraction_cfg.llm_base_url.strip() if not provider or not model or not base_url: logger.warning( "Extraction LLM not configured (provider=%r, model=%r, base_url=%r); " "relation extraction will run without LLM.", provider, model, base_url, ) return None if key_pool is None: try: key_pool = APIKeyPool.from_env("MISTRAL_API_KEY", start_index=6) except ValueError: pass if key_pool is None: try: key_pool = APIKeyPool.from_env() except ValueError: pass if key_pool is None: logger.warning("No MISTRAL_API_KEY found; LLM extraction disabled.") return None return cls( provider=provider, model=model, base_url=base_url, key_pool=key_pool, timeout_seconds=extraction_cfg.llm_timeout_seconds, max_output_tokens=extraction_cfg.llm_max_output_tokens, temperature=extraction_cfg.llm_temperature, min_evidence_chars=extraction_cfg.llm_min_evidence_chars, min_rationale_chars=extraction_cfg.llm_min_rationale_chars, system_prompt=extraction_cfg.llm_system_prompt, min_output_tokens=extraction_cfg.llm_min_output_tokens, )
    # ───────────────────────────────────────────────────────────────────
    # Public API
    # ───────────────────────────────────────────────────────────────────
[docs] def extract_relations( self, chunk: str, source_celex: str = "", ) -> List[RawExtractedRelation]: """Extract all legal relations from a text chunk.""" if not chunk or not chunk.strip(): return [] started_at = perf_counter() prompt = self._build_prompt(chunk.strip(), source_celex) try: output, retries = run_extraction_agent( prompt=prompt, generate_text=self._generate, model_name=f"{self._provider}:{self._model}", min_evidence_chars=self._min_evidence_chars, min_rationale_chars=self._min_rationale_chars, ) except Exception as exc: duration = perf_counter() - started_at error_type = self._classify_error_type(exc) observe_extraction_llm_call( provider=self._provider, model=self._model, outcome="error", duration_seconds=duration, ) observe_extraction_llm_error( provider=self._provider, model=self._model, error_type=error_type, ) raise relations = [ RawExtractedRelation( target_reference=r.target_reference, relation_type=r.relation_type, text_evidence=r.text_evidence, relation_rationale=r.relation_rationale, ) for r in output.relations ] observe_extraction_llm_relations( provider=self._provider, model=self._model, count=len(relations), ) observe_extraction_llm_call( provider=self._provider, model=self._model, outcome="success", duration_seconds=perf_counter() - started_at, ) return relations
# ─────────────────────────────────────────────────────────────────── # Prompt construction # ─────────────────────────────────────────────────────────────────── def _build_prompt(self, chunk: str, source_celex: str = "") -> str: relation_types = ", ".join(sorted(ALLOWED_RELATION_TYPES)) source_line = ( f"Given text from SOURCE act (celex={source_celex}), " if source_celex else "Given text from a legal act, " ) return self._prompt_template.format( source_line=source_line, relation_types=relation_types, chunk=chunk, ) # ─────────────────────────────────────────────────────────────────── # Internal # ─────────────────────────────────────────────────────────────────── def _generate(self, prompt: str) -> str: client = self._next_client() return client.generate(prompt) def _next_client(self) -> JSONHTTPLLMClient: with self._lock: return next(self._cycle) @staticmethod def _classify_error_type(exc: Exception) -> str: text = str(exc).lower() if "timeout" in text: return "timeout" if "429" in text or "rate limit" in text: return "rate_limit" if "401" in text or "403" in text or "unauthorized" in text or "forbidden" in text: return "auth" if "connect" in text or "connection" in text: return "connection" return exc.__class__.__name__.lower()