# Source code for lalandre_extraction.relation_extractor

"""
Regulatory Relation Extractor

GraphRAG-style LLM-first extraction of legal relationships between acts.
The LLM reads text chunks directly and identifies all relations.
Entity linking resolves references to canonical CELEX identifiers post-extraction.
"""

import logging
import re
from collections import Counter, deque
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime
from hashlib import sha256
from threading import Lock
from typing import Dict, List, Optional, Sequence

from lalandre_core.config import ExtractionConfidenceConfig, get_config
from lalandre_core.linking import LegalEntityLinker, LinkResolution, is_generic_target, looks_like_identifier
from lalandre_core.models.types import RelationType
from lalandre_core.utils import normalize_celex
from pydantic import BaseModel, Field

from .llm import ALLOWED_RELATION_TYPES, ExtractionLLMClient, RawExtractedRelation

logger = logging.getLogger(__name__)


# ═══════════════════════════════════════════════════════════════════════════
# Model
# ═══════════════════════════════════════════════════════════════════════════


class ExtractedRelation(BaseModel):
    """A legal relationship extracted from text."""

    # CELEX identifier of the act the text was extracted from.
    source_celex: str
    # Resolved CELEX of the referenced act (or the raw reference when unresolved).
    target_celex: str
    relation_type: RelationType
    # Final score in [0, 1]; computed from resolution method/score and evidence.
    confidence: float = Field(ge=0.0, le=1.0)
    # Supporting snippet; whitespace-normalized and truncated by the extractor.
    text_evidence: str
    # Optional rationale derived from the LLM output.
    relation_description: Optional[str] = None
    extraction_method: str = "llm_extraction"
    effect_date: Optional[datetime] = None
    source_subdivision: Optional[str] = None
    target_subdivision: Optional[str] = None
    # The LLM's original textual reference, before entity linking.
    raw_target_reference: Optional[str] = None
    # How the reference was resolved (e.g. "explicit", "fuzzy_alias",
    # "normalize_fallback", "unresolved") and the linker's score.
    resolution_method: Optional[str] = None
    resolution_score: Optional[float] = None
# ═══════════════════════════════════════════════════════════════════════════
# Extractor
# ═══════════════════════════════════════════════════════════════════════════


# Maps the LLM's relation-type string onto the internal enum; only values
# the extraction prompt is allowed to emit are included.
_RELATION_TYPE_BY_VALUE: Dict[str, RelationType] = {
    member.value: member
    for member in RelationType
    if member.value in ALLOWED_RELATION_TYPES
}
class RegulatoryRelationExtractor:
    """
    LLM-first extraction pipeline (GraphRAG-style):
    - text chunking
    - LLM extraction per chunk
    - entity linking (post-LLM resolution)
    - confidence scoring
    - validation & merge
    """

    # ───────────────────────────────────────────────────────────────────
    # Initialization
    # ───────────────────────────────────────────────────────────────────

    def __init__(
        self,
        max_chunk_size: Optional[int] = None,
        entity_linker: Optional[LegalEntityLinker] = None,
        validation_enabled: Optional[bool] = None,
        min_evidence_chars: Optional[int] = None,
        llm_client: Optional[ExtractionLLMClient] = None,
    ) -> None:
        # Every None parameter falls back to the corresponding value from
        # the global runtime configuration.
        config = get_config()
        extraction_cfg = config.extraction
        if max_chunk_size is None:
            max_chunk_size = config.chunking.extraction_max_chunk_chars
        # A missing confidence floor is treated as a hard configuration
        # error rather than silently defaulting.
        default_min_confidence = config.gateway.job_extract_min_confidence
        if default_min_confidence is None:
            raise ValueError("gateway.job_extract_min_confidence must be configured for extraction")
        resolved_validation_enabled = (
            validation_enabled if validation_enabled is not None else extraction_cfg.validation_enabled
        )
        resolved_min_evidence_chars = (
            min_evidence_chars if min_evidence_chars is not None else extraction_cfg.min_evidence_chars
        )
        self.llm_client = llm_client if llm_client is not None else ExtractionLLMClient.from_runtime(config=config)
        self.max_chunk_size = max_chunk_size
        # Clamp to sane lower bounds: at least one worker, non-negative cache.
        self.llm_max_parallel_chunks = max(1, int(extraction_cfg.llm_max_parallel_chunks))
        self.llm_chunk_cache_size = max(0, int(extraction_cfg.llm_chunk_cache_size))
        self.default_min_confidence = float(default_min_confidence)
        self.entity_linker = entity_linker
        self.validation_enabled = bool(resolved_validation_enabled)
        self.min_evidence_chars = int(resolved_min_evidence_chars)
        self.max_evidence_chars = int(extraction_cfg.max_evidence_chars)
        self.min_description_chars = int(extraction_cfg.min_description_chars)
        # Per-chunk LLM response cache: value map plus FIFO eviction order,
        # both guarded by a lock (chunks may be extracted in parallel).
        self._chunk_cache: Dict[str, List[RawExtractedRelation]] = {}
        self._chunk_cache_order: deque[str] = deque()
        self._chunk_cache_lock = Lock()
        self._conf: ExtractionConfidenceConfig = extraction_cfg.confidence
        if self.llm_client is not None:
            logger.info(
                "Extractor LLM enabled: provider=%s model=%s",
                self.llm_client.provider,
                self.llm_client.model,
            )
[docs] def set_entity_linker(self, entity_linker: Optional[LegalEntityLinker]) -> None: """Replace the entity linker used for post-LLM reference resolution.""" self.entity_linker = entity_linker
# ─────────────────────────────────────────────────────────────────── # Public API # ───────────────────────────────────────────────────────────────────
    def extract_relations(
        self,
        text: str,
        source_celex: str,
        min_confidence: Optional[float] = None,
    ) -> List[ExtractedRelation]:
        """Extract, resolve, merge, and validate relations from one act text.

        Args:
            text: Full text of the source act.
            source_celex: CELEX identifier of the act being processed.
            min_confidence: Optional per-call confidence floor; defaults to
                the configured ``default_min_confidence``.

        Returns:
            Merged (and, when enabled, validated) relations. Empty when no
            LLM client is configured or nothing was extracted.
        """
        if self.llm_client is None:
            logger.warning("No LLM client configured; extraction skipped")
            return []
        resolved_min_confidence = float(min_confidence) if min_confidence is not None else self.default_min_confidence
        chunks = self._chunk_text(text, self.max_chunk_size)
        relations: List[ExtractedRelation] = []
        provenance: Counter[str] = Counter()
        raw_relations_count = 0
        failed_chunks = 0
        # Sequential path: a single chunk, or parallelism disabled by config.
        if len(chunks) <= 1 or self.llm_max_parallel_chunks <= 1:
            for i, chunk in enumerate(chunks):
                try:
                    resolved, chunk_provenance, raw_count = self._extract_single_chunk(
                        chunk=chunk,
                        source_celex=source_celex,
                        min_confidence=resolved_min_confidence,
                    )
                except Exception:
                    # One failing chunk must not abort the whole act.
                    failed_chunks += 1
                    logger.error(
                        "Chunk %d/%d failed for celex=%s",
                        i + 1,
                        len(chunks),
                        source_celex,
                        exc_info=True,
                    )
                    continue
                relations.extend(resolved)
                provenance.update(chunk_provenance)
                raw_relations_count += raw_count
        else:
            # Parallel path: fan chunks out over a bounded thread pool.
            max_workers = min(self.llm_max_parallel_chunks, len(chunks))
            logger.info(
                "Extraction chunk parallelism enabled: chunks=%d workers=%d",
                len(chunks),
                max_workers,
            )
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures: List[Future[tuple[List[ExtractedRelation], Counter[str], int]]] = [
                    executor.submit(
                        self._extract_single_chunk,
                        chunk=chunk,
                        source_celex=source_celex,
                        min_confidence=resolved_min_confidence,
                    )
                    for chunk in chunks
                ]
                # Collect in submission order so log messages can reference
                # the original chunk index.
                for i, future in enumerate(futures):
                    try:
                        resolved, chunk_provenance, raw_count = future.result()
                    except Exception:
                        failed_chunks += 1
                        logger.error(
                            "Chunk %d/%d failed for celex=%s",
                            i + 1,
                            len(chunks),
                            source_celex,
                            exc_info=True,
                        )
                        continue
                    relations.extend(resolved)
                    provenance.update(chunk_provenance)
                    raw_relations_count += raw_count
        logger.info(
            "Extraction summary: chunks=%d, failed=%d, raw_relations=%d, resolved=%d",
            len(chunks),
            failed_chunks,
            raw_relations_count,
            len(relations),
        )
        if provenance:
            logger.info("Extraction provenance: %s", dict(provenance))
        if not relations:
            return []
        merged = self._merge_relations(relations)
        if not self.validation_enabled:
            return merged
        validated = self._validate_relations(
            merged,
            source_celex=source_celex,
            min_confidence=resolved_min_confidence,
        )
        if len(validated) != len(merged):
            logger.info(
                "Validation filtered %d relations (kept=%d)",
                len(merged) - len(validated),
                len(validated),
            )
        return validated
    def _extract_single_chunk(
        self,
        *,
        chunk: str,
        source_celex: str,
        min_confidence: float,
    ) -> tuple[List[ExtractedRelation], Counter[str], int]:
        """Run cached LLM extraction on one chunk, then resolve and score.

        Returns (resolved relations, per-chunk provenance tally, raw count).
        """
        raw_relations = self._extract_chunk_with_cache(chunk=chunk, source_celex=source_celex)
        chunk_provenance: Counter[str] = Counter()
        resolved = self._resolve_and_score(
            raw_relations,
            source_celex=source_celex,
            min_confidence=min_confidence,
            provenance=chunk_provenance,
        )
        return resolved, chunk_provenance, len(raw_relations)

    def _extract_chunk_with_cache(
        self,
        *,
        chunk: str,
        source_celex: str,
    ) -> List[RawExtractedRelation]:
        """Return raw LLM relations for a chunk, consulting the cache first."""
        llm_client = self.llm_client
        assert llm_client is not None
        # Cache disabled: call the LLM directly.
        if self.llm_chunk_cache_size <= 0:
            return llm_client.extract_relations(chunk, source_celex)
        cache_key = self._build_chunk_cache_key(chunk=chunk, source_celex=source_celex)
        cached = self._get_cached_chunk_relations(cache_key)
        if cached is not None:
            return cached
        raw_relations = llm_client.extract_relations(chunk, source_celex)
        self._set_cached_chunk_relations(cache_key, raw_relations)
        return raw_relations

    def _build_chunk_cache_key(self, *, chunk: str, source_celex: str) -> str:
        """Derive a cache key from provider, model, source act, and chunk text."""
        llm_client = self.llm_client
        assert llm_client is not None
        provider = llm_client.provider
        model = llm_client.model
        # Provider/model are part of the key so a model switch invalidates entries.
        fingerprint = f"{provider}|{model}|{source_celex}|{chunk}"
        return sha256(fingerprint.encode("utf-8")).hexdigest()

    def _get_cached_chunk_relations(self, cache_key: str) -> Optional[List[RawExtractedRelation]]:
        """Thread-safe cache lookup; None on miss."""
        with self._chunk_cache_lock:
            cached = self._chunk_cache.get(cache_key)
            if cached is None:
                return None
            # Return a shallow copy to keep cache immutable from callers.
            return list(cached)

    def _set_cached_chunk_relations(
        self,
        cache_key: str,
        raw_relations: Sequence[RawExtractedRelation],
    ) -> None:
        """Thread-safe cache insert with FIFO eviction at the size limit."""
        with self._chunk_cache_lock:
            # Existing key: refresh the value without touching eviction order.
            if cache_key in self._chunk_cache:
                self._chunk_cache[cache_key] = list(raw_relations)
                return
            self._chunk_cache[cache_key] = list(raw_relations)
            self._chunk_cache_order.append(cache_key)
            while len(self._chunk_cache_order) > self.llm_chunk_cache_size:
                oldest_key = self._chunk_cache_order.popleft()
                self._chunk_cache.pop(oldest_key, None)

    # ───────────────────────────────────────────────────────────────────
    # LLM result resolution
    # ───────────────────────────────────────────────────────────────────

    def _resolve_and_score(
        self,
        raw_relations: Sequence[RawExtractedRelation],
        *,
        source_celex: str,
        min_confidence: float,
        provenance: Counter[str],
    ) -> List[ExtractedRelation]:
        """Turn raw LLM output into scored ExtractedRelation objects.

        Unknown relation types are dropped; sub-threshold confidences are
        skipped; provenance counters are updated in place.
        """
        results: List[ExtractedRelation] = []
        for raw in raw_relations:
            relation_type = _RELATION_TYPE_BY_VALUE.get(raw.relation_type)
            if relation_type is None:
                # LLM emitted a type outside the allowed set.
                continue
            resolution = self._resolve_reference(raw.target_reference)
            if resolution is None:
                # Last resort: try normalize_celex directly.
                # normalize_celex is a formatter, not a resolver — it does not
                # verify the entity exists, so the score stays low.
                normalized = normalize_celex(raw.target_reference).strip()
                if normalized:
                    resolution = LinkResolution(
                        celex=normalized,
                        score=self._conf.normalize_fallback_score,
                        method="normalize_fallback",
                        matched_text=raw.target_reference,
                    )
            if resolution is None:
                provenance["unresolved"] += 1
                # Keep with raw reference — validation will filter if not identifier.
                target_celex = raw.target_reference.strip()
                resolution_method = "unresolved"
                resolution_score = 0.0
            else:
                target_celex = resolution.celex
                resolution_method = resolution.method
                resolution_score = resolution.score
            confidence = self._compute_confidence(
                relation_type=relation_type,
                resolution_method=resolution_method,
                resolution_score=resolution_score,
                evidence=raw.text_evidence,
            )
            if confidence < min_confidence:
                continue
            evidence = raw.text_evidence.strip() if raw.text_evidence else ""
            results.append(
                ExtractedRelation(
                    source_celex=source_celex,
                    target_celex=target_celex,
                    relation_type=relation_type,
                    confidence=confidence,
                    text_evidence=self._truncate_evidence(evidence),
                    relation_description=self._build_relation_description(raw),
                    extraction_method="llm_extraction",
                    raw_target_reference=raw.target_reference.strip(),
                    resolution_method=resolution_method,
                    resolution_score=resolution_score,
                )
            )
            provenance[f"llm:{resolution_method}"] += 1
        return results

    def _resolve_reference(self, target_reference: str) -> Optional[LinkResolution]:
        """Resolve a textual reference to a CELEX via entity linker."""
        linker = self.entity_linker
        if linker is None:
            return None
        resolution = linker.resolve(target_reference)
        if resolution is not None:
            return resolution
        # Try normalizing first, then resolving the normalized form.
        normalized = normalize_celex(target_reference).strip()
        if normalized and normalized != target_reference.strip():
            resolution = linker.resolve(normalized)
            if resolution is not None:
                return resolution
        return None

    # ───────────────────────────────────────────────────────────────────
    # Confidence scoring
    # ───────────────────────────────────────────────────────────────────

    def _compute_confidence(
        self,
        *,
        relation_type: RelationType,
        resolution_method: str,
        resolution_score: float,
        evidence: str,
    ) -> float:
        """Additive confidence score, capped at the configured maximum.

        Bonuses for non-CITES types, stronger resolution methods, and
        sufficiently long evidence; fuzzy matches are scaled down by the
        linker's match score (floored at fuzzy_min_factor).
        """
        conf = self._conf
        score = conf.base
        if relation_type != RelationType.CITES:
            score += conf.non_cites_bonus
        if resolution_method == "explicit":
            score += conf.explicit_resolution_bonus
        elif resolution_method in ("exact_alias", "fuzzy_alias"):
            score += conf.alias_resolution_bonus
        if resolution_method == "fuzzy_alias":
            score *= max(conf.fuzzy_min_factor, min(resolution_score, 1.0))
        if evidence and len(evidence.strip()) >= conf.evidence_min_chars:
            score += conf.evidence_bonus
        return min(score, conf.max_confidence)

    # ───────────────────────────────────────────────────────────────────
    # Text chunking
    # ───────────────────────────────────────────────────────────────────

    def _chunk_text(self, text: str, max_size: int) -> List[str]:
        """Split text into chunks of at most max_size chars.

        Splits on paragraph boundaries first; any paragraph run still over
        the limit is further split at line boundaries (then hard-cut).
        """
        if len(text) <= max_size:
            return [text]
        paragraphs = text.split("\n\n")
        chunks: List[str] = []
        current_chunk = ""
        for paragraph in paragraphs:
            # +2 accounts for the rejoined paragraph separator.
            if len(current_chunk) + len(paragraph) + 2 <= max_size:
                current_chunk += paragraph + "\n\n"
                continue
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = paragraph + "\n\n"
        if current_chunk:
            chunks.append(current_chunk.strip())
        if not chunks:
            chunks = [text[:max_size]]
        final: List[str] = []
        for chunk in chunks:
            if len(chunk) <= max_size:
                final.append(chunk)
            else:
                final.extend(self._split_oversized_chunk(chunk, max_size))
        return final

    @staticmethod
    def _split_oversized_chunk(chunk: str, max_size: int) -> List[str]:
        """Break an oversized chunk at \\n boundaries, then hard-cut."""
        lines = chunk.split("\n")
        parts: List[str] = []
        current = ""
        for line in lines:
            if len(current) + len(line) + 1 <= max_size:
                current += line + "\n"
                continue
            # Current buffer is full: flush it before placing this line.
            if current:
                parts.append(current.strip())
            if len(line) > max_size:
                # Single line exceeds the limit: hard-cut into fixed slices.
                start = 0
                while start < len(line):
                    parts.append(line[start : start + max_size].strip())
                    start += max_size
            else:
                current = line + "\n"
                continue
            current = ""
        if current.strip():
            parts.append(current.strip())
        return [p for p in parts if p]

    # ───────────────────────────────────────────────────────────────────
    # Post-extraction (merge + validation)
    # ───────────────────────────────────────────────────────────────────

    def _truncate_evidence(self, sentence: str) -> str:
        """Collapse whitespace and truncate to max_evidence_chars with ellipsis."""
        cleaned = re.sub(r"\s+", " ", sentence).strip()
        if len(cleaned) <= self.max_evidence_chars:
            return cleaned
        return cleaned[: self.max_evidence_chars - 3] + "..."

    def _build_relation_description(self, raw: RawExtractedRelation) -> Optional[str]:
        """Clean the LLM rationale into a bounded description (None if empty)."""
        rationale = raw.relation_rationale.strip() if raw.relation_rationale else ""
        if not rationale:
            return None
        cleaned = re.sub(r"\s+", " ", rationale).strip()
        # Descriptions may be longer than evidence, but never unbounded.
        max_description_chars = max(self.max_evidence_chars * 2, self.min_description_chars)
        if len(cleaned) <= max_description_chars:
            return cleaned
        return cleaned[: max_description_chars - 3] + "..."
def _merge_relations(self, relations: List[ExtractedRelation]) -> List[ExtractedRelation]: grouped: Dict[tuple[str, str, str], ExtractedRelation] = {} for relation in relations: key = (relation.source_celex, relation.target_celex, relation.relation_type.value) existing = grouped.get(key) if existing is None: grouped[key] = relation continue if relation.confidence <= existing.confidence: continue previous_evidence = existing.text_evidence if previous_evidence and previous_evidence != relation.text_evidence: grouped[key] = relation.model_copy( update={"text_evidence": f"{relation.text_evidence} | {previous_evidence}"} ) else: grouped[key] = relation merged = list(grouped.values()) merged.sort(key=lambda item: item.confidence, reverse=True) return merged def _validate_relations( self, relations: List[ExtractedRelation], *, source_celex: str, min_confidence: float, ) -> List[ExtractedRelation]: accepted: List[ExtractedRelation] = [] rejected_reasons: Counter[str] = Counter() normalized_source = normalize_celex(source_celex).lower() for relation in relations: reason = self._validate_relation( relation=relation, normalized_source=normalized_source, min_confidence=min_confidence, ) if reason is None: accepted.append(relation) continue rejected_reasons[reason] += 1 if rejected_reasons: logger.info("Validation rejects: %s", dict(rejected_reasons)) return accepted def _validate_relation( self, *, relation: ExtractedRelation, normalized_source: str, min_confidence: float, ) -> Optional[str]: target = normalize_celex(relation.target_celex).strip() if not target: return "missing_target" if target.lower() == normalized_source: return "self_reference" if relation.confidence < min_confidence: return "below_confidence" evidence = relation.text_evidence.strip() if len(evidence) < self.min_evidence_chars: return "evidence_too_short" if is_generic_target(target): return "generic_target" if not looks_like_identifier(target): return "non_identifier_target" return None