"""
Document Retrieval Service
Orchestrates semantic + lexical search, fusion, reranking, and caching.
Strategy implementations live in retrieval/strategies/.
"""
import logging
import time
from collections.abc import Hashable
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, TypeAlias
from lalandre_core.config import get_config
from lalandre_core.repositories.common import PayloadBuilder
from lalandre_db_postgres import PostgresRepository
from lalandre_db_qdrant import QdrantRepository
from lalandre_embedding import EmbeddingService
from .bm25_search import BM25SearchService
from .fusion_service import ResultFusionService
from .metrics import observe_retrieval_error, observe_retrieval_phase
from .query_expansion import ExpandedQuery, LegalQueryExpansionService
from .rerank_service import RerankService
from .result import RetrievalResult, RetrievalStats
from .result_cache import RetrievalCache
from .search_config import ResolvedSearchConfig
from .semantic_search import SemanticSearchService
from .strategies import HybridStrategyMixin, LexicalStrategyMixin, SemanticStrategyMixin
from .strategies.hybrid import has_explicit_legal_reference
logger = logging.getLogger(__name__)
ResultDedupeKey: TypeAlias = tuple[str, Hashable]
def _mmr_act_diverse(
candidates: List[RetrievalResult],
top_k: int,
max_per_act: int,
) -> List[RetrievalResult]:
"""Select up to top_k results while capping contributions per act_id.
Iterates the score-sorted candidates and picks each result only if
the act has not yet reached ``max_per_act`` slots. This avoids
returning multiple near-duplicate chunks from the same legal act,
ensuring diversity across acts in the final context window.
"""
selected: List[RetrievalResult] = []
act_counts: Dict[int, int] = {}
for r in candidates:
act_id = r.act_id
count = act_counts.get(act_id, 0)
if count < max_per_act:
selected.append(r)
act_counts[act_id] = count + 1
if len(selected) >= top_k:
break
return selected
class RetrievalService(SemanticStrategyMixin, LexicalStrategyMixin, HybridStrategyMixin):
    """
    Unified retrieval service combining semantic and lexical search.

    Responsibilities:
    - Execute hybrid search (retrieve) across multiple collections
    - Delegate semantic/lexical/hybrid-precomputed searches to strategy mixins
    - Fuse results using RRF or weighted scores
    - Rerank, deduplicate, and cache results

    Does NOT:
    - Modify Qdrant collections
    - Generate embeddings (uses EmbeddingService)
    - Enrich context (uses ContextService)

    Per-call telemetry for the most recent retrieval is exposed through
    ``last_retrieval_stats``.
    """
    def __init__(
        self,
        qdrant_repos: Dict[str, QdrantRepository],
        pg_repo: PostgresRepository,
        embedding_service: EmbeddingService,
        reranker: Optional[RerankService] = None,
        payload_builder: Optional[PayloadBuilder] = None,
        search_language: Optional[str] = None,
        candidate_multiplier: Optional[float] = None,
        min_candidates: Optional[int] = None,
        max_candidates: Optional[int] = None,
        semantic_per_collection_oversampling: Optional[float] = None,
        hnsw_ef: Optional[int] = None,
        exact_search: Optional[bool] = None,
        query_expansion_enabled: Optional[bool] = None,
        query_expansion_max_variants: Optional[int] = None,
        query_expansion_min_query_chars: Optional[int] = None,
        lexical_weight: Optional[float] = None,
        semantic_weight: Optional[float] = None,
        fusion_method: Optional[str] = None,
        dynamic_fusion_enabled: Optional[bool] = None,
        redis_client: Optional[Any] = None,
        result_cache_ttl: Optional[int] = None,
        preset_embedding_services: Optional[Dict[str, EmbeddingService]] = None,
    ):
        """Wire up the search backends and resolve effective tuning values.

        Args:
            qdrant_repos: Qdrant repositories keyed by collection name
                (e.g. ``chunks``, ``subdivisions``, preset-suffixed variants).
            pg_repo: Postgres repository backing lexical (BM25) search.
            embedding_service: Default service used to embed queries.
            reranker: Optional reranking service; ``None`` disables reranking.
            payload_builder: Optional payload builder forwarded to BM25 search.
            redis_client: Optional Redis client handed to the result cache.
            preset_embedding_services: Embedding services keyed by preset
                name, used to route preset queries to dedicated collections.

        The remaining keyword arguments are per-instance overrides of the
        global search configuration; any ``None`` value falls back to the
        configured default via ``ResolvedSearchConfig.from_overrides``.
        """
        self.qdrant_repos = qdrant_repos
        self.embedding_service = embedding_service
        self.preset_embedding_services: Dict[str, EmbeddingService] = preset_embedding_services or {}
        # Collapse explicit overrides and configured defaults into one view.
        sc = ResolvedSearchConfig.from_overrides(
            search_language=search_language,
            candidate_multiplier=candidate_multiplier,
            min_candidates=min_candidates,
            max_candidates=max_candidates,
            hnsw_ef=hnsw_ef,
            exact_search=exact_search,
            semantic_per_collection_oversampling=semantic_per_collection_oversampling,
            query_expansion_enabled=query_expansion_enabled,
            query_expansion_max_variants=query_expansion_max_variants,
            query_expansion_min_query_chars=query_expansion_min_query_chars,
            lexical_weight=lexical_weight,
            semantic_weight=semantic_weight,
            fusion_method=fusion_method,
            dynamic_fusion_enabled=dynamic_fusion_enabled,
            result_cache_ttl=result_cache_ttl,
        )
        # Candidate-pool sizing knobs (see _candidate_pool_size).
        self.candidate_multiplier = sc.candidate_multiplier
        self.min_candidates = sc.min_candidates
        self.max_candidates = sc.max_candidates
        # Qdrant search parameters.
        self.hnsw_ef = sc.hnsw_ef
        self.exact_search = sc.exact_search
        self.semantic_per_collection_oversampling = sc.per_collection_oversampling
        # Query expansion settings (see _expanded_queries).
        self.query_expansion_enabled = sc.query_expansion_enabled
        self.query_expansion_max_variants = sc.query_expansion_max_variants
        self.query_expander = LegalQueryExpansionService(
            min_query_chars=sc.query_expansion_min_query_chars,
        )
        # Fusion weights; dynamic fusion may override them per query.
        self.fusion_method = sc.fusion_method
        self.lexical_weight = sc.lexical_weight
        self.semantic_weight = sc.semantic_weight
        self.dynamic_fusion_enabled = sc.dynamic_fusion_enabled
        self._lexical_boost_factor = sc.lexical_boost_factor
        self._lexical_boost_max = sc.lexical_boost_max
        self.bm25_service = BM25SearchService(
            pg_repo=pg_repo,
            language=sc.search_language,
            payload_builder=payload_builder,
        )
        self.semantic_service = SemanticSearchService(
            qdrant_repos=qdrant_repos,
            score_threshold=None,  # thresholding happens post-fusion in retrieve()
            hnsw_ef=sc.hnsw_ef,
            exact_search=sc.exact_search,
            per_collection_oversampling=sc.per_collection_oversampling,
        )
        self.fusion_service = ResultFusionService(
            fusion_method=sc.fusion_method,
            lexical_weight=sc.lexical_weight,
            semantic_weight=sc.semantic_weight,
            rrf_k=get_config().search.fusion_rrf_k,
        )
        self.reranker = reranker
        # NOTE(review): presumably a no-op cache when redis_client is None —
        # confirm in RetrievalCache.
        self._cache = RetrievalCache(redis_client, sc.result_cache_ttl)
        # Telemetry snapshot from the most recent retrieve() call.
        self.last_retrieval_stats: RetrievalStats = RetrievalStats()
def _resolve_embedding_context(
self,
embedding_preset: Optional[str],
) -> tuple[EmbeddingService, str]:
if embedding_preset and embedding_preset in self.preset_embedding_services:
return self.preset_embedding_services[embedding_preset], f"__{embedding_preset}"
return self.embedding_service, ""
def _base_collections_from_granularity(
self,
*,
granularity: Optional[str],
collections: Optional[List[str]],
) -> List[str]:
if granularity:
if granularity == "all":
return [k for k in self.qdrant_repos if "__" not in k]
return [granularity]
if collections:
return list(collections)
return ["chunks"]
@staticmethod
def _apply_preset_suffix(base_collections: List[str], suffix: str) -> List[str]:
if not suffix:
return base_collections
return [collection if "__" in collection else f"{collection}{suffix}" for collection in base_collections]
    def retrieve(
        self,
        query: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        collections: Optional[List[str]] = None,
        granularity: Optional[str] = None,
        embedding_preset: Optional[str] = None,
    ) -> List[RetrievalResult]:
        """Hybrid search (semantic + BM25) with fusion, reranking, and caching.

        Pipeline: cache lookup -> query expansion -> parallel semantic +
        lexical search -> fusion -> score threshold -> rerank -> adaptive
        score-drop cutoff -> per-act diversity cap -> cache store. Per-phase
        timings are recorded via metrics and in ``self.last_retrieval_stats``.

        Args:
            query: Search query
            top_k: Number of final results to return
            score_threshold: Minimum score threshold (post-fusion); ``None``
                falls back to the configured default
            filters: Metadata filters (act_id, celex, etc.)
            collections: Specific collections to search
            granularity: 'subdivisions', 'chunks', or 'all' (overrides collections)
            embedding_preset: Route semantic search to this preset's collections/embedding service

        Returns:
            At most ``top_k`` fused and reranked results.
        """
        operation = "hybrid"
        total_started_at = time.perf_counter()
        # Resolve preset embedding service and collection suffix
        active_embedding_svc, col_suffix = self._resolve_embedding_context(embedding_preset)
        # The cache key covers every argument that can change the result set.
        cache_key = RetrievalCache.cache_key(
            query,
            top_k,
            score_threshold,
            filters,
            granularity,
            collections,
            embedding_preset,
        )
        cached = self._cache.get(cache_key)
        if cached is not None:
            logger.debug("Cache hit for retrieval key %s", cache_key)
            self.last_retrieval_stats = RetrievalStats(
                candidates_returned=len(cached),
                cache_hit=True,
            )
            return cached
        expanded_queries = self._expanded_queries(query)
        # Dynamic fusion may override the static weights per query; None
        # keeps the instance defaults (see stats construction below).
        dyn_lex_w, dyn_sem_w = self._resolve_fusion_override(query)
        base_collections = self._base_collections_from_granularity(
            granularity=granularity,
            collections=collections,
        )
        # Apply preset suffix to route semantic search to the right Qdrant collections
        search_collections = self._apply_preset_suffix(base_collections, col_suffix)
        candidate_k = self._candidate_pool_size(top_k)
        # Lexical (BM25) search targets a single granularity: subdivisions
        # win whenever requested, otherwise chunks (if requested).
        lexical_target: Optional[str] = None
        if "chunks" in base_collections and "subdivisions" not in base_collections:
            lexical_target = "chunks"
        elif "subdivisions" in base_collections:
            lexical_target = "subdivisions"
        try:
            parallel_started_at = time.perf_counter()
            # The strategy mixins fill these dicts with per-branch timings.
            semantic_telemetry: Dict[str, float] = {}
            lexical_telemetry: Dict[str, float] = {}
            # Run the semantic and lexical branches concurrently.
            with ThreadPoolExecutor(max_workers=min(2, self._max_parallel_workers())) as executor:
                sem_future = executor.submit(
                    self._semantic_search_with_expansion,
                    query=query,
                    expanded_queries=expanded_queries,
                    top_k=candidate_k,
                    filters=filters,
                    collections=search_collections,
                    output_top_k=top_k,
                    embedding_service=active_embedding_svc,
                    telemetry=semantic_telemetry,
                )
                lex_future = executor.submit(
                    self._lexical_search_with_expansion,
                    query=query,
                    expanded_queries=expanded_queries,
                    top_k=candidate_k,
                    filters=filters,
                    target=lexical_target,
                    output_top_k=top_k,
                    telemetry=lexical_telemetry,
                )
                # Surface branch failures with their phase attached, then
                # re-raise so the caller sees the original exception.
                try:
                    semantic_results = sem_future.result()
                except Exception as exc:
                    observe_retrieval_error(operation=operation, phase="semantic_search", exc_or_reason=exc)
                    raise
                try:
                    lexical_results = lex_future.result()
                except Exception as exc:
                    observe_retrieval_error(operation=operation, phase="lexical_search", exc_or_reason=exc)
                    raise
            observe_retrieval_phase(
                operation=operation,
                phase="semantic_search",
                duration_seconds=max(float(semantic_telemetry.get("search_ms", 0.0)), 0.0) / 1000.0,
            )
            observe_retrieval_phase(
                operation=operation,
                phase="lexical_search",
                duration_seconds=max(float(lexical_telemetry.get("search_ms", 0.0)), 0.0) / 1000.0,
            )
            parallel_ms = (time.perf_counter() - parallel_started_at) * 1000.0
            # NOTE(review): PERF logging at WARNING level — confirm this is
            # intentional for visibility rather than leftover debugging.
            logger.warning(
                "PERF retrieve parallel phase: %.1f ms (semantic=%d, lexical=%d results)",
                parallel_ms,
                len(semantic_results),
                len(lexical_results),
            )
            fuse_started_at = time.perf_counter()
            fused_results = self._fuse_results(
                lexical_results,
                semantic_results,
                override_lexical_weight=dyn_lex_w,
                override_semantic_weight=dyn_sem_w,
            )
            fusion_ms = (time.perf_counter() - fuse_started_at) * 1000.0
            observe_retrieval_phase(operation=operation, phase="fusion", duration_seconds=fusion_ms / 1000.0)
            cfg = get_config()
            # Post-fusion score filter; an explicit argument wins over config.
            effective_threshold = score_threshold if score_threshold is not None else cfg.search.score_threshold_default
            candidates_after_fusion = len(fused_results)
            if effective_threshold is not None:
                threshold_started_at = time.perf_counter()
                fused_results = [r for r in fused_results if r.score >= effective_threshold]
                observe_retrieval_phase(
                    operation=operation,
                    phase="score_threshold",
                    duration_seconds=time.perf_counter() - threshold_started_at,
                )
            candidates_after_threshold = len(fused_results)
            rerank_started_at = time.perf_counter()
            reranked = self._rerank_results(query, fused_results, top_k, operation=operation)
            rerank_ms = (time.perf_counter() - rerank_started_at) * 1000.0
            candidates_after_rerank = len(reranked)
            # Adaptive score-drop cutoff: truncate at the first large gap
            adaptive_cutoff_applied = False
            drop_threshold = cfg.search.adaptive_score_drop_threshold
            if drop_threshold is not None and len(reranked) > 1:
                scores = [r.score for r in reranked]
                for i in range(1, len(scores)):
                    if scores[i - 1] - scores[i] > drop_threshold:
                        reranked = reranked[:i]
                        adaptive_cutoff_applied = True
                        logger.debug(
                            "Adaptive cutoff at position %d (drop %.3f > %.3f)",
                            i,
                            scores[i - 1] - scores[i],
                            drop_threshold,
                        )
                        break
            # MMR act-based diversity: cap chunks per act_id
            if cfg.search.mmr_enabled and cfg.search.mmr_max_per_act > 0:
                reranked = _mmr_act_diverse(reranked, top_k, cfg.search.mmr_max_per_act)
            final = reranked[:top_k]
            total_retrieve_ms = (time.perf_counter() - total_started_at) * 1000.0
            logger.warning(
                "PERF retrieve total: %.1f ms (parallel=%.1f, fusion=%.1f, rerank=%.1f)",
                total_retrieve_ms,
                parallel_ms,
                fusion_ms,
                rerank_ms,
            )
            self.last_retrieval_stats = RetrievalStats(
                candidates_after_fusion=candidates_after_fusion,
                candidates_after_threshold=candidates_after_threshold,
                candidates_after_rerank=candidates_after_rerank,
                # Counts the MMR diversity cap as well when it is enabled,
                # since reranked is rebound by _mmr_act_diverse above.
                candidates_after_adaptive_cutoff=len(reranked),
                candidates_returned=len(final),
                adaptive_cutoff_applied=adaptive_cutoff_applied,
                effective_score_threshold=effective_threshold,
                fusion_lexical_weight=dyn_lex_w if dyn_lex_w is not None else self.lexical_weight,
                fusion_semantic_weight=dyn_sem_w if dyn_sem_w is not None else self.semantic_weight,
                query_variants_count=len(expanded_queries),
                cache_hit=False,
                embedding_ms=round(float(semantic_telemetry.get("embedding_ms", 0.0)), 1),
                semantic_search_ms=round(float(semantic_telemetry.get("search_ms", 0.0)), 1),
                lexical_search_ms=round(float(lexical_telemetry.get("search_ms", 0.0)), 1),
                parallel_search_ms=round(parallel_ms, 1),
                fusion_ms=round(fusion_ms, 1),
                rerank_ms=round(rerank_ms, 1),
                total_retrieve_ms=round(total_retrieve_ms, 1),
            )
            self._cache.set(cache_key, final)
            return final
        finally:
            # Total phase is observed on success and failure alike.
            observe_retrieval_phase(
                operation=operation,
                phase="total",
                duration_seconds=time.perf_counter() - total_started_at,
            )
# -------------------------------------------------------------------------
# Shared helpers used by strategy mixins via self.*
# -------------------------------------------------------------------------
def _rerank_results(
self,
query: str,
results: List[RetrievalResult],
top_k: Optional[int] = None,
*,
operation: str,
) -> List[RetrievalResult]:
reranker = self.reranker
if reranker is None:
return results
rerank_started_at = time.perf_counter()
try:
return reranker.rerank(query, results, top_k=top_k)
except Exception as exc:
observe_retrieval_error(operation=operation, phase="rerank", exc_or_reason=exc)
raise
finally:
observe_retrieval_phase(
operation=operation,
phase="rerank",
duration_seconds=time.perf_counter() - rerank_started_at,
)
def _expanded_queries(self, query: str) -> List[ExpandedQuery]:
normalized = query.strip()
if not normalized:
return []
if not self.query_expansion_enabled:
return [ExpandedQuery(text=normalized, weight=1.0, strategy="original")]
max_variants = self.query_expansion_max_variants
if not has_explicit_legal_reference(normalized):
max_variants = min(max_variants, 2)
expanded = self.query_expander.expand(normalized, max_variants=max_variants)
if not expanded:
return [ExpandedQuery(text=normalized, weight=1.0, strategy="original")]
return expanded
@staticmethod
def _max_parallel_workers() -> int:
return max(int(get_config().search.max_parallel_workers), 1)
def _branch_workers(self) -> int:
return max(1, self._max_parallel_workers() // 2)
@staticmethod
def _variant_candidate_pool_size(candidate_k: int, *, variants_count: int, output_top_k: int) -> int:
if variants_count <= 1:
return max(candidate_k, output_top_k, 1)
spread = int(candidate_k / variants_count)
buffer = max(int(output_top_k / 2), 1)
return max(min(spread + buffer, candidate_k), output_top_k, 1)
@staticmethod
def _apply_query_variant_metadata(
results: List[RetrievalResult],
*,
expanded: ExpandedQuery,
variant_index: int,
search_branch: str,
expansion_enabled: bool,
) -> None:
for rank, result in enumerate(results, start=1):
raw_score = result.score
weighted_score = raw_score * expanded.weight
metadata = result.metadata
metadata["query_variant"] = expanded.text
metadata["query_variant_strategy"] = expanded.strategy
metadata["query_variant_weight"] = expanded.weight
metadata["query_variant_index"] = variant_index
metadata["query_variant_rank"] = rank
metadata["query_expansion_enabled"] = expansion_enabled
if search_branch == "lexical":
metadata.setdefault("lexical_score_raw", raw_score)
metadata["lexical_score_normalized"] = raw_score
else:
metadata["semantic_score_raw"] = raw_score
result.score = weighted_score
def _dedupe_by_best_score(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
best_by_key: Dict[ResultDedupeKey, RetrievalResult] = {}
for result in results:
key = self._result_dedupe_key(result)
current = best_by_key.get(key)
if current is None or result.score > current.score:
best_by_key[key] = result
return list(best_by_key.values())
@staticmethod
def _result_dedupe_key(result: RetrievalResult) -> ResultDedupeKey:
metadata = result.metadata
chunk_id = metadata.get("chunk_id")
if chunk_id is not None:
chunk_hashable = chunk_id if isinstance(chunk_id, Hashable) else str(chunk_id)
return ("chunk", chunk_hashable)
if result.subdivision_id:
return ("subdivision", result.subdivision_id)
if result.act_id:
return ("act", result.act_id)
return ("unknown", id(result))
def _candidate_pool_size(self, top_k: int) -> int:
scaled = int(top_k * self.candidate_multiplier)
with_floor = max(scaled, self.min_candidates)
with_cap = min(with_floor, self.max_candidates)
return max(with_cap, 1)
def get_statistics(self) -> Dict[str, Any]:
"""Return retrieval service statistics."""
qdrant_repo = self.qdrant_repos.get("chunks")
if not qdrant_repo and self.qdrant_repos:
qdrant_repo = next(iter(self.qdrant_repos.values()))
vector_count: int | str = "unknown"
if qdrant_repo is not None:
try:
collection_info = qdrant_repo.client.get_collection(qdrant_repo.collection_name)
points_count = collection_info.points_count
vector_count = int(points_count) if points_count is not None else "unknown"
except Exception:
vector_count = "unknown"
return {
"fusion_method": self.fusion_method,
"lexical_weight": self.lexical_weight,
"semantic_weight": self.semantic_weight,
"candidate_multiplier": self.candidate_multiplier,
"min_candidates": self.min_candidates,
"max_candidates": self.max_candidates,
"semantic_per_collection_oversampling": self.semantic_per_collection_oversampling,
"hnsw_ef": self.hnsw_ef,
"exact_search": self.exact_search,
"query_expansion_enabled": self.query_expansion_enabled,
"query_expansion_max_variants": self.query_expansion_max_variants,
"query_expansion_min_query_chars": self.query_expander.min_query_chars,
"vector_documents": vector_count,
"embedding_model": self.embedding_service.model_name,
"embedding_dimension": self.embedding_service.get_vector_size(),
}