"""Semantic search strategy mixin for RetrievalService."""
import logging
import math
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from lalandre_core.config import get_config
from ..metrics import observe_retrieval_error, observe_retrieval_phase
from ..query_expansion import ExpandedQuery
from ..result import RetrievalResult
logger = logging.getLogger(__name__)


class SemanticStrategyMixin:
"""Provides semantic_only() and the underlying multi-collection + expansion helpers."""

    def semantic_only(
self,
query: Optional[str] = None,
query_vector: Optional[List[float]] = None,
top_k: int = 10,
score_threshold: Optional[float] = None,
filters: Optional[Dict[str, Any]] = None,
collections: Optional[List[str]] = None,
granularity: Optional[str] = None,
embedding_preset: Optional[str] = None,
) -> List[RetrievalResult]:
"""Execute semantic-only search (no lexical component).
Supports both text queries (embedded on the fly) and pre-computed vectors.
"""
operation = "semantic"
total_started_at = time.perf_counter()
try:
semantic_started_at = time.perf_counter()
active_embedding_svc, col_suffix = self._resolve_embedding_context( # type: ignore[attr-defined]
embedding_preset
)
base_collections = self._base_collections_from_granularity( # type: ignore[attr-defined]
granularity=granularity,
collections=collections,
)
search_collections = self._apply_preset_suffix(base_collections, col_suffix) # type: ignore[attr-defined]
try:
if query_vector is not None:
expected_dim = active_embedding_svc.get_vector_size()
actual_dim = len(query_vector)
if actual_dim != expected_dim:
raise ValueError(f"Invalid embedding dimension: expected {expected_dim}, got {actual_dim}")
results: List[RetrievalResult] = self.semantic_service.search( # type: ignore[attr-defined]
query_vector=query_vector,
top_k=top_k,
filters=filters,
collections=search_collections,
)
elif query:
results = self._semantic_search_with_expansion(
query=query,
top_k=top_k,
filters=filters,
collections=search_collections,
output_top_k=top_k,
embedding_service=active_embedding_svc,
)
else:
raise ValueError("Either query or query_vector must be provided")
except Exception as exc:
observe_retrieval_error(
operation=operation,
phase="semantic_search",
exc_or_reason=exc,
)
raise
observe_retrieval_phase(
operation=operation,
phase="semantic_search",
duration_seconds=time.perf_counter() - semantic_started_at,
)
if score_threshold is not None:
threshold_started_at = time.perf_counter()
results = [r for r in results if r.score >= score_threshold]
observe_retrieval_phase(
operation=operation,
phase="score_threshold",
duration_seconds=time.perf_counter() - threshold_started_at,
)
if query:
results = self._rerank_results(query, results, top_k, operation=operation) # type: ignore[attr-defined]
return results
finally:
observe_retrieval_phase(
operation=operation,
phase="total",
duration_seconds=time.perf_counter() - total_started_at,
)

    def _semantic_search_with_expansion(
self,
*,
query: str,
expanded_queries: Optional[List[ExpandedQuery]] = None,
top_k: int,
filters: Optional[Dict[str, Any]],
collections: Optional[List[str]],
output_top_k: int,
embedding_service=None,
telemetry: Optional[Dict[str, float]] = None,
) -> List[RetrievalResult]:
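        """Embed each query variant (the original plus any expansions), search
        the variants in parallel, then merge, dedupe by best score, and truncate
        to ``top_k``.

        When ``telemetry`` is provided, embedding and search wall times are
        recorded (in milliseconds) under ``embedding_ms`` and ``search_ms``.
        """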
        resolved_expanded_queries = expanded_queries or self._expanded_queries(query)  # type: ignore[attr-defined]
        if not resolved_expanded_queries:
            if telemetry is not None:
                telemetry["embedding_ms"] = 0.0
                telemetry["search_ms"] = 0.0
            return []
        per_variant_k = self._variant_candidate_pool_size(  # type: ignore[attr-defined]
            top_k,
            variants_count=len(resolved_expanded_queries),
            output_top_k=output_top_k,
        )
_emb = embedding_service or self.embedding_service # type: ignore[attr-defined]
variant_texts = [expanded.text for expanded in resolved_expanded_queries]
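        # A single variant uses the single-text embedding call; multiple
        # variants are embedded in one batch call.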
embed_started_at = time.perf_counter()
if len(variant_texts) == 1:
query_vectors = [_emb.embed_text(variant_texts[0])]
else:
query_vectors = _emb.embed_batch(variant_texts, batch_size=len(variant_texts))
embedding_ms = (time.perf_counter() - embed_started_at) * 1000.0
logger.warning("PERF embed_batch(%d): %.1f ms", len(variant_texts), embedding_ms)
combined: List[RetrievalResult] = []
search_started_at = time.perf_counter()
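        # A single variant searches inline; multiple variants fan out across a
        # thread pool bounded by the configured branch worker limit.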
if len(resolved_expanded_queries) == 1:
variant_results = [
self._semantic_search_multi_by_vector(
query_vector=query_vectors[0],
top_k=per_variant_k,
filters=filters,
collections=collections,
)
]
else:
with ThreadPoolExecutor(
max_workers=min(len(resolved_expanded_queries), self._branch_workers()) # type: ignore[attr-defined]
) as executor:
futures = [
executor.submit(
self._semantic_search_multi_by_vector,
query_vector=query_vector,
top_k=per_variant_k,
filters=filters,
collections=collections,
)
for query_vector in query_vectors
]
variant_results = [future.result() for future in futures]
search_ms = (time.perf_counter() - search_started_at) * 1000.0
logger.warning("PERF semantic variants (%d): %.1f ms", len(variant_texts), search_ms)
for variant_index, (expanded, semantic_results) in enumerate(
zip(resolved_expanded_queries, variant_results),
start=1,
):
self._apply_query_variant_metadata( # type: ignore[attr-defined]
semantic_results,
expanded=expanded,
variant_index=variant_index,
search_branch="semantic",
expansion_enabled=self.query_expansion_enabled, # type: ignore[attr-defined]
)
combined.extend(semantic_results)
if telemetry is not None:
telemetry["embedding_ms"] = round(embedding_ms, 1)
telemetry["search_ms"] = round(search_ms, 1)
deduped = self._dedupe_by_best_score(combined) # type: ignore[attr-defined]
deduped.sort(key=lambda item: item.score, reverse=True)
return deduped[:top_k]

    def _semantic_search_multi_by_vector(
self,
*,
query_vector: List[float],
top_k: int,
filters: Optional[Dict[str, Any]],
collections: Optional[List[str]],
) -> List[RetrievalResult]:
"""Execute semantic search across one or more collections for a precomputed vector."""
search_collections = collections or list(self.qdrant_repos.keys()) # type: ignore[attr-defined]
if len(search_collections) <= 1:
all_results = self.semantic_service.search( # type: ignore[attr-defined]
query_vector=query_vector,
top_k=top_k,
filters=filters,
collections=search_collections,
)
else:
valid_collections = [
c
for c in search_collections
if c in self.qdrant_repos # type: ignore[attr-defined]
]
if not valid_collections:
return []
            # Split the requested top_k evenly across collections, then oversample
            # each collection's quota so that cross-collection merging and
            # deduplication still leave enough candidates.
            per_collection_limit = max(1, math.ceil(top_k / len(valid_collections)))
            per_collection_limit = max(
                1,
                math.ceil(
                    per_collection_limit
                    * self.semantic_service.per_collection_oversampling  # type: ignore[attr-defined]
                ),
            )

            def _search_collection(collection: str) -> List[RetrievalResult]:
return self.semantic_service.search( # type: ignore[attr-defined]
query_vector=query_vector,
top_k=per_collection_limit,
filters=filters,
collections=[collection],
)
with ThreadPoolExecutor(
max_workers=min(
len(valid_collections),
max(int(get_config().search.max_parallel_workers), 1),
)
) as executor:
futures = [executor.submit(_search_collection, collection) for collection in valid_collections]
all_results = [result for future in futures for result in future.result()]
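        # Preserve the originating collection on each hit before the merged list
        # is re-sorted and deduplicated.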
for result in all_results:
result.metadata.setdefault("source_collection", result.metadata.get("collection"))
all_results.sort(key=lambda item: item.score, reverse=True)
        # Keep only the best-scoring instance of each result; the list is
        # already sorted, so the first occurrence wins.
        seen: set = set()
        deduplicated: List[RetrievalResult] = []
for result in all_results:
key = self._result_dedupe_key(result) # type: ignore[attr-defined]
if key not in seen:
seen.add(key)
deduplicated.append(result)
return deduplicated[:top_k]

    def _semantic_search_multi(
self,
query: str,
top_k: int,
filters: Optional[Dict[str, Any]] = None,
collections: Optional[List[str]] = None,
embedding_service=None,
) -> List[RetrievalResult]:
"""Execute semantic search across multiple Qdrant collections."""
collections = collections or list(self.qdrant_repos.keys()) # type: ignore[attr-defined]
embed_start = time.perf_counter()
_emb = embedding_service or self.embedding_service # type: ignore[attr-defined]
query_vector = _emb.embed_text(query)
embed_ms = (time.perf_counter() - embed_start) * 1000.0
logger.warning("PERF embed_text: %.1f ms", embed_ms)
qdrant_start = time.perf_counter()
all_results = self._semantic_search_multi_by_vector(
query_vector=query_vector,
top_k=top_k,
filters=filters,
collections=collections,
)
qdrant_ms = (time.perf_counter() - qdrant_start) * 1000.0
logger.warning(
"PERF qdrant search (%s): %.1f ms, %d results",
collections,
qdrant_ms,
len(all_results),
)
return all_results