# Source code for lalandre_rag.retrieval.rerank_service

"""
Rerank Service
Cross-encoder reranking — local (in-process) or via dedicated HTTP service.
"""

import logging
import time
from dataclasses import dataclass
from typing import Any, Optional, cast

import httpx

from .result import RetrievalResult

logger = logging.getLogger(__name__)


@dataclass
class RerankConfig:
    """Runtime configuration for the retrieval reranker.

    The first five fields are required; the remaining ones default to values
    suited to a best-effort reranking step that degrades gracefully.
    """

    model_name: str       # cross-encoder model identifier
    device: str           # inference device for the local model (e.g. "cpu")
    batch_size: int       # batch size passed to CrossEncoder.predict
    max_candidates: int   # how many top results are considered for reranking
    max_chars: int        # per-document truncation limit (<= 0 disables truncation)
    enabled: bool = True  # master switch; when False, rerank() returns input unchanged
    cache_dir: Optional[str] = None           # optional model cache directory
    rerank_service_url: Optional[str] = None  # when set, use the HTTP rerank service
    service_timeout_seconds: float = 10.0     # HTTP request timeout (seconds)
    fallback_to_skip: bool = True             # on service failure, return un-reranked results
    circuit_failure_threshold: int = 2        # consecutive failures before the breaker opens
    circuit_cooldown_seconds: float = 30.0    # how long (seconds) the breaker stays open
class RerankService:
    """
    Cross-encoder reranker.

    Two modes:

    - **HTTP** (when ``rerank_service_url`` is set): calls the dedicated
      rerank-service.
    - **Local** (fallback): loads CrossEncoder in-process via
      sentence-transformers.

    If the HTTP service is unreachable and ``fallback_to_skip`` is True,
    results are returned without reranking.

    Includes a circuit breaker: after *circuit_failure_threshold* consecutive
    HTTP failures, reranking is skipped for *circuit_cooldown_seconds*.
    """

    def __init__(self, config: "RerankConfig"):
        self.config = config
        # HTTP mode whenever a service URL is configured, local otherwise.
        self._use_http = bool(config.rerank_service_url)
        self._model: Any = None
        self._model_load_failed = False
        # Circuit breaker state.
        self._consecutive_failures: int = 0
        self._circuit_open_until: float = 0.0

    # ── Public API ──────────────────────────────────────────────────────

    def load(self) -> bool:
        """Eagerly load the local reranker model (no-op in HTTP mode)."""
        if self._use_http:
            return True
        return self._load_model()

    def rerank(
        self,
        query: str,
        results: list["RetrievalResult"],
        top_k: Optional[int] = None,
    ) -> list["RetrievalResult"]:
        """Rerank retrieval results using a cross-encoder.

        Returns *results* unchanged when reranking is disabled, the query or
        result list is empty, or a fallback path is taken by the backend.
        """
        if not self.config.enabled or not query or not results:
            return results
        if self._use_http:
            return self._rerank_http(query, results, top_k)
        return self._rerank_local(query, results, top_k)

    # ── HTTP mode ───────────────────────────────────────────────────────

    def _rerank_http(
        self,
        query: str,
        results: list["RetrievalResult"],
        top_k: Optional[int],
    ) -> list["RetrievalResult"]:
        # ── Circuit breaker check ─────────────────────────────────────
        now = time.monotonic()
        threshold = self.config.circuit_failure_threshold
        if self._consecutive_failures >= threshold:
            if now < self._circuit_open_until:
                remaining = self._circuit_open_until - now
                logger.warning(
                    "Rerank circuit breaker OPEN (failures=%d, retry in %.1fs), skipping",
                    self._consecutive_failures,
                    remaining,
                )
                if self.config.fallback_to_skip:
                    return results
                raise RuntimeError("Rerank circuit breaker is open")
            else:
                logger.info("Rerank circuit breaker HALF-OPEN, attempting probe request")

        candidates, remainder = self._split_candidates(results)
        documents = [
            {"id": str(idx), "content": self._truncate(self._candidate_text(item))}
            for idx, item in enumerate(candidates)
        ]

        url = f"{self.config.rerank_service_url}/rerank"
        payload = {
            "query": query,
            "documents": documents,
            "top_k": None,  # get scores for all candidates
        }
        try:
            response = httpx.post(
                url,
                json=payload,
                timeout=self.config.service_timeout_seconds,
            )
            response.raise_for_status()
            data = response.json()
            # Parse inside the try: a malformed response body is a service
            # failure too and must honour fallback_to_skip and the circuit
            # breaker. (Fix: parsing previously happened after this handler,
            # so a bad payload crashed even when fallback was requested.)
            score_map: dict[int, float] = {
                int(entry["id"]): float(entry["score"])
                for entry in data.get("results", [])
            }
        except Exception as exc:
            self._consecutive_failures += 1
            self._circuit_open_until = (
                time.monotonic() + self.config.circuit_cooldown_seconds
            )
            logger.warning(
                "Rerank service call failed (%s), %s (failures=%d): %s",
                url,
                "skipping reranking" if self.config.fallback_to_skip else "raising",
                self._consecutive_failures,
                exc,
            )
            if self.config.fallback_to_skip:
                return results
            raise

        # Success — reset circuit breaker.
        self._consecutive_failures = 0

        # Map scores back onto candidates; candidates the service did not
        # score are dropped (there is no score to sort them by).
        reranked = [
            self._build_result(item, score_map[idx], "http")
            for idx, item in enumerate(candidates)
            if idx in score_map
        ]
        reranked.sort(key=lambda x: x.score, reverse=True)
        combined = reranked + remainder
        return combined[:top_k] if top_k is not None else combined

    # ── Local mode ──────────────────────────────────────────────────────

    def _rerank_local(
        self,
        query: str,
        results: list["RetrievalResult"],
        top_k: Optional[int],
    ) -> list["RetrievalResult"]:
        if not self._load_model():
            return results
        model = self._model
        if model is None:
            return results

        candidates, remainder = self._split_candidates(results)
        pairs: list[tuple[str, str]] = [
            (query, self._truncate(self._candidate_text(item)))
            for item in candidates
        ]
        try:
            raw_scores = model.predict(
                pairs,
                batch_size=self.config.batch_size,
                show_progress_bar=False,
            )
            scores = [float(score) for score in raw_scores]
        except Exception as exc:
            logger.warning("Reranker inference failed, falling back: %s", exc)
            return results

        reranked = [
            self._build_result(item, score, "local")
            for item, score in zip(candidates, scores)
        ]
        reranked.sort(key=lambda x: x.score, reverse=True)
        combined = reranked + remainder
        return combined[:top_k] if top_k is not None else combined

    # ── Internal helpers ────────────────────────────────────────────────

    def _split_candidates(
        self, results: list["RetrievalResult"]
    ) -> tuple[list["RetrievalResult"], list["RetrievalResult"]]:
        """Split *results* into (candidates to rerank, untouched remainder)."""
        limit = min(len(results), max(self.config.max_candidates, 1))
        return results[:limit], results[limit:]

    def _candidate_text(self, item: "RetrievalResult") -> str:
        """Text to score for *item*: content, else metadata['content'], else ''."""
        if item.content:
            return item.content
        raw_content = item.metadata.get("content")
        return str(raw_content) if raw_content is not None else ""

    def _build_result(
        self, item: "RetrievalResult", score: float, mode: str
    ) -> "RetrievalResult":
        """Clone *item* with the rerank score applied and provenance metadata."""
        metadata = dict(item.metadata)
        metadata["pre_rerank_score"] = item.score
        metadata["rerank_score"] = score
        metadata["rerank_model"] = self.config.model_name
        metadata["rerank_mode"] = mode
        return RetrievalResult(
            content=item.content,
            score=score,
            subdivision_id=item.subdivision_id,
            act_id=item.act_id,
            celex=item.celex,
            subdivision_type=item.subdivision_type,
            sequence_order=item.sequence_order,
            metadata=metadata,
        )

    def _load_model(self) -> bool:
        """Lazily load the local CrossEncoder.

        Returns True if a model is available. A failed load is remembered so
        later calls fail fast instead of retrying a broken download.
        """
        if not self.config.enabled:
            return False
        if self._model is not None:
            return True
        if self._model_load_failed:
            return False
        try:
            from sentence_transformers import CrossEncoder

            cross_encoder_cls = cast(Any, CrossEncoder)
            if self.config.cache_dir is not None:
                try:
                    loaded_model = cross_encoder_cls(
                        self.config.model_name,
                        device=self.config.device,
                        cache_dir=self.config.cache_dir,
                    )
                except TypeError:
                    # This CrossEncoder signature rejects cache_dir
                    # (presumably an older sentence-transformers release —
                    # TODO confirm); retry without it.
                    loaded_model = cross_encoder_cls(
                        self.config.model_name,
                        device=self.config.device,
                    )
            else:
                loaded_model = cross_encoder_cls(
                    self.config.model_name,
                    device=self.config.device,
                )
            self._model = loaded_model
            return True
        except Exception as exc:
            logger.warning("Failed to load reranker model '%s': %s", self.config.model_name, exc)
            self._model_load_failed = True
            return False

    def _truncate(self, text: str) -> str:
        """Clamp *text* to config.max_chars characters; <= 0 disables clamping."""
        if not text:
            return ""
        max_chars = self.config.max_chars
        if max_chars <= 0:
            return text
        return text[:max_chars]