Source code for lalandre_rag.retrieval.strategies.semantic

"""Semantic search strategy mixin for RetrievalService."""

import logging
import math
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from lalandre_core.config import get_config

from ..metrics import observe_retrieval_error, observe_retrieval_phase
from ..query_expansion import ExpandedQuery
from ..result import RetrievalResult

logger = logging.getLogger(__name__)


class SemanticStrategyMixin:
    """Provides semantic_only() and the underlying multi-collection + expansion helpers."""
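
    # Host-class contract (inferred from the `type: ignore[attr-defined]` call sites
    # below, not stated in the original module): the class mixing this in is expected
    # to provide `embedding_service`, `semantic_service`, `qdrant_repos`,
    # `query_expansion_enabled`, and the helpers `_resolve_embedding_context`,
    # `_base_collections_from_granularity`, `_apply_preset_suffix`,
    # `_expanded_queries`, `_variant_candidate_pool_size`, `_branch_workers`,
    # `_apply_query_variant_metadata`, `_dedupe_by_best_score`,
    # `_result_dedupe_key`, and `_rerank_results`.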

    def semantic_only(
        self,
        query: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        collections: Optional[List[str]] = None,
        granularity: Optional[str] = None,
        embedding_preset: Optional[str] = None,
    ) -> List[RetrievalResult]:
        """Execute semantic-only search (no lexical component).

        Supports both text queries (embedded on the fly) and pre-computed vectors.
        """
        operation = "semantic"
        total_started_at = time.perf_counter()
        try:
            semantic_started_at = time.perf_counter()
            active_embedding_svc, col_suffix = self._resolve_embedding_context(  # type: ignore[attr-defined]
                embedding_preset
            )
            base_collections = self._base_collections_from_granularity(  # type: ignore[attr-defined]
                granularity=granularity,
                collections=collections,
            )
            search_collections = self._apply_preset_suffix(base_collections, col_suffix)  # type: ignore[attr-defined]
            try:
                if query_vector is not None:
                    expected_dim = active_embedding_svc.get_vector_size()
                    actual_dim = len(query_vector)
                    if actual_dim != expected_dim:
                        raise ValueError(
                            f"Invalid embedding dimension: expected {expected_dim}, got {actual_dim}"
                        )
                    results: List[RetrievalResult] = self.semantic_service.search(  # type: ignore[attr-defined]
                        query_vector=query_vector,
                        top_k=top_k,
                        filters=filters,
                        collections=search_collections,
                    )
                elif query:
                    results = self._semantic_search_with_expansion(
                        query=query,
                        top_k=top_k,
                        filters=filters,
                        collections=search_collections,
                        output_top_k=top_k,
                        embedding_service=active_embedding_svc,
                    )
                else:
                    raise ValueError("Either query or query_vector must be provided")
            except Exception as exc:
                observe_retrieval_error(
                    operation=operation,
                    phase="semantic_search",
                    exc_or_reason=exc,
                )
                raise
            observe_retrieval_phase(
                operation=operation,
                phase="semantic_search",
                duration_seconds=time.perf_counter() - semantic_started_at,
            )

            if score_threshold is not None:
                threshold_started_at = time.perf_counter()
                results = [r for r in results if r.score >= score_threshold]
                observe_retrieval_phase(
                    operation=operation,
                    phase="score_threshold",
                    duration_seconds=time.perf_counter() - threshold_started_at,
                )

            if query:
                results = self._rerank_results(query, results, top_k, operation=operation)  # type: ignore[attr-defined]
            return results
        finally:
            observe_retrieval_phase(
                operation=operation,
                phase="total",
                duration_seconds=time.perf_counter() - total_started_at,
            )
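
    # Usage sketch (illustrative, not part of the original module): `service` is a
    # hypothetical RetrievalService instance that mixes in SemanticStrategyMixin and
    # is already wired to embedding and Qdrant services.
    #
    #     results = service.semantic_only(query="load balancer timeout", top_k=5)
    #
    #     vector = service.embedding_service.embed_text("load balancer timeout")
    #     results = service.semantic_only(
    #         query_vector=vector,      # must match the active preset's vector size
    #         score_threshold=0.35,     # drop results scoring below 0.35
    #     )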

    def _semantic_search_with_expansion(
        self,
        *,
        query: str,
        expanded_queries: Optional[List[ExpandedQuery]] = None,
        top_k: int,
        filters: Optional[Dict[str, Any]],
        collections: Optional[List[str]],
        output_top_k: int,
        embedding_service=None,
        telemetry: Optional[Dict[str, float]] = None,
    ) -> List[RetrievalResult]:
        resolved_expanded_queries = expanded_queries or self._expanded_queries(query)  # type: ignore[attr-defined]
        per_variant_k = self._variant_candidate_pool_size(  # type: ignore[attr-defined]
            top_k,
            variants_count=len(resolved_expanded_queries),
            output_top_k=output_top_k,
        )
        if not resolved_expanded_queries:
            if telemetry is not None:
                telemetry["embedding_ms"] = 0.0
                telemetry["search_ms"] = 0.0
            return []

        _emb = embedding_service or self.embedding_service  # type: ignore[attr-defined]
        variant_texts = [expanded.text for expanded in resolved_expanded_queries]
        embed_started_at = time.perf_counter()
        if len(variant_texts) == 1:
            query_vectors = [_emb.embed_text(variant_texts[0])]
        else:
            query_vectors = _emb.embed_batch(variant_texts, batch_size=len(variant_texts))
        embedding_ms = (time.perf_counter() - embed_started_at) * 1000.0
        logger.warning("PERF embed_batch(%d): %.1f ms", len(variant_texts), embedding_ms)

        combined: List[RetrievalResult] = []
        search_started_at = time.perf_counter()
        if len(resolved_expanded_queries) == 1:
            variant_results = [
                self._semantic_search_multi_by_vector(
                    query_vector=query_vectors[0],
                    top_k=per_variant_k,
                    filters=filters,
                    collections=collections,
                )
            ]
        else:
            with ThreadPoolExecutor(
                max_workers=min(len(resolved_expanded_queries), self._branch_workers())  # type: ignore[attr-defined]
            ) as executor:
                futures = [
                    executor.submit(
                        self._semantic_search_multi_by_vector,
                        query_vector=query_vector,
                        top_k=per_variant_k,
                        filters=filters,
                        collections=collections,
                    )
                    for query_vector in query_vectors
                ]
                variant_results = [future.result() for future in futures]
        search_ms = (time.perf_counter() - search_started_at) * 1000.0
        logger.warning("PERF semantic variants (%d): %.1f ms", len(variant_texts), search_ms)

        for variant_index, (expanded, semantic_results) in enumerate(
            zip(resolved_expanded_queries, variant_results),
            start=1,
        ):
            self._apply_query_variant_metadata(  # type: ignore[attr-defined]
                semantic_results,
                expanded=expanded,
                variant_index=variant_index,
                search_branch="semantic",
                expansion_enabled=self.query_expansion_enabled,  # type: ignore[attr-defined]
            )
            combined.extend(semantic_results)

        if telemetry is not None:
            telemetry["embedding_ms"] = round(embedding_ms, 1)
            telemetry["search_ms"] = round(search_ms, 1)

        deduped = self._dedupe_by_best_score(combined)  # type: ignore[attr-defined]
        deduped.sort(key=lambda item: item.score, reverse=True)
        return deduped[:top_k]

    def _semantic_search_multi_by_vector(
        self,
        *,
        query_vector: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]],
        collections: Optional[List[str]],
    ) -> List[RetrievalResult]:
        """Execute semantic search across one or more collections for a precomputed vector."""
        search_collections = collections or list(self.qdrant_repos.keys())  # type: ignore[attr-defined]
        if len(search_collections) <= 1:
            all_results = self.semantic_service.search(  # type: ignore[attr-defined]
                query_vector=query_vector,
                top_k=top_k,
                filters=filters,
                collections=search_collections,
            )
        else:
            valid_collections = [
                c for c in search_collections if c in self.qdrant_repos  # type: ignore[attr-defined]
            ]
            if not valid_collections:
                return []
            per_collection_limit = max(1, math.ceil(top_k / len(valid_collections)))
            per_collection_limit = max(
                1,
                math.ceil(
                    per_collection_limit * self.semantic_service.per_collection_oversampling  # type: ignore[attr-defined]
                ),
            )

            def _search_collection(collection: str) -> List[RetrievalResult]:
                return self.semantic_service.search(  # type: ignore[attr-defined]
                    query_vector=query_vector,
                    top_k=per_collection_limit,
                    filters=filters,
                    collections=[collection],
                )

            with ThreadPoolExecutor(
                max_workers=min(
                    len(valid_collections),
                    max(int(get_config().search.max_parallel_workers), 1),
                )
            ) as executor:
                futures = [executor.submit(_search_collection, collection) for collection in valid_collections]
                all_results = [result for future in futures for result in future.result()]

        for result in all_results:
            result.metadata.setdefault("source_collection", result.metadata.get("collection"))

        all_results.sort(key=lambda item: item.score, reverse=True)
        seen: set = set()
        deduplicated: List[RetrievalResult] = []
        for result in all_results:
            key = self._result_dedupe_key(result)  # type: ignore[attr-defined]
            if key not in seen:
                seen.add(key)
                deduplicated.append(result)
        return deduplicated[:top_k]

    def _semantic_search_multi(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]] = None,
        collections: Optional[List[str]] = None,
        embedding_service=None,
    ) -> List[RetrievalResult]:
        """Execute semantic search across multiple Qdrant collections."""
        collections = collections or list(self.qdrant_repos.keys())  # type: ignore[attr-defined]
        embed_start = time.perf_counter()
        _emb = embedding_service or self.embedding_service  # type: ignore[attr-defined]
        query_vector = _emb.embed_text(query)
        embed_ms = (time.perf_counter() - embed_start) * 1000.0
        logger.warning("PERF embed_text: %.1f ms", embed_ms)

        qdrant_start = time.perf_counter()
        all_results = self._semantic_search_multi_by_vector(
            query_vector=query_vector,
            top_k=top_k,
            filters=filters,
            collections=collections,
        )
        qdrant_ms = (time.perf_counter() - qdrant_start) * 1000.0
        logger.warning(
            "PERF qdrant search (%s): %.1f ms, %d results",
            collections,
            qdrant_ms,
            len(all_results),
        )
        return all_results
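

# Illustrative fan-out arithmetic (not part of the original module): with top_k=10,
# three valid collections, and per_collection_oversampling=1.5, each collection is
# queried for ceil(10 / 3) = 4 results, oversampled to ceil(4 * 1.5) = 6, and the
# merged pool is re-sorted and deduplicated back down to top_k.
#
#     import math
#     top_k, n_collections, oversampling = 10, 3, 1.5
#     base = max(1, math.ceil(top_k / n_collections))                # -> 4
#     per_collection_limit = max(1, math.ceil(base * oversampling))  # -> 6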