"""
Cypher branch — orchestration layer.
Collects Cypher-derived graph evidence to enrich (not replace) the main RAG
answer. Pipeline: intent → template → Text2Cypher → NL→Cypher fallback.
Routing and template selection are delegated to ``intent`` and ``templates``
modules.
"""
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from fastapi import HTTPException
from lalandre_core.utils import as_document_list, sanitize_error_text
from lalandre_rag.graph.helpers import build_nl_to_cypher_prompt, extract_cypher
from lalandre_rag.graph.source_payloads import build_cypher_row_source_item
from langchain_core.output_parsers import StrOutputParser
from rag_service.models import QueryRequest
from .intent import analyze_relation_intent
from .templates import match_template
logger = logging.getLogger(__name__)
# ── Data types ────────────────────────────────────────────────────────────────
@dataclass(frozen=True)
class CypherResolution:
    """Resolved Cypher-support result returned by the graph branch.

    The dataclass is frozen, so fields cannot be reassigned after creation.
    The ``rows`` list is stored as given (not copied), so treat it as
    read-only.
    """

    # Outcome marker; this module produces "ok" and "skipped".
    status: str
    # Name of the path that produced the result (template strategy,
    # "neo4j_graphrag_text2cypher", "legacy_nl_to_cypher" or
    # "skipped_not_relevant").
    strategy: str
    # Cypher text that was executed, or None when no query was produced.
    generated_cypher: Optional[str]
    # Result rows of the executed query; empty when the branch was skipped.
    rows: List[Dict[str, Any]]
    # Human-readable explanation of how the result was obtained.
    detail: str
# ── Execution ─────────────────────────────────────────────────────────────────
def _error_text(exc: BaseException) -> str:
    """Return the sanitized message for ``exc``.

    ``sanitize_error_text`` is only ever handed an ``Exception``: anything
    that is merely a ``BaseException`` is re-wrapped as a plain ``Exception``
    carrying the same string representation before being sanitized.
    """
    if isinstance(exc, Exception):
        return sanitize_error_text(exc)
    wrapped = Exception(str(exc))
    return sanitize_error_text(wrapped)
async def _execute_read_only_query(
    *,
    neo4j_repo: Any,
    generated_cypher: str,
    cypher_max_rows: int,
    timings: Dict[str, float],
) -> List[Dict[str, Any]]:
    """Run ``generated_cypher`` through the repository's read-only entry point.

    The blocking repository call is pushed to a worker thread with
    ``asyncio.to_thread`` so the event loop is not held while it runs.

    Args:
        neo4j_repo: Object exposing ``execute_read_only_query(cypher, result_limit=...)``.
        generated_cypher: Cypher statement to execute.
        cypher_max_rows: Passed to the repository as ``result_limit``.
        timings: Mutable mapping; ``cypher_execution_ms`` is written on success only.

    Returns:
        The repository result normalised by ``as_document_list``.

    Raises:
        HTTPException: 422 when the repository raises ``ValueError`` (reported
            as a read-only validator rejection), 502 for any other exception.
    """
    started_at = time.perf_counter()
    try:
        raw_rows = await asyncio.to_thread(
            neo4j_repo.execute_read_only_query,
            generated_cypher,
            result_limit=cypher_max_rows,
        )
    except Exception as exc:
        # ValueError (and subclasses) is reported as a validator rejection;
        # everything else is reported as an execution failure.
        if isinstance(exc, ValueError):
            status_code = 422
            error_label = "Generated Cypher rejected by read-only validator"
        else:
            status_code = 502
            error_label = "Neo4j execution failed"
        failure_detail = {
            "error": error_label,
            "reason": _error_text(exc),
            "generated_cypher": generated_cypher,
        }
        raise HTTPException(status_code=status_code, detail=failure_detail) from exc
    else:
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        timings["cypher_execution_ms"] = round(elapsed_ms, 1)
        return as_document_list(raw_rows)
# ── Resolution pipeline ──────────────────────────────────────────────────────
async def _resolve_cypher_support(
    *,
    query_request: QueryRequest,
    graph_rag_service: Any,
    rag_service: Any,
    config: Any,
    graph_depth: int,
    timings: Dict[str, float],
) -> CypherResolution:
    """Resolve graph evidence for a question, trying each Cypher strategy in order.

    The pipeline is: relation-intent analysis, deterministic template,
    Neo4j GraphRAG Text2Cypher (when the service supports it), and finally
    LLM-based NL->Cypher generation. The first strategy that yields a result
    wins. A Text2Cypher failure other than a timeout is logged and falls
    through to the LLM strategy.

    Args:
        query_request: Incoming request; supplies the question and optional
            per-request overrides for the Cypher timeout and row limit.
        graph_rag_service: Service exposing ``neo4j``, optionally ``embedding``,
            ``supports_official_text2cypher()`` and ``text_to_cypher_search()``.
        rag_service: Service exposing the ``llm`` runnable used by the
            LLM-based strategy.
        config: Application config; ``config.graph`` provides the default
            timeout and row limit.
        graph_depth: Maximum graph depth forwarded to every strategy.
        timings: Mutable mapping that receives per-phase durations in
            milliseconds (``cypher_template_planning_ms``,
            ``cypher_retrieval_ms``, ``cypher_generation_ms``, plus whatever
            ``_execute_read_only_query`` records).

    Returns:
        A ``CypherResolution`` with status ``"skipped"`` when the question
        shows no relational intent, otherwise ``"ok"``.

    Raises:
        HTTPException: 503 when the Neo4j repository is missing, 504 when
            Text2Cypher or LLM generation times out, 422 when no Cypher can
            be extracted from the LLM output, and any error raised by
            ``_execute_read_only_query``.
    """
    neo4j_repo = getattr(graph_rag_service, "neo4j", None)
    if neo4j_repo is None:
        raise HTTPException(
            status_code=503,
            detail="Graph mode unavailable: Neo4j repository is not initialized",
        )

    # Per-request overrides take precedence over the configured defaults.
    graph_cfg = config.graph
    cypher_timeout = (
        query_request.graph_cypher_timeout_seconds
        if query_request.graph_cypher_timeout_seconds is not None
        else graph_cfg.cypher_timeout_seconds
    )
    cypher_max_rows = (
        query_request.graph_cypher_max_rows
        if query_request.graph_cypher_max_rows is not None
        else graph_cfg.cypher_max_rows
    )

    # ── Step 1: Analyze intent ──
    embedding_service = getattr(graph_rag_service, "embedding", None)
    intent = analyze_relation_intent(
        query_request.question,
        embedding_service=embedding_service,
    )
    if not intent.has_intent:
        return CypherResolution(
            status="skipped",
            strategy="skipped_not_relevant",
            generated_cypher=None,
            rows=[],
            detail="Question peu relationnelle : appui Cypher non utilisé",
        )

    # ── Step 2: Try deterministic template ──
    t0 = time.perf_counter()
    template = match_template(
        intent,
        max_graph_depth=graph_depth,
        row_limit=cypher_max_rows,
    )
    timings["cypher_template_planning_ms"] = round((time.perf_counter() - t0) * 1000.0, 1)
    if template is not None:
        rows = await _execute_read_only_query(
            neo4j_repo=neo4j_repo,
            generated_cypher=template.cypher,
            cypher_max_rows=cypher_max_rows,
            timings=timings,
        )
        return CypherResolution(
            status="ok",
            strategy=template.strategy,
            generated_cypher=template.cypher,
            rows=rows,
            detail=template.detail,
        )

    # ── Step 3: Official Text2Cypher (Neo4j GraphRAG) ──
    if graph_rag_service.supports_official_text2cypher():
        t0 = time.perf_counter()
        try:
            official_result = await asyncio.wait_for(
                asyncio.to_thread(
                    graph_rag_service.text_to_cypher_search,
                    question=query_request.question,
                    max_graph_depth=graph_depth,
                    row_limit=cypher_max_rows,
                ),
                timeout=float(cypher_timeout),
            )
        except asyncio.TimeoutError as exc:
            # A timeout is surfaced to the caller rather than retried through
            # the slower LLM path.
            raise HTTPException(
                status_code=504,
                detail={
                    "error": "Neo4j GraphRAG Text2Cypher timed out",
                    "timeout_seconds": float(cypher_timeout),
                },
            ) from exc
        except Exception as exc:
            # Best-effort strategy: log and fall through to Step 4. The message
            # goes through _error_text, like every other error in this module,
            # so the raw exception text is never written to the log.
            logger.warning(
                "Neo4j GraphRAG Text2Cypher failed, fallback to legacy NL->Cypher: %s",
                _error_text(exc),
            )
        else:
            timings["cypher_retrieval_ms"] = round((time.perf_counter() - t0) * 1000.0, 1)
            return CypherResolution(
                status="ok",
                strategy="neo4j_graphrag_text2cypher",
                generated_cypher=official_result.generated_cypher,
                rows=official_result.rows,
                detail="Cypher généré via Neo4j GraphRAG",
            )

    # ── Step 4: Legacy NL→Cypher (LLM generation) ──
    _str_llm = rag_service.llm | StrOutputParser()
    cypher_prompt = build_nl_to_cypher_prompt(
        question=query_request.question,
        max_graph_depth=graph_depth,
        row_limit=cypher_max_rows,
    )
    t0 = time.perf_counter()
    try:
        # The synchronous LLM call runs in a worker thread so the timeout can
        # be enforced from the event loop.
        cypher_raw = await asyncio.wait_for(
            asyncio.to_thread(_str_llm.invoke, cypher_prompt),
            timeout=float(cypher_timeout),
        )
    except asyncio.TimeoutError as exc:
        raise HTTPException(
            status_code=504,
            detail={
                "error": "NL->Cypher generation timed out",
                "timeout_seconds": float(cypher_timeout),
            },
        ) from exc
    timings["cypher_generation_ms"] = round((time.perf_counter() - t0) * 1000.0, 1)

    generated_cypher = extract_cypher(cypher_raw)
    if not generated_cypher:
        raise HTTPException(
            status_code=422,
            detail={
                "error": "Unable to extract Cypher from LLM output",
                # Truncated so the error payload stays small.
                "llm_output_preview": cypher_raw[:800],
            },
        )

    rows = await _execute_read_only_query(
        neo4j_repo=neo4j_repo,
        generated_cypher=generated_cypher,
        cypher_max_rows=cypher_max_rows,
        timings=timings,
    )
    return CypherResolution(
        status="ok",
        strategy="legacy_nl_to_cypher",
        generated_cypher=generated_cypher,
        rows=rows,
        detail="Cypher généré en fallback NL->Cypher",
    )
# ── Public API ────────────────────────────────────────────────────────────────
def _build_cypher_documents(
    *,
    query_request: QueryRequest,
    config: Any,
    cypher_rows: List[Dict[str, Any]],
    graph_query_strategy: str,
    generated_cypher: Optional[str],
) -> List[Dict[str, Any]]:
    """Turn raw Cypher rows into source items, numbered from 1.

    The preview length is the first truthy ``content_preview_chars`` found on
    ``config.context_budget`` and then ``config.graph``; when neither provides
    one, 200 is used.
    """
    preview_chars: Any = 200
    for section_name in ("context_budget", "graph"):
        section = getattr(config, section_name, None)
        configured = getattr(section, "content_preview_chars", None)
        if configured:
            preview_chars = configured
            break

    documents: List[Dict[str, Any]] = []
    for position, row in enumerate(cypher_rows, start=1):
        source_item = build_cypher_row_source_item(
            row=row,
            row_index=position,
            include_full_content=query_request.include_full_content,
            content_preview_chars=int(preview_chars),
            query_id=query_request.query_id,
            graph_query_strategy=graph_query_strategy,
            generated_cypher=generated_cypher,
        )
        documents.append(source_item)
    return documents
async def collect_cypher_support(
    *,
    query_request: QueryRequest,
    graph_rag_service: Any,
    rag_service: Any,
    config: Any,
    graph_depth: int,
    strict_graph_mode: bool,
) -> tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
    """Collect Cypher-derived graph evidence without replacing the main RAG answer.

    Failures never propagate: every exception raised while resolving the
    Cypher support is logged and reported in the returned metadata.

    Args:
        query_request: Incoming request; ``sources`` decides whether a sources
            payload is built.
        graph_rag_service: Graph service forwarded to the resolution pipeline.
        rag_service: RAG service forwarded to the resolution pipeline.
        config: Application config forwarded to the pipeline and used for the
            preview length of source items.
        graph_depth: Maximum graph depth forwarded to the pipeline and echoed
            in the metadata.
        strict_graph_mode: Echoed in the metadata as ``strict_mode``.

    Returns:
        ``(sources_payload, metadata)``. ``sources_payload`` is ``None`` on
        error, when the branch is skipped, or when ``query_request.sources``
        is falsy. ``metadata`` always has the single key
        ``"graph_cypher_support"`` whose ``status`` is ``"ok"``, ``"skipped"``
        or ``"error"``.
    """
    timings: Dict[str, float] = {}
    t0 = time.perf_counter()
    try:
        resolution = await _resolve_cypher_support(
            query_request=query_request,
            graph_rag_service=graph_rag_service,
            rag_service=rag_service,
            config=config,
            graph_depth=graph_depth,
            timings=timings,
        )
    except HTTPException as exc:
        detail = exc.detail if isinstance(exc.detail, dict) else {"error": str(exc.detail)}
        # Log here as well: the exception is converted into metadata below and
        # would otherwise leave no trace in the service logs.
        logger.warning(
            "Cypher support collection failed (HTTP %s): %s",
            exc.status_code,
            detail.get("error"),
        )
        # The fixed markers stay authoritative even if the exception detail
        # happens to carry keys with the same names.
        extra = {key: value for key, value in detail.items() if key not in ("enabled", "status")}
        return None, {
            "graph_cypher_support": {
                "enabled": True,
                "status": "error",
                **extra,
            }
        }
    except Exception as exc:
        logger.warning("Cypher support collection failed: %s", _error_text(exc))
        return None, {
            "graph_cypher_support": {
                "enabled": True,
                "status": "error",
                "error": _error_text(exc),
            }
        }
    timings["total_graph_pipeline_ms"] = round((time.perf_counter() - t0) * 1000.0, 1)

    if resolution.status == "skipped":
        return None, {
            "graph_cypher_support": {
                "enabled": True,
                "status": "skipped",
                "strict_mode": strict_graph_mode,
                "graph_depth": graph_depth,
                "strategy": resolution.strategy,
                "reason": resolution.detail,
                "phase_timings_ms": timings,
            }
        }

    # Source items are only built when the caller asked for sources.
    sources_payload: Optional[Dict[str, Any]] = None
    if query_request.sources:
        cypher_documents = _build_cypher_documents(
            query_request=query_request,
            config=config,
            cypher_rows=resolution.rows,
            graph_query_strategy=resolution.strategy,
            generated_cypher=resolution.generated_cypher,
        )
        sources_payload = {
            "total": len(cypher_documents),
            "documents": [],
            "cypher_rows": cypher_documents,
            "graph_query": {
                "cypher": resolution.generated_cypher,
                "strategy": resolution.strategy,
                "row_count": len(resolution.rows),
                "strict_mode": strict_graph_mode,
                "graph_depth": graph_depth,
            },
        }

    return sources_payload, {
        "graph_cypher_support": {
            "enabled": True,
            "status": "ok",
            "strict_mode": strict_graph_mode,
            "graph_depth": graph_depth,
            "strategy": resolution.strategy,
            "generated_cypher": resolution.generated_cypher,
            "row_count": len(resolution.rows),
            "detail": resolution.detail,
            "phase_timings_ms": timings,
        }
    }