Source code for lalandre_rag.retrieval.strategies.lexical

"""Lexical (BM25) search strategy mixin for RetrievalService."""

import logging
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from ..metrics import observe_retrieval_error, observe_retrieval_phase
from ..query_expansion import ExpandedQuery
from ..query_utils import truncate_lexical_query
from ..result import RetrievalResult

logger = logging.getLogger(__name__)


class LexicalStrategyMixin:
    """Provides lexical_only() and the underlying BM25 + expansion helpers.

    This is a mixin: it does not define its own state. The host class is
    expected to supply ``bm25_service``, ``query_expansion_enabled`` and the
    helpers ``_expanded_queries``, ``_variant_candidate_pool_size``,
    ``_branch_workers``, ``_apply_query_variant_metadata``,
    ``_dedupe_by_best_score`` and ``_rerank_results`` (hence the
    ``type: ignore[attr-defined]`` markers below).
    """

    def lexical_only(
        self,
        query: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[RetrievalResult]:
        """Execute lexical-only search using BM25 (no semantic component).

        BM25 scores are expected to arrive already normalized into [0, 1]
        (no normalization is performed in this method) so that
        ``score_threshold`` stays comparable across modes.

        Args:
            query: Raw user query.
            top_k: Candidate pool size, also passed to the reranker.
            score_threshold: When given, results scoring below it are dropped
                before reranking.
            filters: Optional filters forwarded to the BM25 search.

        Returns:
            The results returned by ``_rerank_results``.

        Raises:
            Exception: Any error from the lexical search is reported through
                ``observe_retrieval_error`` and then re-raised unchanged.
        """
        operation = "lexical"
        total_started_at = time.perf_counter()
        try:
            lexical_started_at = time.perf_counter()
            try:
                results = self._lexical_search_with_expansion(
                    query=query,
                    top_k=top_k,
                    filters=filters,
                    target="subdivisions",
                    output_top_k=top_k,
                )
            except Exception as exc:
                # Record the failure for metrics, then let the caller see it.
                observe_retrieval_error(
                    operation=operation,
                    phase="lexical_search",
                    exc_or_reason=exc,
                )
                raise
            observe_retrieval_phase(
                operation=operation,
                phase="lexical_search",
                duration_seconds=time.perf_counter() - lexical_started_at,
            )

            if score_threshold is not None:
                threshold_started_at = time.perf_counter()
                results = [r for r in results if r.score >= score_threshold]
                observe_retrieval_phase(
                    operation=operation,
                    phase="score_threshold",
                    duration_seconds=time.perf_counter() - threshold_started_at,
                )

            results = self._rerank_results(query, results, top_k, operation=operation)  # type: ignore[attr-defined]
            return results
        finally:
            # Always emitted, including when the search or rerank raised.
            observe_retrieval_phase(
                operation=operation,
                phase="total",
                duration_seconds=time.perf_counter() - total_started_at,
            )

    def _lexical_search_with_expansion(
        self,
        *,
        query: str,
        expanded_queries: Optional[List[ExpandedQuery]] = None,
        top_k: int,
        filters: Optional[Dict[str, Any]],
        target: Optional[str],
        output_top_k: int,
        telemetry: Optional[Dict[str, float]] = None,
    ) -> List[RetrievalResult]:
        """Run BM25 once per query variant and merge the results.

        Args:
            query: Raw query; truncated with ``truncate_lexical_query`` before
                being expanded.
            expanded_queries: Pre-computed variants. When falsy (``None`` or
                empty), variants are derived via ``self._expanded_queries``.
            top_k: Maximum number of merged results returned.
            filters: Optional filters forwarded to the BM25 search.
            target: BM25 target to search; a falsy value short-circuits to an
                empty result.
            output_top_k: Only forwarded to ``_variant_candidate_pool_size``;
                the returned list is capped at ``top_k``, not at this value.
            telemetry: Optional dict that receives ``"search_ms"``.

        Returns:
            Deduplicated results sorted by descending score, at most ``top_k``.
        """
        if not target:
            if telemetry is not None:
                telemetry["search_ms"] = 0.0
            return []

        # Truncate long queries to avoid near-zero BM25 scores.
        query = truncate_lexical_query(query)
        resolved_expanded_queries = expanded_queries or self._expanded_queries(query)  # type: ignore[attr-defined]

        # Bail out before sizing the candidate pool: with zero variants the
        # pool size is never used, and the sizing helper should not be called
        # with variants_count=0.
        if not resolved_expanded_queries:
            if telemetry is not None:
                telemetry["search_ms"] = 0.0
            return []

        per_variant_k = self._variant_candidate_pool_size(  # type: ignore[attr-defined]
            top_k,
            variants_count=len(resolved_expanded_queries),
            output_top_k=output_top_k,
        )

        combined: List[RetrievalResult] = []
        search_started_at = time.perf_counter()
        if len(resolved_expanded_queries) == 1:
            # Single variant: skip the thread-pool overhead.
            variant_results = [
                self._lexical_search(resolved_expanded_queries[0].text, per_variant_k, filters, target=target)
            ]
        else:
            # ThreadPoolExecutor rejects max_workers <= 0, so clamp to 1 in
            # case the configured branch worker count is zero.
            max_workers = max(
                1,
                min(len(resolved_expanded_queries), self._branch_workers()),  # type: ignore[attr-defined]
            )
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(self._lexical_search, expanded.text, per_variant_k, filters, target)
                    for expanded in resolved_expanded_queries
                ]
                # Collect in submission order so results stay aligned with
                # resolved_expanded_queries for the zip() below.
                variant_results = [future.result() for future in futures]
        search_ms = (time.perf_counter() - search_started_at) * 1000.0
        # NOTE(review): timing is logged at WARNING level; confirm this is
        # intentional rather than leftover profiling output.
        logger.warning("PERF lexical variants (%d): %.1f ms", len(resolved_expanded_queries), search_ms)

        for variant_index, (expanded, lexical_results) in enumerate(
            zip(resolved_expanded_queries, variant_results),
            start=1,
        ):
            self._apply_query_variant_metadata(  # type: ignore[attr-defined]
                lexical_results,
                expanded=expanded,
                variant_index=variant_index,
                search_branch="lexical",
                expansion_enabled=self.query_expansion_enabled,  # type: ignore[attr-defined]
            )
            combined.extend(lexical_results)

        if telemetry is not None:
            telemetry["search_ms"] = round(search_ms, 1)

        deduped = self._dedupe_by_best_score(combined)  # type: ignore[attr-defined]
        deduped.sort(key=lambda item: item.score, reverse=True)
        return deduped[:top_k]

    def _lexical_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]] = None,
        target: str = "subdivisions",
    ) -> List[RetrievalResult]:
        """Execute lexical search via BM25.

        Delegates to ``self.bm25_service.search`` and logs the elapsed time
        and result count.
        """
        bm25_start = time.perf_counter()
        results = self.bm25_service.search(  # type: ignore[attr-defined]
            query=query,
            top_k=top_k,
            filters=filters,
            target=target,
        )
        bm25_ms = (time.perf_counter() - bm25_start) * 1000.0
        # NOTE(review): timing is logged at WARNING level; confirm this is
        # intentional rather than leftover profiling output.
        logger.warning("PERF bm25 search (%s): %.1f ms, %d results", target, bm25_ms, len(results))
        return results