#!/usr/bin/env python3
"""
Generate evaluation datasets from database contents.
Connects to PostgreSQL, selects acts and subdivisions, then uses the
LLM (Mistral) to generate test queries with expected results and
ground-truth answers.
Produces two JSONL files:
- rag_test_retrieval.jsonl (queries + expected_celex)
- rag_test_qa.jsonl (queries + ground_truth answers)
Usage:
python scripts/generate_eval_dataset.py
python scripts/generate_eval_dataset.py --limit 20 --output-dir scripts/datasets
python scripts/generate_eval_dataset.py --celex 32016R0679 --qa-per-act 3
"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
from pathlib import Path
from typing import Any, Optional
# Ensure packages are importable
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_PROJECT_ROOT))
from lalandre_core.config import get_config, get_env_settings # noqa: E402
from lalandre_db_postgres import PostgresRepository # noqa: E402
# Default output directory (scripts/datasets, next to this script), used when
# --output-dir is not supplied on the command line.
_DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent / "datasets"

# Prompt asking the LLM for {n} varied French search questions that a lawyer
# would use to find the act identified by {title}/{celex}/{act_type}.
# The model must answer with a bare JSON array of strings (parsed by
# _parse_json_array below).
_RETRIEVAL_PROMPT = """\
Tu es un expert juridique. Voici le titre d'un acte legislatif europeen :
Titre : {title}
CELEX : {celex}
Type : {act_type}
Genere {n} questions de recherche variees qu'un juriste poserait pour trouver cet acte \
dans une base documentaire. Les questions doivent etre formulees en francais, etre \
realistes et diversifiees (recherche par sujet, par reference, par contenu).
Reponds UNIQUEMENT avec un tableau JSON de strings, sans explication.
Exemple : ["Question 1", "Question 2", "Question 3"]
"""

# Prompt asking the LLM for one grounded question/answer pair built from a
# single subdivision excerpt. The model must answer with a bare JSON object
# {"question": ..., "answer": ...} (parsed by _parse_json_object below).
_QA_PROMPT = """\
Tu es un expert juridique. Voici un extrait d'un acte legislatif europeen :
Titre de l'acte : {title}
CELEX : {celex}
Type de subdivision : {sub_type}
Numero : {sub_number}
Contenu :
{content}
A partir de cet extrait, genere une paire question/reponse pour evaluer un systeme RAG juridique.
La question doit etre precise et la reponse doit etre fidele au contenu de l'extrait.
Reponds UNIQUEMENT avec un objet JSON au format :
{{"question": "...", "answer": "..."}}
"""
def _build_llm(
    *,
    model: str,
    api_key: str,
    temperature: float,
    timeout: float,
) -> Any:
    """Build a Mistral chat chain that yields plain response strings.

    Imports are deferred to call time so that configuration errors (e.g. a
    missing API key) can be reported before any heavy dependency loads.

    Args:
        model: Mistral model identifier.
        api_key: Mistral API key.
        temperature: Sampling temperature for generation.
        timeout: Request timeout in seconds.

    Returns:
        A LangChain runnable: chat model piped into a string output parser.
    """
    from lalandre_core.llm import build_chat_model
    from langchain_core.output_parsers import StrOutputParser

    chat_model = build_chat_model(
        provider="mistral",
        model=model,
        api_key=api_key,
        temperature=temperature,
        timeout_seconds=timeout,
    )
    return chat_model | StrOutputParser()
def _parse_json_array(text: str) -> list[str]:
text = text.strip()
start = text.find("[")
end = text.rfind("]")
if start == -1 or end == -1:
return []
try:
parsed = json.loads(text[start : end + 1])
if isinstance(parsed, list):
return [str(item).strip() for item in parsed if str(item).strip()]
except json.JSONDecodeError:
pass
return []
def _parse_json_object(text: str) -> Optional[dict[str, str]]:
text = text.strip()
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1:
return None
try:
parsed = json.loads(text[start : end + 1])
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
return None
def generate_retrieval_queries(
    acts: list[Any],
    llm_chain: Any,
    queries_per_act: int,
) -> list[dict[str, Any]]:
    """Generate retrieval-oriented evaluation queries for the selected acts.

    For each act, asks the LLM for up to ``queries_per_act`` French search
    questions and records each with the act's CELEX as the expected result.
    Acts whose LLM call fails are logged and skipped (best-effort).

    Args:
        acts: Act records exposing ``celex``, ``title`` and ``act_type``.
        llm_chain: Runnable whose ``invoke(prompt)`` returns the raw LLM text.
        queries_per_act: Maximum number of queries kept per act.

    Returns:
        Rows of the form ``{"query": ..., "expected_celex": [celex]}``.
    """
    dataset: list[dict[str, Any]] = []
    for i, act in enumerate(acts):
        celex = act.celex or ""
        title = act.title or celex
        act_type = getattr(act, "act_type", None)
        # act_type may be an Enum (use .value), a plain string, or absent.
        act_type_str = getattr(act_type, "value", None) or str(act_type or "regulation")
        print(f" [{i + 1}/{len(acts)}] Generating retrieval queries for {celex}...", end=" ", flush=True)
        prompt = _RETRIEVAL_PROMPT.format(
            title=title,
            celex=celex,
            act_type=act_type_str,
            n=queries_per_act,
        )
        try:
            response = llm_chain.invoke(prompt)
            questions = _parse_json_array(response)
        except Exception as exc:
            # Best-effort: a failed act should not abort the whole run.
            print(f"ERROR: {exc}")
            continue
        # The LLM may return more questions than requested; cap once here.
        kept = questions[:queries_per_act]
        for question in kept:
            dataset.append(
                {
                    "query": question,
                    "expected_celex": [celex],
                }
            )
        print(f"{len(kept)} queries")
    return dataset
def generate_qa_pairs(
    acts: list[Any],
    repo: PostgresRepository,
    session: Any,
    llm_chain: Any,
    qa_per_act: int,
    min_content_chars: int,
) -> list[dict[str, Any]]:
    """Generate grounded QA pairs from representative act subdivisions.

    For each act, samples up to ``qa_per_act`` subdivisions with at least
    ``min_content_chars`` characters of content, then asks the LLM for one
    question/answer pair per excerpt. Failed or malformed LLM responses are
    logged and skipped (best-effort).

    Args:
        acts: Act records exposing ``id``, ``celex`` and ``title``.
        repo: Repository used to list an act's subdivisions.
        session: Open database session passed through to the repository.
        llm_chain: Runnable whose ``invoke(prompt)`` returns the raw LLM text.
        qa_per_act: Maximum number of QA pairs per act.
        min_content_chars: Minimum subdivision content length to be eligible.

    Returns:
        Rows with ``query``, ``ground_truth``, ``mode`` and ``expected_celex``.
    """
    dataset: list[dict[str, Any]] = []
    for i, act in enumerate(acts):
        celex = act.celex or ""
        title = act.title or celex
        subdivisions = repo.list_subdivisions_for_act(session, act.id)
        # Keep only subdivisions with non-empty content long enough to
        # ground a faithful answer.
        eligible = [
            s for s in subdivisions
            if (s.content or "") and len(s.content) >= min_content_chars
        ]
        if not eligible:
            print(f" [{i + 1}/{len(acts)}] {celex}: no eligible subdivisions, skipping")
            continue
        selected = random.sample(eligible, min(qa_per_act, len(eligible)))
        print(f" [{i + 1}/{len(acts)}] Generating QA for {celex} ({len(selected)} subdivisions)...", flush=True)
        for sub in selected:
            sub_type = getattr(sub, "subdivision_type", None)
            # subdivision_type may be an Enum (use .value), a string, or absent.
            sub_type_str = getattr(sub_type, "value", None) or str(sub_type or "article")
            # Truncate long subdivisions to keep the prompt within limits.
            content = (sub.content or "")[:3000]
            prompt = _QA_PROMPT.format(
                title=title,
                celex=celex,
                sub_type=sub_type_str,
                sub_number=sub.number or "N/A",
                content=content,
            )
            try:
                response = llm_chain.invoke(prompt)
                qa_pair = _parse_json_object(response)
            except Exception as exc:
                # Best-effort: one bad subdivision should not abort the run.
                print(f" ERROR on subdivision {sub.id}: {exc}")
                continue
            # Only keep fully-formed pairs (both keys present and non-empty).
            if qa_pair and qa_pair.get("question") and qa_pair.get("answer"):
                dataset.append(
                    {
                        "query": qa_pair["question"],
                        "ground_truth": qa_pair["answer"],
                        "mode": "rag",
                        "expected_celex": [celex],
                    }
                )
    return dataset
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse CLI arguments for dataset generation.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]`` when
            ``None``, preserving the original CLI behavior. Passing an
            explicit list makes the parser testable and reusable.

    Returns:
        Parsed argument namespace.
    """
    parser = argparse.ArgumentParser(description="Generate RAG evaluation datasets from database")
    parser.add_argument("--limit", type=int, default=10, help="Max number of acts to process")
    parser.add_argument("--celex", type=str, default=None, help="Specific CELEX to process")
    parser.add_argument("--queries-per-act", type=int, default=3, help="Retrieval queries per act")
    parser.add_argument("--qa-per-act", type=int, default=2, help="QA pairs per act")
    parser.add_argument("--min-content", type=int, default=200, help="Min subdivision content length for QA")
    parser.add_argument("--output-dir", default=None, help="Output directory for datasets")
    parser.add_argument("--model", default="mistral-small-latest", help="Mistral model for generation")
    parser.add_argument("--temperature", type=float, default=0.3, help="LLM temperature")
    parser.add_argument("--timeout", type=float, default=60.0, help="LLM timeout")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    return parser.parse_args(argv)
def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write *rows* to *path* as UTF-8 JSONL, one JSON object per line."""
    with path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def main() -> int:
    """Generate retrieval and QA datasets from the current database snapshot.

    Selects acts (a specific CELEX or a random sample), generates retrieval
    queries and grounded QA pairs via the LLM, and writes both datasets as
    JSONL files into the output directory.

    Returns:
        Process exit code: 0 on success, 1 on configuration or lookup errors.
    """
    args = parse_args()
    random.seed(args.seed)  # reproducible act/subdivision sampling
    output_dir = Path(args.output_dir) if args.output_dir else _DEFAULT_OUTPUT_DIR
    output_dir.mkdir(parents=True, exist_ok=True)
    config = get_config()
    settings = get_env_settings()
    # Prefer the key from .env settings, fall back to the process environment.
    api_key = settings.MISTRAL_API_KEY or os.getenv("MISTRAL_API_KEY", "")
    if not api_key:
        print("ERROR: MISTRAL_API_KEY not set. Set it in .env or as env var.")
        return 1
    print("Connecting to database...")
    repo = PostgresRepository(config.database.connection_string)
    llm_chain = _build_llm(
        model=args.model,
        api_key=api_key,
        temperature=args.temperature,
        timeout=args.timeout,
    )
    with repo.get_session() as session:
        if args.celex:
            act = repo.get_act_by_celex(session, args.celex)
            if not act:
                print(f"ERROR: Act {args.celex} not found")
                return 1
            acts = [act]
        else:
            all_acts = repo.list_acts_with_metadata(session)
            if not all_acts:
                print("ERROR: No acts in database")
                return 1
            acts = random.sample(all_acts, min(args.limit, len(all_acts)))
        print(f"Selected {len(acts)} acts for dataset generation\n")
        # Phase 1: retrieval queries (query -> expected CELEX).
        print("Phase 1: Generating retrieval queries...")
        retrieval_data = generate_retrieval_queries(acts, llm_chain, args.queries_per_act)
        retrieval_path = output_dir / "rag_test_retrieval.jsonl"
        _write_jsonl(retrieval_path, retrieval_data)
        print(f"\n Saved {len(retrieval_data)} retrieval queries to {retrieval_path}")
        # Phase 2: grounded QA pairs (query -> ground-truth answer).
        print("\nPhase 2: Generating QA pairs...")
        qa_data = generate_qa_pairs(
            acts,
            repo,
            session,
            llm_chain,
            qa_per_act=args.qa_per_act,
            min_content_chars=args.min_content,
        )
        qa_path = output_dir / "rag_test_qa.jsonl"
        _write_jsonl(qa_path, qa_data)
        print(f"\n Saved {len(qa_data)} QA pairs to {qa_path}")
    print(f"\nDatasets ready in {output_dir}/")
    return 0
# Script entry point: exit the process with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())