# Source code for lalandre_rag.graph.map_reduce

"""
Map-Reduce generation for Graph RAG.

When the assembled context exceeds a certain threshold, a single LLM call
can time out or produce degraded answers.  This module splits the context
into chunks, runs parallel "map" calls to produce partial summaries, then
merges them into a final answer with a "reduce" call.

Pipeline::

    context_chunks  ──►  LLM map (parallel)  ──►  partial summaries

    question ──────────────────────────────────────►  LLM reduce  ──►  answer

Usage::

    answer = await map_reduce_generate(
        context=long_context,
        question=question,
        llm=llm_chain,
        chunk_chars=6000,
        map_timeout=15.0,
        reduce_timeout=20.0,
    )
"""

import asyncio
import logging
import textwrap
from typing import Any, List, Optional

from lalandre_core.config import get_config
from langchain_core.output_parsers import StrOutputParser

logger = logging.getLogger(__name__)


# ── Context chunking ──────────────────────────────────────────────────────


def _split_context(context: str, chunk_chars: int) -> List[str]:
    """
    Split context into chunks at paragraph boundaries.

    Prefers splitting at ``\\n\\n`` to avoid cutting mid-sentence.
    """
    if len(context) <= chunk_chars:
        return [context]

    paragraphs = context.split("\n\n")
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0

    for para in paragraphs:
        para_len = len(para) + 2  # account for \n\n
        if current_len + para_len > chunk_chars and current:
            chunks.append("\n\n".join(current))
            current = [para]
            current_len = para_len
        else:
            current.append(para)
            current_len += para_len

    if current:
        chunks.append("\n\n".join(current))

    return chunks


# ── Map phase ─────────────────────────────────────────────────────────────

# Prompt for each parallel "map" call: summarize one context chunk with
# respect to the question, keeping native evidence identifiers intact so the
# reduce step can still cite them.
MAP_PROMPT_TEMPLATE = textwrap.dedent("""\
    You are a legal research assistant. Analyze the following legal document
    excerpts and extract the key information relevant to the question.

    Documents:
    {context_chunk}

    Question: {question}

    Provide a concise summary of the relevant information found in these
    excerpts. Focus on facts, legal references (CELEX numbers), regulatory
    relationships, and specific provisions. Keep the native evidence
    identifiers exactly as provided, such as [S1], [G1], [R1], [C1].
    """)


# Prompt for the single "reduce" call: merge the numbered partial analyses
# produced by the map phase into one final, citation-preserving answer.
REDUCE_PROMPT_TEMPLATE = textwrap.dedent("""\
    You are a legal research assistant. Below are partial analyses of
    different sections of a regulatory document corpus, all answering
    the same question.

    Partial analyses:
    {partial_summaries}

    Question: {question}

    Synthesize these partial analyses into a single comprehensive answer.
    Resolve any contradictions. Cite sources using their native identifiers
    [Sx], [Gx], [Rx], [Cx] and do not rewrite everything as [Sx].
    Take into account the regulatory ecosystem (amendments, implementations,
    citations between acts).
    """)


async def _map_one_chunk(
    llm: Any,
    context_chunk: str,
    question: str,
    chunk_index: int,
    timeout: float,
) -> str:
    """
    Run a single map call for one context chunk, bounded by ``timeout``.

    Never raises: a timeout or any other failure is logged and returned as a
    bracketed placeholder string so the caller can filter it out.
    """
    prompt = MAP_PROMPT_TEMPLATE.format(
        context_chunk=context_chunk,
        question=question,
    )
    # String-producing chain; invoke() is synchronous, so run it off-loop.
    chain = llm | StrOutputParser()
    try:
        return str(
            await asyncio.wait_for(
                asyncio.to_thread(chain.invoke, prompt),
                timeout=timeout,
            )
        )
    except asyncio.TimeoutError:
        logger.warning("Map chunk %d timed out after %.1fs", chunk_index, timeout)
        return f"[Chunk {chunk_index + 1}: analysis timed out]"
    except Exception as e:
        logger.warning("Map chunk %d failed: %s", chunk_index, e)
        return f"[Chunk {chunk_index + 1}: analysis failed]"


# ── Reduce phase ──────────────────────────────────────────────────────────


async def _reduce(
    llm: Any,
    partial_summaries: List[str],
    question: str,
    timeout: float,
) -> str:
    """
    Merge the map-phase partial summaries into one final answer.

    Raises ``asyncio.TimeoutError`` (or whatever the LLM raises) — the caller
    is responsible for falling back to a partial summary.
    """
    # Number the analyses so the reduce prompt can reference them distinctly.
    blocks = [
        f"--- Analysis {idx} ---\n{summary}"
        for idx, summary in enumerate(partial_summaries, start=1)
    ]
    prompt = REDUCE_PROMPT_TEMPLATE.format(
        partial_summaries="\n\n".join(blocks),
        question=question,
    )
    # invoke() is synchronous; run it off-loop under a timeout.
    chain = llm | StrOutputParser()
    raw = await asyncio.wait_for(
        asyncio.to_thread(chain.invoke, prompt),
        timeout=timeout,
    )
    return str(raw)


# ── Public API ────────────────────────────────────────────────────────────


def _cfg_map_reduce_threshold() -> int:
    """Context-length threshold (chars) above which map-reduce is used."""
    graph_cfg = get_config().graph
    return int(graph_cfg.map_reduce_threshold)


def _cfg_chunk_chars() -> int:
    """Default map-phase chunk size in characters, from ``config.graph``."""
    graph_cfg = get_config().graph
    return int(graph_cfg.map_reduce_chunk_chars)


def _cfg_max_parallel() -> int:
    """Default maximum number of concurrent map calls, from ``config.graph``."""
    graph_cfg = get_config().graph
    return int(graph_cfg.map_reduce_max_parallel)


def _cfg_map_timeout() -> float:
    """Default per-chunk map-call timeout in seconds, from ``config.graph``."""
    graph_cfg = get_config().graph
    return float(graph_cfg.map_reduce_map_timeout)


def _cfg_reduce_timeout() -> float:
    """Default reduce-call timeout in seconds, from ``config.graph``."""
    graph_cfg = get_config().graph
    return float(graph_cfg.map_reduce_reduce_timeout)


async def map_reduce_generate(
    *,
    context: str,
    question: str,
    llm: Any,
    chunk_chars: Optional[int] = None,
    map_timeout: Optional[float] = None,
    reduce_timeout: Optional[float] = None,
    max_parallel: Optional[int] = None,
) -> str:
    """
    Map-reduce generation pipeline.

    1. Split context into chunks of ``chunk_chars``
    2. Run up to ``max_parallel`` map calls concurrently
    3. Merge with a single reduce call

    Falls back to the concatenated map summaries if the reduce step fails.
    All parameters default to values from ``config.graph``.

    Args:
        context: Assembled (potentially very long) retrieval context.
        question: User question the final answer must address.
        llm: LangChain-compatible runnable used for both map and reduce calls.
        chunk_chars: Maximum characters per map chunk.
        map_timeout: Per-chunk map-call timeout in seconds.
        reduce_timeout: Reduce-call timeout in seconds.
        max_parallel: Maximum number of concurrent map calls.

    Returns:
        The synthesized answer, or the best available partial summary when
        the reduce step times out or fails.
    """
    _chunk_chars = chunk_chars if chunk_chars is not None else _cfg_chunk_chars()
    _map_timeout = map_timeout if map_timeout is not None else _cfg_map_timeout()
    _reduce_timeout = reduce_timeout if reduce_timeout is not None else _cfg_reduce_timeout()
    _max_parallel = max_parallel if max_parallel is not None else _cfg_max_parallel()

    chunks = _split_context(context, _chunk_chars)
    logger.info(
        "Map-reduce: %d chunks (%.1fk chars total), max_parallel=%d",
        len(chunks),
        len(context) / 1000,
        _max_parallel,
    )

    # Map phase — run in batches of at most _max_parallel concurrent calls.
    all_summaries: List[str] = []
    for batch_start in range(0, len(chunks), _max_parallel):
        batch = chunks[batch_start : batch_start + _max_parallel]
        tasks = [
            _map_one_chunk(llm, chunk, question, batch_start + i, _map_timeout)
            for i, chunk in enumerate(batch)
        ]
        batch_results = await asyncio.gather(*tasks)
        all_summaries.extend(batch_results)

    # Filter out empty or failed summaries (failure markers start with "[Chunk").
    valid_summaries = [s for s in all_summaries if s and not s.startswith("[Chunk")]

    if not valid_summaries:
        logger.warning("All map calls failed; returning concatenated chunks")
        return "\n\n".join(all_summaries)

    # If only one summary made it through, skip reduce
    if len(valid_summaries) == 1:
        return valid_summaries[0]

    # Reduce phase — best-effort: degrade to a partial summary on failure.
    try:
        answer = await _reduce(llm, valid_summaries, question, _reduce_timeout)
        return answer
    except asyncio.TimeoutError:
        logger.warning("Reduce timed out; returning best partial summary")
        return valid_summaries[0]
    except Exception as e:
        logger.warning("Reduce failed: %s; returning best partial summary", e)
        return valid_summaries[0]
def should_use_map_reduce(context: str, threshold: Optional[int] = None) -> bool:
    """
    Decide whether to use map-reduce based on context length.

    Args:
        context: The assembled context string.
        threshold: Character-count threshold; defaults to the value from
            ``config.graph``.

    Returns:
        True when ``context`` is strictly longer than the threshold.
    """
    _threshold = threshold if threshold is not None else _cfg_map_reduce_threshold()
    return len(context) > _threshold