Source code for scripts.benchmark_graphrag

#!/usr/bin/env python3
"""Reproducible local GraphRAG microbenchmarks.

The goal is not to simulate end-to-end production latency. This script measures
deterministic building blocks that materially affect GraphRAG responsiveness:

- read-only Cypher validation,
- graph node ranking,
- token-budget-aware graph context assembly,
- parallel semantic retrieval fan-out versus a sequential baseline.
"""

from __future__ import annotations

import json
import platform
import statistics
import subprocess
import time
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Callable, cast

from lalandre_db_neo4j.repository import Neo4jRepository
from lalandre_rag.graph.context_budget import GraphContextBudget
from lalandre_rag.graph.ranker import rank_graph_nodes, rank_relationships
from lalandre_rag.retrieval.query_expansion import ExpandedQuery
from lalandre_rag.retrieval.result import RetrievalResult
from lalandre_rag.retrieval.service import RetrievalService

ROOT = Path(__file__).resolve().parents[1]


@dataclass
class BenchmarkSummary:
    """Latency summary for one benchmarked operation."""

    # Report key identifying the benchmarked operation.
    name: str
    # All latencies are in milliseconds, rounded to 3 decimals by _measure().
    median_ms: float
    p95_ms: float
    min_ms: float
    max_ms: float
    # Number of timed invocations (warmups excluded).
    iterations: int
    # Benchmark-specific metadata merged into the JSON report entry.
    extra: dict[str, Any]
def _measure(name: str, fn: Callable[[], Any], *, iterations: int, warmups: int = 3) -> BenchmarkSummary: for _ in range(warmups): fn() samples_ms: list[float] = [] for _ in range(iterations): started_at = time.perf_counter() fn() samples_ms.append((time.perf_counter() - started_at) * 1000.0) samples_ms.sort() p95_index = min(len(samples_ms) - 1, max(int(round(len(samples_ms) * 0.95)) - 1, 0)) return BenchmarkSummary( name=name, median_ms=round(statistics.median(samples_ms), 3), p95_ms=round(samples_ms[p95_index], 3), min_ms=round(samples_ms[0], 3), max_ms=round(samples_ms[-1], 3), iterations=iterations, extra={}, ) def _cpu_model() -> str: candidates = [ "lscpu | sed -n 's/^Model name:[[:space:]]*//p' | head -n1", "sysctl -n machdep.cpu.brand_string 2>/dev/null", ] for command in candidates: try: result = subprocess.run( ["bash", "-lc", command], check=True, capture_output=True, text=True, ) except Exception: continue value = result.stdout.strip() if value: return value return platform.processor() or "unknown" def _make_graph_fixture( *, nodes_count: int = 1000, relation_span: int = 5 ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: nodes = [ { "id": index, "celex": f"ACT-{index:05d}", "title": f"Regulatory Act {index}", } for index in range(1, nodes_count + 1) ] relationships: list[dict[str, Any]] = [] relation_types = ("AMENDS", "IMPLEMENTS", "SUPPLEMENTS", "CITES", "DEROGATES") for node_id in range(1, nodes_count + 1): for offset in range(1, relation_span + 1): target = node_id + offset if target > nodes_count: break relationships.append( { "start_node": node_id, "end_node": target, "type": relation_types[(node_id + offset) % len(relation_types)], "description": f"Relationship {node_id}->{target}", } ) return nodes, relationships def _make_semantic_docs(count: int) -> list[Any]: class _Doc: def __init__(self, payload: dict[str, Any]) -> None: self.payload = payload docs: list[Any] = [] for index in range(1, count + 1): docs.append( _Doc( { "content": 
("This regulatory passage explains obligations and controls. " * 24)[:900], "celex": f"32016R{index:04d}", "act_title": f"Regulation {index}", "chunk_id": index, "chunk_index": index - 1, "subdivision_id": 1000 + index, "act_id": index, } ) ) return docs class _FakeQdrantRepo: def search(self, *_args: Any, **_kwargs: Any) -> list[Any]: return [] class _FakePGRepo: def search_bm25(self, **_kwargs: Any) -> list[dict[str, Any]]: return [] def search_bm25_chunks(self, **_kwargs: Any) -> list[dict[str, Any]]: return [] class _FakeEmbeddingService: model_name = "benchmark-embedding" def get_vector_size(self) -> int: return 2 def embed_text(self, text: str) -> list[float]: return [float(len(text)), 0.0] def embed_batch(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]: del batch_size return [[float(index + 1), 0.0] for index, _ in enumerate(texts)] def _make_retrieval_result(content: str, score: float, subdivision_id: int, act_id: int, celex: str) -> RetrievalResult: return RetrievalResult( content=content, score=score, subdivision_id=subdivision_id, act_id=act_id, celex=celex, subdivision_type="article", sequence_order=subdivision_id, metadata={}, ) def _make_retrieval_service() -> RetrievalService: return RetrievalService( qdrant_repos=cast(dict[str, Any], {"chunks": _FakeQdrantRepo(), "subdivisions": _FakeQdrantRepo()}), pg_repo=cast(Any, _FakePGRepo()), embedding_service=cast(Any, _FakeEmbeddingService()), reranker=None, result_cache_ttl=0, ) def _attach_fake_semantic_search(service: RetrievalService) -> None: def fake_search( *, query_vector: list[float], top_k: int, filters: dict[str, Any] | None, collections: list[str] | None, ) -> list[RetrievalResult]: del top_k, filters variant_index = int(query_vector[0]) collection = (collections or ["chunks"])[0] time.sleep( { (1, "chunks"): 0.05, (1, "subdivisions"): 0.01, (2, "chunks"): 0.01, (2, "subdivisions"): 0.04, (3, "chunks"): 0.03, (3, "subdivisions"): 0.02, }[(variant_index, collection)] ) 
payloads: dict[tuple[int, str], list[RetrievalResult]] = { (1, "chunks"): [_make_retrieval_result("A", 0.91, 101, 1, "A")], (1, "subdivisions"): [_make_retrieval_result("A-dup", 0.88, 101, 1, "A")], (2, "chunks"): [_make_retrieval_result("B", 0.86, 102, 2, "B")], (2, "subdivisions"): [_make_retrieval_result("C", 0.84, 103, 3, "C")], (3, "chunks"): [_make_retrieval_result("D", 0.83, 104, 4, "D")], (3, "subdivisions"): [], } return payloads[(variant_index, collection)] service.semantic_service.search = fake_search # type: ignore[method-assign] def _semantic_parallel_benchmark() -> dict[str, Any]: service = _make_retrieval_service() _attach_fake_semantic_search(service) expanded_queries = [ ExpandedQuery(text="alpha", weight=1.0, strategy="original"), ExpandedQuery(text="beta", weight=0.9, strategy="keyword_focus"), ExpandedQuery(text="gamma", weight=0.8, strategy="bilingual_mirror"), ] query_vectors = service.embedding_service.embed_batch([item.text for item in expanded_queries], batch_size=3) def sequential() -> list[RetrievalResult]: combined: list[RetrievalResult] = [] for vector in query_vectors: combined.extend( service._semantic_search_multi_by_vector( query_vector=vector, top_k=6, filters=None, collections=["chunks", "subdivisions"], ) ) return combined def parallel() -> list[RetrievalResult]: return service._semantic_search_with_expansion( query="Question complexe", expanded_queries=expanded_queries, top_k=6, filters=None, collections=["chunks", "subdivisions"], output_top_k=6, ) sequential_summary = _measure("sequential_semantic_baseline", sequential, iterations=15, warmups=2) parallel_summary = _measure("parallel_semantic_search", parallel, iterations=15, warmups=2) speedup = sequential_summary.median_ms / parallel_summary.median_ms if parallel_summary.median_ms else 0.0 return { "variants": len(expanded_queries), "collections": 2, "sequential_median_ms": sequential_summary.median_ms, "parallel_median_ms": parallel_summary.median_ms, "speedup_x": 
round(speedup, 2), "parallel_p95_ms": parallel_summary.p95_ms, }
def _summary_payload(summary: BenchmarkSummary) -> dict[str, Any]:
    """Flatten one BenchmarkSummary into a JSON report entry."""
    return {
        "median_ms": summary.median_ms,
        "p95_ms": summary.p95_ms,
        "min_ms": summary.min_ms,
        "max_ms": summary.max_ms,
        "iterations": summary.iterations,
        **summary.extra,
    }


def main() -> None:
    """Run the local GraphRAG benchmark suite and print a JSON report to stdout.

    Benchmarks four building blocks: read-only Cypher validation, graph node
    ranking, token-budget-aware context assembly, and parallel semantic
    retrieval fan-out (against a sequential baseline).
    """
    # Shared synthetic fixtures: 1000-node act graph, ~9% semantically matched
    # acts, and 20 seed acts.
    graph_nodes, graph_relationships = _make_graph_fixture()
    semantic_ids = {node_id for node_id in range(1, 1001, 11)}
    seed_ids = set(range(1, 21))
    ranked_nodes = rank_graph_nodes(
        graph_context=graph_nodes,
        relationships=graph_relationships,
        semantic_act_ids=semantic_ids,
        seed_act_ids=seed_ids,
        max_depth=4,
    )
    top_act_ids = {int(node["id"]) for node in ranked_nodes[:150]}
    ranked_relationships = rank_relationships(
        relationships=graph_relationships,
        top_act_ids=top_act_ids,
    )
    semantic_docs = _make_semantic_docs(24)

    cypher_query = """
    MATCH (a:Act)
    CALL {
        WITH a
        MATCH (a)-[r]->(b:Act)
        RETURN count(r) AS relation_count
    }
    RETURN a.id AS id, relation_count
    LIMIT 12
    """
    cypher_validation = _measure(
        "cypher_validation",
        lambda: Neo4jRepository._validate_read_only_cypher(cypher_query),
        iterations=5000,
    )
    cypher_validation.extra = {
        "query_length_chars": len(cypher_query),
        "throughput_queries_per_second": round(1000.0 / cypher_validation.median_ms, 1)
        if cypher_validation.median_ms
        else None,
    }

    graph_ranking = _measure(
        "graph_ranking",
        lambda: rank_graph_nodes(
            graph_context=graph_nodes,
            relationships=graph_relationships,
            semantic_act_ids=semantic_ids,
            seed_act_ids=seed_ids,
            max_depth=4,
        ),
        iterations=80,
    )
    graph_ranking.extra = {
        "nodes": len(graph_nodes),
        "relationships": len(graph_relationships),
        "seed_acts": len(seed_ids),
    }

    context_budget = _measure(
        "graph_context_budget",
        lambda: GraphContextBudget(max_chars=20_000).build(
            semantic_results=semantic_docs,
            ranked_nodes=ranked_nodes[:150],
            ranked_relationships=ranked_relationships[:300],
        ),
        iterations=120,
    )
    context_budget.extra = {
        "semantic_results": len(semantic_docs),
        "ranked_nodes_considered": 150,
        "ranked_relationships_considered": 300,
        "budget_chars": 20_000,
    }

    semantic_parallel = _semantic_parallel_benchmark()

    payload = {
        "generated_at_utc": datetime.now(UTC).isoformat(),
        "python": platform.python_version(),
        "platform": {
            "system": platform.system(),
            "release": platform.release(),
            "machine": platform.machine(),
            "cpu_model": _cpu_model(),
        },
        "benchmarks": {
            cypher_validation.name: _summary_payload(cypher_validation),
            graph_ranking.name: _summary_payload(graph_ranking),
            context_budget.name: _summary_payload(context_budget),
            "semantic_parallel_search": semantic_parallel,
        },
    }
    print(json.dumps(payload, indent=2, ensure_ascii=False))
# Script entry point: run the benchmark suite when executed directly.
if __name__ == "__main__": main()