# Source code for lalandre_extraction.llm.agent

"""PydanticAI agent for structured relation extraction."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from lalandre_core.llm.structured import run_structured_agent
from pydantic_ai import Agent, ModelRetry

from .models import ALLOWED_RELATION_TYPES, ExtractionOutput


def _build_extraction_agent(
    *,
    min_evidence_chars: int,
    min_rationale_chars: int,
) -> Agent[Any, ExtractionOutput]:
    """Create the relation-extraction agent and attach its output validator.

    Args:
        min_evidence_chars: Relations whose ``text_evidence`` is shorter than
            this many characters are dropped by the validator.
        min_rationale_chars: A non-empty ``relation_rationale`` shorter than
            this many characters is replaced by an empty string; the relation
            itself is kept.

    Returns:
        An agent whose output type is ``ExtractionOutput`` and whose output
        validator filters and de-duplicates the extracted relations.
    """
    system_instructions = (
        "You are an EU/FR legal relation extractor.\n"
        "Identify ALL explicit legal relationships between acts in the text.\n"
        "Only extract explicitly stated relations. Do NOT infer or guess.\n"
        "Return an empty relations list if no explicit relation is present."
    )
    agent: Agent[Any, ExtractionOutput] = Agent(
        output_type=ExtractionOutput,
        instructions=system_instructions,
        output_retries=2,
        defer_model_check=True,
    )

    @agent.output_validator
    def _validate_extraction(result: ExtractionOutput) -> ExtractionOutput:
        """Keep only usable, unique relations; raise ``ModelRetry`` if none survive."""
        proposed = result.relations
        kept = []
        emitted_keys: set[tuple[str, str]] = set()

        for candidate in proposed:
            # Reject on the first failing criterion: unknown type, missing
            # target, or evidence shorter than the configured minimum.
            rejected = (
                candidate.relation_type not in ALLOWED_RELATION_TYPES
                or not candidate.target_reference
                or len(candidate.text_evidence) < min_evidence_chars
            )
            if rejected:
                continue

            # A too-short rationale is blanked rather than costing the relation.
            rationale = candidate.relation_rationale
            if rationale and len(rationale) < min_rationale_chars:
                candidate = candidate.model_copy(update={"relation_rationale": ""})

            # First occurrence wins; the target is compared case-insensitively.
            dedup_key = (candidate.target_reference.lower(), candidate.relation_type)
            if dedup_key not in emitted_keys:
                emitted_keys.add(dedup_key)
                kept.append(candidate)

        # Only complain when the model proposed something and nothing survived;
        # an empty proposal is an accepted answer.
        if not kept and proposed:
            allowed_listing = ", ".join(sorted(ALLOWED_RELATION_TYPES))
            raise ModelRetry(
                f"All {len(proposed)} relations were filtered out. "
                f"Allowed types: {allowed_listing}. "
                f"Min evidence chars: {min_evidence_chars}."
            )

        return result.model_copy(update={"relations": kept})

    return agent


def run_extraction_agent(
    *,
    prompt: str,
    generate_text: Callable[[str], str],
    model_name: str,
    min_evidence_chars: int = 8,
    min_rationale_chars: int = 24,
) -> tuple[ExtractionOutput, int]:
    """Run the extraction agent and return validated output + retry count.

    A fresh agent is built on every call so the validation thresholds can be
    chosen per invocation.

    Args:
        prompt: Prompt text handed to the agent run.
        generate_text: Callable taking a string and returning a string; it is
            forwarded to ``run_structured_agent`` as ``llm_or_generate``.
        model_name: Model name forwarded to ``run_structured_agent``.
        min_evidence_chars: Minimum evidence length, forwarded to
            ``_build_extraction_agent``.
        min_rationale_chars: Minimum rationale length, forwarded to
            ``_build_extraction_agent``.

    Returns:
        Whatever ``run_structured_agent`` returns: per this function's
        contract, the validated ``ExtractionOutput`` and the retry count.
    """
    agent = _build_extraction_agent(
        min_evidence_chars=min_evidence_chars,
        min_rationale_chars=min_rationale_chars,
    )
    return run_structured_agent(
        agent=agent,
        prompt=prompt,
        llm_or_generate=generate_text,
        model_name=model_name,
    )