Source code for lalandre_db_neo4j.repository

"""
Neo4j Repository
Handles all interactions with the Neo4j graph database
"""

import logging
import re
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, LiteralString, Optional, cast

from lalandre_core.config import GraphConfig, get_config
from lalandre_core.repositories.base import BaseRepository
from neo4j import Driver, GraphDatabase, Query, Session
from neo4j.graph import Node, Path, Relationship

from .models import (
    ActNode,
    ActRelationship,
    CommunityNode,
    EntityNode,
    GraphQueryResult,
)

logger = logging.getLogger(__name__)

# --- Patterns used by the read-only Cypher validator -----------------------
# All patterns are applied to the query text AFTER comments have been removed.

# The statement must open with a reading clause or a `CALL { ... }` subquery.
_LEADING_READ_ONLY_CLAUSE_RE = re.compile(
    r"^\s*(?:MATCH|OPTIONAL\s+MATCH|WITH|UNWIND)\b|^\s*CALL\s*\{",
    flags=re.IGNORECASE,
)
# Any of these keywords anywhere in the text marks the query as a write.
# The match is purely lexical, so the keyword is also detected inside string
# literals or identifiers delimited by word boundaries (deliberately strict).
_FORBIDDEN_WRITE_CLAUSE_RE = re.compile(
    r"\b(CREATE|MERGE|SET|DELETE|DETACH|REMOVE|DROP|FOREACH)\b|"
    r"\bLOAD\s+CSV\b",
    flags=re.IGNORECASE,
)
# `CALL` that is NOT immediately followed by `{` (i.e. a procedure call).
_PROCEDURAL_CALL_RE = re.compile(r"\bCALL\b(?!\s*\{)", flags=re.IGNORECASE)
# `/* ... */` comments, possibly spanning several lines.
_BLOCK_COMMENT_RE = re.compile(r"/\*.*?\*/", flags=re.DOTALL)
# `// ...` comments up to end of line. The comment must either start a line or
# be preceded by whitespace, so `//` embedded in a token such as `http://host`
# is left alone. Accepting a line start (and not only preceding whitespace)
# lets a comment at the very beginning of the query text be stripped as well.
_INLINE_COMMENT_RE = re.compile(r"(?m)(?:^|\s+)//.*$")
# Whole lines whose first non-blank character is `#`.
_HASH_COMMENT_RE = re.compile(r"(?m)^\s*#.*$")
# Optional leading EXPLAIN / PROFILE prefix.
_LEADING_PLAN_RE = re.compile(r"^\s*(?:EXPLAIN|PROFILE)\b\s*", flags=re.IGNORECASE)


class Neo4jRepository(BaseRepository):
    """Repository for Neo4j graph operations.

    Responsibilities:
        - Manage the Neo4j driver and sessions
        - Create/update Act nodes and relationships
        - Execute graph traversal queries
        - Support Graph RAG operations

    Note:
        Subdivisions and Versions are managed in PostgreSQL/Qdrant only.
    """

    # Hard upper bound on the rows returned by ``execute_read_only_query``,
    # whatever limit the caller or the configuration asks for.
    _MAX_RESULT_ROWS = 1000

    def __init__(self, settings: Optional[GraphConfig] = None):
        """Initialize the repository and connect to Neo4j immediately.

        Args:
            settings: Neo4j connection settings (defaults to the ``graph``
                section of the global config).

        Raises:
            ValueError: If uri, user or password is missing.
            Exception: Whatever the driver raises when connecting fails.
        """
        if settings is None:
            settings = get_config().graph
        self.settings = settings
        self.driver: Optional[Driver] = None
        self._connect()

    def _connect(self):
        """Create the driver and verify connectivity.

        ``self.driver`` is only assigned once the connectivity check has
        succeeded; on failure the freshly created driver is closed so its
        connection pool is not leaked, and the error is re-raised.
        """
        try:
            uri = self.settings.uri
            user = self.settings.user
            password = self.settings.password
            if uri is None or user is None or password is None:
                raise ValueError("Neo4j configuration incomplete: uri, user, and password are required")

            driver = GraphDatabase.driver(
                uri,
                auth=(user, password),
                max_connection_lifetime=self.settings.max_connection_lifetime,
                max_connection_pool_size=self.settings.max_connection_pool_size,
                connection_timeout=self.settings.connection_timeout,
            )
            try:
                # Test connection
                driver.verify_connectivity()
            except Exception:
                # Do not keep a driver (and its pool) that cannot reach the server.
                driver.close()
                raise
            self.driver = driver
            logger.info("Connected to Neo4j at %s", self.settings.uri)
        except Exception as e:
            logger.error("Failed to connect to Neo4j: %s", e)
            raise

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Yield a session on the configured database and always close it.

        Raises:
            RuntimeError: If the driver is not initialized (or was closed).
        """
        if not self.driver:
            raise RuntimeError("Neo4j driver not initialized")

        session = self.driver.session(database=self.settings.database)
        try:
            yield session
        finally:
            session.close()

    def close(self):
        """Close the Neo4j driver; calling it again is a no-op."""
        if self.driver:
            self.driver.close()
            # Drop the reference so later calls fail fast in get_session()
            # and health_check() reports False instead of using a closed driver.
            self.driver = None
            logger.info("Neo4j connection closed")

    def health_check(self) -> bool:
        """Return True when the driver exists and the server is reachable."""
        try:
            if not self.driver:
                return False
            self.driver.verify_connectivity()
            return True
        except Exception as e:
            logger.error("Neo4j health check failed: %s", e)
            return False

    @staticmethod
    def _to_query(cypher: str) -> Query:
        """Build a Neo4j Query from validated dynamic cypher text.

        The cast only satisfies the driver's ``LiteralString`` typing; callers
        are responsible for having validated any interpolated fragment.
        """
        return Query(cast(LiteralString, cypher))

    @staticmethod
    def _validate_relationship_type(rel_type: str) -> str:
        """Upper-case ``rel_type`` and accept only identifier-like values.

        The value must start with a letter and contain only alphanumerics or
        underscores, which makes it safe to interpolate into Cypher text.

        Raises:
            ValueError: If the type is empty or contains other characters.
        """
        normalized = rel_type.upper()
        if not normalized or not normalized[0].isalpha():
            raise ValueError(f"Invalid relationship type '{rel_type}'")
        if not all(ch.isalnum() or ch == "_" for ch in normalized):
            raise ValueError(f"Invalid relationship type '{rel_type}'")
        return normalized

    @staticmethod
    def _validate_read_only_cypher(cypher: str) -> str:
        """Validate that a Cypher statement is read-only.

        The check is intentionally conservative: write clauses and procedural
        calls are rejected, even when they only appear inside string literals.

        Returns:
            The normalized statement: comments, a leading EXPLAIN/PROFILE and a
            single trailing semicolon removed. This normalized text (not the
            raw input) is what should be executed.

        Raises:
            ValueError: If the statement is empty, contains several
                statements, lacks RETURN, calls a procedure, does not start
                with a reading clause, or contains a write clause.
        """
        normalized = (cypher or "").strip()
        if not normalized:
            raise ValueError("Cypher query is empty.")

        # Strip comments first so keywords hidden in them are not validated
        # and so the executed text is exactly the text that was checked.
        normalized = _BLOCK_COMMENT_RE.sub(" ", normalized)
        normalized = _INLINE_COMMENT_RE.sub("", normalized)
        normalized = _HASH_COMMENT_RE.sub("", normalized)
        normalized = _LEADING_PLAN_RE.sub("", normalized, count=1).strip()

        if normalized.endswith(";"):
            normalized = normalized[:-1].strip()
        if not normalized:
            raise ValueError("Cypher query is empty.")
        if ";" in normalized:
            raise ValueError("Multiple Cypher statements are not allowed.")
        if not re.search(r"\bRETURN\b", normalized, flags=re.IGNORECASE):
            raise ValueError("Cypher query must include a RETURN clause.")
        if _PROCEDURAL_CALL_RE.search(normalized):
            raise ValueError(
                "Cypher query contains procedural CALL clauses; only read-only CALL { ... } subqueries are allowed."
            )
        if not _LEADING_READ_ONLY_CLAUSE_RE.match(normalized):
            raise ValueError("Cypher query must start with MATCH, OPTIONAL MATCH, WITH, UNWIND, or CALL {.")
        if _FORBIDDEN_WRITE_CLAUSE_RE.search(normalized):
            raise ValueError("Cypher query contains write clauses; only read-only queries are allowed.")
        return normalized

    @classmethod
    def _serialize_neo4j_value(cls, value: Any) -> Any:
        """Convert Neo4j values to JSON-friendly payloads.

        Primitives pass through; lists and tuples become lists; dicts get
        string keys; nodes, relationships and paths become tagged dicts
        (``_kind``); anything else is converted with ``str()``.
        """
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, (list, tuple)):
            return [cls._serialize_neo4j_value(item) for item in value]
        if isinstance(value, dict):
            mapping = cast(dict[Any, Any], value)
            return {str(key): cls._serialize_neo4j_value(item) for key, item in mapping.items()}
        if isinstance(value, Node):
            node_properties = cast(dict[Any, Any], dict(value))
            return {
                "_kind": "node",
                "labels": sorted(str(label) for label in value.labels),
                "properties": {str(key): cls._serialize_neo4j_value(item) for key, item in node_properties.items()},
            }
        if isinstance(value, Relationship):
            relationship_properties = cast(dict[Any, Any], dict(value))
            # Endpoints are reported by their "id" property, not by the
            # internal Neo4j element id.
            start_node_id = value.start_node.get("id") if value.start_node is not None else None
            end_node_id = value.end_node.get("id") if value.end_node is not None else None
            return {
                "_kind": "relationship",
                "type": value.type,
                "start_node_id": start_node_id,
                "end_node_id": end_node_id,
                "properties": {
                    str(key): cls._serialize_neo4j_value(item) for key, item in relationship_properties.items()
                },
            }
        if isinstance(value, Path):
            return {
                "_kind": "path",
                "nodes": [cls._serialize_neo4j_value(node) for node in value.nodes],
                "relationships": [cls._serialize_neo4j_value(rel) for rel in value.relationships],
            }
        return str(value)

    def validate_read_only_cypher(self, cypher: str) -> str:
        """Public wrapper — validate that a Cypher statement is read-only."""
        return self._validate_read_only_cypher(cypher)

    def serialize_neo4j_value(self, value: Any) -> Any:
        """Public wrapper — convert a Neo4j value to a JSON-friendly payload."""
        return self._serialize_neo4j_value(value)

    def execute_read_only_query(
        self,
        cypher: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        result_limit: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Execute a validated read-only Cypher query and return JSON-safe rows.

        Args:
            cypher: Cypher text; it is validated and normalized before running.
            params: Query parameters, referenced as ``$name`` in the Cypher.
            result_limit: Maximum rows to return. Defaults to the configured
                ``cypher_max_rows`` and is always clamped to 1..1000.

        Returns:
            One dict per record, with values serialized to JSON-friendly types.

        Raises:
            ValueError: If the query is not accepted as read-only.
        """
        validated_query = self._validate_read_only_cypher(cypher)
        effective_limit = result_limit if result_limit is not None else get_config().graph.cypher_max_rows
        bounded_limit = max(1, min(int(effective_limit), self._MAX_RESULT_ROWS))
        query_params: Dict[str, Any] = dict(params) if params else {}

        rows: List[Dict[str, Any]] = []
        with self.get_session() as session:
            # Pass the parameters as one mapping rather than as **kwargs, so a
            # Cypher parameter named e.g. "query" or "parameters" cannot clash
            # with session.run()'s own arguments.
            result = session.run(self._to_query(validated_query), parameters=query_params)
            for index, record in enumerate(result):
                if index >= bounded_limit:
                    break
                rows.append({str(key): self._serialize_neo4j_value(record.get(key)) for key in record.keys()})
        return rows

    # === ACT NODE OPERATIONS ===

    def create_act_node(self, act: ActNode) -> int:
        """Create (or update, via MERGE on ``id``) an Act node in the graph.

        Args:
            act: ActNode data.

        Returns:
            The act ID.

        Raises:
            RuntimeError: If the query returned no record.
        """
        with self.get_session() as session:
            result = session.run(
                """
                MERGE (a:Act {id: $id})
                SET a += $props
                RETURN elementId(a) as neo4j_element_id, a.id as id
                """,
                id=act.id,
                props=act.to_neo4j_properties(),
            )
            record = result.single()
            if record is None:
                raise RuntimeError(f"Failed to create Act node for {act.celex} (id={act.id})")
            logger.info("Created Act node: %s (id=%s)", act.celex, act.id)
            return int(record["id"])

    # === ENTITY NODE OPERATIONS ===

    def upsert_entity_mention(self, act_celex: str, entity: EntityNode) -> None:
        """Idempotently create an Entity node and a MENTIONS edge from the Act.

        Uses MERGE on (name, type) so duplicate extractions are safe to call
        repeatedly. The description is only written when the entity is first
        created. If no Act matches ``act_celex`` the query matches nothing and
        no entity or edge is written.
        """
        with self.get_session() as session:
            session.run(
                """
                MATCH (a:Act {celex: $celex})
                MERGE (e:Entity {name: $name, type: $entity_type})
                ON CREATE SET e.description = $description
                MERGE (a)-[:MENTIONS]->(e)
                """,
                celex=act_celex,
                name=entity.name,
                entity_type=entity.entity_type,
                description=entity.description,
            )

    # === RELATIONSHIP OPERATIONS ===

    def create_act_relationship(self, relationship: ActRelationship) -> bool:
        """Create (or update, via MERGE) a relationship between two Acts.

        Args:
            relationship: ActRelationship data.

        Returns:
            True if the relationship was written, False when the query
            returned no record (e.g. one of the two Acts does not exist).

        Raises:
            ValueError: If the relationship type is not a valid identifier.
        """
        rel_type = self._validate_relationship_type(relationship.get_neo4j_type())
        props = relationship.to_neo4j_properties()

        # Cypher cannot take a relationship type as a query parameter, so the
        # validated identifier is interpolated into the query text.
        query = self._to_query(
            f"""
            MATCH (source:Act {{id: $source_id}})
            MATCH (target:Act {{id: $target_id}})
            MERGE (source)-[r:{rel_type}]->(target)
            SET r += $props
            RETURN elementId(r) as rel_id
            """
        )
        with self.get_session() as session:
            result = session.run(
                query,
                source_id=relationship.source_act_id,
                target_id=relationship.target_act_id,
                props=props,
            )
            record = result.single()
            if record:
                logger.info(
                    "Created relationship %s: %s -> %s",
                    rel_type,
                    relationship.source_act_id,
                    relationship.target_act_id,
                )
                return True
            return False

    # === COMMUNITY OPERATIONS ===

    def clear_act_relationships(self) -> int:
        """Delete all relationships between Act nodes and return the count removed."""
        with self.get_session() as session:
            result = session.run("MATCH (:Act)-[r]->(:Act) DELETE r RETURN count(r) AS deleted")
            record = result.single()
            deleted = int(record["deleted"]) if record else 0
            logger.info("Cleared %d Act-to-Act relationships", deleted)
            return deleted

    def clear_communities(self) -> int:
        """Delete all Community nodes together with their relationships.

        Also removes the legacy ``community_id`` property from Act nodes.

        Returns:
            The number of Community nodes deleted.
        """
        with self.get_session() as session:
            result = session.run("MATCH (c:Community) DETACH DELETE c RETURN count(c) AS deleted")
            record = result.single()
            deleted = int(record["deleted"]) if record else 0
            # Also clean up legacy community_id property on Act nodes
            session.run("MATCH (a:Act) WHERE a.community_id IS NOT NULL REMOVE a.community_id")
            logger.info("Cleared %d Community nodes and legacy community_id properties", deleted)
            return deleted

    def upsert_community(self, community: CommunityNode) -> None:
        """Create or update a Community node, matched on its ``id``."""
        with self.get_session() as session:
            session.run(
                """
                MERGE (c:Community {id: $id})
                SET c += $props
                """,
                id=community.id,
                props=community.to_neo4j_properties(),
            )

    def upsert_communities_batch(
        self,
        communities: List[CommunityNode],
        batch_size: int = 100,
    ) -> None:
        """Batch upsert Community nodes, one query per ``batch_size`` items.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (a non-positive
                step would otherwise fail obscurely or silently write nothing).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        with self.get_session() as session:
            for start in range(0, len(communities), batch_size):
                batch = [c.to_neo4j_properties() for c in communities[start : start + batch_size]]
                session.run(
                    """
                    UNWIND $batch AS props
                    MERGE (c:Community {id: props.id})
                    SET c += props
                    """,
                    batch=batch,
                )

    def get_communities_for_acts(self, act_ids: List[int]) -> List[Dict[str, Any]]:
        """Return the distinct Community nodes the given Acts belong to.

        Best-effort: any failure is logged as a warning and an empty list is
        returned instead of raising.

        Returns:
            Dicts with ``id``, ``num_acts``, ``summary``, ``central_acts`` and
            ``num_relations``, ordered by ``num_acts`` descending.
        """
        if not act_ids:
            return []
        try:
            with self.get_session() as session:
                result = session.run(
                    """
                    UNWIND $ids AS aid
                    MATCH (a:Act {id: aid})-[:BELONGS_TO]->(c:Community)
                    RETURN DISTINCT c.id AS id, c.num_acts AS num_acts,
                           c.summary AS summary, c.central_acts AS central_acts,
                           c.num_relations AS num_relations
                    ORDER BY c.num_acts DESC
                    """,
                    ids=act_ids,
                )
                return [dict(record) for record in result]
        except Exception as e:
            logger.warning("Failed to get communities for acts: %s", e)
            return []

    # === GRAPH TRAVERSAL QUERIES ===

    def expand_from_acts(
        self,
        act_ids: List[int],
        max_depth: Optional[int] = None,
    ) -> GraphQueryResult:
        """Expand graph context from multiple seed act IDs in a single query.

        Returns deduplicated nodes and relationships reachable within
        ``max_depth`` hops (in either direction) from any of the seed acts.
        ``max_depth`` defaults to the configured graph depth and is raised to
        at least 1.

        Best-effort: any failure is logged as a warning and an empty
        ``GraphQueryResult`` is returned. An empty result is also returned
        when no seed act has a path to another Act.
        """
        if not act_ids:
            return GraphQueryResult()
        try:
            with self.get_session() as session:
                depth = max(int(max_depth if max_depth is not None else get_config().graph.depth), 1)
                # A variable-length bound cannot be a query parameter; `depth`
                # is a validated int, so interpolating it is safe.
                query = self._to_query(
                    f"""
                    MATCH path = (a:Act)-[*1..{depth}]-(related:Act)
                    WHERE a.id IN $ids
                    WITH collect(distinct a) + collect(distinct related) as all_acts,
                         collect(distinct relationships(path)) as all_rels
                    UNWIND all_acts as act
                    WITH collect(distinct act) as nodes, all_rels
                    RETURN nodes, all_rels
                    """
                )
                record = session.run(query, ids=act_ids).single()
                if not record:
                    return GraphQueryResult()

                nodes = [dict(act) for act in record["nodes"]]

                # `all_rels` is a list of per-path relationship lists; the same
                # relationship appears in many paths, so deduplicate on
                # (start id, type, end id).
                relationships: List[Dict[str, Any]] = []
                seen_rels: set[tuple[Any, str, Any]] = set()
                for rel_list in record["all_rels"]:
                    for rel in rel_list:
                        start_id = rel.start_node.get("id")
                        end_id = rel.end_node.get("id")
                        key = (start_id, rel.type, end_id)
                        if key in seen_rels:
                            continue
                        seen_rels.add(key)
                        relationships.append(
                            {
                                "type": rel.type,
                                "properties": dict(rel),
                                "start_node": start_id,
                                "end_node": end_id,
                            }
                        )

                return GraphQueryResult(
                    nodes=nodes,
                    relationships=relationships,
                    metadata={
                        "seed_act_ids": act_ids,
                        "max_depth": depth,
                        "total_nodes": len(nodes),
                        "total_relationships": len(relationships),
                    },
                )
        except Exception as e:
            logger.warning("Neo4j batch expand failed: %s", e)
            return GraphQueryResult()

    # === UTILITIES ===

    def get_statistics(self) -> Dict[str, Any]:
        """Return basic graph counts.

        Returns:
            ``{"acts": <number of Act nodes>, "relationships": <number of
            relationships in the whole graph>}``.
        """
        with self.get_session() as session:
            # Aggregate the Act count down to a single row BEFORE matching
            # relationships; matching both patterns in one scope would build
            # the cartesian product acts x relationships.
            result = session.run(
                """
                MATCH (a:Act)
                WITH count(a) AS act_count
                OPTIONAL MATCH ()-[r]->()
                RETURN act_count, count(r) AS relationship_count
                """
            )
            record = result.single()
            if record is None:
                return {
                    "acts": 0,
                    "relationships": 0,
                }
            return {"acts": int(record["act_count"]), "relationships": int(record["relationship_count"])}