"""
Result Fusion Service
Algorithms for combining search results from multiple sources
"""
import logging
from collections.abc import Hashable, Sequence
from typing import Any, Dict, List, Optional, Tuple, TypeAlias
from .result import RetrievalResult
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Fusion/deduplication key: a (kind, identifier) pair, where kind is the tag
# string produced by ResultFusionService._get_result_key ("chunk",
# "subdivision", "act" or "unknown") and identifier is any hashable value.
ResultKey: TypeAlias = Tuple[str, Hashable]
class ResultFusionService:
    """
    Fusion service for combining search results.

    Implements multiple fusion algorithms:
    - Reciprocal Rank Fusion (RRF) - rank-based fusion
    - Weighted Score Fusion - score-based fusion with weights
    - Score normalization utilities

    Use cases:
    - Combine BM25 (lexical) + semantic search
    - Combine multiple semantic searches
    - Combine Graph RAG + standard RAG
    - Multi-stage retrieval pipelines

    Responsibilities:
    - Implement fusion algorithms (RRF, weighted)
    - Deduplicate results by fusion key (metadata ``chunk_id`` first, then
      ``subdivision_id``, then ``act_id``)
    - Keep fused scores within the [0, 1] range
    - Carry over the representative result's metadata and add fusion
      provenance fields

    Does NOT:
    - Execute searches (uses search services)
    - Generate embeddings
    - Access databases
    """

    def __init__(
        self,
        fusion_method: str = "rrf",
        lexical_weight: float = 0.3,
        semantic_weight: float = 0.7,
        rrf_k: int = 60,
    ):
        """
        Initialize fusion service.

        Args:
            fusion_method: "rrf" or "weighted"
            lexical_weight: Weight for lexical scores (weighted method)
            semantic_weight: Weight for semantic scores (weighted method)
            rrf_k: RRF constant (typically 60)

        Raises:
            ValueError: If ``fusion_method`` is "weighted" and the two weights
                do not add up to a positive number.
        """
        self.fusion_method = fusion_method
        self.rrf_k = rrf_k
        if fusion_method == "weighted":
            # Normalize so the stored weights sum to 1.
            total = lexical_weight + semantic_weight
            if total <= 0:
                # Fail with a clear message instead of a bare ZeroDivisionError.
                raise ValueError(
                    "lexical_weight + semantic_weight must be positive for weighted fusion"
                )
            self.lexical_weight = lexical_weight / total
            self.semantic_weight = semantic_weight / total
        else:
            self.lexical_weight = lexical_weight
            self.semantic_weight = semantic_weight

    def fuse(
        self,
        lexical_results: Sequence["RetrievalResult"],
        semantic_results: Sequence["RetrievalResult"],
        override_lexical_weight: Optional[float] = None,
        override_semantic_weight: Optional[float] = None,
    ) -> List["RetrievalResult"]:
        """
        Fuse lexical and semantic search results.

        When override weights are provided the method always uses weighted
        score fusion regardless of the configured fusion_method. This lets
        callers apply dynamic weights (e.g. boosted lexical weight for
        queries containing explicit legal references) without changing the
        service-level default. If only one override is given, the other
        weight falls back to the instance weight; the pair is normalized to
        sum to 1 by ``weighted_score_fusion``.

        When one of the two inputs is empty, the other is returned as a new
        list without any fusion (no rescoring, no added metadata).

        Args:
            lexical_results: BM25 or other lexical search results
            semantic_results: Vector-based semantic search results
            override_lexical_weight: Forces weighted fusion with this lexical weight
            override_semantic_weight: Forces weighted fusion with this semantic weight

        Returns:
            Fused and sorted results
        """
        # Single-source shortcuts; also covers "both empty" (returns []).
        if not lexical_results:
            return list(semantic_results)
        if not semantic_results:
            return list(lexical_results)

        use_override = override_lexical_weight is not None or override_semantic_weight is not None
        effective_method = "weighted" if use_override else self.fusion_method
        logger.debug(
            "Fusing %d lexical + %d semantic using %s",
            len(lexical_results),
            len(semantic_results),
            effective_method,
        )
        if effective_method == "rrf":
            fused = self.reciprocal_rank_fusion(
                lexical_results,
                semantic_results,
                k=self.rrf_k,
            )
        else:
            # Any method other than "rrf" is treated as weighted fusion.
            fused = self.weighted_score_fusion(
                lexical_results,
                semantic_results,
                lexical_weight=override_lexical_weight,
                semantic_weight=override_semantic_weight,
            )
        logger.info("Fused results: %d unique documents", len(fused))
        return fused

    @staticmethod
    def _to_hashable(value: Any) -> Hashable:
        """Return a hashable representation for fusion keys.

        The value is returned unchanged when ``hash(value)`` succeeds;
        otherwise its ``str()`` form is used. Hashing is attempted directly
        because an ``isinstance(value, Hashable)`` check also accepts
        containers such as a tuple holding a list, which fail when hashed.
        """
        try:
            hash(value)
        except TypeError:
            return str(value)
        return value

    @classmethod
    def _get_result_key(cls, result: "RetrievalResult") -> "ResultKey":
        """Build the deduplication key for a result.

        Priority: metadata ``chunk_id``, then ``subdivision_id``, then
        ``act_id``. A result with none of these gets a key based on its
        object identity, so it never merges with another result.
        """
        metadata: Dict[str, Any] = result.metadata
        chunk_id = metadata.get("chunk_id")
        if chunk_id is not None:
            return ("chunk", cls._to_hashable(chunk_id))
        if result.subdivision_id:
            return ("subdivision", result.subdivision_id)
        if result.act_id:
            return ("act", result.act_id)
        return ("unknown", id(result))

    @staticmethod
    def _is_retrieval_segment(result: "RetrievalResult") -> bool:
        """Return True if the metadata marks the result as a retrieval segment."""
        metadata: Dict[str, Any] = result.metadata
        return bool(metadata.get("is_retrieval_segment") or metadata.get("retrieval_segment_index") is not None)

    @classmethod
    def _prefer_result(
        cls,
        current: "RetrievalResult",
        candidate: "RetrievalResult",
    ) -> "RetrievalResult":
        """Choose which of two results sharing a key represents the fused entry.

        The candidate replaces the current result only when the candidate is
        a retrieval segment and the current one is not.
        """
        current_is_segment = cls._is_retrieval_segment(current)
        candidate_is_segment = cls._is_retrieval_segment(candidate)
        if candidate_is_segment and not current_is_segment:
            return candidate
        return current

    @staticmethod
    def _clamp_unit_interval(score: float) -> float:
        """Clamp ``score`` (converted to float) into [0.0, 1.0]."""
        return min(max(float(score), 0.0), 1.0)

    @staticmethod
    def _normalize_rrf_score(raw_score: float, rrf_constant: int) -> float:
        """Scale a raw RRF score into [0, 1].

        The divisor ``2 / (k + 1)`` is the raw score of a document ranked
        first in both lists. Returns 0.0 when ``k + 1`` is not positive, so
        a degenerate constant cannot cause a division by zero here.
        """
        denominator = float(rrf_constant) + 1.0
        if denominator <= 0.0:
            return 0.0
        max_rrf_score = 2.0 / denominator
        return ResultFusionService._clamp_unit_interval(raw_score / max_rrf_score)

    def _resolve_weights(
        self,
        lexical_weight: Optional[float],
        semantic_weight: Optional[float],
    ) -> Tuple[float, float]:
        """Return the effective (lexical, semantic) weights for weighted fusion.

        A missing weight falls back to the instance weight. The pair is then
        normalized to sum to 1, as the constructor does for the "weighted"
        method. If the sum is not positive the weights are returned as given.
        """
        lex_w = lexical_weight if lexical_weight is not None else self.lexical_weight
        sem_w = semantic_weight if semantic_weight is not None else self.semantic_weight
        total = lex_w + sem_w
        if total > 0:
            return lex_w / total, sem_w / total
        return lex_w, sem_w

    @classmethod
    def _index_results(
        cls,
        results: Sequence["RetrievalResult"],
        with_scores: bool,
    ) -> "Dict[ResultKey, Tuple[RetrievalResult, int, float]]":
        """Deduplicate one ranked list.

        Returns a mapping ``key -> (representative, best_rank, best_score)``
        in order of first appearance. ``best_rank`` is the 1-based position
        of the first occurrence and ``best_score`` the highest score seen for
        the key. With ``with_scores=False`` scores are not read at all and
        are reported as 0.0.
        """
        index: Dict[Any, Tuple[Any, int, float]] = {}
        for rank, result in enumerate(results, start=1):
            key = cls._get_result_key(result)
            score = result.score if with_scores else 0.0
            seen = index.get(key)
            if seen is None:
                index[key] = (result, rank, score)
            else:
                representative, best_rank, best_score = seen
                index[key] = (
                    cls._prefer_result(representative, result),
                    best_rank,
                    max(best_score, score),
                )
        return index

    @classmethod
    def _merge_sources(
        cls,
        lexical_results: Sequence["RetrievalResult"],
        semantic_results: Sequence["RetrievalResult"],
        with_scores: bool,
    ) -> "List[Tuple[RetrievalResult, str, int, float, int, float]]":
        """Join the two lists by fusion key.

        Returns one tuple per unique key:
        ``(representative, search_method, lex_rank, lex_score, sem_rank, sem_score)``.
        Rank 0 and score 0.0 mean the key is absent from that source.
        ``search_method`` is "hybrid", "bm25" or "semantic" and is decided by
        presence in each list, not by score value. Lexical keys come first,
        followed by semantic-only keys.
        """
        lexical_index = cls._index_results(lexical_results, with_scores)
        semantic_index = cls._index_results(semantic_results, with_scores)
        merged = []
        # Dict merge keeps first-insertion order: lexical keys, then new semantic keys.
        for key in {**lexical_index, **semantic_index}:
            lex_entry = lexical_index.get(key)
            sem_entry = semantic_index.get(key)
            if lex_entry is not None and sem_entry is not None:
                representative = cls._prefer_result(lex_entry[0], sem_entry[0])
                method = "hybrid"
            elif lex_entry is not None:
                representative = lex_entry[0]
                method = "bm25"
            else:
                representative = sem_entry[0]
                method = "semantic"
            lex_rank, lex_score = (lex_entry[1], lex_entry[2]) if lex_entry is not None else (0, 0.0)
            sem_rank, sem_score = (sem_entry[1], sem_entry[2]) if sem_entry is not None else (0, 0.0)
            merged.append((representative, method, lex_rank, lex_score, sem_rank, sem_score))
        return merged

    @staticmethod
    def _build_result(
        source: "RetrievalResult",
        score: float,
        fusion_metadata: Dict[str, Any],
    ) -> "RetrievalResult":
        """Create a new result copying ``source`` with a fused score.

        ``fusion_metadata`` is layered over the source metadata, so fusion
        fields win on key collisions.
        """
        return RetrievalResult(
            content=source.content,
            score=score,
            subdivision_id=source.subdivision_id,
            act_id=source.act_id,
            celex=source.celex,
            subdivision_type=source.subdivision_type,
            sequence_order=source.sequence_order,
            metadata={**source.metadata, **fusion_metadata},
        )

    def weighted_score_fusion(
        self,
        lexical_results: Sequence["RetrievalResult"],
        semantic_results: Sequence["RetrievalResult"],
        lexical_weight: Optional[float] = None,
        semantic_weight: Optional[float] = None,
    ) -> List["RetrievalResult"]:
        """
        Weighted score fusion.

        Combines scores using a weighted average:
            combined_score = lexical_weight * lex_score + semantic_weight * sem_score
        The effective weights are normalized to sum to 1 and the combined
        score is clamped to [0, 1]. A document missing from one source
        contributes a score of 0.0 for that source. If a document appears
        more than once in the same list, its best score in that list is used.

        Args:
            lexical_results: Lexical search results (with scores)
            semantic_results: Semantic search results (with scores)
            lexical_weight: Weight for lexical scores (default: instance weight)
            semantic_weight: Weight for semantic scores (default: instance weight)

        Returns:
            Fused results sorted by combined score (descending)
        """
        lex_w, sem_w = self._resolve_weights(lexical_weight, semantic_weight)
        fused_results: List["RetrievalResult"] = []
        merged = self._merge_sources(lexical_results, semantic_results, with_scores=True)
        for result, method, _lex_rank, lex_score, _sem_rank, sem_score in merged:
            combined_score = self._clamp_unit_interval(lex_w * lex_score + sem_w * sem_score)
            fused_results.append(
                self._build_result(
                    result,
                    combined_score,
                    {
                        "search_method": method,
                        "lexical_score": lex_score,
                        "semantic_score": sem_score,
                        "fusion_method": "weighted",
                    },
                )
            )
        # Stable sort: ties keep lexical-first insertion order.
        fused_results.sort(key=lambda item: item.score, reverse=True)
        return fused_results

    def reciprocal_rank_fusion(
        self,
        lexical_results: Sequence["RetrievalResult"],
        semantic_results: Sequence["RetrievalResult"],
        k: Optional[int] = None,
    ) -> List["RetrievalResult"]:
        """
        Reciprocal Rank Fusion (RRF).

        RRF formula: RRF(d) = sum(1 / (k + rank(d)))
        where k is a constant (typically 60) and rank starts at 1.
        RRF is score-agnostic and only considers ranking position,
        making it robust to score distribution differences.

        Each list contributes at most once per document: if a document
        appears several times in one list, only its first (best) rank counts.
        The raw score is divided by ``2 / (k + 1)`` and clamped to [0, 1].

        Args:
            lexical_results: Lexical search results (pre-sorted)
            semantic_results: Semantic search results (pre-sorted)
            k: RRF constant (default: instance rrf_k, typically 60)

        Returns:
            Fused results sorted by RRF score (descending)
        """
        rrf_constant = k if k is not None else self.rrf_k
        fused_results: List["RetrievalResult"] = []
        merged = self._merge_sources(lexical_results, semantic_results, with_scores=False)
        for result, method, lex_rank, _lex_score, sem_rank, _sem_score in merged:
            # Rank 0 means "absent from that list" and contributes nothing.
            rrf_score_raw = 0.0
            if lex_rank > 0:
                rrf_score_raw += 1.0 / (rrf_constant + lex_rank)
            if sem_rank > 0:
                rrf_score_raw += 1.0 / (rrf_constant + sem_rank)
            rrf_score = self._normalize_rrf_score(rrf_score_raw, rrf_constant)
            fused_results.append(
                self._build_result(
                    result,
                    rrf_score,
                    {
                        "search_method": method,
                        "lexical_rank": lex_rank,
                        "semantic_rank": sem_rank,
                        "rrf_score": rrf_score,
                        "rrf_score_raw": rrf_score_raw,
                        "fusion_method": "rrf",
                    },
                )
            )
        # Stable sort: ties keep lexical-first insertion order.
        fused_results.sort(key=lambda item: item.score, reverse=True)
        return fused_results

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get fusion service statistics.

        Returns:
            Dictionary with the configured fusion method, the stored lexical
            and semantic weights, and the RRF constant.
        """
        return {
            "fusion_method": self.fusion_method,
            "lexical_weight": self.lexical_weight,
            "semantic_weight": self.semantic_weight,
            "rrf_k": self.rrf_k,
        }