# Source code for rerank_service.main

"""
Rerank Service — dedicated cross-encoder reranking for Lalandre.

Loads BAAI/bge-reranker-v2-m3 at startup and exposes a /rerank endpoint.
"""

import asyncio
import logging
import time
from contextlib import asynccontextmanager
from typing import Any, Optional

from fastapi import FastAPI, HTTPException
from lalandre_core.config import get_config
from lalandre_core.logging_setup import setup_worker_logging

try:
    from .models import RerankRequest, RerankResponse, RerankResult
except ImportError:  # pragma: no cover - keeps direct script execution working
    from models import RerankRequest, RerankResponse, RerankResult

setup_worker_logging()
logger = logging.getLogger(__name__)

# Module-level state, populated by _load_model() at startup and on /reload.
_model: Optional[Any] = None
_model_name: str = ""
# Serializes inference so only one predict() call runs at a time.
_inference_lock = asyncio.Lock()


def _load_model() -> None:
    """Instantiate the cross-encoder reranker from the current config.

    Reads the model name, device, and cache directory from the shared
    config and stores the loaded model plus its name in module globals.
    """
    global _model, _model_name
    from sentence_transformers import CrossEncoder

    cfg = get_config()
    _model_name = cfg.search.rerank_model
    target_device = cfg.search.rerank_device
    cache_folder = cfg.search.rerank_cache_dir or cfg.models_cache_dir

    logger.info("Loading reranker model: %s (device=%s)", _model_name, target_device)

    if not cache_folder:
        _model = CrossEncoder(_model_name, device=target_device)
    else:
        try:
            _model = CrossEncoder(
                _model_name, cache_folder=cache_folder, device=target_device
            )
        except TypeError:
            # Older sentence-transformers releases reject cache_folder; retry without it.
            _model = CrossEncoder(_model_name, device=target_device)
    logger.info("Reranker model loaded successfully")


def _run_inference(pairs: list[tuple[str, str]], batch_size: int) -> tuple[list[float], float]:
    """Score (query, document) pairs with the loaded cross-encoder.

    Blocking call, intended to run via ``asyncio.to_thread``. Returns the
    per-pair scores and the wall-clock inference time in milliseconds.
    """
    t0 = time.perf_counter()
    assert _model is not None
    raw = _model.predict(pairs, batch_size=batch_size, show_progress_bar=False)
    elapsed_ms = round((time.perf_counter() - t0) * 1000.0, 1)
    return [float(score) for score in raw], elapsed_ms


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the reranker model during FastAPI startup.

    Runs the (blocking) model load in a worker thread so the event loop is
    not stalled, mirroring the /reload endpoint. Requests are served only
    after the load completes.
    """
    await asyncio.to_thread(_load_model)
    yield
# Application instance; the lifespan hook loads the model before serving.
app = FastAPI(
    title="Lalandre Rerank Service",
    description="Cross-encoder reranking service",
    version="0.1.0",
    lifespan=lifespan,
)
@app.get("/health")
async def health():
    """Return rerank-service readiness and current model metadata.

    ``status`` is "loading" until the model has been loaded; ``busy``
    reports whether an inference currently holds the lock.
    """
    return {
        "status": "healthy" if _model is not None else "loading",
        "service": "rerank-service",
        "model": _model_name,
        "busy": _inference_lock.locked(),
    }
@app.post("/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
    """Rerank the supplied documents for the given query.

    Returns results sorted by descending cross-encoder score, truncated to
    ``request.top_k`` when provided. Responds 503 until the model is
    loaded and 500 when inference fails.
    """
    if _model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    if not request.documents:
        return RerankResponse(results=[], model=_model_name, duration_ms=0.0)

    config = get_config()
    max_chars = config.search.rerank_max_chars
    batch_size = config.search.rerank_batch_size

    # Truncate each document (when max_chars > 0) to bound inference cost.
    pairs: list[tuple[str, str]] = []
    for doc in request.documents:
        content = doc.content[:max_chars] if max_chars > 0 else doc.content
        pairs.append((request.query, content))

    # Serialize inference: only one predict() runs at a time; the actual
    # work happens in a worker thread so the event loop stays responsive.
    async with _inference_lock:
        try:
            scores, duration_ms = await asyncio.to_thread(_run_inference, pairs, batch_size)
        except Exception as exc:
            logger.error("Reranker inference failed: %s", exc, exc_info=True)
            raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc

    scored = [(doc, score) for doc, score in zip(request.documents, scores)]
    scored.sort(key=lambda x: x[1], reverse=True)

    top_k = request.top_k
    if top_k is not None:
        scored = scored[:top_k]

    results = [
        RerankResult(id=doc.id, score=score, rank=rank)
        for rank, (doc, score) in enumerate(scored, start=1)
    ]
    logger.info(
        "Rerank completed: %d pairs, %.1f ms",
        len(pairs),
        duration_ms,
    )
    return RerankResponse(
        results=results,
        model=_model_name,
        duration_ms=duration_ms,
    )
@app.post("/reload")
async def reload_model():
    """Reload the reranker model from the current config (app_config.yaml).

    The load runs in a worker thread to avoid blocking the event loop.
    Responds 500 (with the original exception chained) if the reload fails.
    """
    old = _model_name
    try:
        await asyncio.to_thread(_load_model)
        return {"status": "reloaded", "previous_model": old, "current_model": _model_name}
    except Exception as exc:
        logger.error("Failed to reload reranker: %s", exc, exc_info=True)
        # Chain the cause, matching the error handling in /rerank.
        raise HTTPException(status_code=500, detail=f"Reload failed: {exc}") from exc