Source code for lalandre_rag.retrieval.fusion_service

"""
Result Fusion Service
Algorithms for combining search results from multiple sources
"""

import logging
from collections.abc import Hashable, Sequence
from typing import Any, Dict, List, Optional, Tuple, TypeAlias

from .result import RetrievalResult

# Module-level logger named after this module (used by ResultFusionService.fuse).
logger = logging.getLogger(__name__)

# Deduplication key used during fusion: a (kind, identifier) pair, where kind is
# one of "chunk", "subdivision", "act" or "unknown" (see _get_result_key).
ResultKey: TypeAlias = Tuple[str, Hashable]


class ResultFusionService:
    """
    Fusion service for combining search results

    Implements multiple fusion algorithms:
    - Reciprocal Rank Fusion (RRF) - rank-based fusion
    - Weighted Score Fusion - score-based fusion with weights
    - Score normalization utilities

    Use cases:
    - Combine BM25 (lexical) + semantic search
    - Combine multiple semantic searches
    - Combine Graph RAG + standard RAG
    - Multi-stage retrieval pipelines

    Responsibilities:
    - Implement fusion algorithms (RRF, weighted)
    - Deduplicate results by chunk_id, then subdivision_id, then act_id
    - Normalize fused scores to [0, 1] range
    - Preserve metadata of the retained result

    Does NOT:
    - Execute searches (uses search services)
    - Generate embeddings
    - Access databases
    """

    def __init__(
        self,
        fusion_method: str = "rrf",
        lexical_weight: float = 0.3,
        semantic_weight: float = 0.7,
        rrf_k: int = 60,
    ) -> None:
        """
        Initialize fusion service

        Args:
            fusion_method: "rrf" or "weighted". Any value other than "rrf"
                is treated as weighted fusion by fuse().
            lexical_weight: Weight for lexical scores (weighted method)
            semantic_weight: Weight for semantic scores (weighted method)
            rrf_k: RRF constant (typically 60)

        Raises:
            ValueError: If fusion_method is "weighted" and the two weights
                sum to zero, so they cannot be normalized.
        """
        self.fusion_method = fusion_method
        self.rrf_k = rrf_k

        # Weights are only normalized to sum to 1 for the weighted method;
        # otherwise they are stored exactly as given.
        if fusion_method == "weighted":
            total = lexical_weight + semantic_weight
            if total == 0:
                raise ValueError(
                    "lexical_weight and semantic_weight must not sum to zero "
                    "for weighted fusion"
                )
            self.lexical_weight = lexical_weight / total
            self.semantic_weight = semantic_weight / total
        else:
            self.lexical_weight = lexical_weight
            self.semantic_weight = semantic_weight

    def fuse(
        self,
        lexical_results: Sequence["RetrievalResult"],
        semantic_results: Sequence["RetrievalResult"],
        override_lexical_weight: Optional[float] = None,
        override_semantic_weight: Optional[float] = None,
    ) -> List["RetrievalResult"]:
        """
        Fuse lexical and semantic search results.

        When override weights are provided the method always uses weighted
        score fusion regardless of the configured fusion_method. This lets
        callers apply dynamic weights (e.g. boosted lexical weight for
        queries containing explicit legal references) without changing the
        service-level default.

        If only one override is given, the other weight falls back to the
        instance weight; the pair is not re-normalized (the combined score is
        clamped to [0, 1] instead).

        When one of the two inputs is empty, the other is returned as a plain
        list copy: no fusion, deduplication, re-scoring or fusion metadata is
        applied in that case.

        Args:
            lexical_results: BM25 or other lexical search results
            semantic_results: Vector-based semantic search results
            override_lexical_weight: Forces weighted fusion with this lexical weight
            override_semantic_weight: Forces weighted fusion with this semantic weight

        Returns:
            Fused and sorted results
        """
        if not lexical_results and not semantic_results:
            return []
        if not lexical_results:
            return list(semantic_results)
        if not semantic_results:
            return list(lexical_results)

        use_override = (
            override_lexical_weight is not None
            or override_semantic_weight is not None
        )
        effective_method = "weighted" if use_override else self.fusion_method

        logger.debug(
            "Fusing %d lexical + %d semantic using %s",
            len(lexical_results),
            len(semantic_results),
            effective_method,
        )

        if effective_method == "rrf":
            fused = self.reciprocal_rank_fusion(
                lexical_results,
                semantic_results,
                k=self.rrf_k,
            )
        else:
            fused = self.weighted_score_fusion(
                lexical_results,
                semantic_results,
                lexical_weight=override_lexical_weight,
                semantic_weight=override_semantic_weight,
            )

        logger.info("Fused results: %d unique documents", len(fused))
        return fused

    @staticmethod
    def _to_hashable(value: Any) -> Hashable:
        """Return a hashable representation for fusion keys."""
        # isinstance(value, Hashable) is not sufficient: a tuple holding a
        # list passes that check yet still fails when hashed.
        try:
            hash(value)
        except TypeError:
            return str(value)
        return value

    @staticmethod
    def _metadata_of(result: "RetrievalResult") -> Dict[str, Any]:
        """Return the result's metadata, or an empty dict when it is unset."""
        return result.metadata or {}

    @classmethod
    def _get_result_key(cls, result: "RetrievalResult") -> "ResultKey":
        """
        Build the deduplication key for a result.

        Preference order: metadata["chunk_id"], then subdivision_id, then
        act_id. A result with none of these gets a key based on its object
        identity, so it never merges with another result.
        """
        chunk_id = cls._metadata_of(result).get("chunk_id")
        if chunk_id is not None:
            return ("chunk", cls._to_hashable(chunk_id))
        if result.subdivision_id:
            return ("subdivision", result.subdivision_id)
        if result.act_id:
            return ("act", result.act_id)
        return ("unknown", id(result))

    @classmethod
    def _is_retrieval_segment(cls, result: "RetrievalResult") -> bool:
        """Tell whether the result's metadata marks it as a retrieval segment."""
        metadata = cls._metadata_of(result)
        return bool(
            metadata.get("is_retrieval_segment")
            or metadata.get("retrieval_segment_index") is not None
        )

    @classmethod
    def _prefer_result(
        cls,
        current: "RetrievalResult",
        candidate: "RetrievalResult",
    ) -> "RetrievalResult":
        """
        Choose which of two results sharing a key is kept.

        The candidate replaces the current result only when the candidate is
        a retrieval segment and the current one is not.
        """
        if cls._is_retrieval_segment(candidate) and not cls._is_retrieval_segment(current):
            return candidate
        return current

    @classmethod
    def _remember_result(
        cls,
        result_map: "Dict[ResultKey, RetrievalResult]",
        key: "ResultKey",
        result: "RetrievalResult",
    ) -> None:
        """Store result under key, resolving conflicts via _prefer_result."""
        if key in result_map:
            result_map[key] = cls._prefer_result(result_map[key], result)
        else:
            result_map[key] = result

    @staticmethod
    def _provenance(in_lexical: bool, in_semantic: bool) -> str:
        """Name the search method(s) that contributed a result."""
        if in_lexical and in_semantic:
            return "hybrid"
        if in_lexical:
            return "bm25"
        return "semantic"

    @classmethod
    def _build_fused_result(
        cls,
        result: "RetrievalResult",
        score: float,
        extra_metadata: Dict[str, Any],
    ) -> "RetrievalResult":
        """Copy result with a new score and extra metadata layered on top."""
        return RetrievalResult(
            content=result.content,
            score=score,
            subdivision_id=result.subdivision_id,
            act_id=result.act_id,
            celex=result.celex,
            subdivision_type=result.subdivision_type,
            sequence_order=result.sequence_order,
            metadata={**cls._metadata_of(result), **extra_metadata},
        )

    @staticmethod
    def _clamp_unit_interval(score: float) -> float:
        """Clamp score into the closed interval [0, 1]."""
        return min(max(float(score), 0.0), 1.0)

    @staticmethod
    def _normalize_rrf_score(raw_score: float, rrf_constant: int) -> float:
        """
        Scale a raw RRF score into [0, 1].

        The divisor is the best achievable two-list score, i.e. rank 1 in
        both lists: 2 / (rrf_constant + 1). Degenerate constants that make
        this non-positive yield 0.0.
        """
        denominator = float(rrf_constant) + 1.0
        # Check before dividing so rrf_constant == -1 cannot divide by zero.
        if denominator <= 0.0:
            return 0.0
        max_rrf_score = 2.0 / denominator
        if max_rrf_score <= 0.0:
            return 0.0
        return ResultFusionService._clamp_unit_interval(raw_score / max_rrf_score)

    def weighted_score_fusion(
        self,
        lexical_results: Sequence["RetrievalResult"],
        semantic_results: Sequence["RetrievalResult"],
        lexical_weight: Optional[float] = None,
        semantic_weight: Optional[float] = None,
    ) -> List["RetrievalResult"]:
        """
        Weighted score fusion

        Combines scores using weighted average:
        combined_score = lexical_weight * lex_score + semantic_weight * sem_score

        The combined score is clamped to [0, 1]. A result missing from one
        list contributes a score of 0.0 for that list. If a key occurs more
        than once within one list, the highest score for that list is used.

        Args:
            lexical_results: Lexical search results (with scores)
            semantic_results: Semantic search results (with scores)
            lexical_weight: Weight for lexical scores (default: instance weight)
            semantic_weight: Weight for semantic scores (default: instance weight)

        Returns:
            Fused results sorted by combined score (descending)
        """
        lex_w = lexical_weight if lexical_weight is not None else self.lexical_weight
        sem_w = semantic_weight if semantic_weight is not None else self.semantic_weight

        result_map: Dict[ResultKey, RetrievalResult] = {}
        # Membership in these maps (not the score value) records which source
        # returned a key, so a legitimate 0.0 score is still attributed.
        lexical_scores: Dict[ResultKey, float] = {}
        semantic_scores: Dict[ResultKey, float] = {}

        for source_results, source_scores in (
            (lexical_results, lexical_scores),
            (semantic_results, semantic_scores),
        ):
            for result in source_results:
                key = self._get_result_key(result)
                self._remember_result(result_map, key, result)
                best_so_far = source_scores.get(key)
                if best_so_far is None or result.score > best_so_far:
                    source_scores[key] = result.score

        fused_results: List[RetrievalResult] = []
        for key, result in result_map.items():
            lex_score = lexical_scores.get(key, 0.0)
            sem_score = semantic_scores.get(key, 0.0)
            combined_score = self._clamp_unit_interval(
                lex_w * lex_score + sem_w * sem_score
            )
            fused_results.append(
                self._build_fused_result(
                    result,
                    combined_score,
                    {
                        "search_method": self._provenance(
                            key in lexical_scores, key in semantic_scores
                        ),
                        "lexical_score": lex_score,
                        "semantic_score": sem_score,
                        "fusion_method": "weighted",
                    },
                )
            )

        # Sort by combined score (descending); ties keep insertion order.
        fused_results.sort(key=lambda x: x.score, reverse=True)
        return fused_results

    def reciprocal_rank_fusion(
        self,
        lexical_results: Sequence["RetrievalResult"],
        semantic_results: Sequence["RetrievalResult"],
        k: Optional[int] = None,
    ) -> List["RetrievalResult"]:
        """
        Reciprocal Rank Fusion (RRF)

        RRF formula:
        RRF(d) = sum(1 / (k + rank(d)))

        where k is a constant (typically 60) and rank starts at 1

        RRF is score-agnostic and only considers ranking position,
        making it robust to score distribution differences.

        Each list contributes at most once per key: if a key occurs several
        times in one list, only its first (best) rank is counted. The raw sum
        is then normalized to [0, 1] and stored as the result score.

        Args:
            lexical_results: Lexical search results (pre-sorted)
            semantic_results: Semantic search results (pre-sorted)
            k: RRF constant (default: instance rrf_k, typically 60)

        Returns:
            Fused results sorted by RRF score (descending)
        """
        rrf_constant = k if k is not None else self.rrf_k

        result_map: Dict[ResultKey, RetrievalResult] = {}
        lexical_ranks: Dict[ResultKey, int] = {}
        semantic_ranks: Dict[ResultKey, int] = {}

        for source_results, source_ranks in (
            (lexical_results, lexical_ranks),
            (semantic_results, semantic_ranks),
        ):
            for rank, result in enumerate(source_results, start=1):
                key = self._get_result_key(result)
                self._remember_result(result_map, key, result)
                # setdefault keeps the first rank, so duplicates within one
                # list cannot inflate or worsen the contribution.
                source_ranks.setdefault(key, rank)

        fused_results: List[RetrievalResult] = []
        for key, result in result_map.items():
            # Rank 0 means "not present in that list".
            lex_rank = lexical_ranks.get(key, 0)
            sem_rank = semantic_ranks.get(key, 0)
            rrf_score_raw = sum(
                1.0 / (rrf_constant + rank)
                for rank in (lex_rank, sem_rank)
                if rank > 0
            )
            rrf_score = self._normalize_rrf_score(rrf_score_raw, rrf_constant)
            fused_results.append(
                self._build_fused_result(
                    result,
                    rrf_score,
                    {
                        "search_method": self._provenance(lex_rank > 0, sem_rank > 0),
                        "lexical_rank": lex_rank,
                        "semantic_rank": sem_rank,
                        "rrf_score": rrf_score,
                        "rrf_score_raw": rrf_score_raw,
                        "fusion_method": "rrf",
                    },
                )
            )

        # Sort by RRF score (descending); ties keep insertion order.
        fused_results.sort(key=lambda x: x.score, reverse=True)
        return fused_results

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get fusion service statistics

        Returns:
            Dictionary with configuration
        """
        return {
            "fusion_method": self.fusion_method,
            "lexical_weight": self.lexical_weight,
            "semantic_weight": self.semantic_weight,
            "rrf_k": self.rrf_k,
        }