Source code for scripts.generate_eval_dataset

#!/usr/bin/env python3
"""
Generate evaluation datasets from database contents.

Connects to PostgreSQL, samples acts and their subdivisions, then uses a
Mistral LLM to generate retrieval test queries with expected CELEX
identifiers and grounded question/answer pairs.

Produces two JSONL files:
  - rag_test_retrieval.jsonl  (queries + expected_celex)
  - rag_test_qa.jsonl         (queries + ground_truth answers)
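
Example record shapes (one JSON object per line; the field values shown
here are placeholders, not real data):
    rag_test_retrieval.jsonl: {"query": "...", "expected_celex": ["32016R0679"]}
    rag_test_qa.jsonl: {"query": "...", "ground_truth": "...", "mode": "rag", "expected_celex": ["32016R0679"]}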

Usage:
    python scripts/generate_eval_dataset.py
    python scripts/generate_eval_dataset.py --limit 20 --output-dir scripts/datasets
    python scripts/generate_eval_dataset.py --celex 32016R0679 --qa-per-act 3
"""

from __future__ import annotations

import argparse
import json
import os
import random
import sys
from pathlib import Path
from typing import Any, Optional

# Ensure packages are importable
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_PROJECT_ROOT))

from lalandre_core.config import get_config, get_env_settings  # noqa: E402
from lalandre_db_postgres import PostgresRepository  # noqa: E402

_DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent / "datasets"

_RETRIEVAL_PROMPT = """\
Tu es un expert juridique. Voici le titre d'un acte legislatif europeen :

Titre : {title}
CELEX : {celex}
Type : {act_type}

Genere {n} questions de recherche variees qu'un juriste poserait pour trouver cet acte \
dans une base documentaire. Les questions doivent etre formulees en francais, etre \
realistes et diversifiees (recherche par sujet, par reference, par contenu).

Reponds UNIQUEMENT avec un tableau JSON de strings, sans explication.
Exemple : ["Question 1", "Question 2", "Question 3"]
"""

_QA_PROMPT = """\
Tu es un expert juridique. Voici un extrait d'un acte legislatif europeen :

Titre de l'acte : {title}
CELEX : {celex}
Type de subdivision : {sub_type}
Numero : {sub_number}

Contenu :
{content}

A partir de cet extrait, genere une paire question/reponse pour evaluer un systeme RAG juridique.
La question doit etre precise et la reponse doit etre fidele au contenu de l'extrait.

Reponds UNIQUEMENT avec un objet JSON au format :
{{"question": "...", "answer": "..."}}
"""


def _build_llm(
    *,
    model: str,
    api_key: str,
    temperature: float,
    timeout: float,
) -> Any:
    from lalandre_core.llm import build_chat_model
    from langchain_core.output_parsers import StrOutputParser

    llm = build_chat_model(
        provider="mistral",
        model=model,
        api_key=api_key,
        temperature=temperature,
        timeout_seconds=timeout,
    )
    return llm | StrOutputParser()
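
# Usage sketch (hypothetical values): the object returned by `_build_llm` is a
# LangChain runnable, so callers invoke it with a prompt string and get a plain
# string back via StrOutputParser, e.g.
#   chain = _build_llm(model="mistral-small-latest", api_key=key, temperature=0.3, timeout=60.0)
#   raw_json = chain.invoke("...")  # -> str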


def _parse_json_array(text: str) -> list[str]:
    text = text.strip()
    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end == -1:
        return []
    try:
        parsed = json.loads(text[start : end + 1])
        if isinstance(parsed, list):
            return [str(item).strip() for item in parsed if str(item).strip()]
    except json.JSONDecodeError:
        pass
    return []
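
# Example (illustrative): the parser tolerates surrounding chatter by slicing
# from the first "[" to the last "]":
#   _parse_json_array('Voici : ["Q1", "Q2"]')  ->  ["Q1", "Q2"]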


def _parse_json_object(text: str) -> Optional[dict[str, str]]:
    text = text.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1:
        return None
    try:
        parsed = json.loads(text[start : end + 1])
        if isinstance(parsed, dict):
            return parsed
    except json.JSONDecodeError:
        pass
    return None
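
# Example (illustrative): fenced LLM output is handled the same way, by
# slicing from the first "{" to the last "}":
#   _parse_json_object('```json\n{"question": "Q", "answer": "A"}\n```')
#   ->  {"question": "Q", "answer": "A"}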


def generate_retrieval_queries(
    acts: list[Any],
    llm_chain: Any,
    queries_per_act: int,
) -> list[dict[str, Any]]:
    """Generate retrieval-oriented evaluation queries for the selected acts."""
    dataset: list[dict[str, Any]] = []
    for i, act in enumerate(acts):
        celex = act.celex or ""
        title = act.title or celex
        act_type = getattr(act, "act_type", None)
        act_type_str = getattr(act_type, "value", None) or str(act_type or "regulation")
        print(f" [{i + 1}/{len(acts)}] Generating retrieval queries for {celex}...", end=" ", flush=True)
        prompt = _RETRIEVAL_PROMPT.format(
            title=title,
            celex=celex,
            act_type=act_type_str,
            n=queries_per_act,
        )
        try:
            response = llm_chain.invoke(prompt)
            questions = _parse_json_array(response)
        except Exception as exc:
            print(f"ERROR: {exc}")
            continue
        # Cap at queries_per_act in case the model returns extra questions.
        selected_questions = questions[:queries_per_act]
        for question in selected_questions:
            dataset.append(
                {
                    "query": question,
                    "expected_celex": [celex],
                }
            )
        print(f"{len(selected_questions)} queries")
    return dataset


def generate_qa_pairs(
    acts: list[Any],
    repo: PostgresRepository,
    session: Any,
    llm_chain: Any,
    qa_per_act: int,
    min_content_chars: int,
) -> list[dict[str, Any]]:
    """Generate grounded QA pairs from representative act subdivisions."""
    dataset: list[dict[str, Any]] = []
    for i, act in enumerate(acts):
        celex = act.celex or ""
        title = act.title or celex
        subdivisions = repo.list_subdivisions_for_act(session, act.id)
        # Keep only subdivisions with enough text to support a grounded QA pair.
        eligible = [
            s
            for s in subdivisions
            if (s.content or "") and len(s.content or "") >= min_content_chars
        ]
        if not eligible:
            print(f" [{i + 1}/{len(acts)}] {celex}: no eligible subdivisions, skipping")
            continue
        selected = random.sample(eligible, min(qa_per_act, len(eligible)))
        print(f" [{i + 1}/{len(acts)}] Generating QA for {celex} ({len(selected)} subdivisions)...", flush=True)
        for sub in selected:
            sub_type = getattr(sub, "subdivision_type", None)
            sub_type_str = getattr(sub_type, "value", None) or str(sub_type or "article")
            # Truncate long subdivisions to keep the prompt size bounded.
            content = (sub.content or "")[:3000]
            prompt = _QA_PROMPT.format(
                title=title,
                celex=celex,
                sub_type=sub_type_str,
                sub_number=sub.number or "N/A",
                content=content,
            )
            try:
                response = llm_chain.invoke(prompt)
                qa_pair = _parse_json_object(response)
            except Exception as exc:
                print(f" ERROR on subdivision {sub.id}: {exc}")
                continue
            if qa_pair and qa_pair.get("question") and qa_pair.get("answer"):
                dataset.append(
                    {
                        "query": qa_pair["question"],
                        "ground_truth": qa_pair["answer"],
                        "mode": "rag",
                        "expected_celex": [celex],
                    }
                )
    return dataset


def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for dataset generation."""
    parser = argparse.ArgumentParser(description="Generate RAG evaluation datasets from database")
    parser.add_argument("--limit", type=int, default=10, help="Max number of acts to process")
    parser.add_argument("--celex", type=str, default=None, help="Specific CELEX to process")
    parser.add_argument("--queries-per-act", type=int, default=3, help="Retrieval queries per act")
    parser.add_argument("--qa-per-act", type=int, default=2, help="QA pairs per act")
    parser.add_argument("--min-content", type=int, default=200, help="Min subdivision content length for QA")
    parser.add_argument("--output-dir", default=None, help="Output directory for datasets")
    parser.add_argument("--model", default="mistral-small-latest", help="Mistral model for generation")
    parser.add_argument("--temperature", type=float, default=0.3, help="LLM temperature")
    parser.add_argument("--timeout", type=float, default=60.0, help="LLM timeout in seconds")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    return parser.parse_args()


def main() -> int:
    """Generate retrieval and QA datasets from the current database snapshot."""
    args = parse_args()
    random.seed(args.seed)
    output_dir = Path(args.output_dir) if args.output_dir else _DEFAULT_OUTPUT_DIR
    output_dir.mkdir(parents=True, exist_ok=True)

    config = get_config()
    settings = get_env_settings()
    api_key = settings.MISTRAL_API_KEY
    if not api_key:
        api_key = os.getenv("MISTRAL_API_KEY", "")
    if not api_key:
        print("ERROR: MISTRAL_API_KEY not set. Set it in .env or as env var.")
        return 1

    print("Connecting to database...")
    repo = PostgresRepository(config.database.connection_string)
    llm_chain = _build_llm(
        model=args.model,
        api_key=api_key,
        temperature=args.temperature,
        timeout=args.timeout,
    )

    with repo.get_session() as session:
        if args.celex:
            act = repo.get_act_by_celex(session, args.celex)
            if not act:
                print(f"ERROR: Act {args.celex} not found")
                return 1
            acts = [act]
        else:
            all_acts = repo.list_acts_with_metadata(session)
            if not all_acts:
                print("ERROR: No acts in database")
                return 1
            acts = random.sample(all_acts, min(args.limit, len(all_acts)))

        print(f"Selected {len(acts)} acts for dataset generation\n")

        # Generate retrieval dataset
        print("Phase 1: Generating retrieval queries...")
        retrieval_data = generate_retrieval_queries(acts, llm_chain, args.queries_per_act)
        retrieval_path = output_dir / "rag_test_retrieval.jsonl"
        with retrieval_path.open("w", encoding="utf-8") as f:
            for item in retrieval_data:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
        print(f"\n Saved {len(retrieval_data)} retrieval queries to {retrieval_path}")

        # Generate QA dataset
        print("\nPhase 2: Generating QA pairs...")
        qa_data = generate_qa_pairs(
            acts,
            repo,
            session,
            llm_chain,
            qa_per_act=args.qa_per_act,
            min_content_chars=args.min_content,
        )
        qa_path = output_dir / "rag_test_qa.jsonl"
        with qa_path.open("w", encoding="utf-8") as f:
            for item in qa_data:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
        print(f"\n Saved {len(qa_data)} QA pairs to {qa_path}")

    print(f"\nDatasets ready in {output_dir}/")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())