"""
Optional Neo4j GraphRAG integration for graph-mode retrieval.
This module prefers official Neo4j GraphRAG retrievers when the dependency is
installed, while keeping the existing Lalandre graph pipeline as a fallback.
"""
import json
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, LiteralString, Optional, cast
from lalandre_core.llm import normalize_base_url, normalize_provider, resolve_api_key
from lalandre_core.utils.api_key_pool import APIKeyPool
from lalandre_core.utils.shared_key_pool import SharedKeyPoolProxy, build_clients_by_key
from neo4j import Driver, Query
from neo4j_graphrag.llm.mistralai_llm import MistralAILLM
from neo4j_graphrag.llm.openai_llm import OpenAILLM
from neo4j_graphrag.retrievers import Text2CypherRetriever
from neo4j_graphrag.schema import get_schema
from neo4j_graphrag.types import RetrieverResultItem
from lalandre_rag.prompts import get_text2cypher_prompt_template
from .helpers import normalize_cypher_candidate
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class Text2CypherSearchOutput:
    """Result payload for official Text2Cypher retrieval.

    The instance is frozen, so its fields cannot be rebound after
    construction. The ``rows`` list and ``metadata`` dict themselves remain
    ordinary mutable containers.

    Attributes:
        generated_cypher: Cypher query text reported by the retriever.
        rows: One dictionary per result row.
        metadata: Metadata dictionary attached to the retriever result.
    """

    generated_cypher: str
    rows: List[Dict[str, Any]]
    metadata: Dict[str, Any]
class _ReadOnlyDriverProxy(Driver):
    """Proxy Neo4j driver that runs a Cypher validator before ``execute_query``.

    Every query passed to :meth:`execute_query` is normalized with
    ``normalize_cypher_candidate`` and then handed to the ``validator``
    callable; the string the validator returns is what reaches the wrapped
    driver. All other operations are forwarded to the delegate unchanged.

    NOTE: :meth:`session` is forwarded as-is, so Cypher executed through a
    session obtained from this proxy is *not* passed through the validator.
    """

    @staticmethod
    def _to_query(
        cypher: str,
        metadata: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
    ) -> Query:
        """Build a Neo4j Query from validated dynamic cypher text.

        Args:
            cypher: Already validated Cypher text.
            metadata: Optional transaction metadata to attach to the query.
            timeout: Optional transaction timeout to attach to the query.

        Returns:
            A ``Query`` wrapping ``cypher``. The cast only satisfies the
            ``LiteralString`` type annotation; it has no runtime effect.
        """
        return Query(cast(LiteralString, cypher), metadata=metadata, timeout=timeout)

    def __init__(
        self,
        delegate: Driver,
        validator: Callable[[str], str],
    ) -> None:
        """Wrap ``delegate`` and remember the validator.

        ``Driver.__init__`` is intentionally not invoked here; the private
        attributes are copied from the delegate by ``_sync_driver_state``.

        Args:
            delegate: The real driver that executes queries.
            validator: Callable receiving normalized Cypher text and returning
                the text that should actually be executed.
        """
        self._delegate = delegate
        self._validator = validator
        self._sync_driver_state()

    def _sync_driver_state(self) -> None:
        """Mirror internal Driver state expected by neo4j-graphrag internals."""
        delegate = self._delegate
        self._pool = getattr(delegate, "_pool", None)
        self._default_workspace_config = getattr(
            delegate, "_default_workspace_config", None
        )
        self._query_bookmark_manager = getattr(
            delegate, "_query_bookmark_manager", None
        )
        self._closed = bool(getattr(delegate, "_closed", False))

    def __del__(self) -> None:
        """The delegate driver owns the real pool lifecycle and leak warnings."""
        return None

    def __getattr__(self, name: str) -> Any:
        """Forward lookups that fail on the proxy to the wrapped driver.

        Raises:
            AttributeError: If ``_delegate`` itself is missing. Without this
                guard, touching any attribute before ``_delegate`` is assigned
                (copying, unpickling, a partially built instance) would
                re-enter ``__getattr__`` until ``RecursionError``.
        """
        if name == "_delegate":
            raise AttributeError(name)
        return getattr(self._delegate, name)

    def session(self, *args: Any, **kwargs: Any) -> Any:
        """Return a session from the wrapped driver (no Cypher validation)."""
        return self._delegate.session(*args, **kwargs)

    def close(self) -> None:
        """Mark the proxy closed, then close the wrapped driver."""
        self._closed = True
        self._delegate.close()

    def verify_connectivity(self, *args: Any, **kwargs: Any) -> Any:
        """Forward to the wrapped driver's ``verify_connectivity``."""
        return self._delegate.verify_connectivity(*args, **kwargs)

    def get_server_info(self, *args: Any, **kwargs: Any) -> Any:
        """Forward to the wrapped driver's ``get_server_info``."""
        return self._delegate.get_server_info(*args, **kwargs)

    def supports_session_auth(self) -> bool:
        """Report the delegate's answer, or ``False`` if it lacks the method."""
        supports = getattr(self._delegate, "supports_session_auth", None)
        return bool(supports()) if callable(supports) else False

    def execute_query(
        self,
        query_: Any,
        parameters_: Optional[Dict[str, Any]] = None,
        routing_: Any = None,
        database_: Optional[str] = None,
        impersonated_user_: Optional[str] = None,
        bookmark_manager_: Any = None,
        auth_: Any = None,
        result_transformer_: Any = None,
        **kwargs: Any,
    ) -> Any:
        """Validate the Cypher text, then run it on the wrapped driver.

        Args:
            query_: Cypher as a string, a ``Query``, or any object exposing a
                string ``text`` attribute; anything else is stringified.
            parameters_, routing_, database_, impersonated_user_,
            bookmark_manager_, auth_, result_transformer_: Forwarded to the
                delegate only when not ``None``, so the delegate's own
                defaults apply otherwise.
            **kwargs: Extra keyword arguments forwarded verbatim.

        Returns:
            Whatever the delegate's ``execute_query`` returns.

        Raises:
            Any exception raised by the validator or by the delegate.
        """
        query_text = getattr(query_, "text", None)
        if not isinstance(query_text, str):
            query_text = str(query_.text if isinstance(query_, Query) else query_)

        validated_query = self._validator(normalize_cypher_candidate(query_text))

        if isinstance(query_, Query):
            # Keep the caller's transaction configuration; only the text is
            # replaced by its validated form.
            safe_query = self._to_query(
                validated_query,
                metadata=getattr(query_, "metadata", None),
                timeout=getattr(query_, "timeout", None),
            )
        else:
            safe_query = self._to_query(validated_query)

        optional_arguments: Dict[str, Any] = {
            "parameters_": parameters_,
            "routing_": routing_,
            "database_": database_,
            "impersonated_user_": impersonated_user_,
            "bookmark_manager_": bookmark_manager_,
            "auth_": auth_,
            "result_transformer_": result_transformer_,
        }
        execute_kwargs: Dict[str, Any] = {"query_": safe_query}
        execute_kwargs.update(
            (key, value)
            for key, value in optional_arguments.items()
            if value is not None
        )
        execute_kwargs.update(kwargs)
        return self._delegate.execute_query(**execute_kwargs)
class Neo4jGraphRAGAdapter:
    """Bridge between Lalandre graph mode and official Neo4j GraphRAG retrievers.

    The adapter lazily builds and caches one LLM client and one schema string
    per instance, and routes retriever queries through a
    ``_ReadOnlyDriverProxy`` built from ``read_only_validator``.
    """

    def __init__(
        self,
        *,
        neo4j_driver: Driver,
        neo4j_database: Optional[str],
        qdrant_client: Any,
        qdrant_collection_name: str,
        llm_provider: str,
        llm_model: str,
        llm_temperature: float,
        llm_max_tokens: int,
        llm_api_key: Optional[str],
        mistral_api_key: Optional[str],
        llm_base_url: Optional[str],
        key_pool: Optional[APIKeyPool],
        read_only_validator: Callable[[str], str],
        row_serializer: Callable[[Any], Any],
    ) -> None:
        """Store configuration; no network call is made by this constructor itself.

        Args:
            neo4j_driver: Driver used for schema introspection and, wrapped in
                a validating proxy, for retriever queries.
            neo4j_database: Database name passed to the retriever and to
                schema introspection (may be ``None``).
            qdrant_client: Stored on the instance; not used by the methods in
                this class.
            qdrant_collection_name: Stored on the instance; not used by the
                methods in this class.
            llm_provider: Provider name, normalized with ``normalize_provider``.
            llm_model: Model name; surrounding whitespace is stripped.
            llm_temperature: Sampling temperature, coerced to ``float``.
            llm_max_tokens: Token limit, coerced to ``int``.
            llm_api_key: Generic API key handed to ``resolve_api_key``.
            mistral_api_key: Mistral API key handed to ``resolve_api_key``.
            llm_base_url: Base URL used for the ``openai_compatible`` provider.
            key_pool: Optional key pool; used for Mistral when it holds more
                than one key.
            read_only_validator: Callable applied to Cypher before execution.
            row_serializer: Callable applied to every value of a result row.
        """
        self._neo4j_driver = neo4j_driver
        self._neo4j_database = neo4j_database
        self._qdrant_client = qdrant_client
        self._qdrant_collection_name = qdrant_collection_name
        self._provider = normalize_provider(llm_provider)
        self._model = llm_model.strip()
        self._temperature = float(llm_temperature)
        self._max_tokens = int(llm_max_tokens)
        self._api_key = llm_api_key
        self._mistral_api_key = mistral_api_key
        self._base_url = llm_base_url
        self._key_pool = key_pool
        self._read_only_driver = _ReadOnlyDriverProxy(neo4j_driver, read_only_validator)
        self._row_serializer = row_serializer
        # Lazily populated by _get_schema() / _get_llm().
        self._schema_cache: Optional[str] = None
        self._llm_cache: Any = None

    def is_available(self) -> bool:
        """Return whether the adapter is ready to serve official GraphRAG calls.

        Currently this is unconditionally ``True``.
        """
        return True

    def text_to_cypher_search(
        self,
        *,
        question: str,
        max_graph_depth: int,
        row_limit: int,
    ) -> Text2CypherSearchOutput:
        """Run the official Neo4j Text2Cypher retriever and normalize its output.

        Args:
            question: Natural-language question passed as ``query_text``.
            max_graph_depth: Prompt parameter; coerced to ``int`` and raised
                to at least 1.
            row_limit: Prompt parameter; coerced to ``int`` and raised to at
                least 1.

        Returns:
            The generated Cypher, the rows found in each result item's
            ``metadata["row"]``, and the result-level metadata dictionary.

        Raises:
            ValueError: If the result metadata holds no non-blank string under
                ``cypher``, ``generated_cypher`` or ``generated_query``, or if
                the configured LLM provider is unsupported.
        """
        retriever = Text2CypherRetriever(
            driver=self._read_only_driver,
            llm=self._get_llm(),
            neo4j_schema=self._get_schema(),
            result_formatter=self._build_cypher_row_formatter(),
            custom_prompt=self._build_text2cypher_prompt(),
            neo4j_database=self._neo4j_database,
        )
        prompt_params = {
            "max_graph_depth": max(int(max_graph_depth), 1),
            "row_limit": max(int(row_limit), 1),
        }
        result = retriever.search(query_text=question, prompt_params=prompt_params)

        # Rows are carried in item metadata by the formatter built in
        # _build_cypher_row_formatter(); items without a dict row are skipped.
        rows: List[Dict[str, Any]] = []
        for item in self._extract_items(result):
            row = self._item_metadata(item).get("row")
            if isinstance(row, dict):
                rows.append(cast(Dict[str, Any], row))

        result_metadata = self._result_metadata(result)
        generated_cypher = ""
        for key in ("cypher", "generated_cypher", "generated_query"):
            candidate = result_metadata.get(key)
            if isinstance(candidate, str) and candidate.strip():
                generated_cypher = normalize_cypher_candidate(candidate)
                break
        if not generated_cypher:
            raise ValueError("Neo4j GraphRAG did not return a generated Cypher query")

        return Text2CypherSearchOutput(
            generated_cypher=generated_cypher,
            rows=rows,
            metadata=result_metadata,
        )

    def _result_item_class(self) -> type:
        """Return the class used to wrap each formatted result row."""
        return RetrieverResultItem

    def _model_params(self) -> Dict[str, Any]:
        """Return a fresh ``model_params`` dict shared by all LLM clients."""
        return {
            "temperature": self._temperature,
            "max_tokens": self._max_tokens,
        }

    def _get_llm(self) -> Any:
        """Return the cached LLM client, building it on first use.

        Raises:
            ValueError: If the provider is neither ``mistral`` nor
                ``openai_compatible``.
        """
        if self._llm_cache is not None:
            return self._llm_cache

        effective_key = resolve_api_key(
            provider=self._provider,
            api_key=self._api_key,
            mistral_api_key=self._mistral_api_key,
        )

        if self._provider == "mistral":
            if self._key_pool is not None and len(self._key_pool) > 1:
                # Several keys available: one client per key behind a proxy.
                clients_by_key = build_clients_by_key(
                    key_pool=self._key_pool,
                    factory=lambda key: MistralAILLM(
                        model_name=self._model,
                        api_key=key,
                        model_params=self._model_params(),
                    ),
                )
                self._llm_cache = SharedKeyPoolProxy(
                    key_pool=self._key_pool,
                    clients_by_key=clients_by_key,
                )
            else:
                self._llm_cache = MistralAILLM(
                    model_name=self._model,
                    api_key=effective_key,
                    model_params=self._model_params(),
                )
            return self._llm_cache

        if self._provider == "openai_compatible":
            resolved_base_url = normalize_base_url(
                provider=self._provider,
                base_url=self._base_url or "",
            )
            self._llm_cache = OpenAILLM(
                model_name=self._model,
                api_key=effective_key,
                base_url=resolved_base_url,
                model_params=self._model_params(),
            )
            return self._llm_cache

        raise ValueError(
            f"Unsupported Neo4j GraphRAG provider {self._provider!r}. Use mistral or openai_compatible."
        )

    def _get_schema(self) -> str:
        """Return the cached schema text, introspecting Neo4j on first use.

        If introspection raises, a warning is logged and the static fallback
        schema is used instead. The fallback is cached as well, so
        introspection is attempted at most once per adapter instance.
        """
        if self._schema_cache is not None:
            return self._schema_cache
        try:
            self._schema_cache = str(
                get_schema(self._neo4j_driver, database=self._neo4j_database)
            )
        except Exception as exc:
            # Deliberately broad: any introspection failure degrades to the
            # static schema rather than failing the search.
            logger.warning(
                "Neo4j GraphRAG schema introspection failed, fallback to static schema: %s",
                exc,
            )
            self._schema_cache = self._fallback_schema()
        return self._schema_cache

    @staticmethod
    def _build_text2cypher_prompt() -> str:
        """Return the project's Text2Cypher prompt template."""
        return get_text2cypher_prompt_template()

    def _build_cypher_row_formatter(self) -> Callable[[Any], Any]:
        """Build the callback that converts one result record into an item.

        The returned callable serializes every value of the record with
        ``row_serializer``, stores the resulting dict under
        ``metadata["row"]`` and its JSON dump as ``content``.
        """
        result_item_cls = self._result_item_class()

        def _formatter(record: Any) -> Any:
            row: Dict[str, Any] = {}
            for key in record.keys():
                row[str(key)] = self._row_serializer(record.get(key))
            return result_item_cls(
                # default=str stringifies values json cannot encode natively.
                content=json.dumps(row, ensure_ascii=False, default=str),
                metadata={"row": row},
            )

        return _formatter

    @staticmethod
    def _extract_items(result: Any) -> List[Any]:
        """Return a copy of ``result.items`` if it is a list, else ``[]``."""
        items = getattr(result, "items", None)
        if not isinstance(items, list):
            return []
        return list(cast(List[Any], items))

    @staticmethod
    def _item_metadata(item: Any) -> Dict[str, Any]:
        """Return ``item.metadata`` if it is a dict, else an empty dict."""
        metadata = getattr(item, "metadata", None)
        if isinstance(metadata, dict):
            return cast(Dict[str, Any], metadata)
        return {}

    @staticmethod
    def _result_metadata(result: Any) -> Dict[str, Any]:
        """Return ``result.metadata`` if it is a dict, else an empty dict."""
        metadata = getattr(result, "metadata", None)
        if isinstance(metadata, dict):
            return cast(Dict[str, Any], metadata)
        return {}

    @staticmethod
    def _fallback_schema() -> str:
        """Return the static schema description used when introspection fails."""
        return (
            "Node properties:\n"
            "Act {id: INTEGER, celex: STRING, title: STRING, act_type: STRING, "
            "language: STRING, adoption_date: DATE, force_date: DATE, end_date: DATE, "
            "sector: INTEGER, level: INTEGER, official_journal_reference: STRING, "
            "eli: STRING, url_eurlex: STRING}\n"
            "Entity {name: STRING, entity_type: STRING, description: STRING}\n"
            "Community {id: INTEGER, num_acts: INTEGER, num_relations: INTEGER, "
            "relation_types: STRING, central_acts: STRING, summary: STRING, "
            "modularity: FLOAT, resolution: FLOAT}\n"
            "Relationship properties:\n"
            "AMENDS {relation_type: STRING, effect_date: DATE, description: STRING, "
            "source_subdivision_id: INTEGER, target_subdivision_id: INTEGER}\n"
            "REPEALS {relation_type: STRING, effect_date: DATE, description: STRING}\n"
            "REPLACES {relation_type: STRING, effect_date: DATE, description: STRING}\n"
            "IMPLEMENTS {relation_type: STRING, effect_date: DATE, description: STRING}\n"
            "CITES {relation_type: STRING, description: STRING}\n"
            "DEROGATES {relation_type: STRING, effect_date: DATE, description: STRING}\n"
            "SUPPLEMENTS {relation_type: STRING, description: STRING}\n"
            "CORRECTS {relation_type: STRING, description: STRING}\n"
            "MENTIONS {} (Act)-[:MENTIONS]->(Entity)\n"
            "BELONGS_TO {} (Act)-[:BELONGS_TO]->(Community)\n"
            "The relationships AMENDS through CORRECTS are directed between :Act nodes."
        )