"""
Rerank Service — dedicated cross-encoder reranking for Lalandre.
Loads the cross-encoder named by ``config.search.rerank_model`` (e.g. BAAI/bge-reranker-v2-m3)
at startup and exposes the /rerank, /health and /reload endpoints.
"""
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from typing import Any, Optional
from fastapi import FastAPI, HTTPException
from lalandre_core.config import get_config
from lalandre_core.logging_setup import setup_worker_logging
try:
from .models import RerankRequest, RerankResponse, RerankResult
except ImportError: # pragma: no cover - keeps direct script execution working
from models import RerankRequest, RerankResponse, RerankResult
# Project logging setup runs at import time, before the module logger is obtained.
setup_worker_logging()
logger = logging.getLogger(__name__)
# Module-level reranker state, assigned by _load_model().
# _model stays None until a model has been loaded; /health and /rerank check for that.
_model: Optional[Any] = None
# Name of the configured reranker model, echoed in /health, /rerank and /reload responses.
_model_name: str = ""
# Held by /rerank around each inference call; /health reports its locked state as "busy".
_inference_lock = asyncio.Lock()
def _load_model() -> None:
    """Load the configured cross-encoder and publish it in the module globals.

    Reads the model name, device and cache directory from ``get_config()``,
    builds a ``sentence_transformers.CrossEncoder`` and then rebinds
    ``_model`` and ``_model_name`` together.

    The globals are only written after construction succeeded, so a failed
    (re)load leaves the previously loaded model *and* its name in place
    instead of pairing the old model with the new, never-loaded name.

    Raises:
        Exception: whatever ``CrossEncoder`` raises when the model cannot be
            constructed. Only the ``TypeError`` retry below is handled here.
    """
    global _model, _model_name

    # Imported inside the function, so the dependency is only needed when a
    # model is actually loaded.
    from sentence_transformers import CrossEncoder

    config = get_config()
    model_name = config.search.rerank_model
    device = config.search.rerank_device
    # A reranker-specific cache directory wins over the shared models cache.
    cache_dir = config.search.rerank_cache_dir or config.models_cache_dir

    logger.info("Loading reranker model: %s (device=%s)", model_name, device)

    kwargs: dict[str, Any] = {"device": device}
    if cache_dir:
        try:
            model = CrossEncoder(model_name, cache_folder=cache_dir, **kwargs)
        except TypeError:
            # Presumably the installed CrossEncoder does not accept
            # ``cache_folder``; retry without it (verify against the pinned
            # sentence-transformers version).
            model = CrossEncoder(model_name, **kwargs)
    else:
        model = CrossEncoder(model_name, **kwargs)

    # Commit both globals in one statement, only once loading has succeeded.
    _model, _model_name = model, model_name
    logger.info("Reranker model loaded successfully")
def _run_inference(pairs: list[tuple[str, str]], batch_size: int) -> tuple[list[float], float]:
    """Score the given pairs with the loaded cross-encoder (blocking call).

    This function is synchronous; ``rerank`` dispatches it through
    ``asyncio.to_thread`` so the event loop is not blocked.

    Args:
        pairs: ``(query, document_text)`` tuples handed to ``_model.predict``.
        batch_size: Batch size forwarded to ``_model.predict``.

    Returns:
        ``(scores, elapsed_ms)``: the values returned by ``predict`` converted
        to ``float``, and the wall-clock time of the ``predict`` call in
        milliseconds, rounded to one decimal.
    """
    t0 = time.perf_counter()
    assert _model is not None
    predictions = _model.predict(pairs, batch_size=batch_size, show_progress_bar=False)
    # Stop the clock before converting, so only the predict call is timed.
    elapsed_ms = round((time.perf_counter() - t0) * 1000.0, 1)
    scores = list(map(float, predictions))
    return scores, elapsed_ms
[docs]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the reranker model during FastAPI startup.

    Passed to ``FastAPI(lifespan=...)``. The code before ``yield`` runs the
    synchronous ``_load_model()``; nothing follows the ``yield``, so no
    cleanup is performed here when the context exits.
    """
    # Synchronous load: this coroutine does not reach the yield until
    # _load_model() has returned.
    _load_model()
    yield
# ASGI application object; the lifespan hook loads the reranker model on startup.
app = FastAPI(
title="Lalandre Rerank Service",
description="Cross-encoder reranking service",
version="0.1.0",
lifespan=lifespan,
)
[docs]
@app.get("/health")
async def health():
    """Report service readiness, the current model name and whether inference is running.

    ``status`` is ``"loading"`` while no model object is present and
    ``"healthy"`` otherwise; ``busy`` mirrors the inference lock's state.
    """
    if _model is None:
        status = "loading"
    else:
        status = "healthy"

    payload = {
        "status": status,
        "service": "rerank-service",
        "model": _model_name,
        "busy": _inference_lock.locked(),
    }
    return payload
[docs]
@app.post("/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
    """Score every supplied document against the query and return them best-first.

    Steps:
      1. Reject with 503 while no model is loaded.
      2. Short-circuit an empty document list with an empty result and
         ``duration_ms=0.0``.
      3. Pair the query with each document's content, cut to
         ``config.search.rerank_max_chars`` characters when that limit is
         positive.
      4. Run ``_run_inference`` in a worker thread while holding
         ``_inference_lock``.
      5. Sort by descending score, keep the first ``request.top_k`` entries
         when ``top_k`` is set, and number the ranks from 1.

    Raises:
        HTTPException: 503 if the model is not loaded, 500 if inference raises.
    """
    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    documents = request.documents
    if not documents:
        return RerankResponse(results=[], model=_model_name, duration_ms=0.0)

    search_cfg = get_config().search
    max_chars = search_cfg.rerank_max_chars
    batch_size = search_cfg.rerank_batch_size

    pairs: list[tuple[str, str]] = [
        (request.query, doc.content[:max_chars] if max_chars > 0 else doc.content)
        for doc in documents
    ]

    async with _inference_lock:
        try:
            scores, duration_ms = await asyncio.to_thread(_run_inference, pairs, batch_size)
        except Exception as exc:
            logger.exception("Reranker inference failed: %s", exc)
            raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc

    # sorted() is stable, so documents with equal scores keep their request order.
    ranked = sorted(zip(documents, scores), key=lambda item: item[1], reverse=True)
    if request.top_k is not None:
        ranked = ranked[: request.top_k]

    results = [
        RerankResult(id=doc.id, score=score, rank=position)
        for position, (doc, score) in enumerate(ranked, start=1)
    ]

    logger.info("Rerank completed: %d pairs, %.1f ms", len(pairs), duration_ms)
    return RerankResponse(results=results, model=_model_name, duration_ms=duration_ms)
[docs]
@app.post("/reload")
async def reload_model():
    """Reload the reranker model from the current config (app_config.yaml).

    Runs ``_load_model`` in a worker thread and reports the model name before
    and after the reload.

    NOTE(review): the reload does not acquire ``_inference_lock``, so it can
    overlap with an in-flight ``/rerank`` call. Confirm this is intended.

    Returns:
        ``{"status": "reloaded", "previous_model": ..., "current_model": ...}``.

    Raises:
        HTTPException: 500 when loading fails; the original exception is
            chained as the cause, matching the error handling in ``rerank``.
    """
    previous = _model_name
    try:
        # _load_model is synchronous; run it in a thread so the event loop
        # is not blocked while the model loads.
        await asyncio.to_thread(_load_model)
    except Exception as exc:
        logger.exception("Failed to reload reranker: %s", exc)
        raise HTTPException(status_code=500, detail=f"Reload failed: {exc}") from exc
    return {"status": "reloaded", "previous_model": previous, "current_model": _model_name}