#!/usr/bin/env python3
"""
Unified RAG benchmark pipeline.
Runs retrieval and/or generation evaluation, measures latency,
computes deltas against a previous run, and produces a timestamped JSON report.
Usage:
python scripts/benchmark_rag.py # full benchmark
python scripts/benchmark_rag.py --retrieval-only # retrieval only (no LLM cost)
python scripts/benchmark_rag.py --compare benchmarks/prev.json # compare with previous
python scripts/benchmark_rag.py --tag "post-rerank-tuning"
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Sequence
# Ensure scripts/ is on the import path
_SCRIPTS_DIR = Path(__file__).resolve().parent
if str(_SCRIPTS_DIR) not in sys.path:
sys.path.insert(0, str(_SCRIPTS_DIR))
from bench_utils import build_ragas_embeddings as _build_ragas_embeddings # noqa: E402
from bench_utils import format_float as _format_float # noqa: E402
from eval_ragas import (  # noqa: E402
    collect_samples as collect_qa_samples,
    collect_simple_metrics,
    load_dataset as load_qa_dataset,
    run_ragas_evaluation,
)
from eval_retrieval import (  # noqa: E402
    evaluate as evaluate_retrieval,
    load_dataset as load_retrieval_dataset,
)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_DEFAULT_RETRIEVAL_DATASET = _SCRIPTS_DIR / "datasets" / "rag_test_retrieval.jsonl"
_DEFAULT_QA_DATASET = _SCRIPTS_DIR / "datasets" / "rag_test_qa.jsonl"
_DEFAULT_OUTPUT_DIR = _SCRIPTS_DIR.parent / "benchmarks"
_DEFAULT_MODES = ("hybrid", "semantic", "lexical")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _ragas_available() -> bool:
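    """Return True when the optional RAGAS dependencies (ragas, datasets) are importable."""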
try:
__import__("ragas")
__import__("datasets")
return True
except ImportError:
return False
def _delta_str(old: float, new: float) -> str:
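    """Format the relative change between two metric values, e.g. "(+3.1%)" or "(=)"."""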
if old == 0.0 and new == 0.0:
return "(=)"
if old == 0.0:
return f"(+{new:.3f})"
pct = ((new - old) / abs(old)) * 100.0
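    # Changes smaller than 0.05% are treated as noise and reported as unchanged.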
if abs(pct) < 0.05:
return "(=)"
sign = "+" if pct > 0 else ""
return f"({sign}{pct:.1f}%)"
# ---------------------------------------------------------------------------
# Phase 1: Retrieval benchmark
# ---------------------------------------------------------------------------
def run_retrieval_benchmark(
*,
dataset_path: Path,
base_url: str,
modes: Sequence[str],
top_k: int,
score_threshold: Optional[float],
timeout_seconds: float,
filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Run retrieval evaluation for each requested mode and collect summaries."""
examples = load_retrieval_dataset(dataset_path)
if not examples:
print(" [SKIP] Retrieval dataset is empty")
return {}
results: Dict[str, Any] = {
"dataset": str(dataset_path),
"examples_total": len(examples),
"top_k": top_k,
"modes": {},
}
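    # Evaluate each requested mode independently so its summary can be compared side by side.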
for mode in modes:
print(f" [{mode}] Evaluating {len(examples)} queries...", end=" ", flush=True)
report = evaluate_retrieval(
examples=examples,
base_url=base_url,
top_k=top_k,
mode=mode,
score_threshold=score_threshold,
timeout_seconds=timeout_seconds,
filters=filters,
)
results["modes"][mode] = report
summary = report["summary"]
print(
f"Hit@{top_k}={_format_float(summary['celex_hit_at_k'])} "
f"MRR={_format_float(summary['celex_mrr'])} "
f"NDCG={_format_float(summary['celex_ndcg_at_k'])} "
f"p50={summary['latency_p50_ms']:.0f}ms"
)
return results
# ---------------------------------------------------------------------------
# Phase 2: Generation benchmark
# ---------------------------------------------------------------------------
def run_generation_benchmark(
*,
dataset_path: Path,
base_url: str,
default_mode: str,
default_top_k: int,
default_min_score: Optional[float],
timeout_seconds: float,
max_contexts: int,
use_ragas: bool,
judge_provider: str,
judge_model: str,
judge_base_url: str,
judge_api_key: Optional[str],
judge_timeout_seconds: float,
judge_temperature: float,
faithfulness_threshold: float,
metric_names: Sequence[str],
) -> Dict[str, Any]:
"""Run answer-generation evaluation and optional RAGAS scoring."""
examples = load_qa_dataset(dataset_path)
if not examples:
print(" [SKIP] QA dataset is empty")
return {}
print(f" Collecting {len(examples)} query responses (mode={default_mode})...", flush=True)
rows = collect_qa_samples(
examples=examples,
base_url=base_url,
default_mode=default_mode,
default_top_k=default_top_k,
default_min_score=default_min_score,
timeout_seconds=timeout_seconds,
max_contexts=max_contexts,
include_full_content=True,
)
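    # Use RAGAS only when requested and importable; otherwise fall back to simple, non-LLM metrics.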
ragas_ok = use_ragas and _ragas_available()
if ragas_ok:
print(f" Running RAGAS evaluation ({', '.join(metric_names)})...", flush=True)
ragas_embeddings = _build_ragas_embeddings(
provider=judge_provider,
api_key=judge_api_key,
)
report = run_ragas_evaluation(
rows=rows,
judge_provider=judge_provider,
judge_model=judge_model,
judge_base_url=judge_base_url,
judge_api_key=judge_api_key,
judge_timeout_seconds=judge_timeout_seconds,
judge_temperature=judge_temperature,
metric_names=metric_names,
faithfulness_threshold=faithfulness_threshold,
embeddings=ragas_embeddings,
)
else:
if use_ragas:
print(" [WARN] RAGAS not installed, collecting simple metrics only")
report = collect_simple_metrics(rows)
report["dataset"] = str(dataset_path)
return report
# ---------------------------------------------------------------------------
# Phase 3: Report & comparison
# ---------------------------------------------------------------------------
def _print_retrieval_summary(retrieval: Dict[str, Any], top_k: int) -> None:
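    """Print a per-mode table of retrieval quality and latency metrics."""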
modes_data = retrieval.get("modes", {})
if not modes_data:
return
print(f"\n RETRIEVAL ({retrieval.get('examples_total', '?')} queries)")
print(" " + "\u2500" * 72)
hit_col = "Hit@" + str(top_k)
ndcg_col = "NDCG@" + str(top_k)
p_col = "P@" + str(top_k)
header = f" {'Mode':<10} {hit_col:<8} {'MRR':<8} {ndcg_col:<9} {p_col:<8} {'p50(ms)':<9} {'p95(ms)':<9}"
print(header)
for mode_name, mode_report in modes_data.items():
s = mode_report["summary"]
print(
f" {mode_name:<10} "
f"{_format_float(s['celex_hit_at_k']):<8} "
f"{_format_float(s['celex_mrr']):<8} "
f"{_format_float(s['celex_ndcg_at_k']):<9} "
f"{_format_float(s['celex_precision_at_k']):<8} "
f"{s['latency_p50_ms']:<9.0f} "
f"{s['latency_p95_ms']:<9.0f}"
)
def _print_generation_summary(generation: Dict[str, Any]) -> None:
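    """Print RAGAS metrics (or the simple fallback metrics) from a generation report."""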
summary = generation.get("summary", {})
if not summary:
return
total = summary.get("examples_total", "?")
print(f"\n GENERATION ({total} queries)")
print(" " + "\u2500" * 72)
metrics = summary.get("metrics_evaluated", [])
for metric_name in metrics:
mean_key = f"{metric_name}_mean"
mean_val = summary.get(mean_key, 0.0)
extra = ""
if metric_name == "faithfulness":
threshold = summary.get("faithfulness_threshold", 0.8)
extra = f" (threshold: {threshold})"
label = metric_name.replace("_", " ").title()
print(f" {label:<22}: {_format_float(mean_val)}{extra}")
if not metrics:
print(" (RAGAS not available - simple metrics only)")
print(f" {'Avg answer chars':<22}: {summary.get('avg_answer_chars', 0):.0f}")
print(f" {'Avg sources':<22}: {summary.get('avg_sources_per_example', 0):.1f}")
p50 = summary.get("latency_p50_ms", 0)
p95 = summary.get("latency_p95_ms", 0)
print(f" {'Latency p50/p95':<22}: {p50:.0f}ms / {p95:.0f}ms")
def _print_comparison(current: Dict[str, Any], previous: Dict[str, Any], prev_path: str) -> None:
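    """Print per-metric deltas between the current report and a previous one."""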
print(f"\n COMPARISON vs {prev_path}")
print(" " + "\u2500" * 72)
# Compare retrieval metrics
curr_ret = current.get("retrieval", {}).get("modes", {})
prev_ret = previous.get("retrieval", {}).get("modes", {})
for mode_name in curr_ret:
if mode_name not in prev_ret:
continue
cs = curr_ret[mode_name]["summary"]
ps = prev_ret[mode_name]["summary"]
for metric_key, label in [
("celex_hit_at_k", "Hit@K"),
("celex_mrr", "MRR"),
("celex_ndcg_at_k", "NDCG"),
]:
old_val = ps.get(metric_key, 0.0)
new_val = cs.get(metric_key, 0.0)
delta = _delta_str(old_val, new_val)
print(f" {mode_name} {label:<8}: {_format_float(old_val)} -> {_format_float(new_val)} {delta}")
# Compare generation metrics
curr_gen = current.get("generation", {}).get("summary", {})
prev_gen = previous.get("generation", {}).get("summary", {})
if curr_gen and prev_gen:
for metric_name in curr_gen.get("metrics_evaluated", []):
mean_key = f"{metric_name}_mean"
old_val = prev_gen.get(mean_key, 0.0)
new_val = curr_gen.get(mean_key, 0.0)
label = metric_name.replace("_", " ").title()
delta = _delta_str(old_val, new_val)
print(f" {label:<22}: {_format_float(old_val)} -> {_format_float(new_val)} {delta}")
def _find_latest_report(output_dir: Path, exclude: Optional[Path] = None) -> Optional[Path]:
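    """Return the most recently modified JSON report in output_dir, or None if there is none."""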
if not output_dir.exists():
return None
reports = sorted(
(f for f in output_dir.glob("*.json") if f.name != ".gitkeep" and f != exclude),
key=lambda f: f.stat().st_mtime,
reverse=True,
)
return reports[0] if reports else None
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
"""Parse CLI arguments for the unified benchmark runner."""
parser = argparse.ArgumentParser(
description="Unified RAG benchmark pipeline",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Scope
parser.add_argument("--retrieval-only", action="store_true", help="Run retrieval benchmark only")
parser.add_argument("--generation-only", action="store_true", help="Run generation benchmark only")
# Service
parser.add_argument("--base-url", default="http://localhost:8001", help="rag-service base URL")
parser.add_argument("--timeout", type=float, default=180.0, help="HTTP timeout (seconds)")
# Retrieval
parser.add_argument("--retrieval-dataset", default=None, help="Path to retrieval dataset")
parser.add_argument(
"--modes",
default="hybrid,semantic,lexical",
help="Comma-separated retrieval modes to test",
)
parser.add_argument("--top-k", type=int, default=10, help="Top K for retrieval")
parser.add_argument("--score-threshold", type=float, default=None, help="Score threshold")
parser.add_argument(
"--filters",
type=str,
default=None,
help='JSON metadata filters for retrieval (e.g. \'{"act_type": "regulation"}\')',
)
# Generation
parser.add_argument("--qa-dataset", default=None, help="Path to QA dataset")
parser.add_argument("--qa-mode", default="rag", choices=["rag", "graph"], help="Default query mode")
parser.add_argument("--qa-top-k", type=int, default=12, help="Top K for QA queries")
parser.add_argument("--qa-min-score", type=float, default=None, help="Min score for QA")
parser.add_argument("--max-contexts", type=int, default=12, help="Max contexts per example")
parser.add_argument("--no-ragas", action="store_true", help="Disable RAGAS even if installed")
parser.add_argument(
"--metrics",
default="faithfulness,answer_relevancy,context_precision,context_recall",
help="Comma-separated RAGAS metrics",
)
# Judge LLM
parser.add_argument("--judge-provider", default="mistral", choices=["mistral", "openai_compat"])
parser.add_argument("--judge-model", default="mistral-small-latest")
parser.add_argument(
"--judge-base-url",
default=os.getenv("RAGAS_JUDGE_BASE_URL", "https://api.mistral.ai/v1"),
)
parser.add_argument("--judge-api-key", default=os.getenv("RAGAS_JUDGE_API_KEY", ""))
parser.add_argument("--judge-timeout", type=float, default=120.0)
parser.add_argument("--judge-temperature", type=float, default=0.0)
parser.add_argument("--faithfulness-threshold", type=float, default=0.8)
# Output
parser.add_argument("--output-dir", default=None, help="Report output directory")
parser.add_argument("--compare", default=None, help="Path to previous report for comparison")
parser.add_argument("--tag", default="", help="Free-form tag added to report metadata")
return parser.parse_args()
def main() -> int:
"""Execute the benchmark pipeline and persist the generated report."""
args = parse_args()
now = datetime.now(timezone.utc)
timestamp_str = now.strftime("%Y-%m-%d_%H%M%S")
output_dir = Path(args.output_dir) if args.output_dir else _DEFAULT_OUTPUT_DIR
output_dir.mkdir(parents=True, exist_ok=True)
run_retrieval = not args.generation_only
run_generation = not args.retrieval_only
retrieval_dataset = Path(args.retrieval_dataset) if args.retrieval_dataset else _DEFAULT_RETRIEVAL_DATASET
qa_dataset = Path(args.qa_dataset) if args.qa_dataset else _DEFAULT_QA_DATASET
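    # --filters arrives as a JSON object string; malformed input raises json.JSONDecodeError here.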
retrieval_filters = json.loads(args.filters) if args.filters else None
modes = [m.strip() for m in args.modes.split(",") if m.strip()]
metric_names = [m.strip() for m in args.metrics.split(",") if m.strip()]
separator = "\u2550" * 60
print(f"\n {separator}")
print(f" RAG BENCHMARK \u2014 {now.strftime('%Y-%m-%d %H:%M:%S UTC')}")
if args.tag:
print(f" Tag: {args.tag}")
print(f" {separator}")
report: Dict[str, Any] = {
"timestamp": now.isoformat(),
"tag": args.tag,
"base_url": args.base_url,
}
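    # Rough shape of the finished report (phase keys appear only when that phase runs):
    #   {"timestamp": ..., "tag": ..., "base_url": ...,
    #    "retrieval": {"dataset": ..., "examples_total": ..., "top_k": ..., "modes": {...}},
    #    "generation": {"dataset": ..., "summary": {...}, ...}}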
# Phase 1: Retrieval
if run_retrieval:
if not retrieval_dataset.exists():
print(f"\n [SKIP] Retrieval dataset not found: {retrieval_dataset}")
print(" Generate one with: just rag-bench-generate-dataset")
else:
print(f"\n Phase 1: Retrieval ({retrieval_dataset.name})")
report["retrieval"] = run_retrieval_benchmark(
dataset_path=retrieval_dataset,
base_url=args.base_url,
modes=modes,
top_k=args.top_k,
score_threshold=args.score_threshold,
timeout_seconds=args.timeout,
filters=retrieval_filters,
)
# Phase 2: Generation
if run_generation:
if not qa_dataset.exists():
print(f"\n [SKIP] QA dataset not found: {qa_dataset}")
print(" Generate one with: just rag-bench-generate-dataset")
else:
print(f"\n Phase 2: Generation ({qa_dataset.name})")
report["generation"] = run_generation_benchmark(
dataset_path=qa_dataset,
base_url=args.base_url,
default_mode=args.qa_mode,
default_top_k=args.qa_top_k,
default_min_score=args.qa_min_score,
timeout_seconds=args.timeout,
max_contexts=args.max_contexts,
use_ragas=not args.no_ragas,
judge_provider=args.judge_provider,
judge_model=args.judge_model,
judge_base_url=args.judge_base_url,
judge_api_key=str(args.judge_api_key).strip() or None,
judge_timeout_seconds=args.judge_timeout,
judge_temperature=args.judge_temperature,
faithfulness_threshold=args.faithfulness_threshold,
metric_names=metric_names,
)
# Phase 3: Print summary
if "retrieval" in report:
_print_retrieval_summary(report["retrieval"], args.top_k)
if "generation" in report:
_print_generation_summary(report["generation"])
# Save report
report_path = output_dir / f"{timestamp_str}.json"
report_json = json.dumps(report, indent=2, ensure_ascii=False)
report_path.write_text(report_json + "\n", encoding="utf-8")
print(f"\n Report saved: {report_path}")
# Load previous report (used by both console comparison and markdown report)
compare_path: Optional[Path] = None
if args.compare:
compare_path = Path(args.compare)
else:
compare_path = _find_latest_report(output_dir, exclude=report_path)
previous_data: Optional[Dict[str, Any]] = None
if compare_path and compare_path.exists():
try:
previous_data = json.loads(compare_path.read_text(encoding="utf-8"))
except Exception as exc:
print(f"\n [WARN] Could not load previous report: {exc}")
# Console comparison
if previous_data is not None and compare_path is not None:
_print_comparison(report, previous_data, compare_path.name)
# Generate markdown report with charts
try:
from report_generator import generate_report
charts_dir = output_dir / "charts"
charts_dir.mkdir(parents=True, exist_ok=True)
md_path = report_path.with_suffix(".md")
generate_report(report, md_path, charts_dir, previous_report=previous_data)
print(f" Markdown report: {md_path}")
except Exception as exc:
print(f" [WARN] Markdown report generation failed: {exc}")
print(f"\n {separator}\n")
return 0
if __name__ == "__main__":
raise SystemExit(main())