Source code for scripts.benchmark_rag

#!/usr/bin/env python3
"""
Unified RAG benchmark pipeline.

Runs retrieval and/or generation evaluation, measures latency,
computes deltas against a previous run, and produces a timestamped JSON report.

Usage:
    python scripts/benchmark_rag.py                          # full benchmark
    python scripts/benchmark_rag.py --retrieval-only         # retrieval only (no LLM cost)
    python scripts/benchmark_rag.py --compare benchmarks/prev.json  # compare with previous
    python scripts/benchmark_rag.py --tag "post-rerank-tuning"
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Sequence

# Ensure scripts/ is on the import path
_SCRIPTS_DIR = Path(__file__).resolve().parent
if str(_SCRIPTS_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPTS_DIR))

from bench_utils import build_ragas_embeddings as _build_ragas_embeddings  # noqa: E402
from bench_utils import format_float as _format_float  # noqa: E402
from eval_ragas import (  # noqa: E402
    collect_samples as collect_qa_samples,
)
from eval_ragas import (
    collect_simple_metrics,
    run_ragas_evaluation,
)
from eval_ragas import (
    load_dataset as load_qa_dataset,
)
from eval_retrieval import (  # noqa: E402
    evaluate as evaluate_retrieval,
)
from eval_retrieval import (
    load_dataset as load_retrieval_dataset,
)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_DEFAULT_RETRIEVAL_DATASET = _SCRIPTS_DIR / "datasets" / "rag_test_retrieval.jsonl"
_DEFAULT_QA_DATASET = _SCRIPTS_DIR / "datasets" / "rag_test_qa.jsonl"
_DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "benchmarks"
_DEFAULT_MODES = ("hybrid", "semantic", "lexical")
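# Defaults only; the dataset paths, output directory, and retrieval modes can
# all be overridden at run time (see --retrieval-dataset, --qa-dataset,
# --output-dir, and --modes in parse_args below).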


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


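# "ragas" and "datasets" are optional dependencies; when they are missing, the
# generation phase falls back to simple metrics. They can typically be added
# with `pip install ragas datasets` (assumes the standard PyPI package names).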
def _ragas_available() -> bool:
    try:
        __import__("ragas")
        __import__("datasets")
        return True
    except ImportError:
        return False


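# Illustrative outputs, derived by hand from the logic below (not exhaustive):
#   _delta_str(0.50, 0.55) -> "(+10.0%)"
#   _delta_str(0.80, 0.80) -> "(=)"
#   _delta_str(0.00, 0.30) -> "(+0.300)"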
def _delta_str(old: float, new: float) -> str:
    if old == 0.0 and new == 0.0:
        return "(=)"
    if old == 0.0:
        return f"(+{new:.3f})"
    pct = ((new - old) / abs(old)) * 100.0
    if abs(pct) < 0.05:
        return "(=)"
    sign = "+" if pct > 0 else ""
    return f"({sign}{pct:.1f}%)"


# ---------------------------------------------------------------------------
# Phase 1: Retrieval benchmark
# ---------------------------------------------------------------------------


def run_retrieval_benchmark(
    *,
    dataset_path: Path,
    base_url: str,
    modes: Sequence[str],
    top_k: int,
    score_threshold: Optional[float],
    timeout_seconds: float,
    filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Run retrieval evaluation for each requested mode and collect summaries."""
    examples = load_retrieval_dataset(dataset_path)
    if not examples:
        print(" [SKIP] Retrieval dataset is empty")
        return {}

    results: Dict[str, Any] = {
        "dataset": str(dataset_path),
        "examples_total": len(examples),
        "top_k": top_k,
        "modes": {},
    }
    for mode in modes:
        print(f" [{mode}] Evaluating {len(examples)} queries...", end=" ", flush=True)
        report = evaluate_retrieval(
            examples=examples,
            base_url=base_url,
            top_k=top_k,
            mode=mode,
            score_threshold=score_threshold,
            timeout_seconds=timeout_seconds,
            filters=filters,
        )
        results["modes"][mode] = report
        summary = report["summary"]
        print(
            f"Hit@{top_k}={_format_float(summary['celex_hit_at_k'])} "
            f"MRR={_format_float(summary['celex_mrr'])} "
            f"NDCG={_format_float(summary['celex_ndcg_at_k'])} "
            f"p50={summary['latency_p50_ms']:.0f}ms"
        )

    return results

# ---------------------------------------------------------------------------
# Phase 2: Generation benchmark
# ---------------------------------------------------------------------------

def run_generation_benchmark(
    *,
    dataset_path: Path,
    base_url: str,
    default_mode: str,
    default_top_k: int,
    default_min_score: Optional[float],
    timeout_seconds: float,
    max_contexts: int,
    use_ragas: bool,
    judge_provider: str,
    judge_model: str,
    judge_base_url: str,
    judge_api_key: Optional[str],
    judge_timeout_seconds: float,
    judge_temperature: float,
    faithfulness_threshold: float,
    metric_names: Sequence[str],
) -> Dict[str, Any]:
    """Run answer-generation evaluation and optional RAGAS scoring."""
    examples = load_qa_dataset(dataset_path)
    if not examples:
        print(" [SKIP] QA dataset is empty")
        return {}

    print(f" Collecting {len(examples)} query responses (mode={default_mode})...", flush=True)
    rows = collect_qa_samples(
        examples=examples,
        base_url=base_url,
        default_mode=default_mode,
        default_top_k=default_top_k,
        default_min_score=default_min_score,
        timeout_seconds=timeout_seconds,
        max_contexts=max_contexts,
        include_full_content=True,
    )

    ragas_ok = use_ragas and _ragas_available()
    if ragas_ok:
        print(f" Running RAGAS evaluation ({', '.join(metric_names)})...", flush=True)
        ragas_embeddings = _build_ragas_embeddings(
            provider=judge_provider,
            api_key=judge_api_key,
        )
        report = run_ragas_evaluation(
            rows=rows,
            judge_provider=judge_provider,
            judge_model=judge_model,
            judge_base_url=judge_base_url,
            judge_api_key=judge_api_key,
            judge_timeout_seconds=judge_timeout_seconds,
            judge_temperature=judge_temperature,
            metric_names=metric_names,
            faithfulness_threshold=faithfulness_threshold,
            embeddings=ragas_embeddings,
        )
    else:
        if use_ragas:
            print(" [WARN] RAGAS not installed, collecting simple metrics only")
        report = collect_simple_metrics(rows)

    report["dataset"] = str(dataset_path)
    return report

# ---------------------------------------------------------------------------
# Phase 3: Report & comparison
# ---------------------------------------------------------------------------


def _print_retrieval_summary(retrieval: Dict[str, Any], top_k: int) -> None:
    modes_data = retrieval.get("modes", {})
    if not modes_data:
        return

    print(f"\n RETRIEVAL ({retrieval.get('examples_total', '?')} queries)")
    print(" " + "\u2500" * 72)
    hit_col = "Hit@" + str(top_k)
    ndcg_col = "NDCG@" + str(top_k)
    p_col = "P@" + str(top_k)
    header = f" {'Mode':<10} {hit_col:<8} {'MRR':<8} {ndcg_col:<9} {p_col:<8} {'p50(ms)':<9} {'p95(ms)':<9}"
    print(header)
    for mode_name, mode_report in modes_data.items():
        s = mode_report["summary"]
        print(
            f" {mode_name:<10} "
            f"{_format_float(s['celex_hit_at_k']):<8} "
            f"{_format_float(s['celex_mrr']):<8} "
            f"{_format_float(s['celex_ndcg_at_k']):<9} "
            f"{_format_float(s['celex_precision_at_k']):<8} "
            f"{s['latency_p50_ms']:<9.0f} "
            f"{s['latency_p95_ms']:<9.0f}"
        )


def _print_generation_summary(generation: Dict[str, Any]) -> None:
    summary = generation.get("summary", {})
    if not summary:
        return

    total = summary.get("examples_total", "?")
    print(f"\n GENERATION ({total} queries)")
    print(" " + "\u2500" * 72)

    metrics = summary.get("metrics_evaluated", [])
    for metric_name in metrics:
        mean_key = f"{metric_name}_mean"
        mean_val = summary.get(mean_key, 0.0)
        extra = ""
        if metric_name == "faithfulness":
            threshold = summary.get("faithfulness_threshold", 0.8)
            extra = f" (threshold: {threshold})"
        label = metric_name.replace("_", " ").title()
        print(f" {label:<22}: {_format_float(mean_val)}{extra}")

    if not metrics:
        print(" (RAGAS not available - simple metrics only)")

    print(f" {'Avg answer chars':<22}: {summary.get('avg_answer_chars', 0):.0f}")
    print(f" {'Avg sources':<22}: {summary.get('avg_sources_per_example', 0):.1f}")
    p50 = summary.get("latency_p50_ms", 0)
    p95 = summary.get("latency_p95_ms", 0)
    print(f" {'Latency p50/p95':<22}: {p50:.0f}ms / {p95:.0f}ms")


def _print_comparison(current: Dict[str, Any], previous: Dict[str, Any], prev_path: str) -> None:
    print(f"\n COMPARISON vs {prev_path}")
    print(" " + "\u2500" * 72)

    # Compare retrieval metrics
    curr_ret = current.get("retrieval", {}).get("modes", {})
    prev_ret = previous.get("retrieval", {}).get("modes", {})
    for mode_name in curr_ret:
        if mode_name not in prev_ret:
            continue
        cs = curr_ret[mode_name]["summary"]
        ps = prev_ret[mode_name]["summary"]
        for metric_key, label in [
            ("celex_hit_at_k", "Hit@K"),
            ("celex_mrr", "MRR"),
            ("celex_ndcg_at_k", "NDCG"),
        ]:
            old_val = ps.get(metric_key, 0.0)
            new_val = cs.get(metric_key, 0.0)
            delta = _delta_str(old_val, new_val)
            print(f" {mode_name} {label:<8}: {_format_float(old_val)} -> {_format_float(new_val)} {delta}")

    # Compare generation metrics
    curr_gen = current.get("generation", {}).get("summary", {})
    prev_gen = previous.get("generation", {}).get("summary", {})
    if curr_gen and prev_gen:
        for metric_name in curr_gen.get("metrics_evaluated", []):
            mean_key = f"{metric_name}_mean"
            old_val = prev_gen.get(mean_key, 0.0)
            new_val = curr_gen.get(mean_key, 0.0)
            label = metric_name.replace("_", " ").title()
            delta = _delta_str(old_val, new_val)
            print(f" {label:<22}: {_format_float(old_val)} -> {_format_float(new_val)} {delta}")


def _find_latest_report(output_dir: Path, exclude: Optional[Path] = None) -> Optional[Path]:
    if not output_dir.exists():
        return None
    reports = sorted(
        (f for f in output_dir.glob("*.json") if f.name != ".gitkeep" and f != exclude),
        key=lambda f: f.stat().st_mtime,
        reverse=True,
    )
    return reports[0] if reports else None


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the unified benchmark runner."""
    parser = argparse.ArgumentParser(
        description="Unified RAG benchmark pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Scope
    parser.add_argument("--retrieval-only", action="store_true", help="Run retrieval benchmark only")
    parser.add_argument("--generation-only", action="store_true", help="Run generation benchmark only")
    # Service
    parser.add_argument("--base-url", default="http://localhost:8001", help="rag-service base URL")
    parser.add_argument("--timeout", type=float, default=180.0, help="HTTP timeout (seconds)")
    # Retrieval
    parser.add_argument("--retrieval-dataset", default=None, help="Path to retrieval dataset")
    parser.add_argument(
        "--modes",
        default="hybrid,semantic,lexical",
        help="Comma-separated retrieval modes to test",
    )
    parser.add_argument("--top-k", type=int, default=10, help="Top K for retrieval")
    parser.add_argument("--score-threshold", type=float, default=None, help="Score threshold")
    parser.add_argument(
        "--filters",
        type=str,
        default=None,
        help='JSON metadata filters for retrieval (e.g. \'{"act_type": "regulation"}\')',
    )
    # Generation
    parser.add_argument("--qa-dataset", default=None, help="Path to QA dataset")
    parser.add_argument("--qa-mode", default="rag", choices=["rag", "graph"], help="Default query mode")
    parser.add_argument("--qa-top-k", type=int, default=12, help="Top K for QA queries")
    parser.add_argument("--qa-min-score", type=float, default=None, help="Min score for QA")
    parser.add_argument("--max-contexts", type=int, default=12, help="Max contexts per example")
    parser.add_argument("--no-ragas", action="store_true", help="Disable RAGAS even if installed")
    parser.add_argument(
        "--metrics",
        default="faithfulness,answer_relevancy,context_precision,context_recall",
        help="Comma-separated RAGAS metrics",
    )
    # Judge LLM
    parser.add_argument("--judge-provider", default="mistral", choices=["mistral", "openai_compat"])
    parser.add_argument("--judge-model", default="mistral-small-latest")
    parser.add_argument(
        "--judge-base-url",
        default=os.getenv("RAGAS_JUDGE_BASE_URL", "https://api.mistral.ai/v1"),
    )
    parser.add_argument("--judge-api-key", default=os.getenv("RAGAS_JUDGE_API_KEY", ""))
    parser.add_argument("--judge-timeout", type=float, default=120.0)
    parser.add_argument("--judge-temperature", type=float, default=0.0)
    parser.add_argument("--faithfulness-threshold", type=float, default=0.8)
    # Output
    parser.add_argument("--output-dir", default=None, help="Report output directory")
    parser.add_argument("--compare", default=None, help="Path to previous report for comparison")
    parser.add_argument("--tag", default="", help="Free-form tag added to report metadata")
    return parser.parse_args()

def main() -> int:
    """Execute the benchmark pipeline and persist the generated report."""
    args = parse_args()
    now = datetime.now(timezone.utc)
    timestamp_str = now.strftime("%Y-%m-%d_%H%M%S")
    output_dir = Path(args.output_dir) if args.output_dir else _DEFAULT_OUTPUT_DIR
    output_dir.mkdir(parents=True, exist_ok=True)

    run_retrieval = not args.generation_only
    run_generation = not args.retrieval_only

    retrieval_dataset = Path(args.retrieval_dataset) if args.retrieval_dataset else _DEFAULT_RETRIEVAL_DATASET
    qa_dataset = Path(args.qa_dataset) if args.qa_dataset else _DEFAULT_QA_DATASET
    retrieval_filters = json.loads(args.filters) if args.filters else None

    modes = [m.strip() for m in args.modes.split(",") if m.strip()]
    metric_names = [m.strip() for m in args.metrics.split(",") if m.strip()]

    separator = "\u2550" * 60
    print(f"\n {separator}")
    print(f" RAG BENCHMARK \u2014 {now.strftime('%Y-%m-%d %H:%M:%S UTC')}")
    if args.tag:
        print(f" Tag: {args.tag}")
    print(f" {separator}")

    report: Dict[str, Any] = {
        "timestamp": now.isoformat(),
        "tag": args.tag,
        "base_url": args.base_url,
    }

    # Phase 1: Retrieval
    if run_retrieval:
        if not retrieval_dataset.exists():
            print(f"\n [SKIP] Retrieval dataset not found: {retrieval_dataset}")
            print(" Generate one with: just rag-bench-generate-dataset")
        else:
            print(f"\n Phase 1: Retrieval ({retrieval_dataset.name})")
            report["retrieval"] = run_retrieval_benchmark(
                dataset_path=retrieval_dataset,
                base_url=args.base_url,
                modes=modes,
                top_k=args.top_k,
                score_threshold=args.score_threshold,
                timeout_seconds=args.timeout,
                filters=retrieval_filters,
            )

    # Phase 2: Generation
    if run_generation:
        if not qa_dataset.exists():
            print(f"\n [SKIP] QA dataset not found: {qa_dataset}")
            print(" Generate one with: just rag-bench-generate-dataset")
        else:
            print(f"\n Phase 2: Generation ({qa_dataset.name})")
            report["generation"] = run_generation_benchmark(
                dataset_path=qa_dataset,
                base_url=args.base_url,
                default_mode=args.qa_mode,
                default_top_k=args.qa_top_k,
                default_min_score=args.qa_min_score,
                timeout_seconds=args.timeout,
                max_contexts=args.max_contexts,
                use_ragas=not args.no_ragas,
                judge_provider=args.judge_provider,
                judge_model=args.judge_model,
                judge_base_url=args.judge_base_url,
                judge_api_key=str(args.judge_api_key).strip() or None,
                judge_timeout_seconds=args.judge_timeout,
                judge_temperature=args.judge_temperature,
                faithfulness_threshold=args.faithfulness_threshold,
                metric_names=metric_names,
            )

    # Phase 3: Print summary
    if "retrieval" in report:
        _print_retrieval_summary(report["retrieval"], args.top_k)
    if "generation" in report:
        _print_generation_summary(report["generation"])

    # Save report
    report_path = output_dir / f"{timestamp_str}.json"
    report_json = json.dumps(report, indent=2, ensure_ascii=False)
    report_path.write_text(report_json + "\n", encoding="utf-8")
    print(f"\n Report saved: {report_path}")

    # Load previous report (used by both console comparison and markdown report)
    compare_path: Optional[Path] = None
    if args.compare:
        compare_path = Path(args.compare)
    else:
        compare_path = _find_latest_report(output_dir, exclude=report_path)

    previous_data: Optional[Dict[str, Any]] = None
    if compare_path and compare_path.exists():
        try:
            previous_data = json.loads(compare_path.read_text(encoding="utf-8"))
        except Exception as exc:
            print(f"\n [WARN] Could not load previous report: {exc}")

    # Console comparison
    if previous_data is not None and compare_path is not None:
        _print_comparison(report, previous_data, compare_path.name)

    # Generate markdown report with charts
    try:
        from report_generator import generate_report

        charts_dir = output_dir / "charts"
        charts_dir.mkdir(parents=True, exist_ok=True)
        md_path = report_path.with_suffix(".md")
        generate_report(report, md_path, charts_dir, previous_report=previous_data)
        print(f" Markdown report: {md_path}")
    except Exception as exc:
        print(f" [WARN] Markdown report generation failed: {exc}")

    print(f"\n {separator}\n")
    return 0

if __name__ == "__main__":
    raise SystemExit(main())