# Source code for lalandre_rag.retrieval.service

"""
Document Retrieval Service
Orchestrates semantic + lexical search, fusion, reranking, and caching.
Strategy implementations live in retrieval/strategies/.
"""

import logging
import time
from collections.abc import Hashable
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, TypeAlias

from lalandre_core.config import get_config
from lalandre_core.repositories.common import PayloadBuilder
from lalandre_db_postgres import PostgresRepository
from lalandre_db_qdrant import QdrantRepository
from lalandre_embedding import EmbeddingService

from .bm25_search import BM25SearchService
from .fusion_service import ResultFusionService
from .metrics import observe_retrieval_error, observe_retrieval_phase
from .query_expansion import ExpandedQuery, LegalQueryExpansionService
from .rerank_service import RerankService
from .result import RetrievalResult, RetrievalStats
from .result_cache import RetrievalCache
from .search_config import ResolvedSearchConfig
from .semantic_search import SemanticSearchService
from .strategies import HybridStrategyMixin, LexicalStrategyMixin, SemanticStrategyMixin
from .strategies.hybrid import has_explicit_legal_reference

logger = logging.getLogger(__name__)

# Deduplication key for retrieval results: a (kind, identifier) pair, where
# kind is one of "chunk" / "subdivision" / "act" / "unknown" (see
# RetrievalService._result_dedupe_key).
ResultDedupeKey: TypeAlias = tuple[str, Hashable]


def _mmr_act_diverse(
    candidates: List[RetrievalResult],
    top_k: int,
    max_per_act: int,
) -> List[RetrievalResult]:
    """Select up to top_k results while capping contributions per act_id.

    Iterates the score-sorted candidates and picks each result only if
    the act has not yet reached ``max_per_act`` slots.  This avoids
    returning multiple near-duplicate chunks from the same legal act,
    ensuring diversity across acts in the final context window.
    """
    selected: List[RetrievalResult] = []
    act_counts: Dict[int, int] = {}
    for r in candidates:
        act_id = r.act_id
        count = act_counts.get(act_id, 0)
        if count < max_per_act:
            selected.append(r)
            act_counts[act_id] = count + 1
        if len(selected) >= top_k:
            break
    return selected


class RetrievalService(SemanticStrategyMixin, LexicalStrategyMixin, HybridStrategyMixin):
    """
    Unified retrieval service combining semantic and lexical search.

    Responsibilities:
    - Execute hybrid search (retrieve) across multiple collections
    - Delegate semantic/lexical/hybrid-precomputed searches to strategy mixins
    - Fuse results using RRF or weighted scores
    - Rerank, deduplicate, and cache results

    Does NOT:
    - Modify Qdrant collections
    - Generate embeddings (uses EmbeddingService)
    - Enrich context (uses ContextService)
    """

    def __init__(
        self,
        qdrant_repos: Dict[str, QdrantRepository],
        pg_repo: PostgresRepository,
        embedding_service: EmbeddingService,
        reranker: Optional[RerankService] = None,
        payload_builder: Optional[PayloadBuilder] = None,
        search_language: Optional[str] = None,
        candidate_multiplier: Optional[float] = None,
        min_candidates: Optional[int] = None,
        max_candidates: Optional[int] = None,
        semantic_per_collection_oversampling: Optional[float] = None,
        hnsw_ef: Optional[int] = None,
        exact_search: Optional[bool] = None,
        query_expansion_enabled: Optional[bool] = None,
        query_expansion_max_variants: Optional[int] = None,
        query_expansion_min_query_chars: Optional[int] = None,
        lexical_weight: Optional[float] = None,
        semantic_weight: Optional[float] = None,
        fusion_method: Optional[str] = None,
        dynamic_fusion_enabled: Optional[bool] = None,
        redis_client: Optional[Any] = None,
        result_cache_ttl: Optional[int] = None,
        preset_embedding_services: Optional[Dict[str, EmbeddingService]] = None,
    ):
        self.qdrant_repos = qdrant_repos
        self.embedding_service = embedding_service
        # Optional per-preset embedding services; retrieve() routes to one of
        # these (and to suffixed Qdrant collections) when embedding_preset is set.
        self.preset_embedding_services: Dict[str, EmbeddingService] = preset_embedding_services or {}

        # Resolve all tunables against config defaults; any explicit kwarg
        # overrides the configured value.
        sc = ResolvedSearchConfig.from_overrides(
            search_language=search_language,
            candidate_multiplier=candidate_multiplier,
            min_candidates=min_candidates,
            max_candidates=max_candidates,
            hnsw_ef=hnsw_ef,
            exact_search=exact_search,
            semantic_per_collection_oversampling=semantic_per_collection_oversampling,
            query_expansion_enabled=query_expansion_enabled,
            query_expansion_max_variants=query_expansion_max_variants,
            query_expansion_min_query_chars=query_expansion_min_query_chars,
            lexical_weight=lexical_weight,
            semantic_weight=semantic_weight,
            fusion_method=fusion_method,
            dynamic_fusion_enabled=dynamic_fusion_enabled,
            result_cache_ttl=result_cache_ttl,
        )

        # Candidate-pool sizing knobs (see _candidate_pool_size).
        self.candidate_multiplier = sc.candidate_multiplier
        self.min_candidates = sc.min_candidates
        self.max_candidates = sc.max_candidates
        self.hnsw_ef = sc.hnsw_ef
        self.exact_search = sc.exact_search
        self.semantic_per_collection_oversampling = sc.per_collection_oversampling

        # Query expansion (see _expanded_queries).
        self.query_expansion_enabled = sc.query_expansion_enabled
        self.query_expansion_max_variants = sc.query_expansion_max_variants
        self.query_expander = LegalQueryExpansionService(
            min_query_chars=sc.query_expansion_min_query_chars,
        )

        # Fusion configuration; dynamic overrides may be applied per query.
        self.fusion_method = sc.fusion_method
        self.lexical_weight = sc.lexical_weight
        self.semantic_weight = sc.semantic_weight
        self.dynamic_fusion_enabled = sc.dynamic_fusion_enabled
        self._lexical_boost_factor = sc.lexical_boost_factor
        self._lexical_boost_max = sc.lexical_boost_max

        self.bm25_service = BM25SearchService(
            pg_repo=pg_repo,
            language=sc.search_language,
            payload_builder=payload_builder,
        )
        self.semantic_service = SemanticSearchService(
            qdrant_repos=qdrant_repos,
            score_threshold=None,
            hnsw_ef=sc.hnsw_ef,
            exact_search=sc.exact_search,
            per_collection_oversampling=sc.per_collection_oversampling,
        )
        self.fusion_service = ResultFusionService(
            fusion_method=sc.fusion_method,
            lexical_weight=sc.lexical_weight,
            semantic_weight=sc.semantic_weight,
            rrf_k=get_config().search.fusion_rrf_k,
        )
        self.reranker = reranker
        self._cache = RetrievalCache(redis_client, sc.result_cache_ttl)
        # Telemetry snapshot of the most recent retrieve() call.
        self.last_retrieval_stats: RetrievalStats = RetrievalStats()

    def _resolve_embedding_context(
        self,
        embedding_preset: Optional[str],
    ) -> tuple[EmbeddingService, str]:
        """Return (embedding service, collection suffix) for a preset.

        Unknown or absent presets fall back to the default embedding service
        and an empty suffix.
        """
        if embedding_preset and embedding_preset in self.preset_embedding_services:
            return self.preset_embedding_services[embedding_preset], f"__{embedding_preset}"
        return self.embedding_service, ""

    def _base_collections_from_granularity(
        self,
        *,
        granularity: Optional[str],
        collections: Optional[List[str]],
    ) -> List[str]:
        """Resolve the base (un-suffixed) collection names to search.

        Precedence: granularity > explicit collections > default ["chunks"].
        ``granularity == "all"`` selects every non-preset repo (names without
        the ``__`` preset-suffix separator).
        """
        if granularity:
            if granularity == "all":
                return [k for k in self.qdrant_repos if "__" not in k]
            return [granularity]
        if collections:
            return list(collections)
        return ["chunks"]

    @staticmethod
    def _apply_preset_suffix(base_collections: List[str], suffix: str) -> List[str]:
        """Append the preset suffix to collections that are not already suffixed."""
        if not suffix:
            return base_collections
        return [collection if "__" in collection else f"{collection}{suffix}" for collection in base_collections]

    def retrieve(
        self,
        query: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        collections: Optional[List[str]] = None,
        granularity: Optional[str] = None,
        embedding_preset: Optional[str] = None,
    ) -> List[RetrievalResult]:
        """Hybrid search (semantic + BM25) with fusion, reranking, and caching.

        Args:
            query: Search query
            top_k: Number of final results to return
            score_threshold: Minimum score threshold (post-fusion)
            filters: Metadata filters (act_id, celex, etc.)
            collections: Specific collections to search
            granularity: 'subdivisions', 'chunks', or 'all' (overrides collections)
            embedding_preset: Route semantic search to this preset's collections/embedding service
        """
        operation = "hybrid"
        total_started_at = time.perf_counter()

        # Resolve preset embedding service and collection suffix
        active_embedding_svc, col_suffix = self._resolve_embedding_context(embedding_preset)

        cache_key = RetrievalCache.cache_key(
            query,
            top_k,
            score_threshold,
            filters,
            granularity,
            collections,
            embedding_preset,
        )
        cached = self._cache.get(cache_key)
        if cached is not None:
            logger.debug("Cache hit for retrieval key %s", cache_key)
            self.last_retrieval_stats = RetrievalStats(
                candidates_returned=len(cached),
                cache_hit=True,
            )
            return cached

        expanded_queries = self._expanded_queries(query)
        dyn_lex_w, dyn_sem_w = self._resolve_fusion_override(query)

        base_collections = self._base_collections_from_granularity(
            granularity=granularity,
            collections=collections,
        )
        # Apply preset suffix to route semantic search to the right Qdrant collections
        search_collections = self._apply_preset_suffix(base_collections, col_suffix)
        candidate_k = self._candidate_pool_size(top_k)

        # Lexical (BM25) search targets exactly one table; chunks win unless
        # the caller asked for subdivisions.
        lexical_target: Optional[str] = None
        if "chunks" in base_collections and "subdivisions" not in base_collections:
            lexical_target = "chunks"
        elif "subdivisions" in base_collections:
            lexical_target = "subdivisions"

        try:
            parallel_started_at = time.perf_counter()
            semantic_telemetry: Dict[str, float] = {}
            lexical_telemetry: Dict[str, float] = {}
            # Run semantic and lexical branches concurrently.
            with ThreadPoolExecutor(max_workers=min(2, self._max_parallel_workers())) as executor:
                sem_future = executor.submit(
                    self._semantic_search_with_expansion,
                    query=query,
                    expanded_queries=expanded_queries,
                    top_k=candidate_k,
                    filters=filters,
                    collections=search_collections,
                    output_top_k=top_k,
                    embedding_service=active_embedding_svc,
                    telemetry=semantic_telemetry,
                )
                lex_future = executor.submit(
                    self._lexical_search_with_expansion,
                    query=query,
                    expanded_queries=expanded_queries,
                    top_k=candidate_k,
                    filters=filters,
                    target=lexical_target,
                    output_top_k=top_k,
                    telemetry=lexical_telemetry,
                )
                try:
                    semantic_results = sem_future.result()
                except Exception as exc:
                    observe_retrieval_error(operation=operation, phase="semantic_search", exc_or_reason=exc)
                    raise
                try:
                    lexical_results = lex_future.result()
                except Exception as exc:
                    observe_retrieval_error(operation=operation, phase="lexical_search", exc_or_reason=exc)
                    raise

            observe_retrieval_phase(
                operation=operation,
                phase="semantic_search",
                duration_seconds=max(float(semantic_telemetry.get("search_ms", 0.0)), 0.0) / 1000.0,
            )
            observe_retrieval_phase(
                operation=operation,
                phase="lexical_search",
                duration_seconds=max(float(lexical_telemetry.get("search_ms", 0.0)), 0.0) / 1000.0,
            )
            parallel_ms = (time.perf_counter() - parallel_started_at) * 1000.0
            logger.warning(
                "PERF retrieve parallel phase: %.1f ms (semantic=%d, lexical=%d results)",
                parallel_ms,
                len(semantic_results),
                len(lexical_results),
            )

            fuse_started_at = time.perf_counter()
            fused_results = self._fuse_results(
                lexical_results,
                semantic_results,
                override_lexical_weight=dyn_lex_w,
                override_semantic_weight=dyn_sem_w,
            )
            fusion_ms = (time.perf_counter() - fuse_started_at) * 1000.0
            observe_retrieval_phase(operation=operation, phase="fusion", duration_seconds=fusion_ms / 1000.0)

            cfg = get_config()
            effective_threshold = score_threshold if score_threshold is not None else cfg.search.score_threshold_default
            candidates_after_fusion = len(fused_results)
            if effective_threshold is not None:
                threshold_started_at = time.perf_counter()
                fused_results = [r for r in fused_results if r.score >= effective_threshold]
                observe_retrieval_phase(
                    operation=operation,
                    phase="score_threshold",
                    duration_seconds=time.perf_counter() - threshold_started_at,
                )
            candidates_after_threshold = len(fused_results)

            rerank_started_at = time.perf_counter()
            reranked = self._rerank_results(query, fused_results, top_k, operation=operation)
            rerank_ms = (time.perf_counter() - rerank_started_at) * 1000.0
            candidates_after_rerank = len(reranked)

            # Adaptive score-drop cutoff: truncate at the first large gap
            adaptive_cutoff_applied = False
            drop_threshold = cfg.search.adaptive_score_drop_threshold
            if drop_threshold is not None and len(reranked) > 1:
                scores = [r.score for r in reranked]
                for i in range(1, len(scores)):
                    if scores[i - 1] - scores[i] > drop_threshold:
                        reranked = reranked[:i]
                        adaptive_cutoff_applied = True
                        logger.debug(
                            "Adaptive cutoff at position %d (drop %.3f > %.3f)",
                            i,
                            scores[i - 1] - scores[i],
                            drop_threshold,
                        )
                        break

            # MMR act-based diversity: cap chunks per act_id
            if cfg.search.mmr_enabled and cfg.search.mmr_max_per_act > 0:
                reranked = _mmr_act_diverse(reranked, top_k, cfg.search.mmr_max_per_act)

            final = reranked[:top_k]
            total_retrieve_ms = (time.perf_counter() - total_started_at) * 1000.0
            logger.warning(
                "PERF retrieve total: %.1f ms (parallel=%.1f, fusion=%.1f, rerank=%.1f)",
                total_retrieve_ms,
                parallel_ms,
                fusion_ms,
                rerank_ms,
            )
            self.last_retrieval_stats = RetrievalStats(
                candidates_after_fusion=candidates_after_fusion,
                candidates_after_threshold=candidates_after_threshold,
                candidates_after_rerank=candidates_after_rerank,
                candidates_after_adaptive_cutoff=len(reranked),
                candidates_returned=len(final),
                adaptive_cutoff_applied=adaptive_cutoff_applied,
                effective_score_threshold=effective_threshold,
                fusion_lexical_weight=dyn_lex_w if dyn_lex_w is not None else self.lexical_weight,
                fusion_semantic_weight=dyn_sem_w if dyn_sem_w is not None else self.semantic_weight,
                query_variants_count=len(expanded_queries),
                cache_hit=False,
                embedding_ms=round(float(semantic_telemetry.get("embedding_ms", 0.0)), 1),
                semantic_search_ms=round(float(semantic_telemetry.get("search_ms", 0.0)), 1),
                lexical_search_ms=round(float(lexical_telemetry.get("search_ms", 0.0)), 1),
                parallel_search_ms=round(parallel_ms, 1),
                fusion_ms=round(fusion_ms, 1),
                rerank_ms=round(rerank_ms, 1),
                total_retrieve_ms=round(total_retrieve_ms, 1),
            )
            self._cache.set(cache_key, final)
            return final
        finally:
            observe_retrieval_phase(
                operation=operation,
                phase="total",
                duration_seconds=time.perf_counter() - total_started_at,
            )

    # -------------------------------------------------------------------------
    # Shared helpers used by strategy mixins via self.*
    # -------------------------------------------------------------------------

    def _rerank_results(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: Optional[int] = None,
        *,
        operation: str,
    ) -> List[RetrievalResult]:
        """Rerank results with the configured reranker; no-op when unset.

        Errors are observed as metrics and re-raised; timing is always recorded.
        """
        reranker = self.reranker
        if reranker is None:
            return results
        rerank_started_at = time.perf_counter()
        try:
            return reranker.rerank(query, results, top_k=top_k)
        except Exception as exc:
            observe_retrieval_error(operation=operation, phase="rerank", exc_or_reason=exc)
            raise
        finally:
            observe_retrieval_phase(
                operation=operation,
                phase="rerank",
                duration_seconds=time.perf_counter() - rerank_started_at,
            )

    def _expanded_queries(self, query: str) -> List[ExpandedQuery]:
        """Return the query variants to search with.

        Empty queries yield no variants.  With expansion disabled, or when the
        expander returns nothing, the original query is used with weight 1.0.
        Queries without an explicit legal reference are limited to 2 variants.
        """
        normalized = query.strip()
        if not normalized:
            return []
        if not self.query_expansion_enabled:
            return [ExpandedQuery(text=normalized, weight=1.0, strategy="original")]
        max_variants = self.query_expansion_max_variants
        if not has_explicit_legal_reference(normalized):
            max_variants = min(max_variants, 2)
        expanded = self.query_expander.expand(normalized, max_variants=max_variants)
        if not expanded:
            return [ExpandedQuery(text=normalized, weight=1.0, strategy="original")]
        return expanded

    @staticmethod
    def _max_parallel_workers() -> int:
        # Configured worker count, floored at 1.
        return max(int(get_config().search.max_parallel_workers), 1)

    def _branch_workers(self) -> int:
        # Each of the two search branches gets half the worker budget.
        return max(1, self._max_parallel_workers() // 2)

    @staticmethod
    def _variant_candidate_pool_size(candidate_k: int, *, variants_count: int, output_top_k: int) -> int:
        """Per-variant candidate pool: split candidate_k across variants plus a buffer."""
        if variants_count <= 1:
            return max(candidate_k, output_top_k, 1)
        spread = int(candidate_k / variants_count)
        buffer = max(int(output_top_k / 2), 1)
        return max(min(spread + buffer, candidate_k), output_top_k, 1)

    @staticmethod
    def _apply_query_variant_metadata(
        results: List[RetrievalResult],
        *,
        expanded: ExpandedQuery,
        variant_index: int,
        search_branch: str,
        expansion_enabled: bool,
    ) -> None:
        """Annotate results with variant provenance and apply the variant weight.

        Mutates each result in place: records the variant text/strategy/weight/
        rank in metadata, preserves the raw branch score, and scales the result
        score by the variant weight.
        """
        for rank, result in enumerate(results, start=1):
            raw_score = result.score
            weighted_score = raw_score * expanded.weight
            metadata = result.metadata
            metadata["query_variant"] = expanded.text
            metadata["query_variant_strategy"] = expanded.strategy
            metadata["query_variant_weight"] = expanded.weight
            metadata["query_variant_index"] = variant_index
            metadata["query_variant_rank"] = rank
            metadata["query_expansion_enabled"] = expansion_enabled
            if search_branch == "lexical":
                # setdefault: keep the first raw score if this result was
                # already seen under an earlier variant.
                metadata.setdefault("lexical_score_raw", raw_score)
                metadata["lexical_score_normalized"] = raw_score
            else:
                metadata["semantic_score_raw"] = raw_score
            result.score = weighted_score

    def _dedupe_by_best_score(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
        """Collapse duplicates (same dedupe key), keeping the highest-scoring one."""
        best_by_key: Dict[ResultDedupeKey, RetrievalResult] = {}
        for result in results:
            key = self._result_dedupe_key(result)
            current = best_by_key.get(key)
            if current is None or result.score > current.score:
                best_by_key[key] = result
        return list(best_by_key.values())

    @staticmethod
    def _result_dedupe_key(result: RetrievalResult) -> ResultDedupeKey:
        """Identity key for deduplication: chunk > subdivision > act > object id."""
        metadata = result.metadata
        chunk_id = metadata.get("chunk_id")
        if chunk_id is not None:
            chunk_hashable = chunk_id if isinstance(chunk_id, Hashable) else str(chunk_id)
            return ("chunk", chunk_hashable)
        if result.subdivision_id:
            return ("subdivision", result.subdivision_id)
        if result.act_id:
            return ("act", result.act_id)
        return ("unknown", id(result))

    def _candidate_pool_size(self, top_k: int) -> int:
        """Candidate pool size: top_k * multiplier, clamped to [min, max], floored at 1."""
        scaled = int(top_k * self.candidate_multiplier)
        with_floor = max(scaled, self.min_candidates)
        with_cap = min(with_floor, self.max_candidates)
        return max(with_cap, 1)

    def get_statistics(self) -> Dict[str, Any]:
        """Return retrieval service statistics."""
        # Prefer the "chunks" repo for the vector count; fall back to any repo.
        qdrant_repo = self.qdrant_repos.get("chunks")
        if not qdrant_repo and self.qdrant_repos:
            qdrant_repo = next(iter(self.qdrant_repos.values()))
        vector_count: int | str = "unknown"
        if qdrant_repo is not None:
            try:
                collection_info = qdrant_repo.client.get_collection(qdrant_repo.collection_name)
                points_count = collection_info.points_count
                vector_count = int(points_count) if points_count is not None else "unknown"
            except Exception:
                # Best-effort stat; an unreachable collection is not an error here.
                vector_count = "unknown"
        return {
            "fusion_method": self.fusion_method,
            "lexical_weight": self.lexical_weight,
            "semantic_weight": self.semantic_weight,
            "candidate_multiplier": self.candidate_multiplier,
            "min_candidates": self.min_candidates,
            "max_candidates": self.max_candidates,
            "semantic_per_collection_oversampling": self.semantic_per_collection_oversampling,
            "hnsw_ef": self.hnsw_ef,
            "exact_search": self.exact_search,
            "query_expansion_enabled": self.query_expansion_enabled,
            "query_expansion_max_variants": self.query_expansion_max_variants,
            "query_expansion_min_query_chars": self.query_expander.min_query_chars,
            "vector_documents": vector_count,
            "embedding_model": self.embedding_service.model_name,
            "embedding_dimension": self.embedding_service.get_vector_size(),
        }