# Source code for extraction_worker.graph.community_builder

"""
Graph community detection for extraction worker.

Mirrors the logic in scripts/build_communities.py but runs inside the worker
process, reusing the already-open Neo4j driver from bootstrap.

Communities are persisted as :Community nodes with BELONGS_TO relationships
in Neo4j (no JSON file on disk).

Community summaries are generated by LLM (GraphRAG-style) when a client is
available, with deterministic fallback.
"""

import json
import logging
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Tuple, cast

import networkx as nx
from lalandre_core.config import get_config
from lalandre_core.http.llm_client import JSONHTTPLLMClient

logger = logging.getLogger(__name__)


# ── Neo4j extraction ──────────────────────────────────────────────────────


def _extract_graph(
    driver: Any,
    database: str,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Pull every Act node and every Act→Act relationship from Neo4j as plain dicts."""
    acts: List[Dict[str, Any]] = []
    links: List[Dict[str, Any]] = []

    with driver.session(database=database) as session:
        # All acts with the attributes later used for distributions/summaries.
        for record in session.run(
            """
            MATCH (a:Act)
            RETURN a.id AS id,
                   a.celex AS celex,
                   a.title AS title,
                   a.act_type AS act_type,
                   a.language AS language,
                   a.level AS level,
                   a.sector AS sector,
                   a.adoption_date AS adoption_date
            """
        ):
            acts.append(dict(record))

        # Directed relationships between acts, reduced to (source, target, type).
        for record in session.run(
            """
            MATCH (a:Act)-[r]->(b:Act)
            RETURN a.id AS source_id, b.id AS target_id, type(r) AS rel_type
            """
        ):
            links.append(dict(record))

    logger.info(
        "Community builder: extracted %d nodes and %d relationships",
        len(acts),
        len(links),
    )
    return acts, links


# ── Louvain detection ─────────────────────────────────────────────────────


def _get_relation_weights() -> tuple[Dict[str, float], float]:
    """Read per-relation-type Louvain edge weights and the default weight from config."""
    graph_cfg = get_config().graph
    weights = graph_cfg.community_relation_weights
    fallback = graph_cfg.community_default_relation_weight
    return weights, fallback


def _detect_and_filter(
    nodes: List[Dict[str, Any]],
    relationships: List[Dict[str, Any]],
    resolution: float,
    min_community_size: int,
) -> Tuple[Dict[int, int], Dict[int, List[int]], float]:
    """Run Louvain on the act graph, then drop communities below *min_community_size*.

    Returns (act_id -> community_id partition, community_id -> sorted member ids,
    modularity). Acts in dropped communities are mapped to -1 in the partition.
    """
    from community import community_louvain  # type: ignore[import-untyped]

    graph: nx.Graph[Any] = nx.Graph()  # type: ignore[possibly-unbound]

    for act in nodes:
        graph.add_node(act["id"], celex=act.get("celex"), title=act.get("title"))

    type_weights, fallback_weight = _get_relation_weights()
    for link in relationships:
        a = link["source_id"]
        b = link["target_id"]
        w = type_weights.get(link.get("rel_type", "RELATED_TO"), fallback_weight)
        # Parallel relations between the same pair accumulate their weights.
        if graph.has_edge(a, b):
            graph[a][b]["weight"] += w
        else:
            graph.add_edge(a, b, weight=w)

    logger.info(
        "Community builder: graph has %d nodes, %d edges",
        graph.number_of_nodes(),
        graph.number_of_edges(),
    )

    raw_partition: Dict[Any, int] = cast(
        Dict[Any, int],
        community_louvain.best_partition(  # type: ignore[misc]
            graph, resolution=resolution, weight="weight"
        ),
    )
    modularity: float = cast(
        float,
        community_louvain.modularity(  # type: ignore[misc]
            raw_partition, graph, weight="weight"
        ),
    )

    # Group member act ids under their raw Louvain label.
    grouped: Dict[int, List[int]] = defaultdict(list)
    for act_id, label in raw_partition.items():
        grouped[label].append(act_id)

    # Renumber surviving communities densely from 0; too-small ones map to -1.
    partition: Dict[int, int] = {}
    kept: Dict[int, List[int]] = {}
    for label in sorted(grouped.keys()):
        members = grouped[label]
        if len(members) < min_community_size:
            partition.update((act_id, -1) for act_id in members)
            continue
        new_label = len(kept)
        partition.update((act_id, new_label) for act_id in members)
        kept[new_label] = sorted(members)

    logger.info(
        "Community builder: %d communities detected (modularity=%.4f)",
        len(kept),
        modularity,
    )
    return partition, kept, modularity


# ── Summary generation ────────────────────────────────────────────────────


def _collect_community_data(
    nodes: List[Dict[str, Any]],
    relationships: List[Dict[str, Any]],
    communities: Dict[int, List[int]],
) -> List[Dict[str, Any]]:
    """Collect structural data for each community (acts, relations, centrality)."""
    node_map: Dict[int, Dict[str, Any]] = {n["id"]: n for n in nodes}

    rel_by_src: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
    for rel in relationships:
        rel_by_src[rel["source_id"]].append(rel)
        rel_by_src[rel["target_id"]].append(rel)

    result: List[Dict[str, Any]] = []

    for comm_id, members in sorted(communities.items()):
        member_set = set(members)

        # Collect + deduplicate internal relations
        seen_rels: set[Tuple[int, int, str]] = set()
        unique_rels: List[Dict[str, Any]] = []
        for member_id in members:
            for rel in rel_by_src.get(member_id, []):
                other = rel["target_id"] if rel["source_id"] == member_id else rel["source_id"]
                if other in member_set:
                    key = (
                        min(rel["source_id"], rel["target_id"]),
                        max(rel["source_id"], rel["target_id"]),
                        rel["rel_type"],
                    )
                    if key not in seen_rels:
                        seen_rels.add(key)
                        unique_rels.append(rel)

        rel_dist: Counter[str] = Counter(r["rel_type"] for r in unique_rels)

        act_type_dist: Counter[str] = Counter()
        language_dist: Counter[str] = Counter()
        level_dist: Counter[str] = Counter()
        adoption_year_dist: Counter[str] = Counter()
        for aid in members:
            node = node_map.get(aid, {})

            act_type_raw = node.get("act_type")
            if isinstance(act_type_raw, str) and act_type_raw.strip():
                act_type_dist[act_type_raw.strip()] += 1

            language_raw = node.get("language")
            if isinstance(language_raw, str) and language_raw.strip():
                language_dist[language_raw.strip()] += 1

            level_raw = node.get("level")
            if level_raw is not None:
                level_dist[str(level_raw)] += 1

            adoption_raw = node.get("adoption_date")
            if adoption_raw is not None:
                text = str(adoption_raw)
                if len(text) >= 4 and text[:4].isdigit():
                    adoption_year_dist[text[:4]] += 1

        degrees: Counter[int] = Counter()
        for rel in unique_rels:
            degrees[rel["source_id"]] += 1
            degrees[rel["target_id"]] += 1

        central_acts: List[Dict[str, Any]] = []
        for aid, _ in degrees.most_common(3):
            n = node_map.get(aid, {})
            central_acts.append(
                {
                    "act_id": aid,
                    "celex": n.get("celex", f"ACT-{aid}"),
                    "title": n.get("title", "Unknown"),
                    "degree": degrees.get(aid, 0),
                }
            )

        result.append(
            {
                "community_id": comm_id,
                "num_acts": len(members),
                "num_relations": len(unique_rels),
                "relation_types": dict(rel_dist),
                "act_type_distribution": dict(act_type_dist),
                "language_distribution": dict(language_dist),
                "level_distribution": dict(level_dist),
                "adoption_year_distribution": dict(adoption_year_dist),
                "central_acts": central_acts,
                "members": members,
                "unique_rels": unique_rels,
                "node_map": node_map,
            }
        )

    return result


def _deterministic_summary(data: Dict[str, Any]) -> str:
    """Fallback: rule-based summary when LLM is unavailable."""
    rel_dist = data["relation_types"]
    act_types = data.get("act_type_distribution", {})
    rel_dist_str = ", ".join(f"{t}:{c}" for t, c in Counter(rel_dist).most_common(5))
    act_type_str = ", ".join(f"{t}:{c}" for t, c in Counter(act_types).most_common(3))
    central = data["central_acts"]
    members = data["members"]
    node_map = data["node_map"]

    member_titles = [node_map.get(aid, {}).get("title", "Unknown") for aid in members[:10]]
    titles_str = "; ".join(member_titles)
    if len(members) > 10:
        titles_str += f" (+{len(members) - 10} more)"

    return (
        f"Communauté C{data['community_id']} : {data['num_acts']} actes réglementaires, "
        f"{data['num_relations']} relations internes "
        f"(types: {rel_dist_str or 'aucune'}). "
        f"Profils d'actes: {act_type_str or 'non renseigné'}. "
        f"Actes centraux: {', '.join(a['celex'] for a in central[:3])}. "
        f"Actes: {titles_str}"
    )


def _build_llm_summary_prompt(data: Dict[str, Any]) -> str:
    """Build a prompt for LLM-based community summarization (GraphRAG-style)."""
    node_map = data["node_map"]
    members = data["members"]
    unique_rels = data["unique_rels"]

    # Acts list (cap at 25 to fit context)
    acts_lines: List[str] = []
    for aid in members[:25]:
        n = node_map.get(aid, {})
        celex = n.get("celex", f"ACT-{aid}")
        title = n.get("title", "Unknown")
        acts_lines.append(f"- {celex}: {title}")
    if len(members) > 25:
        acts_lines.append(f"- ... (+{len(members) - 25} autres actes)")

    # Relations list (cap at 40)
    rel_lines: List[str] = []
    for rel in unique_rels[:40]:
        src = node_map.get(rel["source_id"], {}).get("celex", "?")
        tgt = node_map.get(rel["target_id"], {}).get("celex", "?")
        rtype = rel.get("rel_type", "RELATED_TO")
        rel_lines.append(f"- {src} --[{rtype}]--> {tgt}")
    if len(unique_rels) > 40:
        rel_lines.append(f"- ... (+{len(unique_rels) - 40} autres relations)")

    acts_block = "\n".join(acts_lines)
    rels_block = "\n".join(rel_lines) if rel_lines else "(aucune relation interne)"

    act_type_dist = data.get("act_type_distribution", {})
    language_dist = data.get("language_distribution", {})
    level_dist = data.get("level_distribution", {})
    year_dist = data.get("adoption_year_distribution", {})

    attr_lines = [
        f"- act_types: {json.dumps(act_type_dist, ensure_ascii=False)}",
        f"- languages: {json.dumps(language_dist, ensure_ascii=False)}",
        f"- levels: {json.dumps(level_dist, ensure_ascii=False)}",
        f"- adoption_years: {json.dumps(year_dist, ensure_ascii=False)}",
    ]
    attributes_block = "\n".join(attr_lines)

    return (
        f"Tu es un expert en droit européen et français. "
        f"Voici une communauté de {data['num_acts']} actes réglementaires "
        f"liés par {data['num_relations']} relations.\n\n"
        f"Actes:\n{acts_block}\n\n"
        f"Relations:\n{rels_block}\n\n"
        f"Attributs agrégés de la communauté:\n{attributes_block}\n\n"
        "Rédige un résumé thématique concis (3-5 phrases) décrivant :\n"
        "1. Le domaine réglementaire couvert par cette communauté\n"
        "2. Les instruments centraux et leur rôle\n"
        "3. Comment ces actes s'articulent entre eux (amendements, transpositions, etc.)\n\n"
        "Réponds directement avec le résumé, sans préambule."
    )


def _generate_summaries(
    nodes: List[Dict[str, Any]],
    relationships: List[Dict[str, Any]],
    communities: Dict[int, List[int]],
    llm_client: Optional[JSONHTTPLLMClient] = None,
) -> List[Dict[str, Any]]:
    """Produce one summary record per community, preferring LLM text over the rule-based fallback."""
    per_community = _collect_community_data(nodes, relationships, communities)

    records: List[Dict[str, Any]] = []
    n_llm = 0
    n_fallback = 0

    for data in per_community:
        # Single-act communities are summarized deterministically; LLM adds no value.
        if llm_client is not None and data["num_acts"] >= 2:
            candidate = ""
            try:
                candidate = llm_client.generate(_build_llm_summary_prompt(data)).strip()
            except Exception as exc:
                logger.warning(
                    "LLM summary failed for community C%d: %s",
                    data["community_id"],
                    exc,
                )
            # Reject empty or suspiciously short completions.
            if len(candidate) >= 30:
                text = candidate
                n_llm += 1
            else:
                text = _deterministic_summary(data)
                n_fallback += 1
        else:
            text = _deterministic_summary(data)

        records.append(
            {
                "community_id": data["community_id"],
                "num_acts": data["num_acts"],
                "num_relations": data["num_relations"],
                "relation_types": data["relation_types"],
                "act_type_distribution": data.get("act_type_distribution", {}),
                "language_distribution": data.get("language_distribution", {}),
                "level_distribution": data.get("level_distribution", {}),
                "adoption_year_distribution": data.get("adoption_year_distribution", {}),
                "central_acts": data["central_acts"],
                "summary": text,
            }
        )

    if llm_client is not None:
        logger.info(
            "Community summaries: %d LLM-generated, %d deterministic fallback",
            n_llm,
            n_fallback,
        )

    return records


# ── Write to Neo4j ────────────────────────────────────────────────────────


def _write_to_neo4j(
    driver: Any,
    database: str,
    partition: Dict[int, int],
    summaries: List[Dict[str, Any]],
    modularity: float,
    resolution: float,
) -> None:
    """Replace all :Community nodes and BELONGS_TO links in Neo4j with this run's result."""
    # Community node properties; nested structures are stored as JSON strings.
    payload: List[Dict[str, Any]] = []
    for s in summaries:
        payload.append(
            {
                "id": s["community_id"],
                "num_acts": s["num_acts"],
                "num_relations": s["num_relations"],
                "relation_types": json.dumps(s["relation_types"], ensure_ascii=False),
                "act_type_distribution": json.dumps(s.get("act_type_distribution", {}), ensure_ascii=False),
                "language_distribution": json.dumps(s.get("language_distribution", {}), ensure_ascii=False),
                "level_distribution": json.dumps(s.get("level_distribution", {}), ensure_ascii=False),
                "adoption_year_distribution": json.dumps(s.get("adoption_year_distribution", {}), ensure_ascii=False),
                "central_acts": json.dumps(s["central_acts"], ensure_ascii=False),
                "summary": s["summary"],
                "modularity": round(modularity, 4),
                "resolution": resolution,
            }
        )

    # Acts filtered out of any community carry id -1 and get no link.
    memberships = [
        {"act_id": act_id, "comm_id": comm_id}
        for act_id, comm_id in partition.items()
        if comm_id >= 0
    ]

    with driver.session(database=database) as session:
        # Drop the previous detection run before writing the new one.
        session.run("MATCH (c:Community) DETACH DELETE c")
        session.run("MATCH (a:Act) WHERE a.community_id IS NOT NULL REMOVE a.community_id")

        session.run(
            """
            UNWIND $batch AS props
            CREATE (c:Community)
            SET c = props
            """,
            batch=payload,
        )

        # Link acts to their community in chunks of 500.
        chunk_size = 500
        for offset in range(0, len(memberships), chunk_size):
            session.run(
                """
                UNWIND $batch AS item
                MATCH (a:Act {id: item.act_id})
                MATCH (c:Community {id: item.comm_id})
                MERGE (a)-[:BELONGS_TO]->(c)
                """,
                batch=memberships[offset : offset + chunk_size],
            )

        # Best-effort: lookup index on Community.id for the MATCH above and readers.
        try:
            session.run("CREATE INDEX community_id_idx IF NOT EXISTS FOR (c:Community) ON (c.id)")
        except Exception as exc:
            logger.warning("Could not ensure Community index: %s", exc)

    logger.info(
        "Persisted %d Community nodes and %d BELONGS_TO relationships",
        len(payload),
        len(memberships),
    )


# ── Public entry point ────────────────────────────────────────────────────


def run_community_detection(
    *,
    driver: Any,
    database: str,
    resolution: float = 1.0,
    min_community_size: int = 2,
    llm_client: Optional[JSONHTTPLLMClient] = None,
) -> Dict[str, Any]:
    """
    Full community detection pipeline: extract → detect → summarize → persist to Neo4j.

    When *llm_client* is provided, community summaries are generated by LLM
    (GraphRAG-style). Falls back to deterministic summaries on failure.

    Parameters are keyword-only. *resolution* tunes Louvain granularity;
    *min_community_size* drops communities with fewer members.

    Returns a summary dict with status, counts, and modularity.
    """
    nodes, relationships = _extract_graph(driver, database)

    # Nothing to cluster: bail out early with an explicit skip status.
    if not nodes:
        logger.warning("Community builder: no Act nodes found, skipping")
        return {"status": "skipped", "reason": "no_acts", "num_communities": 0}
    if not relationships:
        logger.warning("Community builder: no relationships found, skipping")
        return {"status": "skipped", "reason": "no_relationships", "num_communities": 0}

    partition, communities, modularity = _detect_and_filter(
        nodes, relationships, resolution, min_community_size
    )
    summaries = _generate_summaries(
        nodes,
        relationships,
        communities,
        llm_client=llm_client,
    )
    _write_to_neo4j(driver, database, partition, summaries, modularity, resolution)

    return {
        "status": "success",
        "num_nodes": len(nodes),
        "num_relationships": len(relationships),
        "num_communities": len(communities),
        "modularity": round(modularity, 4),
    }