# Source code for rag_service.graph.cypher_branch

"""
Cypher branch — orchestration layer.

Collects Cypher-derived graph evidence to enrich (not replace) the main RAG
answer.  Pipeline: intent → template → Text2Cypher → NL→Cypher fallback.

Routing and template selection are delegated to ``intent`` and ``templates``
modules.
"""

import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from fastapi import HTTPException
from lalandre_core.utils import as_document_list, sanitize_error_text
from lalandre_rag.graph.helpers import build_nl_to_cypher_prompt, extract_cypher
from lalandre_rag.graph.source_payloads import build_cypher_row_source_item
from langchain_core.output_parsers import StrOutputParser
from rag_service.models import QueryRequest

from .intent import analyze_relation_intent
from .templates import match_template

logger = logging.getLogger(__name__)


# ── Data types ────────────────────────────────────────────────────────────────


@dataclass(frozen=True)
class CypherResolution:
    """Resolved Cypher-support result returned by the graph branch.

    Attributes:
        status: Outcome marker; this module produces ``"ok"`` and ``"skipped"``.
        strategy: Name of the resolution path that produced the result
            (e.g. a template strategy, ``"neo4j_graphrag_text2cypher"``,
            ``"legacy_nl_to_cypher"`` or ``"skipped_not_relevant"``).
        generated_cypher: The Cypher text that was executed, or ``None``
            when no query was run.
        rows: Result rows, one mapping per row (empty when skipped).
        detail: Human-readable explanation of how the result was obtained.
    """

    status: str
    strategy: str
    generated_cypher: Optional[str]
    rows: List[Dict[str, Any]]
    detail: str
# ── Execution ─────────────────────────────────────────────────────────────────


def _error_text(exc: BaseException) -> str:
    """Return the sanitized text for *exc*.

    ``sanitize_error_text`` is always handed an ``Exception`` instance: a
    ``BaseException`` that is not an ``Exception`` is first wrapped in a plain
    ``Exception`` carrying its string form.
    """
    if isinstance(exc, Exception):
        return sanitize_error_text(exc)
    return sanitize_error_text(Exception(str(exc)))


def _elapsed_ms(started_at: float) -> float:
    """Return the milliseconds elapsed since *started_at*, rounded to 0.1 ms.

    *started_at* must be a ``time.perf_counter()`` reading.
    """
    return round((time.perf_counter() - started_at) * 1000.0, 1)


async def _execute_read_only_query(
    *,
    neo4j_repo: Any,
    generated_cypher: str,
    cypher_max_rows: int,
    timings: Dict[str, float],
) -> List[Dict[str, Any]]:
    """Run *generated_cypher* through the repository's read-only entry point.

    The blocking repository call is moved to a worker thread so the event loop
    stays responsive.

    Args:
        neo4j_repo: Object exposing ``execute_read_only_query(cypher, result_limit=...)``.
        generated_cypher: Cypher text to execute.
        cypher_max_rows: Passed to the repository as ``result_limit``.
        timings: Mutated in place; ``"cypher_execution_ms"`` is recorded on success only.

    Returns:
        The repository result converted with ``as_document_list``.

    Raises:
        HTTPException: 422 when the repository raises ``ValueError`` (reported
            as a read-only validator rejection), 502 for any other failure.
    """
    t0 = time.perf_counter()
    try:
        raw = await asyncio.to_thread(
            neo4j_repo.execute_read_only_query,
            generated_cypher,
            result_limit=cypher_max_rows,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=422,
            detail={
                "error": "Generated Cypher rejected by read-only validator",
                "reason": _error_text(exc),
                "generated_cypher": generated_cypher,
            },
        ) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=502,
            detail={
                "error": "Neo4j execution failed",
                "reason": _error_text(exc),
                "generated_cypher": generated_cypher,
            },
        ) from exc
    timings["cypher_execution_ms"] = _elapsed_ms(t0)
    return as_document_list(raw)


# ── Resolution pipeline ──────────────────────────────────────────────────────


async def _resolve_cypher_support(
    *,
    query_request: QueryRequest,
    graph_rag_service: Any,
    rag_service: Any,
    config: Any,
    graph_depth: int,
    timings: Dict[str, float],
) -> CypherResolution:
    """Resolve graph evidence for the question, trying strategies in order.

    Order: relation-intent analysis → deterministic template → official
    Text2Cypher (when the graph service supports it) → legacy LLM-based
    NL→Cypher generation. The first strategy that yields a query wins.

    Args:
        query_request: Incoming request; per-request timeout / row-limit
            overrides take precedence over ``config.graph`` when not ``None``.
        graph_rag_service: Provides ``neo4j``, optionally ``embedding``, and
            the Text2Cypher entry points.
        rag_service: Provides the ``llm`` used by the legacy fallback.
        config: Application config; ``config.graph`` supplies the defaults.
        graph_depth: Maximum graph depth forwarded to every strategy.
        timings: Mutated in place with per-phase durations in milliseconds.

    Returns:
        A ``CypherResolution`` with status ``"skipped"`` when the question
        shows no relation intent, otherwise ``"ok"``.

    Raises:
        HTTPException: 503 when the Neo4j repository is missing, 504 on a
            Text2Cypher / LLM timeout, 422 when no Cypher can be extracted
            from the LLM output, plus anything raised by
            ``_execute_read_only_query``.
    """
    neo4j_repo = getattr(graph_rag_service, "neo4j", None)
    if neo4j_repo is None:
        raise HTTPException(
            status_code=503,
            detail="Graph mode unavailable: Neo4j repository is not initialized",
        )

    graph_cfg = config.graph
    # Per-request overrides win over the configured defaults.
    cypher_timeout = (
        query_request.graph_cypher_timeout_seconds
        if query_request.graph_cypher_timeout_seconds is not None
        else graph_cfg.cypher_timeout_seconds
    )
    cypher_max_rows = (
        query_request.graph_cypher_max_rows
        if query_request.graph_cypher_max_rows is not None
        else graph_cfg.cypher_max_rows
    )

    # ── Step 1: Analyze intent ──
    embedding_service = getattr(graph_rag_service, "embedding", None)
    intent = analyze_relation_intent(
        query_request.question,
        embedding_service=embedding_service,
    )
    if not intent.has_intent:
        return CypherResolution(
            status="skipped",
            strategy="skipped_not_relevant",
            generated_cypher=None,
            rows=[],
            detail="Question peu relationnelle : appui Cypher non utilisé",
        )

    # ── Step 2: Try deterministic template ──
    t0 = time.perf_counter()
    template = match_template(
        intent,
        max_graph_depth=graph_depth,
        row_limit=cypher_max_rows,
    )
    timings["cypher_template_planning_ms"] = _elapsed_ms(t0)
    if template is not None:
        rows = await _execute_read_only_query(
            neo4j_repo=neo4j_repo,
            generated_cypher=template.cypher,
            cypher_max_rows=cypher_max_rows,
            timings=timings,
        )
        return CypherResolution(
            status="ok",
            strategy=template.strategy,
            generated_cypher=template.cypher,
            rows=rows,
            detail=template.detail,
        )

    # Both remaining strategies share the same generation timeout.
    timeout_seconds = float(cypher_timeout)

    # ── Step 3: Official Text2Cypher (Neo4j GraphRAG) ──
    if graph_rag_service.supports_official_text2cypher():
        t0 = time.perf_counter()
        try:
            official_result = await asyncio.wait_for(
                asyncio.to_thread(
                    graph_rag_service.text_to_cypher_search,
                    question=query_request.question,
                    max_graph_depth=graph_depth,
                    row_limit=cypher_max_rows,
                ),
                timeout=timeout_seconds,
            )
        except asyncio.TimeoutError as exc:
            raise HTTPException(
                status_code=504,
                detail={
                    "error": "Neo4j GraphRAG Text2Cypher timed out",
                    "timeout_seconds": timeout_seconds,
                },
            ) from exc
        except Exception as exc:
            # Best effort: any non-timeout failure falls through to Step 4.
            # The message is sanitized like every other error in this module.
            logger.warning(
                "Neo4j GraphRAG Text2Cypher failed, fallback to legacy NL->Cypher: %s",
                _error_text(exc),
            )
        else:
            timings["cypher_retrieval_ms"] = _elapsed_ms(t0)
            return CypherResolution(
                status="ok",
                strategy="neo4j_graphrag_text2cypher",
                generated_cypher=official_result.generated_cypher,
                rows=official_result.rows,
                detail="Cypher généré via Neo4j GraphRAG",
            )

    # ── Step 4: Legacy NL→Cypher (LLM generation) ──
    _str_llm = rag_service.llm | StrOutputParser()
    cypher_prompt = build_nl_to_cypher_prompt(
        question=query_request.question,
        max_graph_depth=graph_depth,
        row_limit=cypher_max_rows,
    )
    t0 = time.perf_counter()
    try:
        cypher_raw = await asyncio.wait_for(
            asyncio.to_thread(_str_llm.invoke, cypher_prompt),
            timeout=timeout_seconds,
        )
    except asyncio.TimeoutError as exc:
        raise HTTPException(
            status_code=504,
            detail={
                "error": "NL->Cypher generation timed out",
                "timeout_seconds": timeout_seconds,
            },
        ) from exc
    timings["cypher_generation_ms"] = _elapsed_ms(t0)

    generated_cypher = extract_cypher(cypher_raw)
    if not generated_cypher:
        raise HTTPException(
            status_code=422,
            detail={
                "error": "Unable to extract Cypher from LLM output",
                # Truncated so the error payload stays small.
                "llm_output_preview": cypher_raw[:800],
            },
        )

    rows = await _execute_read_only_query(
        neo4j_repo=neo4j_repo,
        generated_cypher=generated_cypher,
        cypher_max_rows=cypher_max_rows,
        timings=timings,
    )
    return CypherResolution(
        status="ok",
        strategy="legacy_nl_to_cypher",
        generated_cypher=generated_cypher,
        rows=rows,
        detail="Cypher généré en fallback NL->Cypher",
    )


# ── Public API ────────────────────────────────────────────────────────────────


def _build_cypher_documents(
    *,
    query_request: QueryRequest,
    config: Any,
    cypher_rows: List[Dict[str, Any]],
    graph_query_strategy: str,
    generated_cypher: Optional[str],
) -> List[Dict[str, Any]]:
    """Convert Cypher result rows into source items, numbered from 1.

    The preview length is the first truthy value among
    ``config.context_budget.content_preview_chars``,
    ``config.graph.content_preview_chars`` and the fallback ``200``; missing
    attributes are tolerated.
    """
    preview_chars = (
        getattr(getattr(config, "context_budget", None), "content_preview_chars", None)
        or getattr(getattr(config, "graph", None), "content_preview_chars", None)
        or 200
    )
    return [
        build_cypher_row_source_item(
            row=row,
            row_index=i,
            include_full_content=query_request.include_full_content,
            content_preview_chars=int(preview_chars),
            query_id=query_request.query_id,
            graph_query_strategy=graph_query_strategy,
            generated_cypher=generated_cypher,
        )
        for i, row in enumerate(cypher_rows, start=1)
    ]
async def collect_cypher_support(
    *,
    query_request: QueryRequest,
    graph_rag_service: Any,
    rag_service: Any,
    config: Any,
    graph_depth: int,
    strict_graph_mode: bool,
) -> tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
    """Collect Cypher-derived graph evidence without replacing the main RAG answer.

    Failures never propagate: any exception raised while resolving the query
    is logged and reported through the returned metadata instead.

    Args:
        query_request: Incoming request; ``sources`` decides whether a sources
            payload is built.
        graph_rag_service: Forwarded to ``_resolve_cypher_support``.
        rag_service: Forwarded to ``_resolve_cypher_support``.
        config: Forwarded to ``_resolve_cypher_support`` and used for preview sizing.
        graph_depth: Maximum graph depth, echoed in the metadata.
        strict_graph_mode: Echoed in the metadata as ``strict_mode``.

    Returns:
        ``(sources_payload, metadata)``. ``sources_payload`` is ``None`` on
        error, when the branch was skipped, or when the request did not ask
        for sources. ``metadata`` always holds one ``"graph_cypher_support"``
        entry whose ``"status"`` is ``"ok"``, ``"skipped"`` or ``"error"``.
    """
    timings: Dict[str, float] = {}
    t0 = time.perf_counter()
    try:
        resolution = await _resolve_cypher_support(
            query_request=query_request,
            graph_rag_service=graph_rag_service,
            rag_service=rag_service,
            config=config,
            graph_depth=graph_depth,
            timings=timings,
        )
    except HTTPException as exc:
        detail = exc.detail if isinstance(exc.detail, dict) else {"error": str(exc.detail)}
        # Logged like the generic branch below so the failure is not silent.
        logger.warning(
            "Cypher support collection failed (HTTP %s): %s",
            exc.status_code,
            detail.get("error"),
        )
        return None, {
            "graph_cypher_support": {
                "enabled": True,
                "status": "error",
                **detail,
            }
        }
    except Exception as exc:
        logger.warning("Cypher support collection failed: %s", _error_text(exc))
        return None, {
            "graph_cypher_support": {
                "enabled": True,
                "status": "error",
                "error": _error_text(exc),
            }
        }
    timings["total_graph_pipeline_ms"] = round((time.perf_counter() - t0) * 1000.0, 1)

    if resolution.status == "skipped":
        return None, {
            "graph_cypher_support": {
                "enabled": True,
                "status": "skipped",
                "strict_mode": strict_graph_mode,
                "graph_depth": graph_depth,
                "strategy": resolution.strategy,
                "reason": resolution.detail,
                "phase_timings_ms": timings,
            }
        }

    # Source items are only materialized when the caller asked for sources.
    sources_payload: Optional[Dict[str, Any]] = None
    if query_request.sources:
        cypher_documents = _build_cypher_documents(
            query_request=query_request,
            config=config,
            cypher_rows=resolution.rows,
            graph_query_strategy=resolution.strategy,
            generated_cypher=resolution.generated_cypher,
        )
        sources_payload = {
            "total": len(cypher_documents),
            "documents": [],
            "cypher_rows": cypher_documents,
            "graph_query": {
                "cypher": resolution.generated_cypher,
                "strategy": resolution.strategy,
                "row_count": len(resolution.rows),
                "strict_mode": strict_graph_mode,
                "graph_depth": graph_depth,
            },
        }

    return sources_payload, {
        "graph_cypher_support": {
            "enabled": True,
            "status": "ok",
            "strict_mode": strict_graph_mode,
            "graph_depth": graph_depth,
            "strategy": resolution.strategy,
            "generated_cypher": resolution.generated_cypher,
            "row_count": len(resolution.rows),
            "detail": resolution.detail,
            "phase_timings_ms": timings,
        }
    }