"""
Rerank Service
Cross-encoder reranking — local (in-process) or via dedicated HTTP service.
"""
import logging
import time
from dataclasses import dataclass
from typing import Any, Optional, cast
import httpx
from .result import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class RerankConfig:
    """Runtime configuration for the retrieval reranker.

    When ``rerank_service_url`` is set the reranker runs in HTTP mode
    (calls the dedicated rerank-service); otherwise a local cross-encoder
    model is loaded in-process.
    """

    model_name: str      # cross-encoder model identifier (also recorded in result metadata)
    device: str          # device for local inference, e.g. "cpu" or "cuda"
    batch_size: int      # inference batch size (local mode only)
    max_candidates: int  # at most this many leading results are rescored; clamped to >= 1
    max_chars: int       # per-document truncation limit; <= 0 disables truncation
    enabled: bool = True                       # master switch; when False rerank() is a no-op
    cache_dir: Optional[str] = None            # model download/cache dir (local mode only)
    rerank_service_url: Optional[str] = None   # base URL of the rerank-service (enables HTTP mode)
    service_timeout_seconds: float = 10.0      # timeout for each HTTP rerank call
    fallback_to_skip: bool = True              # on HTTP failure, return results unreranked instead of raising
    circuit_failure_threshold: int = 2         # consecutive HTTP failures before the circuit opens
    circuit_cooldown_seconds: float = 30.0     # how long the open circuit skips rerank attempts
class RerankService:
    """
    Cross-encoder reranker.

    Two modes:

    - **HTTP** (when ``rerank_service_url`` is set): calls the dedicated
      rerank-service.
    - **Local** (fallback): loads CrossEncoder in-process via
      sentence-transformers.

    If the HTTP service is unreachable and ``fallback_to_skip`` is True,
    results are returned without reranking.

    Includes a circuit breaker: after *circuit_failure_threshold* consecutive
    HTTP failures, reranking is skipped for *circuit_cooldown_seconds*.
    """

    def __init__(self, config: "RerankConfig"):
        self.config = config
        self._use_http = bool(config.rerank_service_url)
        self._model: Any = None
        # Once a local model load fails we never retry (avoids repeated
        # slow import/download attempts on every query).
        self._model_load_failed = False
        # Circuit breaker state: consecutive HTTP failures and the
        # monotonic deadline until which the circuit stays open.
        self._consecutive_failures: int = 0
        self._circuit_open_until: float = 0.0

    # ── Public API ──────────────────────────────────────────────────────

    def load(self) -> bool:
        """Eagerly load the local reranker model (no-op in HTTP mode)."""
        if self._use_http:
            return True
        return self._load_model()

    def rerank(
        self,
        query: str,
        results: list["RetrievalResult"],
        top_k: Optional[int] = None,
    ) -> list["RetrievalResult"]:
        """Rerank retrieval results using a cross-encoder.

        Returns *results* unchanged when reranking is disabled, the query is
        empty, or there is nothing to rerank.
        """
        if not self.config.enabled or not query or not results:
            return results
        if self._use_http:
            return self._rerank_http(query, results, top_k)
        return self._rerank_local(query, results, top_k)

    # ── HTTP mode ───────────────────────────────────────────────────────

    def _rerank_http(
        self,
        query: str,
        results: list["RetrievalResult"],
        top_k: Optional[int],
    ) -> list["RetrievalResult"]:
        """Rerank via the dedicated HTTP rerank-service.

        Honors the circuit breaker; on request failure either returns
        *results* unchanged (``fallback_to_skip``) or re-raises.
        """
        # ── Circuit breaker check ─────────────────────────────────────
        now = time.monotonic()
        threshold = self.config.circuit_failure_threshold
        if self._consecutive_failures >= threshold:
            if now < self._circuit_open_until:
                remaining = self._circuit_open_until - now
                logger.warning(
                    "Rerank circuit breaker OPEN (failures=%d, retry in %.1fs), skipping",
                    self._consecutive_failures,
                    remaining,
                )
                if self.config.fallback_to_skip:
                    return results
                raise RuntimeError("Rerank circuit breaker is open")
            # Cooldown elapsed: allow one probe request through.
            logger.info("Rerank circuit breaker HALF-OPEN, attempting probe request")

        candidates, remainder = self._split_candidates(results)
        documents = [
            {"id": str(idx), "content": self._truncate(self._candidate_text(item))}
            for idx, item in enumerate(candidates)
        ]

        # rstrip guards against a configured base URL with a trailing slash.
        url = f"{self.config.rerank_service_url.rstrip('/')}/rerank"
        payload = {
            "query": query,
            "documents": documents,
            "top_k": None,  # get scores for all candidates
        }
        try:
            response = httpx.post(
                url,
                json=payload,
                timeout=self.config.service_timeout_seconds,
            )
            response.raise_for_status()
            data = response.json()
        except Exception as exc:
            # Any failure (network, HTTP status, bad JSON) counts toward
            # opening the circuit.
            self._consecutive_failures += 1
            self._circuit_open_until = time.monotonic() + self.config.circuit_cooldown_seconds
            logger.warning(
                "Rerank service call failed (%s), %s (failures=%d): %s",
                url,
                "skipping reranking" if self.config.fallback_to_skip else "raising",
                self._consecutive_failures,
                exc,
            )
            if self.config.fallback_to_skip:
                return results
            raise

        # Success — reset circuit breaker
        self._consecutive_failures = 0

        # Map scores back to RetrievalResult objects. Candidates the service
        # did not score are dropped (matches the service's own filtering).
        score_map: dict[int, float] = {
            int(item["id"]): float(item["score"]) for item in data.get("results", [])
        }
        reranked = [
            self._rescored(item, score_map[idx], "http")
            for idx, item in enumerate(candidates)
            if idx in score_map
        ]
        reranked.sort(key=lambda x: x.score, reverse=True)
        combined = reranked + remainder
        return combined[:top_k] if top_k is not None else combined

    # ── Local mode ──────────────────────────────────────────────────────

    def _rerank_local(
        self,
        query: str,
        results: list["RetrievalResult"],
        top_k: Optional[int],
    ) -> list["RetrievalResult"]:
        """Rerank in-process with a sentence-transformers CrossEncoder.

        Falls back to the unmodified *results* when the model cannot be
        loaded or inference raises.
        """
        if not self._load_model():
            return results
        model = self._model
        if model is None:
            return results

        candidates, remainder = self._split_candidates(results)
        pairs = [(query, self._truncate(self._candidate_text(item))) for item in candidates]
        try:
            raw_scores = model.predict(
                pairs,
                batch_size=self.config.batch_size,
                show_progress_bar=False,
            )
            scores = [float(score) for score in raw_scores]
        except Exception as exc:
            logger.warning("Reranker inference failed, falling back: %s", exc)
            return results

        reranked = [
            self._rescored(item, score, "local")
            for item, score in zip(candidates, scores)
        ]
        reranked.sort(key=lambda x: x.score, reverse=True)
        combined = reranked + remainder
        return combined[:top_k] if top_k is not None else combined

    # ── Internal helpers ────────────────────────────────────────────────

    def _split_candidates(
        self, results: list["RetrievalResult"]
    ) -> tuple[list["RetrievalResult"], list["RetrievalResult"]]:
        """Split *results* into (rescore candidates, untouched remainder)."""
        max_candidates = min(len(results), max(self.config.max_candidates, 1))
        return results[:max_candidates], results[max_candidates:]

    @staticmethod
    def _candidate_text(item: "RetrievalResult") -> str:
        """Best-effort document text: ``content``, else ``metadata['content']``."""
        text = item.content
        if not text:
            raw_content = item.metadata.get("content")
            text = str(raw_content) if raw_content is not None else ""
        return text

    def _rescored(self, item: "RetrievalResult", score: float, mode: str) -> "RetrievalResult":
        """Clone *item* with the rerank score and provenance metadata."""
        metadata = dict(item.metadata)
        metadata["pre_rerank_score"] = item.score
        metadata["rerank_score"] = score
        metadata["rerank_model"] = self.config.model_name
        metadata["rerank_mode"] = mode
        return RetrievalResult(
            content=item.content,
            score=score,
            subdivision_id=item.subdivision_id,
            act_id=item.act_id,
            celex=item.celex,
            subdivision_type=item.subdivision_type,
            sequence_order=item.sequence_order,
            metadata=metadata,
        )

    def _load_model(self) -> bool:
        """Lazily load the CrossEncoder model; returns True when usable.

        A failed load is cached so subsequent calls fail fast.
        """
        if not self.config.enabled:
            return False
        if self._model is not None:
            return True
        if self._model_load_failed:
            return False
        try:
            from sentence_transformers import CrossEncoder

            cross_encoder_cls = cast(Any, CrossEncoder)
            if self.config.cache_dir is not None:
                try:
                    loaded_model = cross_encoder_cls(
                        self.config.model_name,
                        device=self.config.device,
                        cache_dir=self.config.cache_dir,
                    )
                except TypeError:
                    # Older sentence-transformers versions do not accept
                    # cache_dir — retry without it.
                    loaded_model = cross_encoder_cls(
                        self.config.model_name,
                        device=self.config.device,
                    )
            else:
                loaded_model = cross_encoder_cls(
                    self.config.model_name,
                    device=self.config.device,
                )
            self._model = loaded_model
            return True
        except Exception as exc:
            logger.warning("Failed to load reranker model '%s': %s", self.config.model_name, exc)
            self._model_load_failed = True
            return False

    def _truncate(self, text: str) -> str:
        """Cap *text* at ``max_chars`` characters (no-op when limit <= 0)."""
        if not text:
            return ""
        max_chars = self.config.max_chars
        if max_chars <= 0:
            return text
        return text[:max_chars]