# Source code for lalandre_rag.retrieval.strategies.hybrid

"""Hybrid search strategy mixin for RetrievalService (pre-computed embeddings)."""

import logging
import re
import time
from typing import Any, Dict, List, Optional, Tuple

from ..metrics import observe_retrieval_error, observe_retrieval_phase
from ..result import RetrievalResult

logger = logging.getLogger(__name__)

# Regex that detects explicit legal references in a query.
# When matched, dynamic fusion boosts lexical weight for precise BM25 recall.
#
# Recognised forms (matching is case-insensitive):
#   * CELEX identifiers, e.g. "32016R0679";
#   * act numbers carrying an institutional suffix, e.g. "2014/65/UE";
#   * an act keyword followed by a number, optionally with an institution in
#     parentheses and a number marker, e.g. "Regulation (EU) No. 2016/679",
#     "règlement n° 1234/2012" or "directive 2014/65".
_LEGAL_REF_RE: re.Pattern[str] = re.compile(
    r"\b("
    r"3\d{4}[A-Z]\d{4}|"  # CELEX
    r"\d{2,4}/\d{2,4}/(?:UE|EU|CE|EC|CEE)|"  # directive/regulation number
    r"(?:directive|regulation|r[èe]glement|decision|d[ée]cision)\s*"
    r"(?:\((?:UE|EU|CE|EC|CEE)\)\s*)?"
    # The whitespace run sits outside the optional marker group so that every
    # marker spelling ("No", "No.", "n°", "no") may be followed by spaces.
    r"(?:no\.?|n[°o])?\s*\d{2,4}/\d{2,4}"
    r")\b",
    re.IGNORECASE,
)






class HybridStrategyMixin:
    """Hybrid search strategy mixin for pre-computed query embeddings.

    Provides hybrid_with_embedding(), _fuse_results(), and _resolve_fusion_override().

    The mixin keeps no state of its own.  Its methods read the following from
    the host class: ``semantic_service``, ``fusion_service``,
    ``lexical_weight``, ``dynamic_fusion_enabled``, ``_lexical_boost_factor``
    and ``_lexical_boost_max``, plus the helpers
    ``_resolve_embedding_context``, ``_base_collections_from_granularity``,
    ``_apply_preset_suffix``, ``_candidate_pool_size``,
    ``_lexical_search_with_expansion`` and ``_rerank_results``.
    """

    def hybrid_with_embedding(
        self,
        query: str,
        query_vector: List[float],
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        collections: Optional[List[str]] = None,
        granularity: Optional[str] = None,
        embedding_preset: Optional[str] = None,
    ) -> List[RetrievalResult]:
        """Execute hybrid search with a pre-computed query embedding.

        Avoids redundant embedding computation when the caller already has
        the vector.  The pipeline is: semantic search, lexical search, fusion,
        optional score filtering, reranking.  Each phase reports its duration
        through ``observe_retrieval_phase``; the ``"total"`` phase is reported
        even when a search phase raises.

        Args:
            query: Raw query text, used for the lexical search and reranking.
            query_vector: Embedding of ``query``, used for the semantic search.
            top_k: Maximum number of results returned.
            score_threshold: When given, fused results whose ``score`` is
                below this value are dropped before reranking.
            filters: Passed unchanged to both the semantic and lexical search.
            collections: Passed to ``_base_collections_from_granularity``.
            granularity: Passed to ``_base_collections_from_granularity``.
            embedding_preset: Passed to ``_resolve_embedding_context`` to pick
                the embedding service and collection suffix.

        Returns:
            At most ``top_k`` reranked results.

        Raises:
            ValueError: If ``len(query_vector)`` differs from the vector size
                reported by the resolved embedding service.
            Exception: Any error raised by the semantic or lexical search is
                recorded with ``observe_retrieval_error`` and re-raised.
        """
        operation = "hybrid_precomputed"
        total_started_at = time.perf_counter()

        active_embedding_svc, col_suffix = self._resolve_embedding_context(  # type: ignore[attr-defined]
            embedding_preset
        )
        expected_dim = active_embedding_svc.get_vector_size()
        actual_dim = len(query_vector)
        if actual_dim != expected_dim:
            raise ValueError(f"Invalid embedding dimension: expected {expected_dim}, got {actual_dim}")

        base_collections = self._base_collections_from_granularity(  # type: ignore[attr-defined]
            granularity=granularity,
            collections=collections,
        )
        search_collections = self._apply_preset_suffix(base_collections, col_suffix)  # type: ignore[attr-defined]
        # Both searches fetch a candidate pool sized by the host class; the
        # final cut to top_k happens after reranking.
        candidate_k = self._candidate_pool_size(top_k)  # type: ignore[attr-defined]

        try:
            semantic_started_at = time.perf_counter()
            try:
                semantic_results = self.semantic_service.search(  # type: ignore[attr-defined]
                    query_vector=query_vector,
                    top_k=candidate_k,
                    filters=filters,
                    collections=search_collections,
                )
            except Exception as exc:
                observe_retrieval_error(
                    operation=operation,
                    phase="semantic_search",
                    exc_or_reason=exc,
                )
                raise
            observe_retrieval_phase(
                operation=operation,
                phase="semantic_search",
                duration_seconds=time.perf_counter() - semantic_started_at,
            )

            # The lexical target is derived from the un-suffixed collection
            # names; "subdivisions" wins when both kinds are requested, and
            # the target stays None when neither is present.
            lexical_target: Optional[str] = None
            if "subdivisions" in base_collections:
                lexical_target = "subdivisions"
            elif "chunks" in base_collections:
                lexical_target = "chunks"

            lexical_started_at = time.perf_counter()
            try:
                lexical_results = self._lexical_search_with_expansion(  # type: ignore[attr-defined]
                    query=query,
                    top_k=candidate_k,
                    filters=filters,
                    target=lexical_target,
                    output_top_k=top_k,
                )
            except Exception as exc:
                observe_retrieval_error(
                    operation=operation,
                    phase="lexical_search",
                    exc_or_reason=exc,
                )
                raise
            observe_retrieval_phase(
                operation=operation,
                phase="lexical_search",
                duration_seconds=time.perf_counter() - lexical_started_at,
            )

            # NOTE(review): fusion runs with the default weights here; the
            # override from _resolve_fusion_override() is not applied on this
            # path — confirm whether that is intentional.
            fuse_started_at = time.perf_counter()
            fused_results = self._fuse_results(lexical_results, semantic_results)
            observe_retrieval_phase(
                operation=operation,
                phase="fusion",
                duration_seconds=time.perf_counter() - fuse_started_at,
            )

            if score_threshold is not None:
                threshold_started_at = time.perf_counter()
                fused_results = [r for r in fused_results if r.score >= score_threshold]
                observe_retrieval_phase(
                    operation=operation,
                    phase="score_threshold",
                    duration_seconds=time.perf_counter() - threshold_started_at,
                )

            reranked = self._rerank_results(  # type: ignore[attr-defined]
                query, fused_results, top_k, operation=operation
            )
            return reranked[:top_k]
        finally:
            observe_retrieval_phase(
                operation=operation,
                phase="total",
                duration_seconds=time.perf_counter() - total_started_at,
            )

    def _resolve_fusion_override(self, query: str) -> Tuple[Optional[float], Optional[float]]:
        """Return boosted (lexical_weight, semantic_weight) for queries with explicit legal refs.

        Returns (None, None) when no boost is needed, i.e. when dynamic fusion
        is disabled or the query carries no explicit legal reference.

        The boosted lexical weight is ``lexical_weight * _lexical_boost_factor``
        capped at ``_lexical_boost_max`` and at 1.0; the semantic weight is its
        complement to 1.0.
        """
        if not self.dynamic_fusion_enabled:  # type: ignore[attr-defined]
            return None, None
        if not has_explicit_legal_reference(query):
            return None, None

        # The extra 1.0 cap keeps the complementary semantic weight from going
        # negative if _lexical_boost_max is configured above 1.0.
        boosted_lex = min(
            self.lexical_weight * self._lexical_boost_factor,  # type: ignore[attr-defined]
            self._lexical_boost_max,  # type: ignore[attr-defined]
            1.0,
        )
        remaining_sem = 1.0 - boosted_lex
        logger.debug("Dynamic fusion boost applied for legal reference in query")
        return boosted_lex, remaining_sem

    def _fuse_results(
        self,
        lexical_results: List[RetrievalResult],
        semantic_results: List[RetrievalResult],
        override_lexical_weight: Optional[float] = None,
        override_semantic_weight: Optional[float] = None,
    ) -> List[RetrievalResult]:
        """Fuse lexical and semantic results by delegating to ``fusion_service.fuse``.

        The optional override weights are forwarded unchanged; ``None`` is
        passed through when the caller supplies no override.
        """
        return self.fusion_service.fuse(  # type: ignore[attr-defined]
            lexical_results,
            semantic_results,
            override_lexical_weight=override_lexical_weight,
            override_semantic_weight=override_semantic_weight,
        )