# Source code for lalandre_rag.retrieval.query_expansion

"""
Legal query expansion utilities.

Provides deterministic multi-query expansion for EU/FR legal retrieval.
"""

import re
from dataclasses import dataclass
from typing import List

from lalandre_core.config import get_config


@dataclass(frozen=True)
class ExpandedQuery:
    """One expanded query candidate with a weighting hint."""

    # Variant text (already whitespace-normalized by the expansion service).
    text: str
    # Weighting hint for this variant relative to the original query.
    weight: float
    # Name of the expansion strategy that produced this variant.
    strategy: str
class LegalQueryExpansionService:
    """
    Deterministic query expansion focused on legal references (UE/France).

    The objective is recall improvement while keeping runtime bounded:
    every variant is produced by pure regex/token transforms (no I/O,
    no model calls), and the number of variants is capped by config.
    """

    # Collapses runs of whitespace to a single space during normalization.
    _SPACE_RE = re.compile(r"\s+")

    # Legal-reference detector; alternatives, in order:
    #   - 3NNNNLNNNN-style numbers (presumably CELEX identifiers — confirm),
    #   - "NNNN/NN/UE"-style numbered references,
    #   - textual references such as "directive (UE) no 123/2006".
    _IDENTIFIER_RE = re.compile(
        r"\b("
        r"3\d{4}[A-Z]\d{4}|"
        r"\d{2,4}/\d{2,4}/(?:UE|EU|CE|EC|CEE)|"
        r"(?:directive|regulation|r[èe]glement|decision|d[ée]cision)\s*"
        r"(?:\((?:UE|EU|CE|EC|CEE)\)\s*)?"
        r"(?:no\.?|n[°o]\s*)?\d{2,4}/\d{2,4}"
        r")\b",
        re.IGNORECASE,
    )

    # Word-ish tokens; accented letters, digits and "/_-" are kept so that
    # references such as "2009/28/CE" survive tokenization as one token.
    _TOKEN_RE = re.compile(r"[A-Za-zÀ-ÿ0-9/_-]+")

    # Precompiled here (instead of an inline re.search in
    # _bilingual_mirror_variant) for consistency with the other class-level
    # patterns and to avoid the per-call pattern-cache lookup.
    _FRENCH_ACCENT_RE = re.compile(r"[éèàùâêîôûç]")

    # Common French legal abbreviations -> expanded forms.
    _ABBREVIATION_RULES: tuple[tuple[re.Pattern[str], str], ...] = (
        (re.compile(r"\bart\.\s*", re.IGNORECASE), "article "),
        (re.compile(r"\bdir\.\s*", re.IGNORECASE), "directive "),
        (re.compile(r"\br[èe]gl\.\s*", re.IGNORECASE), "règlement "),
        (re.compile(r"\bd[ée]c\.\s*", re.IGNORECASE), "décision "),
    )

    # French -> English mirror vocabulary for bilingual recall.
    _FR_TO_EN_RULES: tuple[tuple[re.Pattern[str], str], ...] = (
        (re.compile(r"\br[èe]glement(?:s)?\b", re.IGNORECASE), "regulation"),
        (re.compile(r"\bd[ée]cision(?:s)?\b", re.IGNORECASE), "decision"),
        (re.compile(r"\bloi(?:s)?\b", re.IGNORECASE), "law"),
        (re.compile(r"\bd[ée]cret(?:s)?\b", re.IGNORECASE), "decree"),
        (re.compile(r"\btransposition\b", re.IGNORECASE), "implementation"),
    )

    # English -> French mirror vocabulary.
    _EN_TO_FR_RULES: tuple[tuple[re.Pattern[str], str], ...] = (
        (re.compile(r"\bregulation(?:s)?\b", re.IGNORECASE), "règlement"),
        (re.compile(r"\bdecision(?:s)?\b", re.IGNORECASE), "décision"),
        (re.compile(r"\blaw(?:s)?\b", re.IGNORECASE), "loi"),
        (re.compile(r"\bdecree(?:s)?\b", re.IGNORECASE), "décret"),
        (re.compile(r"\bimplementation\b", re.IGNORECASE), "transposition"),
    )

    # Mixed FR/EN stopwords dropped by the keyword-focus variant.
    _STOPWORDS = frozenset(
        {
            "a", "à", "au", "aux", "avec", "ce", "ces", "dans", "de", "des",
            "du", "en", "et", "for", "how", "la", "le", "les", "of", "on",
            "ou", "par", "pour", "quel", "quelle", "quelles", "quels", "sur",
            "the", "to", "un", "une", "what", "which", "with",
        }
    )

    # Tokens that signal a legal reference; kept by the reference-focus variant.
    _LEGAL_CUE_TOKENS = frozenset(
        {
            "article", "articles", "directive", "directives",
            "regulation", "regulations", "règlement", "règlements",
            "decision", "decisions", "décision", "décisions",
            "law", "laws", "loi", "lois",
            "decree", "décret", "décrets",
            "transposition", "implementation",
            "amendment", "modification", "abrogation", "repeal",
        }
    )

    def __init__(self, *, min_query_chars: int = 24) -> None:
        # Queries shorter than this threshold (and without an explicit legal
        # identifier) are not expanded, to avoid generating noisy variants.
        self.min_query_chars = max(min_query_chars, 0)

    def expand(self, query: str, *, max_variants: int = 3) -> List[ExpandedQuery]:
        """
        Expand a query into deterministic variants.

        Parameters
        ----------
        query:
            Raw user query; whitespace is collapsed and surrounding
            punctuation stripped before expansion.
        max_variants:
            Upper bound on the number of returned variants; additionally
            capped by ``search.query_expansion_max_variants_cap`` from the
            application configuration, and floored at 1.

        Returns
        -------
        A list of :class:`ExpandedQuery` with the normalized original first.
        The list is empty only when the normalized query itself is empty;
        otherwise it always contains at least the original.
        """
        normalized = self._normalize(query)
        if not normalized:
            return []

        search_cfg = get_config().search
        bounded_max_variants = max(
            1, min(int(max_variants), search_cfg.query_expansion_max_variants_cap)
        )

        variants: List[ExpandedQuery] = []
        seen: set[str] = set()
        self._add_variant(
            variants=variants,
            seen=seen,
            text=normalized,
            weight=1.0,
            strategy="original",
        )

        # Avoid generating noisy variants for very short, non-reference queries.
        if len(normalized) < self.min_query_chars and not self._IDENTIFIER_RE.search(
            normalized
        ):
            return variants

        self._add_variant(
            variants=variants,
            seen=seen,
            text=self._expand_abbreviations(normalized),
            weight=search_cfg.query_expansion_abbreviation_weight,
            strategy="abbreviation_normalization",
        )
        self._add_variant(
            variants=variants,
            seen=seen,
            text=self._keyword_focus_variant(normalized),
            weight=search_cfg.query_expansion_keyword_focus_weight,
            strategy="keyword_focus",
        )
        self._add_variant(
            variants=variants,
            seen=seen,
            text=self._reference_focus_variant(normalized),
            weight=search_cfg.query_expansion_reference_focus_weight,
            strategy="reference_focus",
        )
        bilingual = self._bilingual_mirror_variant(normalized)
        if bilingual is not None:
            self._add_variant(
                variants=variants,
                seen=seen,
                text=bilingual,
                weight=search_cfg.query_expansion_bilingual_weight,
                strategy="bilingual_mirror",
            )
        return variants[:bounded_max_variants]

    def _normalize(self, text: str) -> str:
        """Collapse internal whitespace and strip surrounding punctuation."""
        value = self._SPACE_RE.sub(" ", text.strip())
        return value.strip(" \t\r\n;,.")

    def _add_variant(
        self,
        *,
        variants: List[ExpandedQuery],
        seen: set[str],
        text: str,
        weight: float,
        strategy: str,
    ) -> None:
        """
        Append a normalized, deduplicated variant to ``variants`` in place.

        Variants shorter than 4 characters are dropped; deduplication is
        case-insensitive; the weight is clamped into [0.1, 1.0].
        """
        normalized = self._normalize(text)
        if len(normalized) < 4:
            return
        dedupe_key = normalized.lower()
        if dedupe_key in seen:
            return
        seen.add(dedupe_key)
        variants.append(
            ExpandedQuery(
                text=normalized,
                weight=max(0.1, min(weight, 1.0)),
                strategy=strategy,
            )
        )

    def _expand_abbreviations(self, query: str) -> str:
        """Rewrite French legal abbreviations ("art.", "dir.", ...) in full."""
        expanded = query
        for pattern, replacement in self._ABBREVIATION_RULES:
            expanded = pattern.sub(replacement, expanded)
        return expanded

    def _keyword_focus_variant(self, query: str) -> str:
        """Keep only content-bearing tokens (drop stopwords and short words)."""
        tokens = self._TOKEN_RE.findall(query)
        kept: list[str] = []
        for token in tokens:
            lowered = token.lower()
            # Digit-bearing tokens are kept even if they look like stopwords,
            # so numeric references are never discarded.
            if lowered in self._STOPWORDS and not any(ch.isdigit() for ch in token):
                continue
            if len(token) <= 2 and token.isalpha():
                continue
            kept.append(token)
        return " ".join(kept)

    def _reference_focus_variant(self, query: str) -> str:
        """
        Keep only legal identifiers plus legal cue words, in original order.

        Returns "" when the query contains no identifier at all, which makes
        ``_add_variant`` drop the variant.
        """
        identifiers: list[str] = [
            match.group(0) for match in self._IDENTIFIER_RE.finditer(query)
        ]
        if not identifiers:
            return ""
        legal_tokens: list[str] = []
        for token in self._TOKEN_RE.findall(query):
            lowered = token.lower()
            if lowered in self._LEGAL_CUE_TOKENS:
                legal_tokens.append(token)
        merged: list[str] = identifiers + legal_tokens
        # Case-insensitive, order-preserving dedup.
        seen: set[str] = set()
        ordered_unique: list[str] = []
        for item in merged:
            key = item.lower()
            if key in seen:
                continue
            seen.add(key)
            ordered_unique.append(item)
        return " ".join(ordered_unique)

    def _bilingual_mirror_variant(self, query: str) -> str | None:
        """
        Mirror legal vocabulary into the other language (FR<->EN).

        Language is guessed from the presence of accented characters;
        returns None when no rule changed anything.
        """
        contains_french = bool(self._FRENCH_ACCENT_RE.search(query.lower()))
        rules = self._FR_TO_EN_RULES if contains_french else self._EN_TO_FR_RULES
        mirrored = query
        changed = False
        for pattern, replacement in rules:
            updated = pattern.sub(replacement, mirrored)
            if updated != mirrored:
                changed = True
                mirrored = updated
        return mirrored if changed else None