# Source code for lalandre_rag.prompts

"""
Centralized prompt loaders for lalandre_rag.
All prompt text lives in ``prompts/`` to keep code and content separated.
"""

import importlib.resources
import json
from functools import lru_cache
from typing import Any, Sequence, Tuple, cast

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from llama_index.core.prompts import PromptTemplate

_LangChainMessage = Tuple[str, str]


@lru_cache(maxsize=None)
def _load_text(relative_path: str) -> str:
    """Read a prompt file shipped as package data, caching the result.

    The cache is unbounded but effectively finite: the key space is the
    fixed set of prompt files bundled with the package.
    """
    resource = importlib.resources.files(__package__)
    return resource.joinpath(relative_path).read_text(encoding="utf-8")


def _build_chat_prompt(messages: Sequence[Any]) -> ChatPromptTemplate:
    """Assemble a :class:`ChatPromptTemplate` from *messages*.

    NOTE(review): the previous ``getattr(ChatPromptTemplate, "from_messages")``
    was identical to plain attribute access at runtime (a constant-string
    ``getattr`` cannot survive a rename either), so the direct call is used.
    ``cast`` keeps the return type narrow for static checkers, since
    ``from_messages`` is loosely typed in some langchain-core versions.
    """
    return cast(ChatPromptTemplate, ChatPromptTemplate.from_messages(messages))


def get_langchain_prompt(
    prompt_type: str,
    *,
    with_history: bool = False,
) -> ChatPromptTemplate:
    """Return the LangChain chat prompt for the given *prompt_type*.

    When *with_history* is ``True`` a ``MessagesPlaceholder("chat_history")``
    is inserted between the system and human messages so that conversation
    history can be injected at invocation time. The placeholder is marked
    ``optional=True`` so callers without history can omit the key (or pass
    an empty list) and the prompt remains unchanged.
    """
    # Unified RAG prompt — graph_context is injected when available (empty otherwise)
    use_rag = prompt_type in ("rag", "rag_graph")
    if use_rag:
        system_text = _load_text("langchain/rag_system.txt")
        human_text = _load_text("langchain/rag_human.txt")
    else:
        system_text = _load_text("langchain/default_system.txt")
        human_text = _load_text("langchain/default_human.txt")

    parts: list[Any] = [("system", system_text)]
    if with_history:
        parts.append(MessagesPlaceholder(variable_name="chat_history", optional=True))
    parts.append(("human", human_text))
    return _build_chat_prompt(parts)
def get_llamaindex_prompt(prompt_type: str) -> PromptTemplate:
    """Return the LlamaIndex prompt template for summary/comparison (with fallback)."""
    # Dispatch table replaces the if/elif chain; unknown types fall back to default.
    paths = {
        "summary": "llamaindex/summary.txt",
        "comparison": "llamaindex/comparison.txt",
    }
    relative = paths.get(prompt_type, "llamaindex/default.txt")
    return PromptTemplate(_load_text(relative))
def render_llm_only_prompt(*, question: str) -> str:
    """Prompt used by LLM-only mode (no retrieval)."""
    return _load_text("llm/llm_only.txt").format(question=question)
def render_nl_to_cypher_prompt(*, question: str, max_graph_depth: int, row_limit: int) -> str:
    """System prompt to translate natural language to Cypher (graph_helpers)."""
    # Clamp both numeric knobs to at least 1; the question is JSON-encoded so
    # quotes/newlines inside it cannot break the template.
    depth = max(int(max_graph_depth), 1)
    limit = max(int(row_limit), 1)
    encoded_question = json.dumps(question, ensure_ascii=False)
    template = _load_text("graph/nl_to_cypher.txt")
    return template.format(
        max_graph_depth=depth,
        row_limit=limit,
        question=encoded_question,
    )
def render_planner_prompt(*, question: str) -> str:
    """Prompt for the retrieval planner that decides multi-step strategy."""
    return _load_text("retrieval/planner.txt").format(question=question)
def render_compressor_prompt(
    *,
    celex: str,
    title: str,
    level: str,
    fragments: str,
    max_chars: int,
) -> str:
    """Prompt for context compression of multiple fragments from one act."""
    template = _load_text("retrieval/compressor.txt")
    # An empty/falsy level is rendered as "inconnu" (unknown).
    values = {
        "celex": celex,
        "title": title,
        "level": level or "inconnu",
        "fragments": fragments,
        "max_chars": max_chars,
    }
    return template.format(**values)
def render_intent_parser_prompt(
    *,
    question: str,
    top_k: int,
    requested_granularity: str,
) -> str:
    """Prompt for the LLM intent parser used in query_parser."""
    # top_k is clamped to >= 1; the question is JSON-encoded to stay template-safe.
    safe_question = json.dumps(question, ensure_ascii=False)
    safe_top_k = max(int(top_k), 1)
    return _load_text("retrieval/intent_parser.txt").format(
        question=safe_question,
        top_k=safe_top_k,
        requested_granularity=requested_granularity,
    )
def get_text2cypher_prompt_template() -> str:
    """Template consumed by neo4j_graphrag Text2Cypher retriever.

    Returned raw (no ``.format``): the retriever performs its own substitution.
    """
    return _load_text("graph/text2cypher.txt")
__all__ = [ "get_langchain_prompt", "get_llamaindex_prompt", "render_llm_only_prompt", "render_planner_prompt", "render_compressor_prompt", "render_nl_to_cypher_prompt", "render_intent_parser_prompt", "get_text2cypher_prompt_template", ]