# Source code for lalandre_extraction.relation_extractor
"""
Regulatory Relation Extractor
GraphRAG-style LLM-first extraction of legal relationships between acts.
The LLM reads text chunks directly and identifies all relations.
Entity linking resolves references to canonical CELEX identifiers post-extraction.
"""
import logging
import re
from collections import Counter, deque
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime
from hashlib import sha256
from threading import Lock
from typing import Dict, List, Optional, Sequence
from lalandre_core.config import ExtractionConfidenceConfig, get_config
from lalandre_core.linking import LegalEntityLinker, LinkResolution, is_generic_target, looks_like_identifier
from lalandre_core.models.types import RelationType
from lalandre_core.utils import normalize_celex
from pydantic import BaseModel, Field
from .llm import ALLOWED_RELATION_TYPES, ExtractionLLMClient, RawExtractedRelation
logger = logging.getLogger(__name__)
# ═══════════════════════════════════════════════════════════════════════════
# Model
# ═══════════════════════════════════════════════════════════════════════════
class ExtractedRelation(BaseModel):
    """A legal relationship extracted from text.

    One directed edge between two acts: the act the text came from
    (``source_celex``) and the act it refers to (``target_celex``),
    together with the evidence and the entity-linking provenance that
    produced the resolved target.
    """

    source_celex: str  # CELEX of the act the text was extracted from
    target_celex: str  # resolved target CELEX (or the raw reference when unresolved)
    relation_type: RelationType
    confidence: float = Field(ge=0.0, le=1.0)  # combined extraction confidence
    text_evidence: str  # sentence(s) supporting the relation
    relation_description: Optional[str] = None  # cleaned LLM rationale, when provided
    extraction_method: str = "llm_extraction"
    effect_date: Optional[datetime] = None  # not populated by this module's pipeline
    source_subdivision: Optional[str] = None  # not populated by this module's pipeline
    target_subdivision: Optional[str] = None  # not populated by this module's pipeline
    raw_target_reference: Optional[str] = None  # textual reference as emitted by the LLM
    resolution_method: Optional[str] = None  # e.g. linker method, "normalize_fallback", "unresolved"
    resolution_score: Optional[float] = None  # linker/fallback score (0.0 when unresolved)
# ═══════════════════════════════════════════════════════════════════════════
# Extractor
# ═══════════════════════════════════════════════════════════════════════════

# Fast lookup table from the LLM's string output to the internal enum,
# restricted to the allowed set imported from the LLM client module.
_RELATION_TYPE_BY_VALUE: Dict[str, RelationType] = {
    member.value: member
    for member in RelationType
    if member.value in ALLOWED_RELATION_TYPES
}
class RegulatoryRelationExtractor:
"""
LLM-first extraction pipeline (GraphRAG-style):
- text chunking
- LLM extraction per chunk
- entity linking (post-LLM resolution)
- confidence scoring
- validation & merge
"""
# ───────────────────────────────────────────────────────────────────
# Initialization
# ───────────────────────────────────────────────────────────────────
def __init__(
    self,
    max_chunk_size: Optional[int] = None,
    entity_linker: Optional[LegalEntityLinker] = None,
    validation_enabled: Optional[bool] = None,
    min_evidence_chars: Optional[int] = None,
    llm_client: Optional[ExtractionLLMClient] = None,
) -> None:
    """Build the extractor, filling any unset argument from runtime config.

    Raises:
        ValueError: if ``gateway.job_extract_min_confidence`` is not configured.
    """
    cfg = get_config()
    ext_cfg = cfg.extraction

    if max_chunk_size is None:
        max_chunk_size = cfg.chunking.extraction_max_chunk_chars
    if validation_enabled is None:
        validation_enabled = ext_cfg.validation_enabled
    if min_evidence_chars is None:
        min_evidence_chars = ext_cfg.min_evidence_chars

    # A default confidence floor is mandatory; there is no sane fallback.
    min_conf_default = cfg.gateway.job_extract_min_confidence
    if min_conf_default is None:
        raise ValueError("gateway.job_extract_min_confidence must be configured for extraction")

    self.llm_client = llm_client if llm_client is not None else ExtractionLLMClient.from_runtime(config=cfg)
    self.max_chunk_size = max_chunk_size
    self.llm_max_parallel_chunks = max(1, int(ext_cfg.llm_max_parallel_chunks))
    self.llm_chunk_cache_size = max(0, int(ext_cfg.llm_chunk_cache_size))
    self.default_min_confidence = float(min_conf_default)
    self.entity_linker = entity_linker
    self.validation_enabled = bool(validation_enabled)
    self.min_evidence_chars = int(min_evidence_chars)
    self.max_evidence_chars = int(ext_cfg.max_evidence_chars)
    self.min_description_chars = int(ext_cfg.min_description_chars)

    # Bounded per-chunk LLM result cache (dict + FIFO order deque), guarded by a lock.
    self._chunk_cache: Dict[str, List[RawExtractedRelation]] = {}
    self._chunk_cache_order: deque[str] = deque()
    self._chunk_cache_lock = Lock()
    self._conf: ExtractionConfidenceConfig = ext_cfg.confidence

    if self.llm_client is not None:
        logger.info(
            "Extractor LLM enabled: provider=%s model=%s",
            self.llm_client.provider,
            self.llm_client.model,
        )
def set_entity_linker(self, entity_linker: Optional[LegalEntityLinker]) -> None:
    """Replace the entity linker used for post-LLM reference resolution."""
    # None is allowed: resolution then falls back to normalize_celex only.
    self.entity_linker = entity_linker
# ───────────────────────────────────────────────────────────────────
# Public API
# ───────────────────────────────────────────────────────────────────
def extract_relations(
    self,
    text: str,
    source_celex: str,
    min_confidence: Optional[float] = None,
) -> List[ExtractedRelation]:
    """Extract, resolve, merge, and validate relations from one act text.

    Splits ``text`` into chunks, runs LLM extraction per chunk (in parallel
    when more than one chunk and parallelism is configured), merges duplicate
    relations, and — when validation is enabled — filters the merged set.

    Args:
        text: Full text of the source act.
        source_celex: CELEX identifier of the act the text belongs to.
        min_confidence: Optional override of the configured confidence floor.

    Returns:
        Validated (or merged, when validation is off) relations sorted by
        descending confidence; empty list when no LLM client is configured.
    """
    if self.llm_client is None:
        logger.warning("No LLM client configured; extraction skipped")
        return []

    resolved_min_confidence = float(min_confidence) if min_confidence is not None else self.default_min_confidence
    chunks = self._chunk_text(text, self.max_chunk_size)

    relations: List[ExtractedRelation] = []
    provenance: Counter[str] = Counter()
    raw_relations_count = 0
    failed_chunks = 0

    def consume(index: int, get_result) -> None:
        # Shared result handling for the sequential and parallel paths:
        # a failed chunk is logged and skipped, never fatal for the whole act.
        nonlocal raw_relations_count, failed_chunks
        try:
            resolved, chunk_provenance, raw_count = get_result()
        except Exception:
            failed_chunks += 1
            logger.error(
                "Chunk %d/%d failed for celex=%s",
                index + 1,
                len(chunks),
                source_celex,
                exc_info=True,
            )
            return
        relations.extend(resolved)
        provenance.update(chunk_provenance)
        raw_relations_count += raw_count

    if len(chunks) <= 1 or self.llm_max_parallel_chunks <= 1:
        for i, chunk in enumerate(chunks):
            consume(
                i,
                # Bind chunk as a default to avoid late-binding closure bugs.
                lambda chunk=chunk: self._extract_single_chunk(
                    chunk=chunk,
                    source_celex=source_celex,
                    min_confidence=resolved_min_confidence,
                ),
            )
    else:
        max_workers = min(self.llm_max_parallel_chunks, len(chunks))
        logger.info(
            "Extraction chunk parallelism enabled: chunks=%d workers=%d",
            len(chunks),
            max_workers,
        )
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures: List[Future[tuple[List[ExtractedRelation], Counter[str], int]]] = [
                executor.submit(
                    self._extract_single_chunk,
                    chunk=chunk,
                    source_celex=source_celex,
                    min_confidence=resolved_min_confidence,
                )
                for chunk in chunks
            ]
            # Iterate in submission order so chunk indices in logs stay stable.
            for i, future in enumerate(futures):
                consume(i, future.result)

    logger.info(
        "Extraction summary: chunks=%d, failed=%d, raw_relations=%d, resolved=%d",
        len(chunks),
        failed_chunks,
        raw_relations_count,
        len(relations),
    )
    if provenance:
        logger.info("Extraction provenance: %s", dict(provenance))
    if not relations:
        return []

    merged = self._merge_relations(relations)
    if not self.validation_enabled:
        return merged

    validated = self._validate_relations(
        merged,
        source_celex=source_celex,
        min_confidence=resolved_min_confidence,
    )
    if len(validated) != len(merged):
        logger.info(
            "Validation filtered %d relations (kept=%d)",
            len(merged) - len(validated),
            len(validated),
        )
    return validated
def _extract_single_chunk(
    self,
    *,
    chunk: str,
    source_celex: str,
    min_confidence: float,
) -> tuple[List[ExtractedRelation], Counter[str], int]:
    """Extract and resolve one chunk.

    Returns a triple of (resolved relations, per-chunk provenance counter,
    number of raw relations the LLM emitted before filtering).
    """
    raw = self._extract_chunk_with_cache(chunk=chunk, source_celex=source_celex)
    chunk_provenance: Counter[str] = Counter()
    scored = self._resolve_and_score(
        raw,
        source_celex=source_celex,
        min_confidence=min_confidence,
        provenance=chunk_provenance,
    )
    return scored, chunk_provenance, len(raw)
def _extract_chunk_with_cache(
    self,
    *,
    chunk: str,
    source_celex: str,
) -> List[RawExtractedRelation]:
    """Run LLM extraction for one chunk, memoized in the bounded chunk cache."""
    client = self.llm_client
    assert client is not None

    # Cache disabled entirely when its configured size is zero.
    if self.llm_chunk_cache_size <= 0:
        return client.extract_relations(chunk, source_celex)

    key = self._build_chunk_cache_key(chunk=chunk, source_celex=source_celex)
    hit = self._get_cached_chunk_relations(key)
    if hit is not None:
        return hit

    extracted = client.extract_relations(chunk, source_celex)
    self._set_cached_chunk_relations(key, extracted)
    return extracted
def _build_chunk_cache_key(self, *, chunk: str, source_celex: str) -> str:
llm_client = self.llm_client
assert llm_client is not None
provider = llm_client.provider
model = llm_client.model
fingerprint = f"{provider}|{model}|{source_celex}|{chunk}"
return sha256(fingerprint.encode("utf-8")).hexdigest()
def _get_cached_chunk_relations(self, cache_key: str) -> Optional[List[RawExtractedRelation]]:
    """Thread-safe cache lookup; hands back a copy so the cached list stays pristine."""
    with self._chunk_cache_lock:
        entry = self._chunk_cache.get(cache_key)
        return None if entry is None else list(entry)
def _set_cached_chunk_relations(
    self,
    cache_key: str,
    raw_relations: Sequence[RawExtractedRelation],
) -> None:
    """Thread-safe cache insert with FIFO eviction beyond the configured size."""
    with self._chunk_cache_lock:
        already_present = cache_key in self._chunk_cache
        self._chunk_cache[cache_key] = list(raw_relations)
        if already_present:
            # Overwrite in place; insertion (and hence eviction) order unchanged.
            return
        self._chunk_cache_order.append(cache_key)
        while len(self._chunk_cache_order) > self.llm_chunk_cache_size:
            evicted = self._chunk_cache_order.popleft()
            self._chunk_cache.pop(evicted, None)
# ───────────────────────────────────────────────────────────────────
# LLM result resolution
# ───────────────────────────────────────────────────────────────────
def _resolve_and_score(
    self,
    raw_relations: Sequence[RawExtractedRelation],
    *,
    source_celex: str,
    min_confidence: float,
    provenance: Counter[str],
) -> List[ExtractedRelation]:
    """Turn raw LLM relations into scored ``ExtractedRelation`` objects.

    For each raw relation: map the relation-type string onto the internal
    enum (unknown types are dropped), resolve the target reference to a
    CELEX, compute a confidence score, and keep only relations at or above
    ``min_confidence``. ``provenance`` is updated in place with counters
    describing how each target was handled.
    """
    results: List[ExtractedRelation] = []
    for raw in raw_relations:
        relation_type = _RELATION_TYPE_BY_VALUE.get(raw.relation_type)
        if relation_type is None:
            # LLM emitted a type outside the allowed set; drop silently.
            continue
        resolution = self._resolve_reference(raw.target_reference)
        if resolution is None:
            # Last resort: try normalize_celex directly.
            # normalize_celex is a formatter, not a resolver — it does not
            # verify the entity exists, so the score stays low.
            normalized = normalize_celex(raw.target_reference).strip()
            if normalized:
                resolution = LinkResolution(
                    celex=normalized,
                    score=self._conf.normalize_fallback_score,
                    method="normalize_fallback",
                    matched_text=raw.target_reference,
                )
        if resolution is None:
            provenance["unresolved"] += 1
            # Keep with raw reference — validation will filter if not identifier.
            target_celex = raw.target_reference.strip()
            resolution_method = "unresolved"
            resolution_score = 0.0
        else:
            target_celex = resolution.celex
            resolution_method = resolution.method
            resolution_score = resolution.score
        confidence = self._compute_confidence(
            relation_type=relation_type,
            resolution_method=resolution_method,
            resolution_score=resolution_score,
            evidence=raw.text_evidence,
        )
        if confidence < min_confidence:
            # Below threshold: dropped without an "llm:" provenance entry.
            continue
        evidence = raw.text_evidence.strip() if raw.text_evidence else ""
        results.append(
            ExtractedRelation(
                source_celex=source_celex,
                target_celex=target_celex,
                relation_type=relation_type,
                confidence=confidence,
                text_evidence=self._truncate_evidence(evidence),
                relation_description=self._build_relation_description(raw),
                extraction_method="llm_extraction",
                raw_target_reference=raw.target_reference.strip(),
                resolution_method=resolution_method,
                resolution_score=resolution_score,
            )
        )
        provenance[f"llm:{resolution_method}"] += 1
    return results
def _resolve_reference(self, target_reference: str) -> Optional[LinkResolution]:
    """Resolve a textual reference to a CELEX via the entity linker, if configured."""
    linker = self.entity_linker
    if linker is None:
        return None
    direct = linker.resolve(target_reference)
    if direct is not None:
        return direct
    # Second pass: resolve the normalized form, but only when it differs.
    normalized = normalize_celex(target_reference).strip()
    if normalized and normalized != target_reference.strip():
        return linker.resolve(normalized)
    return None
# ───────────────────────────────────────────────────────────────────
# Confidence scoring
# ───────────────────────────────────────────────────────────────────
def _compute_confidence(
    self,
    *,
    relation_type: RelationType,
    resolution_method: str,
    resolution_score: float,
    evidence: str,
) -> float:
    """Combine base, relation-type, resolution, and evidence signals into one score."""
    weights = self._conf
    total = weights.base
    # Anything more specific than a plain citation earns a bonus.
    if relation_type != RelationType.CITES:
        total += weights.non_cites_bonus
    if resolution_method == "explicit":
        total += weights.explicit_resolution_bonus
    elif resolution_method in ("exact_alias", "fuzzy_alias"):
        total += weights.alias_resolution_bonus
    # Fuzzy matches get scaled down by the linker score, bounded below.
    if resolution_method == "fuzzy_alias":
        total *= max(weights.fuzzy_min_factor, min(resolution_score, 1.0))
    if evidence and len(evidence.strip()) >= weights.evidence_min_chars:
        total += weights.evidence_bonus
    return min(total, weights.max_confidence)
# ───────────────────────────────────────────────────────────────────
# Text chunking
# ───────────────────────────────────────────────────────────────────
def _chunk_text(self, text: str, max_size: int) -> List[str]:
if len(text) <= max_size:
return [text]
paragraphs = text.split("\n\n")
chunks: List[str] = []
current_chunk = ""
for paragraph in paragraphs:
if len(current_chunk) + len(paragraph) + 2 <= max_size:
current_chunk += paragraph + "\n\n"
continue
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = paragraph + "\n\n"
if current_chunk:
chunks.append(current_chunk.strip())
if not chunks:
chunks = [text[:max_size]]
final: List[str] = []
for chunk in chunks:
if len(chunk) <= max_size:
final.append(chunk)
else:
final.extend(self._split_oversized_chunk(chunk, max_size))
return final
@staticmethod
def _split_oversized_chunk(chunk: str, max_size: int) -> List[str]:
"""Break an oversized chunk at \\n boundaries, then hard-cut."""
lines = chunk.split("\n")
parts: List[str] = []
current = ""
for line in lines:
if len(current) + len(line) + 1 <= max_size:
current += line + "\n"
continue
if current:
parts.append(current.strip())
if len(line) > max_size:
start = 0
while start < len(line):
parts.append(line[start : start + max_size].strip())
start += max_size
else:
current = line + "\n"
continue
current = ""
if current.strip():
parts.append(current.strip())
return [p for p in parts if p]
# ───────────────────────────────────────────────────────────────────
# Post-extraction (merge + validation)
# ───────────────────────────────────────────────────────────────────
def _truncate_evidence(self, sentence: str) -> str:
cleaned = re.sub(r"\s+", " ", sentence).strip()
if len(cleaned) <= self.max_evidence_chars:
return cleaned
return cleaned[: self.max_evidence_chars - 3] + "..."
def _build_relation_description(self, raw: RawExtractedRelation) -> Optional[str]:
    """Normalize the LLM rationale into a bounded description, or None when absent."""
    if not raw.relation_rationale or not raw.relation_rationale.strip():
        return None
    description = re.sub(r"\s+", " ", raw.relation_rationale).strip()
    # Descriptions may run longer than evidence, but never unbounded.
    limit = max(self.max_evidence_chars * 2, self.min_description_chars)
    if len(description) <= limit:
        return description
    return description[: limit - 3] + "..."
def _merge_relations(self, relations: List[ExtractedRelation]) -> List[ExtractedRelation]:
    """Deduplicate by (source, target, type), keeping the highest-confidence copy.

    When a higher-confidence duplicate replaces an earlier one with different
    evidence, both evidence strings are kept (new first, pipe-separated).
    """
    best: Dict[tuple[str, str, str], ExtractedRelation] = {}
    for candidate in relations:
        key = (candidate.source_celex, candidate.target_celex, candidate.relation_type.value)
        current = best.get(key)
        if current is None:
            best[key] = candidate
        elif candidate.confidence > current.confidence:
            old_evidence = current.text_evidence
            if old_evidence and old_evidence != candidate.text_evidence:
                best[key] = candidate.model_copy(
                    update={"text_evidence": f"{candidate.text_evidence} | {old_evidence}"}
                )
            else:
                best[key] = candidate
    return sorted(best.values(), key=lambda rel: rel.confidence, reverse=True)
def _validate_relations(
    self,
    relations: List[ExtractedRelation],
    *,
    source_celex: str,
    min_confidence: float,
) -> List[ExtractedRelation]:
    """Filter relations through per-relation validation; log aggregate reject reasons."""
    kept: List[ExtractedRelation] = []
    rejects: Counter[str] = Counter()
    normalized_source = normalize_celex(source_celex).lower()
    for candidate in relations:
        verdict = self._validate_relation(
            relation=candidate,
            normalized_source=normalized_source,
            min_confidence=min_confidence,
        )
        if verdict is None:
            kept.append(candidate)
        else:
            rejects[verdict] += 1
    if rejects:
        logger.info("Validation rejects: %s", dict(rejects))
    return kept
def _validate_relation(
    self,
    *,
    relation: ExtractedRelation,
    normalized_source: str,
    min_confidence: float,
) -> Optional[str]:
    """Return a rejection reason for one relation, or None when it is acceptable.

    Checks run in priority order: the first failing check names the reason.
    """
    normalized_target = normalize_celex(relation.target_celex).strip()
    if not normalized_target:
        return "missing_target"
    if normalized_target.lower() == normalized_source:
        return "self_reference"
    if relation.confidence < min_confidence:
        return "below_confidence"
    if len(relation.text_evidence.strip()) < self.min_evidence_chars:
        return "evidence_too_short"
    if is_generic_target(normalized_target):
        return "generic_target"
    if not looks_like_identifier(normalized_target):
        return "non_identifier_target"
    return None