Source code for lalandre_core.linking.ner_client

"""Tiny HTTP client for the dedicated NER service.

The NER service exposes a single ``POST /detect`` endpoint that runs GLiNER
in zero-shot mode. The client is deliberately minimal: it does one synchronous
request per call, surfaces a typed result, and never raises on network errors —
it returns an empty span list and lets the caller log/skip.

Designed to be reused outside RAG (extraction pipeline, evaluation scripts).
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Iterable, List, Optional

import requests

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class NerSpan:
    """Immutable record describing one entity span reported by the NER service.

    Instances are frozen, so they are hashable and compare by field values.

    Attributes:
        text: Text of the span as reported by the service.
        start: Start position of the span. Presumably a character offset into
            the submitted text -- confirm against the service contract.
        end: End position of the span. Whether it is inclusive or exclusive
            is not established here -- confirm against the service contract.
        type: Label the service assigned to the span.
        score: Score the service reported for the span. Presumably the value
            compared against the request threshold -- confirm.
    """

    text: str
    start: int
    end: int
    type: str
    score: float
class NerClient:
    """Thin HTTP wrapper around ``ner-service /detect``.

    Failure modes (network error, non-2xx, malformed JSON, JSON of an
    unexpected shape) are swallowed and logged at WARNING level; the call
    returns ``[]`` in that case so the caller can degrade gracefully.
    """

    def __init__(
        self,
        base_url: str,
        *,
        timeout_seconds: float = 5.0,
        default_threshold: float = 0.5,
        default_entity_types: Optional[Iterable[str]] = None,
    ) -> None:
        """Store the service location and the per-call defaults.

        Args:
            base_url: Root URL of the NER service; trailing slashes are
                stripped so ``/detect`` can be appended safely.
            timeout_seconds: Timeout handed to ``requests.post`` on each call.
            default_threshold: Threshold sent when ``detect`` receives none.
            default_entity_types: Entity types sent when ``detect`` receives
                none. When omitted or empty, a built-in set of five types
                is used.
        """
        self._base_url = base_url.rstrip("/")
        self._timeout_seconds = timeout_seconds
        self._default_threshold = default_threshold
        self._default_entity_types: List[str] = list(
            default_entity_types
            or ("regulation", "directive", "article", "decision", "communication")
        )

    @property
    def base_url(self) -> str:
        """Return the base URL of the configured NER service."""
        return self._base_url

    @staticmethod
    def _extract_raw_spans(data: object) -> Optional[List[object]]:
        """Pull the raw span list out of a decoded ``/detect`` response.

        Args:
            data: Whatever ``response.json()`` produced.

        Returns:
            The list stored under ``"spans"``; ``[]`` when the key is missing,
            null or otherwise empty; ``None`` when the body is not a JSON
            object or ``"spans"`` holds a non-empty value that is not a list.
        """
        if not isinstance(data, dict):
            # Valid JSON, but not an object: there is no "spans" key to read.
            return None
        raw = data.get("spans")
        if not raw:
            # Missing / null / empty: a legitimate "nothing detected" answer.
            return []
        if not isinstance(raw, list):
            return None
        return raw

    def detect(
        self,
        text: str,
        *,
        entity_types: Optional[Iterable[str]] = None,
        threshold: Optional[float] = None,
    ) -> List[NerSpan]:
        """Call ``POST /detect`` and return the matched spans (empty on any failure).

        Args:
            text: Text to analyse. Empty or whitespace-only text short-circuits
                to ``[]`` without contacting the service.
            entity_types: Entity types to request; the instance default is
                used when omitted or empty.
            threshold: Threshold to send; the instance default is used when
                ``None``.

        Returns:
            One ``NerSpan`` per well-formed item in the response. Items that
            lack a field or carry an unconvertible value are skipped.
        """
        if not text or not text.strip():
            return []

        payload = {
            "text": text,
            "entity_types": list(entity_types) if entity_types else self._default_entity_types,
            "threshold": (self._default_threshold if threshold is None else float(threshold)),
        }

        try:
            response = requests.post(
                f"{self._base_url}/detect",
                json=payload,
                timeout=self._timeout_seconds,
            )
        except requests.RequestException:
            logger.warning("ner-service unreachable at %s", self._base_url, exc_info=True)
            return []

        if response.status_code != 200:
            logger.warning(
                "ner-service returned status=%s body=%r",
                response.status_code,
                response.text[:200],
            )
            return []

        try:
            data = response.json()
        except ValueError:
            logger.warning("ner-service returned non-JSON body")
            return []

        # A body that decodes to a list/str/number (or whose "spans" is not a
        # list) must degrade to [] like every other failure, not raise.
        raw_spans = self._extract_raw_spans(data)
        if raw_spans is None:
            logger.warning(
                "ner-service returned unexpected JSON shape (top-level %s)",
                type(data).__name__,
            )
            return []

        spans: List[NerSpan] = []
        skipped = 0
        for item in raw_spans:
            try:
                spans.append(
                    NerSpan(
                        text=str(item["text"]),
                        start=int(item["start"]),
                        end=int(item["end"]),
                        type=str(item["type"]),
                        score=float(item["score"]),
                    )
                )
            except (KeyError, TypeError, ValueError, OverflowError):
                # OverflowError: int() of an ``Infinity`` literal, which the
                # JSON decoder accepts by default.
                skipped += 1

        if skipped:
            logger.warning("ner-service returned %d malformed span(s); skipped", skipped)
        return spans
# Public names of this module (what ``from <module> import *`` exports).
__all__ = ["NerClient", "NerSpan"]