Source code for scripts.eval_retrieval

#!/usr/bin/env python3
"""Run offline retrieval evaluation against the running RAG service.

The dataset may be provided as JSONL or as a JSON list. Each record defines a
``query`` plus optional ``expected_celex`` and ``expected_subdivision_ids``
values used to score retrieval quality.
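
An illustrative JSONL record (the identifier values below are only examples,
not taken from any real dataset) might look like::

    {"query": "notification deadline for a personal data breach",
     "expected_celex": ["32016R0679"],
     "expected_subdivision_ids": [412]}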
"""

from __future__ import annotations

import argparse
import json
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence

import requests
from bench_utils import percentile as _percentile
from bench_utils import safe_mean as _safe_mean


@dataclass(frozen=True)
class EvalExample:
    """One retrieval evaluation example with expected hits."""

    query: str
    expected_celex: tuple[str, ...]
    expected_subdivision_ids: tuple[int, ...]
def _normalize_celex_values(values: Sequence[Any]) -> tuple[str, ...]:
    normalized: list[str] = []
    seen: set[str] = set()
    for item in values:
        text = str(item).strip()
        if not text:
            continue
        upper = text.upper()
        if upper in seen:
            continue
        seen.add(upper)
        normalized.append(upper)
    return tuple(normalized)


def _normalize_int_values(values: Sequence[Any]) -> tuple[int, ...]:
    normalized: list[int] = []
    seen: set[int] = set()
    for item in values:
        try:
            parsed = int(item)
        except (TypeError, ValueError):
            continue
        if parsed in seen:
            continue
        seen.add(parsed)
        normalized.append(parsed)
    return tuple(normalized)


def _parse_example(raw: Dict[str, Any], *, index: int) -> EvalExample:
    query = str(raw.get("query", "")).strip()
    if not query:
        raise ValueError(f"Dataset item #{index} is missing a non-empty 'query' field")
    raw_celex = raw.get("expected_celex", [])
    if isinstance(raw_celex, str):
        expected_celex = _normalize_celex_values([raw_celex])
    elif isinstance(raw_celex, list):
        expected_celex = _normalize_celex_values(raw_celex)
    else:
        expected_celex = ()
    raw_subdivisions = raw.get("expected_subdivision_ids", [])
    if isinstance(raw_subdivisions, list):
        expected_subdivision_ids = _normalize_int_values(raw_subdivisions)
    else:
        expected_subdivision_ids = ()
    return EvalExample(
        query=query,
        expected_celex=expected_celex,
        expected_subdivision_ids=expected_subdivision_ids,
    )
def load_dataset(dataset_path: Path) -> List[EvalExample]:
    """Load retrieval examples from JSON or JSONL."""
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset not found: {dataset_path}")
    if dataset_path.suffix.lower() == ".jsonl":
        examples: list[EvalExample] = []
        with dataset_path.open("r", encoding="utf-8") as handle:
            for index, line in enumerate(handle, start=1):
                stripped = line.strip()
                if not stripped:
                    continue
                parsed = json.loads(stripped)
                if not isinstance(parsed, dict):
                    raise ValueError(f"Invalid JSONL item at line {index}: expected object")
                examples.append(_parse_example(parsed, index=index))
        return examples
    parsed = json.loads(dataset_path.read_text(encoding="utf-8"))
    if not isinstance(parsed, list):
        raise ValueError("JSON dataset must be a list of objects")
    return [
        _parse_example(item, index=index)
        for index, item in enumerate(parsed, start=1)
        if isinstance(item, dict)
    ]
def _first_rank_match(values: Iterable[Any], expected: set[Any]) -> Optional[int]:
    for rank, value in enumerate(values, start=1):
        if value in expected:
            return rank
    return None


def _ndcg_at_k(
    result_values: Sequence[Any],
    expected: set[Any],
    k: int,
) -> float:
    """Compute NDCG@K for a single query."""
    dcg = 0.0
    for i, val in enumerate(result_values[:k]):
        if val in expected:
            dcg += 1.0 / math.log2(i + 2)  # i+2 because rank starts at 1
    # Ideal DCG: all relevant items at the top
    n_relevant = min(len(expected), k)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(n_relevant))
    if idcg == 0.0:
        return 0.0
    return dcg / idcg


def _precision_at_k(
    result_values: Sequence[Any],
    expected: set[Any],
    k: int,
) -> float:
    """Compute Precision@K for a single query."""
    hits = sum(1 for val in result_values[:k] if val in expected)
    return hits / k if k > 0 else 0.0
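
# Worked example for the metric helpers above (numbers are illustrative):
# with k = 3 and a single expected value found at rank 2,
#   DCG  = 1 / log2(2 + 1) ≈ 0.631
#   IDCG = 1 / log2(1 + 1) = 1.0
# so NDCG@3 ≈ 0.631, Precision@3 = 1/3, and the reciprocal rank is 0.5.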
def evaluate(
    *,
    examples: Sequence[EvalExample],
    base_url: str,
    top_k: int,
    mode: str,
    score_threshold: Optional[float],
    timeout_seconds: float,
    filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Evaluate ``/search`` responses against the expected identifiers."""
    per_example: list[Dict[str, Any]] = []
    celex_hits: list[float] = []
    celex_rr: list[float] = []
    celex_ndcg: list[float] = []
    celex_precision: list[float] = []
    subdivision_hits: list[float] = []
    subdivision_rr: list[float] = []
    subdivision_ndcg: list[float] = []
    subdivision_precision: list[float] = []
    result_sizes: list[float] = []
    latencies_ms: list[float] = []
    endpoint = base_url.rstrip("/") + "/search"
    for example in examples:
        payload: Dict[str, Any] = {
            "query": example.query,
            "top_k": top_k,
            "mode": mode,
        }
        if score_threshold is not None:
            payload["score_threshold"] = score_threshold
        if filters is not None:
            payload["filters"] = filters
        t0 = time.perf_counter()
        response = requests.post(endpoint, json=payload, timeout=timeout_seconds)
        latency = (time.perf_counter() - t0) * 1000.0
        response.raise_for_status()
        latencies_ms.append(latency)
        data = response.json()
        results_raw = data.get("results", [])
        results = results_raw if isinstance(results_raw, list) else []
        result_sizes.append(float(len(results)))
        celex_values = [
            str(item.get("celex", "")).upper()
            for item in results
            if isinstance(item, dict) and item.get("celex")
        ]
        subdivision_values = [
            int(v)
            for item in results
            if isinstance(item, dict) and (v := item.get("subdivision_id")) is not None
        ]
        celex_rank = None
        if example.expected_celex:
            expected_set = set(example.expected_celex)
            celex_rank = _first_rank_match(celex_values, expected_set)
            celex_hits.append(1.0 if celex_rank is not None else 0.0)
            celex_rr.append(1.0 / celex_rank if celex_rank is not None else 0.0)
            celex_ndcg.append(_ndcg_at_k(celex_values, expected_set, top_k))
            celex_precision.append(_precision_at_k(celex_values, expected_set, top_k))
        subdivision_rank = None
        if example.expected_subdivision_ids:
            expected_set = set(example.expected_subdivision_ids)
            subdivision_rank = _first_rank_match(
                subdivision_values,
                expected_set,
            )
            subdivision_hits.append(1.0 if subdivision_rank is not None else 0.0)
            subdivision_rr.append(1.0 / subdivision_rank if subdivision_rank is not None else 0.0)
            subdivision_ndcg.append(_ndcg_at_k(subdivision_values, expected_set, top_k))
            subdivision_precision.append(_precision_at_k(subdivision_values, expected_set, top_k))
        per_example.append(
            {
                "query": example.query,
                "results_count": len(results),
                "celex_rank": celex_rank,
                "subdivision_rank": subdivision_rank,
                "latency_ms": round(latency, 1),
            }
        )
    summary: Dict[str, Any] = {
        "examples_total": len(examples),
        "mode": mode,
        "top_k": top_k,
        "score_threshold": score_threshold,
        "avg_results_count": _safe_mean(result_sizes),
        "celex_expectation_count": sum(1 for item in examples if item.expected_celex),
        "subdivision_expectation_count": sum(1 for item in examples if item.expected_subdivision_ids),
        "celex_hit_at_k": _safe_mean(celex_hits),
        "celex_mrr": _safe_mean(celex_rr),
        "celex_ndcg_at_k": _safe_mean(celex_ndcg),
        "celex_precision_at_k": _safe_mean(celex_precision),
        "subdivision_hit_at_k": _safe_mean(subdivision_hits),
        "subdivision_mrr": _safe_mean(subdivision_rr),
        "subdivision_ndcg_at_k": _safe_mean(subdivision_ndcg),
        "subdivision_precision_at_k": _safe_mean(subdivision_precision),
        "latency_p50_ms": round(_percentile(latencies_ms, 50), 1),
        "latency_p95_ms": round(_percentile(latencies_ms, 95), 1),
        "latency_p99_ms": round(_percentile(latencies_ms, 99), 1),
    }
    return {"summary": summary, "details": per_example}
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the retrieval evaluation command."""
    parser = argparse.ArgumentParser(description="Evaluate retrieval quality against /search")
    parser.add_argument("--dataset", required=True, help="Path to JSON or JSONL dataset")
    parser.add_argument("--base-url", default="http://localhost:8001", help="RAG service base URL")
    parser.add_argument("--mode", default="hybrid", choices=["semantic", "lexical", "hybrid"])
    parser.add_argument("--top-k", type=int, default=12)
    parser.add_argument("--score-threshold", type=float, default=None)
    parser.add_argument("--timeout", type=float, default=30.0)
    parser.add_argument(
        "--filters",
        type=str,
        default=None,
        help='JSON metadata filters (e.g. \'{"act_type": "regulation"}\')',
    )
    parser.add_argument("--output", default="", help="Optional path to write JSON report")
    return parser.parse_args()
def main() -> int:
    """Run the retrieval evaluation CLI and print the JSON report."""
    args = parse_args()
    dataset_path = Path(args.dataset).resolve()
    examples = load_dataset(dataset_path)
    if not examples:
        raise ValueError("Dataset is empty after parsing")
    filters = json.loads(args.filters) if args.filters else None
    report = evaluate(
        examples=examples,
        base_url=args.base_url,
        top_k=max(args.top_k, 1),
        mode=args.mode,
        score_threshold=args.score_threshold,
        timeout_seconds=max(args.timeout, 1.0),
        filters=filters,
    )
    output_json = json.dumps(report, indent=2, ensure_ascii=False)
    print(output_json)
    if args.output:
        output_path = Path(args.output).resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(output_json + "\n", encoding="utf-8")
    return 0
if __name__ == "__main__": raise SystemExit(main())