Source code for lalandre_rag.graph.neo4j_adapter

"""
Optional Neo4j GraphRAG integration for graph-mode retrieval.

This module prefers official Neo4j GraphRAG retrievers when the dependency is
installed, while keeping the existing Lalandre graph pipeline as a fallback.
"""

import json
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, LiteralString, Optional, cast

from lalandre_core.llm import normalize_base_url, normalize_provider, resolve_api_key
from lalandre_core.utils.api_key_pool import APIKeyPool
from lalandre_core.utils.shared_key_pool import SharedKeyPoolProxy, build_clients_by_key
from neo4j import Driver, Query
from neo4j_graphrag.llm.mistralai_llm import MistralAILLM
from neo4j_graphrag.llm.openai_llm import OpenAILLM
from neo4j_graphrag.retrievers import Text2CypherRetriever
from neo4j_graphrag.schema import get_schema
from neo4j_graphrag.types import RetrieverResultItem

from lalandre_rag.prompts import get_text2cypher_prompt_template

from .helpers import normalize_cypher_candidate

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class Text2CypherSearchOutput:
    """Immutable bundle returned by an official Text2Cypher retrieval.

    Instances are frozen: fields are set once at construction and cannot be
    reassigned afterwards.
    """

    # Cypher statement that was generated for the request.
    generated_cypher: str
    # Result rows, one mapping per row.
    rows: List[Dict[str, Any]]
    # Additional information attached to the retrieval.
    metadata: Dict[str, Any]
class _ReadOnlyDriverProxy(Driver):
    """Driver stand-in that validates Cypher before ``execute_query`` runs it.

    Every call to :meth:`execute_query` has its query text normalized with
    ``normalize_cypher_candidate`` and passed through the validator callable
    supplied at construction; the validator's return value is what is sent to
    the wrapped driver. All other operations are forwarded to the wrapped
    driver unchanged.

    NOTE(review): ``session()`` is forwarded as-is, so statements run through
    a session are *not* seen by the validator. Confirm that callers of this
    proxy only use ``execute_query`` for generated Cypher.
    """

    @staticmethod
    def _to_query(cypher: str) -> Query:
        """Wrap already-validated, dynamically built Cypher text in a ``Query``.

        The cast only satisfies the ``LiteralString`` annotation; it performs
        no runtime check.
        """
        return Query(cast(LiteralString, cypher))

    def __init__(
        self,
        delegate: Driver,
        validator: Callable[[str], str],
    ) -> None:
        """Wrap ``delegate`` and remember the ``validator`` to apply.

        Args:
            delegate: The real driver that executes every request.
            validator: Callable receiving normalized Cypher text and returning
                the text to execute. Whatever it raises propagates to the
                caller of :meth:`execute_query`.

        The base-class initializer is not called here; the private attributes
        are copied from the delegate by :meth:`_sync_driver_state` instead.
        """
        self._delegate = delegate
        self._validator = validator
        self._sync_driver_state()

    def _sync_driver_state(self) -> None:
        """Mirror internal Driver state expected by neo4j-graphrag internals."""
        delegate = self._delegate
        self._pool = getattr(delegate, "_pool", None)
        self._default_workspace_config = getattr(
            delegate,
            "_default_workspace_config",
            None,
        )
        self._query_bookmark_manager = getattr(
            delegate,
            "_query_bookmark_manager",
            None,
        )
        self._closed = bool(getattr(delegate, "_closed", False))

    def __del__(self) -> None:
        """The delegate driver owns the real pool lifecycle and leak warnings."""
        return None

    def __getattr__(self, name: str) -> Any:
        """Fall back to the wrapped driver for attributes not found here.

        Raises:
            AttributeError: If the attribute is missing on the delegate, or if
                the proxy has no delegate yet.
        """
        # __getattr__ runs only after normal lookup has failed. When
        # ``_delegate`` itself is absent (instance created without __init__,
        # e.g. through __new__, copy or unpickling), reading ``self._delegate``
        # below would re-enter this method until RecursionError.
        if name == "_delegate":
            raise AttributeError(name)
        return getattr(self._delegate, name)

    def session(self, *args: Any, **kwargs: Any) -> Any:
        """Open a session on the wrapped driver (no validation is applied)."""
        return self._delegate.session(*args, **kwargs)

    def close(self) -> None:
        """Mark the proxy closed, then close the wrapped driver."""
        self._closed = True
        self._delegate.close()

    def verify_connectivity(self, *args: Any, **kwargs: Any) -> Any:
        """Forward to the wrapped driver's ``verify_connectivity``."""
        return self._delegate.verify_connectivity(*args, **kwargs)

    def get_server_info(self, *args: Any, **kwargs: Any) -> Any:
        """Forward to the wrapped driver's ``get_server_info``."""
        return self._delegate.get_server_info(*args, **kwargs)

    def supports_session_auth(self) -> bool:
        """Return the delegate's answer, or ``False`` if it has no such method."""
        supports = getattr(self._delegate, "supports_session_auth", None)
        if callable(supports):
            return bool(supports())
        return False

    def execute_query(
        self,
        query_: Any,
        parameters_: Optional[Dict[str, Any]] = None,
        routing_: Any = None,
        database_: Optional[str] = None,
        impersonated_user_: Optional[str] = None,
        bookmark_manager_: Any = None,
        auth_: Any = None,
        result_transformer_: Any = None,
        **kwargs: Any,
    ) -> Any:
        """Validate the query text, then run it on the wrapped driver.

        Args:
            query_: Cypher as a string, a ``Query``, or any object exposing a
                string ``text`` attribute; anything else is converted with
                ``str()``.
            parameters_, routing_, database_, impersonated_user_,
            bookmark_manager_, auth_, result_transformer_: Forwarded to the
                delegate only when not ``None``, so the delegate's own
                defaults apply otherwise.
            **kwargs: Forwarded to the delegate as-is.

        Returns:
            Whatever the delegate's ``execute_query`` returns.

        Raises:
            Whatever the validator or the delegate raises.
        """
        query_text = getattr(query_, "text", None)
        if not isinstance(query_text, str):
            query_text = str(query_ if not isinstance(query_, Query) else query_.text)

        validated_query = self._validator(normalize_cypher_candidate(query_text))

        # NOTE(review): the outgoing Query is rebuilt from text alone, so any
        # other fields carried by an incoming Query object are not forwarded.
        # Confirm that is acceptable for the callers of this proxy.
        execute_kwargs: Dict[str, Any] = {"query_": self._to_query(validated_query)}

        optional_arguments = (
            ("parameters_", parameters_),
            ("routing_", routing_),
            ("database_", database_),
            ("impersonated_user_", impersonated_user_),
            ("bookmark_manager_", bookmark_manager_),
            ("auth_", auth_),
            ("result_transformer_", result_transformer_),
        )
        for argument_name, argument_value in optional_arguments:
            if argument_value is not None:
                execute_kwargs[argument_name] = argument_value

        execute_kwargs.update(kwargs)
        return self._delegate.execute_query(**execute_kwargs)
class Neo4jGraphRAGAdapter:
    """Adapter between Lalandre graph mode and the official Neo4j GraphRAG stack.

    The adapter stores its configuration at construction time and builds the
    LLM client and the textual graph schema lazily, caching each after the
    first successful build.
    """

    def __init__(
        self,
        *,
        neo4j_driver: Driver,
        neo4j_database: Optional[str],
        qdrant_client: Any,
        qdrant_collection_name: str,
        llm_provider: str,
        llm_model: str,
        llm_temperature: float,
        llm_max_tokens: int,
        llm_api_key: Optional[str],
        mistral_api_key: Optional[str],
        llm_base_url: Optional[str],
        key_pool: Optional[APIKeyPool],
        read_only_validator: Callable[[str], str],
        row_serializer: Callable[[Any], Any],
    ) -> None:
        """Record configuration; no LLM client or schema is created here.

        The provider name is normalized with ``normalize_provider``, the model
        name is stripped of surrounding whitespace, and temperature / token
        limit are coerced to ``float`` / ``int``. ``read_only_validator`` is
        handed to a ``_ReadOnlyDriverProxy`` wrapping ``neo4j_driver``.
        """
        # Graph and vector store handles.
        self._neo4j_driver = neo4j_driver
        self._neo4j_database = neo4j_database
        self._qdrant_client = qdrant_client
        self._qdrant_collection_name = qdrant_collection_name

        # LLM settings.
        self._provider = normalize_provider(llm_provider)
        self._model = llm_model.strip()
        self._temperature = float(llm_temperature)
        self._max_tokens = int(llm_max_tokens)
        self._api_key = llm_api_key
        self._mistral_api_key = mistral_api_key
        self._base_url = llm_base_url
        self._key_pool = key_pool

        # Query execution helpers.
        self._read_only_driver = _ReadOnlyDriverProxy(neo4j_driver, read_only_validator)
        self._row_serializer = row_serializer

        # Lazily filled caches (see _get_schema / _get_llm).
        self._schema_cache: Optional[str] = None
        self._llm_cache: Any = None

    def is_available(self) -> bool:
        """Return whether the adapter is ready to serve official GraphRAG calls."""
        return True

    def _result_item_class(self) -> type:
        """Return the class used to wrap formatted result rows."""
        return RetrieverResultItem

    def _get_llm(self) -> Any:
        """Return the LLM client, building and caching it on first use.

        For the ``mistral`` provider a key-pool proxy over one client per key
        is built when a pool with more than one key is configured; otherwise a
        single client is created. ``openai_compatible`` yields a single client
        pointed at the normalized base URL.

        Raises:
            ValueError: If the configured provider is neither ``mistral`` nor
                ``openai_compatible``.
        """
        cached = self._llm_cache
        if cached is not None:
            return cached

        def _sampling_params() -> Dict[str, Any]:
            # A new dict on every call so no two clients share one mapping.
            return {
                "temperature": self._temperature,
                "max_tokens": self._max_tokens,
            }

        api_key = resolve_api_key(
            provider=self._provider,
            api_key=self._api_key,
            mistral_api_key=self._mistral_api_key,
        )

        if self._provider == "mistral":
            pool = self._key_pool
            if pool is None or len(pool) <= 1:
                self._llm_cache = MistralAILLM(
                    model_name=self._model,
                    api_key=api_key,
                    model_params=_sampling_params(),
                )
                return self._llm_cache

            def _client_for(key: Any) -> Any:
                return MistralAILLM(
                    model_name=self._model,
                    api_key=key,
                    model_params=_sampling_params(),
                )

            per_key_clients = build_clients_by_key(
                key_pool=pool,
                factory=_client_for,
            )
            self._llm_cache = SharedKeyPoolProxy(
                key_pool=pool,
                clients_by_key=per_key_clients,
            )
            return self._llm_cache

        if self._provider == "openai_compatible":
            endpoint = normalize_base_url(
                provider=self._provider,
                base_url=self._base_url or "",
            )
            self._llm_cache = OpenAILLM(
                model_name=self._model,
                api_key=api_key,
                base_url=endpoint,
                model_params=_sampling_params(),
            )
            return self._llm_cache

        raise ValueError(
            f"Unsupported Neo4j GraphRAG provider {self._provider!r}. Use mistral or openai_compatible."
        )

    def _get_schema(self) -> str:
        """Return the graph schema text, introspecting the database once.

        If introspection raises, a warning is logged and the static schema
        from ``_fallback_schema`` is cached and returned instead.
        """
        if self._schema_cache is None:
            try:
                introspected = get_schema(self._neo4j_driver, database=self._neo4j_database)
                self._schema_cache = str(introspected)
            except Exception as exc:
                logger.warning(
                    "Neo4j GraphRAG schema introspection failed, fallback to static schema: %s",
                    exc,
                )
                self._schema_cache = self._fallback_schema()
        return self._schema_cache

    @staticmethod
    def _build_text2cypher_prompt() -> str:
        """Return the Text2Cypher prompt template."""
        return get_text2cypher_prompt_template()

    def _build_cypher_row_formatter(self) -> Callable[[Any], Any]:
        """Return a callable turning one record into a result item.

        Each value of the record is passed through the configured row
        serializer; the item carries the row as JSON text in ``content`` and
        as a dict under ``metadata["row"]``.
        """
        item_factory = self._result_item_class()
        serialize = self._row_serializer

        def _formatter(record: Any) -> Any:
            row: Dict[str, Any] = {
                str(column): serialize(record.get(column)) for column in record.keys()
            }
            payload = json.dumps(row, ensure_ascii=False, default=str)
            return item_factory(content=payload, metadata={"row": row})

        return _formatter

    @staticmethod
    def _extract_items(result: Any) -> List[Any]:
        """Return a copy of ``result.items`` when it is a list, else ``[]``."""
        candidate = getattr(result, "items", None)
        if not isinstance(candidate, list):
            return []
        return list(cast(List[Any], candidate))

    @staticmethod
    def _item_metadata(item: Any) -> Dict[str, Any]:
        """Return ``item.metadata`` when it is a dict, else an empty dict."""
        candidate = getattr(item, "metadata", None)
        if isinstance(candidate, dict):
            return cast(Dict[str, Any], candidate)
        return {}

    @staticmethod
    def _result_metadata(result: Any) -> Dict[str, Any]:
        """Return ``result.metadata`` when it is a dict, else an empty dict."""
        candidate = getattr(result, "metadata", None)
        if isinstance(candidate, dict):
            return cast(Dict[str, Any], candidate)
        return {}

    @staticmethod
    def _fallback_schema() -> str:
        """Return the static schema description used when introspection fails."""
        schema_lines = (
            "Node properties:",
            (
                "Act {id: INTEGER, celex: STRING, title: STRING, act_type: STRING, "
                "language: STRING, adoption_date: DATE, force_date: DATE, end_date: DATE, "
                "sector: INTEGER, level: INTEGER, official_journal_reference: STRING, "
                "eli: STRING, url_eurlex: STRING}"
            ),
            "Entity {name: STRING, entity_type: STRING, description: STRING}",
            (
                "Community {id: INTEGER, num_acts: INTEGER, num_relations: INTEGER, "
                "relation_types: STRING, central_acts: STRING, summary: STRING, "
                "modularity: FLOAT, resolution: FLOAT}"
            ),
            "Relationship properties:",
            (
                "AMENDS {relation_type: STRING, effect_date: DATE, description: STRING, "
                "source_subdivision_id: INTEGER, target_subdivision_id: INTEGER}"
            ),
            "REPEALS {relation_type: STRING, effect_date: DATE, description: STRING}",
            "REPLACES {relation_type: STRING, effect_date: DATE, description: STRING}",
            "IMPLEMENTS {relation_type: STRING, effect_date: DATE, description: STRING}",
            "CITES {relation_type: STRING, description: STRING}",
            "DEROGATES {relation_type: STRING, effect_date: DATE, description: STRING}",
            "SUPPLEMENTS {relation_type: STRING, description: STRING}",
            "CORRECTS {relation_type: STRING, description: STRING}",
            "MENTIONS {} (Act)-[:MENTIONS]->(Entity)",
            "BELONGS_TO {} (Act)-[:BELONGS_TO]->(Community)",
            "The relationships AMENDS through CORRECTS are directed between :Act nodes.",
        )
        return "\n".join(schema_lines)