Source code for lalandre_rag.graph.helpers

"""
Graph-mode helper functions for the RAG service.

Contains NL→Cypher prompt building, Cypher extraction from LLM output,
and Cypher-row context formatting.  These are domain-level utilities
that live in the *package* layer, not in the service layer.
"""

import json
import re
from typing import Any, Dict, List, Optional

from lalandre_core.utils import extract_json_object

from lalandre_rag.prompts import render_nl_to_cypher_prompt

# ---------------------------------------------------------------------------
# NL → Cypher prompt
# ---------------------------------------------------------------------------


def build_nl_to_cypher_prompt(
    *,
    question: str,
    max_graph_depth: int,
    row_limit: int,
) -> str:
    """Build the prompt that asks the model to turn *question* into Cypher.

    Both numeric settings are coerced with ``int()`` and raised to at least 1
    before being forwarded to ``render_nl_to_cypher_prompt``.

    Args:
        question: Natural-language question to translate.
        max_graph_depth: Requested traversal depth; values below 1 become 1.
        row_limit: Requested row cap; values below 1 become 1.

    Returns:
        Whatever ``render_nl_to_cypher_prompt`` returns for the clamped values.
    """
    depth = int(max_graph_depth)
    limit = int(row_limit)
    return render_nl_to_cypher_prompt(
        question=question,
        max_graph_depth=depth if depth > 1 else 1,
        row_limit=limit if limit > 1 else 1,
    )
# ---------------------------------------------------------------------------
# Cypher extraction from LLM output
# ---------------------------------------------------------------------------

# Leading label a model may put before the query: "cypher", "query",
# "cypher query" or "generated_cypher", with an optional colon (any case).
_CYPHER_PREFIX_RE = re.compile(
    r"^\s*(?:cypher|query|cypher\s+query|generated_cypher)\s*:?\s*",
    flags=re.IGNORECASE,
)
# Leading EXPLAIN / PROFILE keyword (any case).
_CYPHER_PLAN_RE = re.compile(r"^\s*(?:EXPLAIN|PROFILE)\b\s*", flags=re.IGNORECASE)
# Any run of statement terminators and whitespace at the very end of the text.
_CYPHER_TRAILING_TERMINATOR_RE = re.compile(r"[;\s]+\Z")


def normalize_cypher_candidate(candidate: str) -> str:
    """Strip common model artifacts around an otherwise valid Cypher query.

    Steps, in order:

    1. Trim surrounding whitespace; an empty result returns ``""``.
    2. If the text starts with a Markdown fence, drop the opening fence (with
       an optional ``cypher`` tag) and a closing fence at the end, if any.
    3. Repeatedly remove leading labels (``cypher:``, ``query:`` …) and
       leading ``EXPLAIN`` / ``PROFILE`` keywords until nothing changes.
    4. Remove every trailing ``;`` together with the whitespace around them.
    5. Trim each line and drop blank lines.

    Args:
        candidate: Raw query text taken from model output.

    Returns:
        The cleaned query, or ``""`` when nothing is left.
    """
    normalized = candidate.strip()
    if not normalized:
        return ""

    if normalized.startswith("```"):
        normalized = re.sub(
            r"^\s*```(?:cypher)?\s*",
            "",
            normalized,
            count=1,
            flags=re.IGNORECASE,
        )
        normalized = re.sub(r"\s*```\s*$", "", normalized, count=1)

    # Labels and plan keywords can be stacked in any order
    # ("cypher query: EXPLAIN ..."), so peel until a pass changes nothing.
    while True:
        peeled = _CYPHER_PREFIX_RE.sub("", normalized, count=1)
        peeled = _CYPHER_PLAN_RE.sub("", peeled, count=1).strip()
        if peeled == normalized:
            break
        normalized = peeled

    # Drop the whole trailing run, so "...;;" or "... ; ;" does not leave a
    # dangling terminator behind.
    normalized = _CYPHER_TRAILING_TERMINATOR_RE.sub("", normalized)

    return "\n".join(
        stripped for stripped in (line.strip() for line in normalized.splitlines()) if stripped
    )
def extract_cypher(text: str) -> Optional[str]:
    """Pull a Cypher query out of raw LLM *text*, or return ``None``.

    Three sources are tried in order; the first one that yields a non-blank
    candidate decides the result:

    1. The ``"cypher"`` key of a JSON object found by ``extract_json_object``.
    2. The body of the first fenced code block (```cypher … ```).
    3. Everything from the first MATCH / OPTIONAL MATCH / WITH / UNWIND / CALL
       keyword (matched case-insensitively) to the end of the text.

    The chosen candidate goes through ``normalize_cypher_candidate``. If the
    normalized text is empty, ``None`` is returned without trying later
    sources.
    """
    # 1. JSON payload with a non-blank string under "cypher".
    parsed = extract_json_object(text)
    if parsed is not None:
        from_json = parsed.get("cypher")
        if isinstance(from_json, str) and from_json.strip():
            return normalize_cypher_candidate(from_json) or None

    # 2. First fenced block, with or without the "cypher" language tag.
    fence = re.search(
        r"```(?:cypher)?\s*(.*?)```",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if fence is not None:
        inner = fence.group(1).strip()
        if inner:
            return normalize_cypher_candidate(inner) or None

    # 3. Bare query: keep everything from the first recognised keyword onward.
    keyword = re.search(
        r"\b(MATCH|OPTIONAL MATCH|WITH|UNWIND|CALL)\b",
        text,
        flags=re.IGNORECASE,
    )
    if keyword is not None:
        tail = text[keyword.start():].strip()
        if tail:
            return normalize_cypher_candidate(tail) or None

    return None
# --------------------------------------------------------------------------- # Context formatting # ---------------------------------------------------------------------------
def format_cypher_rows_for_context(
    *,
    rows: List[Dict[str, Any]],
    max_rows: int,
    max_chars: int,
) -> str:
    """Render Cypher result *rows* as a compact text block for LLM context.

    Each row becomes one line of the form ``[C<n>] <json>``, numbered from 1.
    Values that JSON cannot encode are stringified with ``str``.

    Args:
        rows: Result rows to serialize.
        max_rows: Maximum number of rows to include; values below 1 become 1.
        max_chars: Character budget for the output; values below 0 become 0.
            Every emitted line costs its length plus one for the separator,
            and the line that crosses the budget is cut to fit.

    Returns:
        The lines joined by newlines; ``""`` when nothing fits.
    """
    selected = rows[: max(int(max_rows), 1)]
    budget = max(int(max_chars), 0)
    chunks: List[str] = []

    for position, record in enumerate(selected, start=1):
        if budget <= 0:
            break
        payload = json.dumps(record, ensure_ascii=False, default=str)
        # Slicing is a no-op when the entry already fits in the budget.
        entry = f"[C{position}] {payload}"[:budget]
        chunks.append(entry)
        budget -= len(entry) + 1  # +1 pays for the joining newline

    return "\n".join(chunks)