Source code for lalandre_rag.graph.helpers
"""
Graph-mode helper functions for the RAG service.
Contains NL→Cypher prompt building, Cypher extraction from LLM output,
and Cypher-row context formatting. These are domain-level utilities
that live in the *package* layer, not in the service layer.
"""
import json
import re
from typing import Any, Dict, List, Optional
from lalandre_core.utils import extract_json_object
from lalandre_rag.prompts import render_nl_to_cypher_prompt
# ---------------------------------------------------------------------------
# NL → Cypher prompt
# ---------------------------------------------------------------------------
[docs]
def build_nl_to_cypher_prompt(
    *,
    question: str,
    max_graph_depth: int,
    row_limit: int,
) -> str:
    """Render the prompt that asks the model to translate *question* into Cypher.

    Args:
        question: The user's natural-language question, passed through as-is.
        max_graph_depth: Traversal-depth hint for the template; coerced with
            ``int()`` and raised to 1 if smaller.
        row_limit: Row-limit hint for the template; coerced with ``int()`` and
            raised to 1 if smaller.

    Returns:
        Whatever ``render_nl_to_cypher_prompt`` produces for these values.
    """
    safe_depth = int(max_graph_depth)
    if safe_depth < 1:
        safe_depth = 1

    safe_limit = int(row_limit)
    if safe_limit < 1:
        safe_limit = 1

    return render_nl_to_cypher_prompt(
        max_graph_depth=safe_depth,
        row_limit=safe_limit,
        question=question,
    )
# ---------------------------------------------------------------------------
# Cypher extraction from LLM output
# ---------------------------------------------------------------------------
# Leading label a model may emit before the query ("Cypher:", "query:",
# "generated_cypher:" ...), matched case-insensitively at the very start and
# consuming an optional colon plus surrounding whitespace.
# NOTE(review): alternation is tried left to right, so "cypher" always wins
# over "cypher\s+query"; "Cypher query: X" loses only "Cypher " per pass and
# relies on a repeated-application loop to remove the remaining "query:".
_CYPHER_PREFIX_RE = re.compile(
    r"^\s*(?:cypher|query|cypher\s+query|generated_cypher)\s*:?\s*",
    re.IGNORECASE,
)

# Leading EXPLAIN / PROFILE keyword (case-insensitive, whole word only).
_CYPHER_PLAN_RE = re.compile(
    r"^\s*(?:EXPLAIN|PROFILE)\b\s*",
    re.IGNORECASE,
)
[docs]
def normalize_cypher_candidate(candidate: str) -> str:
    """Strip common model artifacts around an otherwise valid Cypher query.

    The following wrappers are peeled off repeatedly, in any nesting order,
    until the text stops changing:

    * a Markdown code fence: an opening three-backtick fence, optionally
      tagged ``cypher`` (case-insensitive), together with a closing fence at
      the end of the text;
    * a leading label matched by ``_CYPHER_PREFIX_RE`` (``cypher:``,
      ``query:``, ``generated_cypher:`` ...);
    * a leading ``EXPLAIN`` / ``PROFILE`` keyword (``_CYPHER_PLAN_RE``).

    Afterwards all trailing semicolons are removed, each line is stripped of
    surrounding whitespace, and blank lines are dropped.

    Args:
        candidate: Raw text produced by the model.

    Returns:
        The cleaned query text, or ``""`` if nothing remains.
    """
    normalized = candidate.strip()
    if not normalized:
        return ""

    while True:
        updated = normalized
        # Fence removal is part of the loop (not a one-off pre-step) so that a
        # label in front of a fenced block -- "Cypher:\n```cypher\n...\n```" --
        # is handled: the label goes on one pass, the fences on the next.
        if updated.startswith("```"):
            updated = re.sub(
                r"^\s*```(?:cypher)?\s*",
                "",
                updated,
                count=1,
                flags=re.IGNORECASE,
            )
            updated = re.sub(r"\s*```\s*$", "", updated, count=1)
        updated = _CYPHER_PREFIX_RE.sub("", updated, count=1)
        updated = _CYPHER_PLAN_RE.sub("", updated, count=1)
        updated = updated.strip()
        # Every step above only deletes characters, so this reaches a fixed point.
        if updated == normalized:
            break
        normalized = updated

    # Drop trailing statement terminators (possibly several, e.g. "... ; ;").
    while normalized.endswith(";"):
        normalized = normalized[:-1].strip()

    lines = [line.strip() for line in normalized.splitlines() if line.strip()]
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Context formatting
# ---------------------------------------------------------------------------
[docs]
def format_cypher_rows_for_context(
    *,
    rows: List[Dict[str, Any]],
    max_rows: int,
    max_chars: int,
) -> str:
    """Render Cypher result rows as numbered JSON lines for an LLM context block.

    Each kept row becomes ``[C<n>] <json>`` with a 1-based ``n``, serialized
    with ``ensure_ascii=False`` and ``default=str`` (values ``json`` cannot
    encode natively are stringified). Lines are joined with newlines.

    Args:
        rows: Result rows; only a leading slice of them is considered.
        max_rows: Maximum number of rows to render, coerced with ``int()``
            and never less than 1.
        max_chars: Character budget for the whole block, coerced with
            ``int()`` and never less than 0. Newline separators count against
            it; the line that reaches the budget is cut short and no further
            rows are rendered.

    Returns:
        The joined block, or ``""`` when there are no rows or the budget is 0.
    """
    row_cap = max(int(max_rows), 1)
    candidates = rows[:row_cap]
    budget = max(int(max_chars), 0)

    rendered: List[str] = []
    for position, record in enumerate(candidates, start=1):
        if budget <= 0:
            break
        payload = json.dumps(record, ensure_ascii=False, default=str)
        # Slicing beyond the end is a no-op, so only an overflowing line is trimmed.
        entry = f"[C{position}] {payload}"[:budget]
        rendered.append(entry)
        # The extra 1 pays for the newline that join() inserts after this entry.
        budget -= len(entry) + 1
    return "\n".join(rendered)