"""Hybrid search strategy mixin for RetrievalService (pre-computed embeddings)."""
import logging
import re
import time
from typing import Any, Dict, List, Optional, Tuple
from ..metrics import observe_retrieval_error, observe_retrieval_phase
from ..result import RetrievalResult
logger = logging.getLogger(__name__)
# Regex that detects explicit legal references in a query.
# When matched, dynamic fusion boosts lexical weight for precise BM25 recall.
# Recognised forms (case-insensitive):
#   * CELEX identifiers, e.g. "32016R0679";
#   * act numbers carrying an EU/EC suffix, e.g. "95/46/EC", "2016/680/UE";
#   * a keyword (directive/regulation/règlement/decision/décision), an optional
#     "(EU)"-style tag and an optional number marker, followed by "NN/NN"
#     style numbers, e.g. "Regulation (EU) 2016/679",
#     "Regulation (EC) No. 1234/2007", "Règlement (UE) n° 1308/2013".
_LEGAL_REF_RE: re.Pattern[str] = re.compile(
    r"\b("
    r"3\d{4}[A-Z]\d{4}|"  # CELEX
    r"\d{2,4}/\d{2,4}/(?:UE|EU|CE|EC|CEE)|"  # directive/regulation number
    r"(?:directive|regulation|r[èe]glement|decision|d[ée]cision)\s*"
    r"(?:\((?:UE|EU|CE|EC|CEE)\)\s*)?"
    # Number marker: "No", "No.", "n°" or "nº". Whitespace is allowed after
    # every variant so that e.g. "No. 1234/2007" is recognised.
    r"(?:n(?:o\.?|[°º])\s*)?"
    r"\d{2,4}/\d{2,4}"
    r")\b",
    re.IGNORECASE,
)


def has_explicit_legal_reference(query: str) -> bool:
    """Return True when the query contains a strong legal-reference cue.

    A cue is any match of ``_LEGAL_REF_RE``: a CELEX identifier, an act
    number with an EU/EC suffix, or a keyword-prefixed act number.
    """
    return bool(_LEGAL_REF_RE.search(query))
class HybridStrategyMixin:
    """Provides hybrid_with_embedding(), _fuse_results(), and _resolve_fusion_override().

    Mixin for RetrievalService. It relies on members supplied by the host
    class (not defined here): ``semantic_service``, ``fusion_service``,
    ``lexical_weight``, ``dynamic_fusion_enabled``, ``_lexical_boost_factor``,
    ``_lexical_boost_max`` and the helpers ``_resolve_embedding_context``,
    ``_base_collections_from_granularity``, ``_apply_preset_suffix``,
    ``_candidate_pool_size``, ``_lexical_search_with_expansion`` and
    ``_rerank_results``.
    """

    def hybrid_with_embedding(
        self,
        query: str,
        query_vector: List[float],
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        collections: Optional[List[str]] = None,
        granularity: Optional[str] = None,
        embedding_preset: Optional[str] = None,
    ) -> List[RetrievalResult]:
        """Execute hybrid search with a pre-computed query embedding.

        Avoids redundant embedding computation when the caller already has the vector.

        Pipeline: semantic search, lexical search (with expansion), fusion
        (including the dynamic legal-reference boost from
        ``_resolve_fusion_override``), optional score threshold, reranking.
        Each phase's duration is reported via ``observe_retrieval_phase``;
        failures in the two search phases are reported via
        ``observe_retrieval_error`` and re-raised.

        Args:
            query: Raw query text, used for lexical search, fusion weighting
                and reranking.
            query_vector: Pre-computed embedding; its length must equal the
                active embedding service's vector size.
            top_k: Maximum number of results returned.
            score_threshold: If set, fused results scoring below it are
                dropped before reranking.
            filters: Optional filters forwarded to both search backends.
            collections: Optional collection list, forwarded to
                ``_base_collections_from_granularity``.
            granularity: Optional granularity, forwarded to
                ``_base_collections_from_granularity``.
            embedding_preset: Optional preset; selects the embedding service
                and the collection-name suffix.

        Returns:
            At most ``top_k`` reranked results.

        Raises:
            ValueError: If ``query_vector`` has the wrong dimension.
        """
        operation = "hybrid_precomputed"
        total_started_at = time.perf_counter()
        active_embedding_svc, col_suffix = self._resolve_embedding_context(  # type: ignore[attr-defined]
            embedding_preset
        )
        # Reject a vector of the wrong size before any search is issued.
        expected_dim = active_embedding_svc.get_vector_size()
        actual_dim = len(query_vector)
        if actual_dim != expected_dim:
            raise ValueError(f"Invalid embedding dimension: expected {expected_dim}, got {actual_dim}")
        base_collections = self._base_collections_from_granularity(  # type: ignore[attr-defined]
            granularity=granularity,
            collections=collections,
        )
        # Semantic search uses the preset-suffixed names; the lexical target
        # below is chosen from the un-suffixed base names.
        search_collections = self._apply_preset_suffix(base_collections, col_suffix)  # type: ignore[attr-defined]
        # Candidate pool fetched from each backend (presumably >= top_k so that
        # fusion/reranking have room to reorder — see _candidate_pool_size).
        candidate_k = self._candidate_pool_size(top_k)  # type: ignore[attr-defined]
        try:
            semantic_started_at = time.perf_counter()
            try:
                semantic_results = self.semantic_service.search(  # type: ignore[attr-defined]
                    query_vector=query_vector,
                    top_k=candidate_k,
                    filters=filters,
                    collections=search_collections,
                )
            except Exception as exc:
                observe_retrieval_error(
                    operation=operation,
                    phase="semantic_search",
                    exc_or_reason=exc,
                )
                raise
            observe_retrieval_phase(
                operation=operation,
                phase="semantic_search",
                duration_seconds=time.perf_counter() - semantic_started_at,
            )

            # Lexical index: "subdivisions" when present, otherwise "chunks";
            # None when neither base collection is requested.
            lexical_target: Optional[str] = None
            if "chunks" in base_collections and "subdivisions" not in base_collections:
                lexical_target = "chunks"
            elif "subdivisions" in base_collections:
                lexical_target = "subdivisions"

            lexical_started_at = time.perf_counter()
            try:
                lexical_results = self._lexical_search_with_expansion(  # type: ignore[attr-defined]
                    query=query,
                    top_k=candidate_k,
                    filters=filters,
                    target=lexical_target,
                    output_top_k=top_k,
                )
            except Exception as exc:
                observe_retrieval_error(
                    operation=operation,
                    phase="lexical_search",
                    exc_or_reason=exc,
                )
                raise
            observe_retrieval_phase(
                operation=operation,
                phase="lexical_search",
                duration_seconds=time.perf_counter() - lexical_started_at,
            )

            # Apply the dynamic legal-reference boost; without it, queries that
            # cite a specific act would not get the lexical boost on this
            # pre-computed-embedding path. (None, None) means "no override".
            override_lex, override_sem = self._resolve_fusion_override(query)
            fuse_started_at = time.perf_counter()
            fused_results = self._fuse_results(
                lexical_results,
                semantic_results,
                override_lexical_weight=override_lex,
                override_semantic_weight=override_sem,
            )
            observe_retrieval_phase(
                operation=operation,
                phase="fusion",
                duration_seconds=time.perf_counter() - fuse_started_at,
            )

            if score_threshold is not None:
                threshold_started_at = time.perf_counter()
                fused_results = [r for r in fused_results if r.score >= score_threshold]
                observe_retrieval_phase(
                    operation=operation,
                    phase="score_threshold",
                    duration_seconds=time.perf_counter() - threshold_started_at,
                )

            reranked = self._rerank_results(  # type: ignore[attr-defined]
                query, fused_results, top_k, operation=operation
            )
            # Enforce the top_k contract regardless of what the reranker returns.
            return reranked[:top_k]
        finally:
            # Runs on success and on failure once the search phase has started
            # (the dimension check above is outside this try).
            observe_retrieval_phase(
                operation=operation,
                phase="total",
                duration_seconds=time.perf_counter() - total_started_at,
            )

    def _resolve_fusion_override(self, query: str) -> Tuple[Optional[float], Optional[float]]:
        """Return boosted (lexical_weight, semantic_weight) for queries with explicit legal refs.

        Returns (None, None) when dynamic fusion is disabled or the query has
        no explicit legal reference. Otherwise the lexical weight is
        ``lexical_weight * _lexical_boost_factor``, capped at
        ``_lexical_boost_max`` and at 1.0. The semantic weight is the
        remainder ``1.0 - lexical``, which is therefore never negative.
        """
        if not self.dynamic_fusion_enabled:  # type: ignore[attr-defined]
            return None, None
        if not has_explicit_legal_reference(query):
            return None, None
        boosted_lex = min(
            self.lexical_weight * self._lexical_boost_factor,  # type: ignore[attr-defined]
            self._lexical_boost_max,  # type: ignore[attr-defined]
            # Hard cap so a boost max configured above 1.0 cannot push the
            # semantic weight below zero.
            1.0,
        )
        remaining_sem = 1.0 - boosted_lex
        logger.debug("Dynamic fusion boost applied for legal reference in query")
        return boosted_lex, remaining_sem

    def _fuse_results(
        self,
        lexical_results: List[RetrievalResult],
        semantic_results: List[RetrievalResult],
        override_lexical_weight: Optional[float] = None,
        override_semantic_weight: Optional[float] = None,
    ) -> List[RetrievalResult]:
        """Fuse lexical and semantic result lists via ``self.fusion_service.fuse``.

        The optional weight overrides are forwarded unchanged. ``None`` is
        what ``_resolve_fusion_override`` yields when no boost applies.
        """
        return self.fusion_service.fuse(  # type: ignore[attr-defined]
            lexical_results,
            semantic_results,
            override_lexical_weight=override_lexical_weight,
            override_semantic_weight=override_semantic_weight,
        )