"""
Centralized prompt loaders for lalandre_rag.
All prompt text lives in ``prompts/`` to keep code and content separated.
"""
import importlib.resources
import json
from functools import lru_cache
from typing import Any, Sequence, Tuple, cast
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from llama_index.core.prompts import PromptTemplate
# A LangChain-style (role, template) message pair.
# NOTE(review): not referenced anywhere in this module's visible code —
# presumably a typing aid for callers; confirm before removing.
_LangChainMessage = Tuple[str, str]
@lru_cache(maxsize=None)
def _load_text(relative_path: str) -> str:
    """Read a prompt text file bundled as package data, memoizing the result.

    The unbounded cache is acceptable here: the key space is the small,
    fixed set of prompt file paths shipped with the package.
    """
    resource = importlib.resources.files(__package__).joinpath(relative_path)
    return resource.read_text(encoding="utf-8")
def _build_chat_prompt(messages: Sequence[Any]) -> ChatPromptTemplate:
    """Assemble a ``ChatPromptTemplate`` from a sequence of message specs.

    ``from_messages`` is looked up dynamically on purpose: some
    langchain-core releases renamed this constructor, and the ``getattr``
    keeps the module importable across those versions.
    """
    from_messages: Any = getattr(ChatPromptTemplate, "from_messages")
    built = from_messages(messages)
    return cast(ChatPromptTemplate, built)
def get_langchain_prompt(
    prompt_type: str,
    *,
    with_history: bool = False,
) -> ChatPromptTemplate:
    """Return the LangChain chat prompt for the given *prompt_type*.

    Parameters
    ----------
    prompt_type:
        ``"rag"`` or ``"rag_graph"`` select the unified RAG prompt pair;
        any other value falls back to the default system/human pair.
    with_history:
        When ``True`` a ``MessagesPlaceholder("chat_history")`` is inserted
        between the system and human messages so that conversation history
        can be injected at invocation time.  The placeholder is marked
        ``optional=True`` so callers without history can simply omit the
        key (or pass an empty list) and the prompt remains unchanged.
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact that made this
    # line invalid Python.
    if prompt_type in ("rag", "rag_graph"):
        # Unified RAG prompt — graph_context is injected when available
        # (empty otherwise).
        system_text = _load_text("langchain/rag_system.txt")
        human_text = _load_text("langchain/rag_human.txt")
    else:
        system_text = _load_text("langchain/default_system.txt")
        human_text = _load_text("langchain/default_human.txt")

    messages: list[Any] = [("system", system_text)]
    if with_history:
        messages.append(
            MessagesPlaceholder(variable_name="chat_history", optional=True),
        )
    messages.append(("human", human_text))
    return _build_chat_prompt(messages)
def get_llamaindex_prompt(prompt_type: str) -> PromptTemplate:
    """Return the LlamaIndex prompt template for *prompt_type*.

    ``"summary"`` and ``"comparison"`` map to dedicated templates; any
    other value falls back to ``llamaindex/default.txt``.
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact (invalid Python).
    # Dict dispatch replaces the if/elif chain; behavior is unchanged.
    paths = {
        "summary": "llamaindex/summary.txt",
        "comparison": "llamaindex/comparison.txt",
    }
    path = paths.get(prompt_type, "llamaindex/default.txt")
    return PromptTemplate(_load_text(path))
def render_llm_only_prompt(*, question: str) -> str:
    """Render the prompt used by LLM-only mode (no retrieval).

    The template is expected to contain a ``{question}`` placeholder;
    ``str.format`` will raise if the file contains other brace fields.
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact (invalid Python).
    template = _load_text("llm/llm_only.txt")
    return template.format(question=question)
def render_nl_to_cypher_prompt(*, question: str, max_graph_depth: int, row_limit: int) -> str:
    """Render the system prompt translating natural language to Cypher.

    *question* is JSON-encoded so quotes/newlines cannot break the prompt
    structure; ``max_graph_depth`` and ``row_limit`` are clamped to >= 1.
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact (invalid Python).
    template = _load_text("graph/nl_to_cypher.txt")
    return template.format(
        max_graph_depth=max(int(max_graph_depth), 1),
        row_limit=max(int(row_limit), 1),
        question=json.dumps(question, ensure_ascii=False),
    )
def render_planner_prompt(*, question: str) -> str:
    """Render the prompt for the retrieval planner (multi-step strategy).

    NOTE(review): unlike :func:`render_nl_to_cypher_prompt`, *question* is
    interpolated raw here (no JSON escaping) — confirm the template
    tolerates braces/newlines in user questions.
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact (invalid Python).
    template = _load_text("retrieval/planner.txt")
    return template.format(question=question)
def render_compressor_prompt(
    *,
    celex: str,
    title: str,
    level: str,
    fragments: str,
    max_chars: int,
) -> str:
    """Render the prompt compressing multiple fragments from one act.

    Parameters are interpolated verbatim into ``retrieval/compressor.txt``;
    an empty ``level`` falls back to ``"inconnu"`` (runtime string kept
    byte-for-byte — it is part of the prompt contract).
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact (invalid Python).
    template = _load_text("retrieval/compressor.txt")
    return template.format(
        celex=celex,
        title=title,
        level=level or "inconnu",
        fragments=fragments,
        max_chars=max_chars,
    )
def render_intent_parser_prompt(
    *,
    question: str,
    top_k: int,
    requested_granularity: str,
) -> str:
    """Render the prompt for the LLM intent parser used in query_parser.

    *question* is JSON-encoded to keep quotes/newlines prompt-safe and
    ``top_k`` is clamped to >= 1, mirroring
    :func:`render_nl_to_cypher_prompt`.
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact (invalid Python).
    template = _load_text("retrieval/intent_parser.txt")
    return template.format(
        question=json.dumps(question, ensure_ascii=False),
        top_k=max(int(top_k), 1),
        requested_granularity=requested_granularity,
    )
def get_text2cypher_prompt_template() -> str:
    """Return the raw template consumed by the neo4j_graphrag Text2Cypher
    retriever (no interpolation is done here — the retriever fills it).
    """
    # Fixed: removed stray "[docs]" Sphinx-HTML artifact (invalid Python).
    return _load_text("graph/text2cypher.txt")
# Explicit public API: prompt getters and renderers only; the underscored
# helpers (_load_text, _build_chat_prompt) stay private.
__all__ = [
"get_langchain_prompt",
"get_llamaindex_prompt",
"render_llm_only_prompt",
"render_planner_prompt",
"render_compressor_prompt",
"render_nl_to_cypher_prompt",
"render_intent_parser_prompt",
"get_text2cypher_prompt_template",
]