"""Lexical (BM25) search strategy mixin for RetrievalService."""
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from ..metrics import observe_retrieval_error, observe_retrieval_phase
from ..query_expansion import ExpandedQuery
from ..query_utils import truncate_lexical_query
from ..result import RetrievalResult
logger = logging.getLogger(__name__)
class LexicalStrategyMixin:
    """Provides lexical_only() and the underlying BM25 + expansion helpers.

    This is a mixin: it reads several attributes that must be supplied by the
    host class (hence the ``# type: ignore[attr-defined]`` markers below):
    ``bm25_service``, ``query_expansion_enabled``, ``_expanded_queries``,
    ``_variant_candidate_pool_size``, ``_branch_workers``,
    ``_apply_query_variant_metadata``, ``_dedupe_by_best_score`` and
    ``_rerank_results``.
    """

    def lexical_only(
        self,
        query: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[RetrievalResult]:
        """Execute lexical-only search using BM25 (no semantic component).

        The search always targets ``"subdivisions"``. Results are optionally
        filtered by ``score_threshold`` and then handed to
        ``self._rerank_results``, whose return value is returned as-is.

        BM25 scores are expected to be normalized into [0, 1] before
        thresholding so that ``score_threshold`` stays comparable across
        modes. NOTE(review): no normalization happens in this method; it is
        presumably applied upstream (e.g. by ``bm25_service``) -- confirm.

        Args:
            query: Raw user query.
            top_k: Used both as the candidate ``top_k`` and the
                ``output_top_k`` of the underlying search, and passed to the
                reranker.
            score_threshold: When not ``None``, results whose ``score`` is
                below this value are dropped before reranking.
            filters: Optional filters forwarded to the BM25 search.

        Returns:
            The list produced by ``self._rerank_results``.

        Raises:
            Exception: Any error raised by the lexical search is reported via
                ``observe_retrieval_error`` and re-raised unchanged.
        """
        operation = "lexical"
        total_started_at = time.perf_counter()
        try:
            lexical_started_at = time.perf_counter()
            try:
                results = self._lexical_search_with_expansion(
                    query=query,
                    top_k=top_k,
                    filters=filters,
                    target="subdivisions",
                    output_top_k=top_k,
                )
            except Exception as exc:
                # Record the failure for metrics, then let the caller see it.
                observe_retrieval_error(
                    operation=operation,
                    phase="lexical_search",
                    exc_or_reason=exc,
                )
                raise
            # The phase duration is only recorded when the search succeeded.
            observe_retrieval_phase(
                operation=operation,
                phase="lexical_search",
                duration_seconds=time.perf_counter() - lexical_started_at,
            )

            if score_threshold is not None:
                threshold_started_at = time.perf_counter()
                results = [r for r in results if r.score >= score_threshold]
                observe_retrieval_phase(
                    operation=operation,
                    phase="score_threshold",
                    duration_seconds=time.perf_counter() - threshold_started_at,
                )

            results = self._rerank_results(query, results, top_k, operation=operation)  # type: ignore[attr-defined]
            return results
        finally:
            # The total duration is recorded on both success and failure.
            observe_retrieval_phase(
                operation=operation,
                phase="total",
                duration_seconds=time.perf_counter() - total_started_at,
            )

    def _lexical_search_with_expansion(
        self,
        *,
        query: str,
        expanded_queries: Optional[List[ExpandedQuery]] = None,
        top_k: int,
        filters: Optional[Dict[str, Any]],
        target: Optional[str],
        output_top_k: int,
        telemetry: Optional[Dict[str, float]] = None,
    ) -> List[RetrievalResult]:
        """Run BM25 once per query variant and merge the results.

        Each variant is searched with a per-variant candidate pool size
        obtained from ``self._variant_candidate_pool_size``. With more than
        one variant the searches run concurrently in a thread pool. Every
        result list is tagged through ``self._apply_query_variant_metadata``
        (variant indices start at 1), then all results are merged, deduplicated
        with ``self._dedupe_by_best_score``, sorted by descending ``score``
        and truncated to ``top_k``.

        Args:
            query: Raw query; it is shortened with ``truncate_lexical_query``
                before being used for expansion.
            expanded_queries: Pre-computed variants. When ``None`` or empty,
                ``self._expanded_queries`` is called with the truncated query.
            top_k: Forwarded to the pool-size helper and used as the cap on
                the number of results returned.
            filters: Optional filters forwarded to each BM25 search.
            target: Search target; a falsy value short-circuits to ``[]``.
            output_top_k: Forwarded to the pool-size helper only; the final
                truncation uses ``top_k``.
            telemetry: Optional dict; when given, ``"search_ms"`` is set to
                the search wall time in milliseconds rounded to one decimal
                (``0.0`` when no search was performed).

        Returns:
            Deduplicated results, best score first, at most ``top_k`` items.
        """
        if not target:
            if telemetry is not None:
                telemetry["search_ms"] = 0.0
            return []

        # Truncate long queries to avoid near-zero BM25 scores.
        query = truncate_lexical_query(query)
        resolved_expanded_queries = expanded_queries or self._expanded_queries(query)  # type: ignore[attr-defined]

        # Bail out before sizing the candidate pool so the helper is never
        # invoked with variants_count=0.
        if not resolved_expanded_queries:
            if telemetry is not None:
                telemetry["search_ms"] = 0.0
            return []

        per_variant_k = self._variant_candidate_pool_size(  # type: ignore[attr-defined]
            top_k,
            variants_count=len(resolved_expanded_queries),
            output_top_k=output_top_k,
        )

        combined: List[RetrievalResult] = []
        search_started_at = time.perf_counter()
        if len(resolved_expanded_queries) == 1:
            # A single variant does not need a thread pool.
            variant_results = [
                self._lexical_search(resolved_expanded_queries[0].text, per_variant_k, filters, target=target)
            ]
        else:
            with ThreadPoolExecutor(
                max_workers=min(len(resolved_expanded_queries), self._branch_workers())  # type: ignore[attr-defined]
            ) as executor:
                futures = [
                    executor.submit(self._lexical_search, expanded.text, per_variant_k, filters, target)
                    for expanded in resolved_expanded_queries
                ]
                # Collect in submission order so results line up with their
                # variants in the zip() below.
                variant_results = [future.result() for future in futures]
        search_ms = (time.perf_counter() - search_started_at) * 1000.0
        # NOTE(review): timing is emitted at WARNING level; kept as-is because
        # it looks like deliberate perf instrumentation -- confirm.
        logger.warning("PERF lexical variants (%d): %.1f ms", len(resolved_expanded_queries), search_ms)

        for variant_index, (expanded, lexical_results) in enumerate(
            zip(resolved_expanded_queries, variant_results),
            start=1,
        ):
            self._apply_query_variant_metadata(  # type: ignore[attr-defined]
                lexical_results,
                expanded=expanded,
                variant_index=variant_index,
                search_branch="lexical",
                expansion_enabled=self.query_expansion_enabled,  # type: ignore[attr-defined]
            )
            combined.extend(lexical_results)

        if telemetry is not None:
            telemetry["search_ms"] = round(search_ms, 1)

        deduped = self._dedupe_by_best_score(combined)  # type: ignore[attr-defined]
        deduped.sort(key=lambda item: item.score, reverse=True)
        return deduped[:top_k]

    def _lexical_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]] = None,
        target: str = "subdivisions",
    ) -> List[RetrievalResult]:
        """Execute lexical search via BM25.

        Delegates to ``self.bm25_service.search`` and logs the elapsed time
        and result count.

        Args:
            query: Query text passed to the BM25 service.
            top_k: Maximum number of results requested from the service.
            filters: Optional filters forwarded to the service.
            target: Search target forwarded to the service.

        Returns:
            Whatever ``bm25_service.search`` returns, unmodified.
        """
        bm25_start = time.perf_counter()
        results = self.bm25_service.search(  # type: ignore[attr-defined]
            query=query,
            top_k=top_k,
            filters=filters,
            target=target,
        )
        bm25_ms = (time.perf_counter() - bm25_start) * 1000.0
        logger.warning("PERF bm25 search (%s): %.1f ms, %d results", target, bm25_ms, len(results))
        return results