# Source code for lalandre_extraction.llm.client
"""
LLM client for GraphRAG-style relation extraction.
Extracts legal relations directly from text chunks using LLM APIs
with API key pool round-robin for rate-limit distribution.
"""
import importlib.resources
import itertools
import logging
from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from time import perf_counter
from typing import Any, Iterator, List, Optional
from lalandre_core.http.llm_client import JSONHTTPLLMClient
from lalandre_core.utils import APIKeyPool
from ..metrics import (
observe_extraction_llm_call,
observe_extraction_llm_error,
observe_extraction_llm_relations,
)
from .agent import run_extraction_agent
from .models import ALLOWED_RELATION_TYPES
logger = logging.getLogger(__name__)

# Built-in prompt template, bundled with the package as a data file
# (prompts/extraction.txt).  Loaded once at import time; callers override it
# via load_prompt_template(path).
_DEFAULT_PROMPT_TEMPLATE: str = (
    importlib.resources.files("lalandre_extraction").joinpath("prompts/extraction.txt").read_text(encoding="utf-8")
)
def load_prompt_template(path: Optional[str] = None) -> str:
    """Return the extraction prompt template text.

    When *path* is a non-empty string, the template is read from that file
    (UTF-8).  Otherwise the packaged default template is returned.
    """
    if not path:
        return _DEFAULT_PROMPT_TEMPLATE
    return Path(path).read_text(encoding="utf-8")
# ═══════════════════════════════════════════════════════════════════════════
# Data types
# ═══════════════════════════════════════════════════════════════════════════
@dataclass(frozen=True)
class RawExtractedRelation:
    """Relation as returned by the LLM (before entity resolution)."""

    # Free-text reference to the target act/provision, as emitted by the LLM.
    target_reference: str
    # Relation label; expected to be one of ALLOWED_RELATION_TYPES
    # (the allowed set is injected into the prompt — see _build_prompt).
    relation_type: str
    # Supporting snippet quoted from the source chunk.
    text_evidence: str
    # Optional model-provided justification for the relation.
    relation_rationale: str = ""
# ═══════════════════════════════════════════════════════════════════════════
# Extraction LLM client
# ═══════════════════════════════════════════════════════════════════════════
class ExtractionLLMClient:
    """
    GraphRAG-style LLM extraction with API key pool round-robin.

    The model reads raw text chunks and identifies all legal relations
    directly.  One :class:`JSONHTTPLLMClient` is built per API key in the
    pool; successive generate calls rotate through them so rate limits are
    spread across keys.
    """

    def __init__(
        self,
        *,
        provider: str,
        model: str,
        base_url: str,
        key_pool: APIKeyPool,
        timeout_seconds: float,
        max_output_tokens: int,
        temperature: float = 0.0,
        min_evidence_chars: int = 8,
        min_rationale_chars: int = 24,
        system_prompt: str = "You are an EU/FR legal relation extractor. Return valid JSON only.",
        min_output_tokens: int = 80,
    ) -> None:
        self._provider = provider.strip().lower()
        self._model = model.strip()
        self._base_url = base_url.strip()
        # Clamp validation thresholds so pathological config values cannot
        # fully disable evidence validation.
        self._min_evidence_chars = max(1, int(min_evidence_chars))
        self._min_rationale_chars = max(0, int(min_rationale_chars))
        # One underlying HTTP client per API key in the pool.
        self._clients: List[JSONHTTPLLMClient] = [
            JSONHTTPLLMClient(
                provider=self._provider,
                model=self._model,
                base_url=self._base_url,
                timeout_seconds=timeout_seconds,
                api_key=key,
                # Never allow the output budget below the configured floor.
                max_output_tokens=max(min_output_tokens, max_output_tokens),
                temperature=temperature,
                system_prompt=system_prompt,
                error_preview_chars=400,
            )
            for key in key_pool
        ]
        # Round-robin over the per-key clients; the lock ensures concurrent
        # callers each advance the cycle atomically.
        self._cycle: Iterator[JSONHTTPLLMClient] = itertools.cycle(self._clients)
        self._lock = Lock()
        self._prompt_template = load_prompt_template()
        logger.info(
            "ExtractionLLMClient initialized: provider=%s model=%s keys=%d",
            self._provider,
            self._model,
            len(self._clients),
        )

    @property
    def provider(self) -> str:
        """Return the normalized (stripped, lower-cased) provider name."""
        return self._provider

    @property
    def model(self) -> str:
        """Return the model identifier used by the client."""
        return self._model

    @classmethod
    def from_runtime(
        cls,
        *,
        config: Any,
        key_pool: Optional[APIKeyPool] = None,
    ) -> Optional["ExtractionLLMClient"]:
        """Build an extraction client from the application runtime config.

        Returns ``None`` (after logging a warning) when the LLM endpoint is
        not fully configured or no API key pool can be discovered, so the
        pipeline degrades to extraction without LLM instead of crashing.
        """
        extraction_cfg = config.extraction
        provider = extraction_cfg.llm_provider.strip().lower()
        model = extraction_cfg.llm_model.strip()
        base_url = extraction_cfg.llm_base_url.strip()
        if not provider or not model or not base_url:
            logger.warning(
                "Extraction LLM not configured (provider=%r, model=%r, base_url=%r); "
                "relation extraction will run without LLM.",
                provider,
                model,
                base_url,
            )
            return None
        if key_pool is None:
            key_pool = cls._discover_key_pool()
        if key_pool is None:
            logger.warning("No MISTRAL_API_KEY found; LLM extraction disabled.")
            return None
        return cls(
            provider=provider,
            model=model,
            base_url=base_url,
            key_pool=key_pool,
            timeout_seconds=extraction_cfg.llm_timeout_seconds,
            max_output_tokens=extraction_cfg.llm_max_output_tokens,
            temperature=extraction_cfg.llm_temperature,
            min_evidence_chars=extraction_cfg.llm_min_evidence_chars,
            min_rationale_chars=extraction_cfg.llm_min_rationale_chars,
            system_prompt=extraction_cfg.llm_system_prompt,
            min_output_tokens=extraction_cfg.llm_min_output_tokens,
        )

    @staticmethod
    def _discover_key_pool() -> Optional[APIKeyPool]:
        """Find an API key pool in the environment, or return ``None``.

        Tries the indexed MISTRAL_API_KEY pool first (start_index=6), then
        the library's default environment lookup.
        """
        try:
            return APIKeyPool.from_env("MISTRAL_API_KEY", start_index=6)
        except ValueError:
            pass
        try:
            return APIKeyPool.from_env()
        except ValueError:
            return None

    # ───────────────────────────────────────────────────────────────────
    # Public API
    # ───────────────────────────────────────────────────────────────────

    def extract_relations(
        self,
        chunk: str,
        source_celex: str = "",
    ) -> List[RawExtractedRelation]:
        """Extract all legal relations from a text chunk.

        Empty or whitespace-only chunks short-circuit to ``[]`` without an
        LLM call.  LLM failures are recorded in metrics and re-raised.
        """
        if not chunk or not chunk.strip():
            return []
        started_at = perf_counter()
        prompt = self._build_prompt(chunk.strip(), source_celex)
        try:
            # The retry count is reported by the agent but unused here.
            output, _retries = run_extraction_agent(
                prompt=prompt,
                generate_text=self._generate,
                model_name=f"{self._provider}:{self._model}",
                min_evidence_chars=self._min_evidence_chars,
                min_rationale_chars=self._min_rationale_chars,
            )
        except Exception as exc:
            # Record the failed call before propagating so metrics see it.
            duration = perf_counter() - started_at
            error_type = self._classify_error_type(exc)
            observe_extraction_llm_call(
                provider=self._provider,
                model=self._model,
                outcome="error",
                duration_seconds=duration,
            )
            observe_extraction_llm_error(
                provider=self._provider,
                model=self._model,
                error_type=error_type,
            )
            raise
        relations = [
            RawExtractedRelation(
                target_reference=r.target_reference,
                relation_type=r.relation_type,
                text_evidence=r.text_evidence,
                relation_rationale=r.relation_rationale,
            )
            for r in output.relations
        ]
        observe_extraction_llm_relations(
            provider=self._provider,
            model=self._model,
            count=len(relations),
        )
        observe_extraction_llm_call(
            provider=self._provider,
            model=self._model,
            outcome="success",
            duration_seconds=perf_counter() - started_at,
        )
        return relations

    # ───────────────────────────────────────────────────────────────────
    # Prompt construction
    # ───────────────────────────────────────────────────────────────────

    def _build_prompt(self, chunk: str, source_celex: str = "") -> str:
        """Fill the prompt template with the chunk, relation types, and source line."""
        relation_types = ", ".join(sorted(ALLOWED_RELATION_TYPES))
        source_line = (
            f"Given text from SOURCE act (celex={source_celex}), " if source_celex else "Given text from a legal act, "
        )
        return self._prompt_template.format(
            source_line=source_line,
            relation_types=relation_types,
            chunk=chunk,
        )

    # ───────────────────────────────────────────────────────────────────
    # Internal
    # ───────────────────────────────────────────────────────────────────

    def _generate(self, prompt: str) -> str:
        """Send *prompt* to the next client in the round-robin cycle."""
        client = self._next_client()
        return client.generate(prompt)

    def _next_client(self) -> JSONHTTPLLMClient:
        """Atomically advance the round-robin cycle and return a client."""
        with self._lock:
            return next(self._cycle)

    @staticmethod
    def _classify_error_type(exc: Exception) -> str:
        """Map an exception to a coarse metric label by message substring.

        Falls back to the lower-cased exception class name when no known
        pattern matches.
        """
        text = str(exc).lower()
        if "timeout" in text:
            return "timeout"
        if "429" in text or "rate limit" in text:
            return "rate_limit"
        if "401" in text or "403" in text or "unauthorized" in text or "forbidden" in text:
            return "auth"
        if "connect" in text or "connection" in text:
            return "connection"
        return exc.__class__.__name__.lower()