"""
Legal query expansion utilities.
Provides deterministic multi-query expansion for EU/FR legal retrieval.
"""
import re
from dataclasses import dataclass
from typing import List
from lalandre_core.config import get_config
@dataclass(frozen=True)
class ExpandedQuery:
"""One expanded query candidate with a weighting hint."""
text: str
weight: float
strategy: str
class LegalQueryExpansionService:
"""
Deterministic query expansion focused on legal references (UE/France).
The objective is recall improvement while keeping runtime bounded.
"""
_SPACE_RE = re.compile(r"\s+")
_IDENTIFIER_RE = re.compile(
r"\b("
r"3\d{4}[A-Z]\d{4}|"
r"\d{2,4}/\d{2,4}/(?:UE|EU|CE|EC|CEE)|"
r"(?:directive|regulation|r[èe]glement|decision|d[ée]cision)\s*"
r"(?:\((?:UE|EU|CE|EC|CEE)\)\s*)?"
r"(?:no\.?|n[°o]\s*)?\d{2,4}/\d{2,4}"
r")\b",
re.IGNORECASE,
)
_TOKEN_RE = re.compile(r"[A-Za-zÀ-ÿ0-9/_-]+")
_ABBREVIATION_RULES: tuple[tuple[re.Pattern[str], str], ...] = (
(re.compile(r"\bart\.\s*", re.IGNORECASE), "article "),
(re.compile(r"\bdir\.\s*", re.IGNORECASE), "directive "),
(re.compile(r"\br[èe]gl\.\s*", re.IGNORECASE), "règlement "),
(re.compile(r"\bd[ée]c\.\s*", re.IGNORECASE), "décision "),
)
_FR_TO_EN_RULES: tuple[tuple[re.Pattern[str], str], ...] = (
(re.compile(r"\br[èe]glement(?:s)?\b", re.IGNORECASE), "regulation"),
(re.compile(r"\bd[ée]cision(?:s)?\b", re.IGNORECASE), "decision"),
(re.compile(r"\bloi(?:s)?\b", re.IGNORECASE), "law"),
(re.compile(r"\bd[ée]cret(?:s)?\b", re.IGNORECASE), "decree"),
(re.compile(r"\btransposition\b", re.IGNORECASE), "implementation"),
)
_EN_TO_FR_RULES: tuple[tuple[re.Pattern[str], str], ...] = (
(re.compile(r"\bregulation(?:s)?\b", re.IGNORECASE), "règlement"),
(re.compile(r"\bdecision(?:s)?\b", re.IGNORECASE), "décision"),
(re.compile(r"\blaw(?:s)?\b", re.IGNORECASE), "loi"),
(re.compile(r"\bdecree(?:s)?\b", re.IGNORECASE), "décret"),
(re.compile(r"\bimplementation\b", re.IGNORECASE), "transposition"),
)
_STOPWORDS = frozenset(
{
"a",
"à",
"au",
"aux",
"avec",
"ce",
"ces",
"dans",
"de",
"des",
"du",
"en",
"et",
"for",
"how",
"la",
"le",
"les",
"of",
"on",
"ou",
"par",
"pour",
"quel",
"quelle",
"quelles",
"quels",
"sur",
"the",
"to",
"un",
"une",
"what",
"which",
"with",
}
)
_LEGAL_CUE_TOKENS = frozenset(
{
"article",
"articles",
"directive",
"directives",
"regulation",
"regulations",
"règlement",
"règlements",
"decision",
"decisions",
"décision",
"décisions",
"law",
"laws",
"loi",
"lois",
"decree",
"décret",
"décrets",
"transposition",
"implementation",
"amendment",
"modification",
"abrogation",
"repeal",
}
)
def __init__(self, *, min_query_chars: int = 24) -> None:
self.min_query_chars = max(min_query_chars, 0)
[docs]
def expand(self, query: str, *, max_variants: int = 3) -> List[ExpandedQuery]:
"""
Expand a query into deterministic variants.
Always returns at least one query (the normalized original).
"""
normalized = self._normalize(query)
if not normalized:
return []
search_cfg = get_config().search
bounded_max_variants = max(1, min(int(max_variants), search_cfg.query_expansion_max_variants_cap))
variants: List[ExpandedQuery] = []
seen: set[str] = set()
self._add_variant(
variants=variants,
seen=seen,
text=normalized,
weight=1.0,
strategy="original",
)
# Avoid generating noisy variants for very short, non-reference queries.
if len(normalized) < self.min_query_chars and not self._IDENTIFIER_RE.search(normalized):
return variants
self._add_variant(
variants=variants,
seen=seen,
text=self._expand_abbreviations(normalized),
weight=search_cfg.query_expansion_abbreviation_weight,
strategy="abbreviation_normalization",
)
self._add_variant(
variants=variants,
seen=seen,
text=self._keyword_focus_variant(normalized),
weight=search_cfg.query_expansion_keyword_focus_weight,
strategy="keyword_focus",
)
self._add_variant(
variants=variants,
seen=seen,
text=self._reference_focus_variant(normalized),
weight=search_cfg.query_expansion_reference_focus_weight,
strategy="reference_focus",
)
bilingual = self._bilingual_mirror_variant(normalized)
if bilingual is not None:
self._add_variant(
variants=variants,
seen=seen,
text=bilingual,
weight=search_cfg.query_expansion_bilingual_weight,
strategy="bilingual_mirror",
)
return variants[:bounded_max_variants]
def _normalize(self, text: str) -> str:
value = self._SPACE_RE.sub(" ", text.strip())
return value.strip(" \t\r\n;,.")
def _add_variant(
self,
*,
variants: List[ExpandedQuery],
seen: set[str],
text: str,
weight: float,
strategy: str,
) -> None:
normalized = self._normalize(text)
if len(normalized) < 4:
return
dedupe_key = normalized.lower()
if dedupe_key in seen:
return
seen.add(dedupe_key)
variants.append(
ExpandedQuery(
text=normalized,
weight=max(0.1, min(weight, 1.0)),
strategy=strategy,
)
)
def _expand_abbreviations(self, query: str) -> str:
expanded = query
for pattern, replacement in self._ABBREVIATION_RULES:
expanded = pattern.sub(replacement, expanded)
return expanded
def _keyword_focus_variant(self, query: str) -> str:
tokens = self._TOKEN_RE.findall(query)
kept: list[str] = []
for token in tokens:
lowered = token.lower()
if lowered in self._STOPWORDS and not any(ch.isdigit() for ch in token):
continue
if len(token) <= 2 and token.isalpha():
continue
kept.append(token)
return " ".join(kept)
def _reference_focus_variant(self, query: str) -> str:
identifiers: list[str] = [match.group(0) for match in self._IDENTIFIER_RE.finditer(query)]
if not identifiers:
return ""
legal_tokens: list[str] = []
for token in self._TOKEN_RE.findall(query):
lowered = token.lower()
if lowered in self._LEGAL_CUE_TOKENS:
legal_tokens.append(token)
merged: list[str] = identifiers + legal_tokens
seen: set[str] = set()
ordered_unique: list[str] = []
for item in merged:
key = item.lower()
if key in seen:
continue
seen.add(key)
ordered_unique.append(item)
return " ".join(ordered_unique)
def _bilingual_mirror_variant(self, query: str) -> str | None:
contains_french = bool(re.search(r"[éèàùâêîôûç]", query.lower()))
rules = self._FR_TO_EN_RULES if contains_french else self._EN_TO_FR_RULES
mirrored = query
changed = False
for pattern, replacement in rules:
updated = pattern.sub(replacement, mirrored)
if updated != mirrored:
changed = True
mirrored = updated
return mirrored if changed else None