# Source code for lalandre_rag.linker_factory

"""Build a ``LegalEntityLinker`` wired for the rag-service runtime.

The linker is shared with the extraction pipeline. For RAG use, we additionally
populate ``act_id`` on each ``ActAliasEntry`` (so resolutions carry the act
primary key) and supply an ``article_lookup`` callable that maps
``(act_id, article_number)`` to a subdivision id using a small in-memory cache
seeded from the database.
"""

from __future__ import annotations

import logging
import os
import re
from functools import lru_cache
from threading import RLock
from typing import Callable, List, Optional

from lalandre_core.linking import ActAliasEntry, LegalEntityLinker, NerClient
from lalandre_db_postgres import PostgresRepository
from lalandre_db_postgres.models import ActsSQL, SubdivisionsSQL

from .ner_external import build_ner_external_detector
from .prose_linker import ExternalDetector

logger = logging.getLogger(__name__)


# Discriminator value selecting article rows in the ``subdivisions`` table.
_ARTICLE_SUBDIVISION_TYPE = "ARTICLE"

# Matches runs of whitespace in article numbers (used during normalization
# so lookups are insensitive to spacing differences).
_ARTICLE_NUMBER_NORMALIZER = re.compile(r"\s+")

# Canonical EU CELEX shape (year + sector letter + sequence): used to flag
# the "main" normative texts so they win acronym ownership over derivative
# notes / letters / reports that mention the same acronym in their title.
_CELEX_CANONICAL_RE = re.compile(r"^3\d{4}[A-Z]\d{4}$")


def _normalize_article_number(raw: Optional[str]) -> str:
    """Normalize an article number for lookup.

    Removes *all* whitespace (so ``"5 bis"`` and ``"5bis"`` compare equal),
    lowercases, and drops trailing punctuation that may come from regex
    capture. Returns ``""`` for ``None`` input.
    """
    if raw is None:
        return ""
    # Substituting every whitespace run with "" leaves nothing to strip, so
    # no additional .strip() is needed before lowercasing.
    cleaned = _ARTICLE_NUMBER_NORMALIZER.sub("", raw).lower()
    return cleaned.rstrip(",;.")


class _ArticleLookup:
    """Callable cache that resolves ``(act_id, article_number) -> subdivision_id``.

    Backed by PostgresRepository. Uses a single shared LRU cache (thread-safe
    via RLock). Returns ``None`` when no matching subdivision is found.
    """

    def __init__(self, pg_repo: PostgresRepository, *, max_entries: int = 4096) -> None:
        self._pg_repo = pg_repo
        self._lock = RLock()
        # Per-instance LRU wrapper (instead of decorating the method) so each
        # linker instance owns its own bounded cache.
        self._lookup = lru_cache(maxsize=max_entries)(self._fetch)

    def __call__(self, act_id: int, article_number: str) -> Optional[int]:
        normalized = _normalize_article_number(article_number)
        if not normalized:
            return None
        try:
            with self._lock:
                return self._lookup(act_id, normalized)
        except Exception:
            # Transient DB failures must stay retryable: lru_cache memoizes
            # only successful returns, so catching the error *outside* the
            # cached function avoids permanently caching a spurious miss
            # (the previous code caught inside _fetch and cached None).
            logger.debug(
                "Article lookup failed for act_id=%s article=%s",
                act_id,
                normalized,
                exc_info=True,
            )
            return None

    def _fetch(self, act_id: int, normalized_article: str) -> Optional[int]:
        # Fetch all ARTICLE subdivisions for the act and compare in Python so
        # both sides of the match go through the same normalization.
        with self._pg_repo.get_session() as session:
            candidates = (
                session.query(SubdivisionsSQL.id, SubdivisionsSQL.number)
                .filter(SubdivisionsSQL.act_id == act_id)
                .filter(SubdivisionsSQL.subdivision_type == _ARTICLE_SUBDIVISION_TYPE)
                .all()
            )

        for subdivision_id, number in candidates:
            if number is None:
                continue
            if _normalize_article_number(number) == normalized_article:
                return subdivision_id
        return None


def build_linker(
    pg_repo: PostgresRepository,
    *,
    fuzzy_threshold: float,
    fuzzy_min_gap: float,
    fuzzy_limit: int,
    min_alias_chars: int,
    article_cache_size: int = 4096,
) -> LegalEntityLinker:
    """Construct a ``LegalEntityLinker`` seeded from the acts table.

    The returned linker carries ``act_id`` on every alias entry and is wired
    with an ``article_lookup`` callable that queries ``subdivisions`` on
    demand (LRU-cached, bounded).
    """
    # Snapshot plain dicts while the session is open so no ORM attribute
    # access happens after the session is closed.
    with pg_repo.get_session() as session:
        acts = session.query(ActsSQL).all()
        snapshot = [
            {
                "id": act.id,
                "celex": act.celex,
                "title": act.title,
                "eli": act.eli,
                "official_journal_reference": act.official_journal_reference,
                "form_number": act.form_number,
            }
            for act in acts
        ]

    # Sort so canonical EU acts come first: when an acronym (e.g. DORA) is
    # shared by the main regulation and derivative notes/letters, the main
    # regulation wins ownership of the alias.
    snapshot.sort(key=lambda a: 0 if _CELEX_CANONICAL_RE.match(a["celex"] or "") else 1)

    entries: List[ActAliasEntry] = []
    for act in snapshot:
        aliases = LegalEntityLinker.derive_aliases(
            title=act["title"],
            eli=act["eli"],
            official_journal_reference=act["official_journal_reference"],
            form_number=act["form_number"],
        )
        # Only canonical EU acts get acronyms — derivative ESA/EIOPA notes
        # that happen to mention "(DORA)" or "(MAR)" in their title should
        # not steal the acronym from the main regulation.
        is_canonical = bool(_CELEX_CANONICAL_RE.match(act["celex"] or ""))
        acronyms = LegalEntityLinker.derive_acronyms(act["title"]) if is_canonical else ()
        entries.append(
            ActAliasEntry(
                celex=act["celex"],
                title=act["title"],
                aliases=aliases,
                acronyms=acronyms,
                act_id=act["id"],
                eli=act["eli"],
            )
        )

    article_lookup: Callable[[int, str], Optional[int]] = _ArticleLookup(
        pg_repo,
        max_entries=article_cache_size,
    )
    linker = LegalEntityLinker(
        entries,
        fuzzy_threshold=fuzzy_threshold,
        fuzzy_min_gap=fuzzy_min_gap,
        fuzzy_limit=fuzzy_limit,
        min_alias_chars=min_alias_chars,
        article_lookup=article_lookup,
    )
    logger.info(
        "Prose linker initialized: acts=%s aliases=%s article_cache=%s",
        len(entries),
        linker.alias_count,
        article_cache_size,
    )
    return linker
def build_external_detector(linker: LegalEntityLinker) -> Optional[ExternalDetector]:
    """Construct the optional NER-backed ``ExternalDetector`` for prose linking.

    Reads ``NER_SERVICE_URL`` from the environment. When unset (or empty),
    returns ``None`` so the regex+fuzzy linker keeps its V1 behaviour with
    zero overhead. When set, builds a small HTTP client and wraps it in an
    adapter that resolves NER spans through the same ``LegalEntityLinker``.
    """
    base_url = (os.environ.get("NER_SERVICE_URL") or "").strip()
    if not base_url:
        return None

    # Malformed timeout values fall back to the 5-second default rather
    # than failing startup.
    timeout_str = os.environ.get("NER_SERVICE_TIMEOUT_SECONDS", "5.0")
    try:
        timeout_seconds = float(timeout_str)
    except ValueError:
        timeout_seconds = 5.0

    try:
        client = NerClient(base_url, timeout_seconds=timeout_seconds)
    except Exception:
        # Best-effort feature: a bad URL/client config disables the hook
        # instead of taking the whole service down.
        logger.warning(
            "NER_SERVICE_URL=%r set but client init failed; disabling NER hook",
            base_url,
            exc_info=True,
        )
        return None

    detector = build_ner_external_detector(client, linker)
    logger.info("NER external detector enabled: %s (timeout=%.1fs)", base_url, timeout_seconds)
    return detector
# Explicit public API: only the two factory functions are part of the
# module's contract.
__all__ = ["build_linker", "build_external_detector"]