#!/usr/bin/env python3
"""Reproducible local GraphRAG microbenchmarks.
The goal is not to simulate end-to-end production latency. This script measures
deterministic building blocks that materially affect GraphRAG responsiveness:
- read-only Cypher validation,
- graph node ranking,
- token-budget-aware graph context assembly,
- parallel semantic retrieval fan-out versus a sequential baseline.
"""
from __future__ import annotations
import json
import platform
import statistics
import subprocess
import time
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Callable, cast
from lalandre_db_neo4j.repository import Neo4jRepository
from lalandre_rag.graph.context_budget import GraphContextBudget
from lalandre_rag.graph.ranker import rank_graph_nodes, rank_relationships
from lalandre_rag.retrieval.query_expansion import ExpandedQuery
from lalandre_rag.retrieval.result import RetrievalResult
from lalandre_rag.retrieval.service import RetrievalService
# Repository root (one directory above this script). Not referenced elsewhere
# in this file -- presumably kept for report-path construction; TODO confirm.
ROOT = Path(__file__).resolve().parents[1]
@dataclass
class BenchmarkSummary:
    """Latency summary for one benchmarked operation.

    Attributes:
        name: Identifier of the benchmarked operation.
        median_ms: Median sample latency, in milliseconds.
        p95_ms: Nearest-rank 95th-percentile latency, in milliseconds.
        min_ms: Fastest observed sample, in milliseconds.
        max_ms: Slowest observed sample, in milliseconds.
        iterations: Number of measured (non-warmup) iterations.
        extra: Operation-specific metadata merged into the JSON report.
    """

    name: str
    median_ms: float
    p95_ms: float
    min_ms: float
    max_ms: float
    iterations: int
    extra: dict[str, Any]
def _measure(name: str, fn: Callable[[], Any], *, iterations: int, warmups: int = 3) -> BenchmarkSummary:
    """Time ``fn`` over ``iterations`` runs after ``warmups`` discarded runs.

    Returns a :class:`BenchmarkSummary` with median / p95 / min / max
    latencies in milliseconds; ``extra`` starts empty and may be filled by
    the caller afterwards.
    """
    # Discard warm-up runs so cold-start effects don't pollute the samples.
    for _ in range(warmups):
        fn()

    durations: list[float] = []
    for _ in range(iterations):
        t0 = time.perf_counter()
        fn()
        durations.append((time.perf_counter() - t0) * 1000.0)
    durations.sort()

    # Nearest-rank p95 index, clamped into [0, len - 1].
    count = len(durations)
    rank = int(round(count * 0.95)) - 1
    p95_at = min(count - 1, max(rank, 0))

    return BenchmarkSummary(
        name=name,
        median_ms=round(statistics.median(durations), 3),
        p95_ms=round(durations[p95_at], 3),
        min_ms=round(durations[0], 3),
        max_ms=round(durations[-1], 3),
        iterations=iterations,
        extra={},
    )
def _cpu_model() -> str:
candidates = [
"lscpu | sed -n 's/^Model name:[[:space:]]*//p' | head -n1",
"sysctl -n machdep.cpu.brand_string 2>/dev/null",
]
for command in candidates:
try:
result = subprocess.run(
["bash", "-lc", command],
check=True,
capture_output=True,
text=True,
)
except Exception:
continue
value = result.stdout.strip()
if value:
return value
return platform.processor() or "unknown"
def _make_graph_fixture(
*, nodes_count: int = 1000, relation_span: int = 5
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
nodes = [
{
"id": index,
"celex": f"ACT-{index:05d}",
"title": f"Regulatory Act {index}",
}
for index in range(1, nodes_count + 1)
]
relationships: list[dict[str, Any]] = []
relation_types = ("AMENDS", "IMPLEMENTS", "SUPPLEMENTS", "CITES", "DEROGATES")
for node_id in range(1, nodes_count + 1):
for offset in range(1, relation_span + 1):
target = node_id + offset
if target > nodes_count:
break
relationships.append(
{
"start_node": node_id,
"end_node": target,
"type": relation_types[(node_id + offset) % len(relation_types)],
"description": f"Relationship {node_id}->{target}",
}
)
return nodes, relationships
def _make_semantic_docs(count: int) -> list[Any]:
class _Doc:
def __init__(self, payload: dict[str, Any]) -> None:
self.payload = payload
docs: list[Any] = []
for index in range(1, count + 1):
docs.append(
_Doc(
{
"content": ("This regulatory passage explains obligations and controls. " * 24)[:900],
"celex": f"32016R{index:04d}",
"act_title": f"Regulation {index}",
"chunk_id": index,
"chunk_index": index - 1,
"subdivision_id": 1000 + index,
"act_id": index,
}
)
)
return docs
class _FakeQdrantRepo:
def search(self, *_args: Any, **_kwargs: Any) -> list[Any]:
return []
class _FakePGRepo:
def search_bm25(self, **_kwargs: Any) -> list[dict[str, Any]]:
return []
def search_bm25_chunks(self, **_kwargs: Any) -> list[dict[str, Any]]:
return []
class _FakeEmbeddingService:
model_name = "benchmark-embedding"
def get_vector_size(self) -> int:
return 2
def embed_text(self, text: str) -> list[float]:
return [float(len(text)), 0.0]
def embed_batch(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
del batch_size
return [[float(index + 1), 0.0] for index, _ in enumerate(texts)]
def _make_retrieval_result(content: str, score: float, subdivision_id: int, act_id: int, celex: str) -> RetrievalResult:
    """Build a minimal article-level RetrievalResult for the fake search paths."""
    fields: dict[str, Any] = {
        "content": content,
        "score": score,
        "subdivision_id": subdivision_id,
        "act_id": act_id,
        "celex": celex,
        "subdivision_type": "article",
        # Reuse the subdivision id as a stable ordering key.
        "sequence_order": subdivision_id,
        "metadata": {},
    }
    return RetrievalResult(**fields)
def _make_retrieval_service() -> RetrievalService:
    """Assemble a RetrievalService wired entirely to in-process fakes."""
    vector_repos = {"chunks": _FakeQdrantRepo(), "subdivisions": _FakeQdrantRepo()}
    return RetrievalService(
        qdrant_repos=cast(dict[str, Any], vector_repos),
        pg_repo=cast(Any, _FakePGRepo()),
        embedding_service=cast(Any, _FakeEmbeddingService()),
        reranker=None,
        # TTL of 0 -- presumably disables result caching so every iteration
        # measures real work; confirm against RetrievalService.
        result_cache_ttl=0,
    )
def _attach_fake_semantic_search(service: RetrievalService) -> None:
    """Monkey-patch ``service.semantic_service.search`` with a deterministic fake.

    Each (query-variant, collection) pair gets a fixed artificial delay and a
    fixed canned result set, so sequential-vs-parallel fan-out timings are
    reproducible.
    """
    # (variant 1-based index, collection) -> artificial latency in seconds.
    delays: dict[tuple[int, str], float] = {
        (1, "chunks"): 0.05,
        (1, "subdivisions"): 0.01,
        (2, "chunks"): 0.01,
        (2, "subdivisions"): 0.04,
        (3, "chunks"): 0.03,
        (3, "subdivisions"): 0.02,
    }
    # (variant, collection) -> args for _make_retrieval_result, one tuple per hit.
    result_specs: dict[tuple[int, str], list[tuple[str, float, int, int, str]]] = {
        (1, "chunks"): [("A", 0.91, 101, 1, "A")],
        (1, "subdivisions"): [("A-dup", 0.88, 101, 1, "A")],
        (2, "chunks"): [("B", 0.86, 102, 2, "B")],
        (2, "subdivisions"): [("C", 0.84, 103, 3, "C")],
        (3, "chunks"): [("D", 0.83, 104, 4, "D")],
        (3, "subdivisions"): [],
    }

    def fake_search(
        *,
        query_vector: list[float],
        top_k: int,
        filters: dict[str, Any] | None,
        collections: list[str] | None,
    ) -> list[RetrievalResult]:
        del top_k, filters
        # The fake embedding encodes the variant's 1-based index in component 0.
        key = (int(query_vector[0]), (collections or ["chunks"])[0])
        time.sleep(delays[key])
        # Build fresh result objects per call, as the original did.
        return [_make_retrieval_result(*spec) for spec in result_specs[key]]

    service.semantic_service.search = fake_search  # type: ignore[method-assign]
def _semantic_parallel_benchmark() -> dict[str, Any]:
    """Benchmark parallel semantic fan-out against a sequential baseline.

    Uses three query variants over two collections, each backed by the
    deterministic fake search, and reports medians plus the speedup ratio.
    """
    service = _make_retrieval_service()
    _attach_fake_semantic_search(service)

    variants = [
        ExpandedQuery(text="alpha", weight=1.0, strategy="original"),
        ExpandedQuery(text="beta", weight=0.9, strategy="keyword_focus"),
        ExpandedQuery(text="gamma", weight=0.8, strategy="bilingual_mirror"),
    ]
    vectors = service.embedding_service.embed_batch([variant.text for variant in variants], batch_size=3)

    def run_sequential() -> list[RetrievalResult]:
        # Baseline: one multi-collection search per variant, back to back.
        gathered: list[RetrievalResult] = []
        for vector in vectors:
            hits = service._semantic_search_multi_by_vector(
                query_vector=vector,
                top_k=6,
                filters=None,
                collections=["chunks", "subdivisions"],
            )
            gathered.extend(hits)
        return gathered

    def run_parallel() -> list[RetrievalResult]:
        # Production path: expansion-aware parallel fan-out.
        return service._semantic_search_with_expansion(
            query="Question complexe",
            expanded_queries=variants,
            top_k=6,
            filters=None,
            collections=["chunks", "subdivisions"],
            output_top_k=6,
        )

    baseline = _measure("sequential_semantic_baseline", run_sequential, iterations=15, warmups=2)
    fanout = _measure("parallel_semantic_search", run_parallel, iterations=15, warmups=2)
    speedup = baseline.median_ms / fanout.median_ms if fanout.median_ms else 0.0
    return {
        "variants": len(variants),
        "collections": 2,
        "sequential_median_ms": baseline.median_ms,
        "parallel_median_ms": fanout.median_ms,
        "speedup_x": round(speedup, 2),
        "parallel_p95_ms": fanout.p95_ms,
    }
def _summary_payload(summary: BenchmarkSummary) -> dict[str, Any]:
    """Flatten a BenchmarkSummary (plus its ``extra`` metadata) for the JSON report."""
    return {
        "median_ms": summary.median_ms,
        "p95_ms": summary.p95_ms,
        "min_ms": summary.min_ms,
        "max_ms": summary.max_ms,
        "iterations": summary.iterations,
        **summary.extra,
    }


def main() -> None:
    """Run the local GraphRAG benchmark suite and print a JSON report to stdout."""
    # Shared synthetic fixtures: 1000-node act graph, sparse semantic hits,
    # a small seed set, and 24 fake semantic documents.
    graph_nodes, graph_relationships = _make_graph_fixture()
    semantic_ids = {node_id for node_id in range(1, 1001, 11)}
    seed_ids = set(range(1, 21))
    ranked_nodes = rank_graph_nodes(
        graph_context=graph_nodes,
        relationships=graph_relationships,
        semantic_act_ids=semantic_ids,
        seed_act_ids=seed_ids,
        max_depth=4,
    )
    top_act_ids = {int(node["id"]) for node in ranked_nodes[:150]}
    ranked_relationships = rank_relationships(
        relationships=graph_relationships,
        top_act_ids=top_act_ids,
    )
    semantic_docs = _make_semantic_docs(24)

    # Representative read-only query with a subquery, used only as validator input.
    cypher_query = """
    MATCH (a:Act)
    CALL {
        WITH a
        MATCH (a)-[r]->(b:Act)
        RETURN count(r) AS relation_count
    }
    RETURN a.id AS id, relation_count
    LIMIT 12
    """

    cypher_validation = _measure(
        "cypher_validation",
        lambda: Neo4jRepository._validate_read_only_cypher(cypher_query),
        iterations=5000,
    )
    cypher_validation.extra = {
        "query_length_chars": len(cypher_query),
        # Guard against a sub-resolution 0.0 median before dividing.
        "throughput_queries_per_second": round(1000.0 / cypher_validation.median_ms, 1)
        if cypher_validation.median_ms
        else None,
    }

    graph_ranking = _measure(
        "graph_ranking",
        lambda: rank_graph_nodes(
            graph_context=graph_nodes,
            relationships=graph_relationships,
            semantic_act_ids=semantic_ids,
            seed_act_ids=seed_ids,
            max_depth=4,
        ),
        iterations=80,
    )
    graph_ranking.extra = {
        "nodes": len(graph_nodes),
        "relationships": len(graph_relationships),
        "seed_acts": len(seed_ids),
    }

    context_budget = _measure(
        "graph_context_budget",
        lambda: GraphContextBudget(max_chars=20_000).build(
            semantic_results=semantic_docs,
            ranked_nodes=ranked_nodes[:150],
            ranked_relationships=ranked_relationships[:300],
        ),
        iterations=120,
    )
    context_budget.extra = {
        "semantic_results": len(semantic_docs),
        "ranked_nodes_considered": 150,
        "ranked_relationships_considered": 300,
        "budget_chars": 20_000,
    }

    semantic_parallel = _semantic_parallel_benchmark()

    payload = {
        "generated_at_utc": datetime.now(UTC).isoformat(),
        "python": platform.python_version(),
        "platform": {
            "system": platform.system(),
            "release": platform.release(),
            "machine": platform.machine(),
            "cpu_model": _cpu_model(),
        },
        "benchmarks": {
            # One serialization helper instead of three hand-copied dicts.
            cypher_validation.name: _summary_payload(cypher_validation),
            graph_ranking.name: _summary_payload(graph_ranking),
            context_budget.name: _summary_payload(context_budget),
            "semantic_parallel_search": semantic_parallel,
        },
    }
    print(json.dumps(payload, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()