"""
Map-Reduce generation for Graph RAG.
When the assembled context exceeds a certain threshold, a single LLM call
can time out or produce degraded answers. This module splits the context
into chunks, runs parallel "map" calls to produce partial summaries, then
merges them into a final answer with a "reduce" call.
Pipeline::
context_chunks ──► LLM map (parallel) ──► partial summaries
│
question ──────────────────────────────────────► LLM reduce ──► answer
Usage::
answer = await map_reduce_generate(
context=long_context,
question=question,
llm=llm_chain,
chunk_chars=6000,
map_timeout=15.0,
reduce_timeout=20.0,
)
"""
import asyncio
import logging
import textwrap
from typing import Any, List, Optional
from lalandre_core.config import get_config
from langchain_core.output_parsers import StrOutputParser
logger = logging.getLogger(__name__)
# ── Context chunking ──────────────────────────────────────────────────────
def _split_context(context: str, chunk_chars: int) -> List[str]:
"""
Split context into chunks at paragraph boundaries.
Prefers splitting at ``\\n\\n`` to avoid cutting mid-sentence.
"""
if len(context) <= chunk_chars:
return [context]
paragraphs = context.split("\n\n")
chunks: List[str] = []
current: List[str] = []
current_len = 0
for para in paragraphs:
para_len = len(para) + 2 # account for \n\n
if current_len + para_len > chunk_chars and current:
chunks.append("\n\n".join(current))
current = [para]
current_len = para_len
else:
current.append(para)
current_len += para_len
if current:
chunks.append("\n\n".join(current))
return chunks
# ── Map phase ─────────────────────────────────────────────────────────────
# Prompt for the parallel "map" phase: each call sees a single context chunk
# and must produce a partial summary that preserves the native evidence
# identifiers ([S1], [G1], ...) so the reduce step can still cite sources.
MAP_PROMPT_TEMPLATE = textwrap.dedent("""\
You are a legal research assistant. Analyze the following legal document
excerpts and extract the key information relevant to the question.
Documents:
{context_chunk}
Question: {question}
Provide a concise summary of the relevant information found in these
excerpts. Focus on facts, legal references (CELEX numbers), regulatory
relationships, and specific provisions. Keep the native evidence
identifiers exactly as provided, such as [S1], [G1], [R1], [C1].
""")

# Prompt for the single "reduce" phase: merges the numbered partial analyses
# (formatted by `_reduce`) into one final answer to the same question.
REDUCE_PROMPT_TEMPLATE = textwrap.dedent("""\
You are a legal research assistant. Below are partial analyses of
different sections of a regulatory document corpus, all answering
the same question.
Partial analyses:
{partial_summaries}
Question: {question}
Synthesize these partial analyses into a single comprehensive answer.
Resolve any contradictions. Cite sources using their native identifiers
[Sx], [Gx], [Rx], [Cx] and do not rewrite everything as [Sx].
Take into account the regulatory ecosystem (amendments, implementations,
citations between acts).
""")
async def _map_one_chunk(
    llm: Any,
    context_chunk: str,
    question: str,
    chunk_index: int,
    timeout: float,
) -> str:
    """
    Run one "map" LLM call over a single context chunk.

    Never raises: a timeout or any other failure is logged and converted
    into a ``[Chunk N: ...]`` placeholder string so the caller can safely
    ``gather`` these coroutines without ``return_exceptions``.
    """
    rendered = MAP_PROMPT_TEMPLATE.format(
        context_chunk=context_chunk,
        question=question,
    )
    chain = llm | StrOutputParser()
    try:
        # The (synchronous) chain invocation runs in a worker thread so the
        # event loop stays responsive; wait_for enforces the per-call budget.
        raw = await asyncio.wait_for(
            asyncio.to_thread(chain.invoke, rendered),
            timeout=timeout,
        )
        return str(raw)
    except asyncio.TimeoutError:
        logger.warning("Map chunk %d timed out after %.1fs", chunk_index, timeout)
        return f"[Chunk {chunk_index + 1}: analysis timed out]"
    except Exception as exc:
        logger.warning("Map chunk %d failed: %s", chunk_index, exc)
        return f"[Chunk {chunk_index + 1}: analysis failed]"
# ── Reduce phase ──────────────────────────────────────────────────────────
async def _reduce(
    llm: Any,
    partial_summaries: List[str],
    question: str,
    timeout: float,
) -> str:
    """
    Merge the map-phase partial summaries into a single final answer.

    Unlike the map phase, failures here are NOT swallowed — timeouts and
    other exceptions propagate so the caller can apply its own fallback.
    """
    # Number each partial analysis so the reduce prompt can refer to them.
    sections = [
        f"--- Analysis {idx} ---\n{summary}"
        for idx, summary in enumerate(partial_summaries, start=1)
    ]
    prompt = REDUCE_PROMPT_TEMPLATE.format(
        partial_summaries="\n\n".join(sections),
        question=question,
    )
    chain = llm | StrOutputParser()
    result = await asyncio.wait_for(
        asyncio.to_thread(chain.invoke, prompt),
        timeout=timeout,
    )
    return str(result)
# ── Public API ────────────────────────────────────────────────────────────
# Thin accessors for the map-reduce tuning knobs in ``config.graph``.
# Reading the config lazily (per call) picks up runtime config changes.

def _cfg_map_reduce_threshold() -> int:
    # Context length (chars) above which map-reduce should be used.
    return int(get_config().graph.map_reduce_threshold)

def _cfg_chunk_chars() -> int:
    # Target size (chars) of each map-phase context chunk.
    return int(get_config().graph.map_reduce_chunk_chars)

def _cfg_max_parallel() -> int:
    # Maximum number of concurrent map calls per batch.
    return int(get_config().graph.map_reduce_max_parallel)

def _cfg_map_timeout() -> float:
    # Per-map-call timeout, in seconds.
    return float(get_config().graph.map_reduce_map_timeout)

def _cfg_reduce_timeout() -> float:
    # Reduce-call timeout, in seconds.
    return float(get_config().graph.map_reduce_reduce_timeout)
async def map_reduce_generate(
    *,
    context: str,
    question: str,
    llm: Any,
    chunk_chars: Optional[int] = None,
    map_timeout: Optional[float] = None,
    reduce_timeout: Optional[float] = None,
    max_parallel: Optional[int] = None,
) -> str:
    """
    Map-reduce generation pipeline.

    1. Split ``context`` into chunks of roughly ``chunk_chars`` characters.
    2. Run up to ``max_parallel`` map calls concurrently, batch by batch.
    3. Merge the partial summaries with a single reduce call.

    Degrades gracefully: if the reduce step fails, the first valid partial
    summary is returned; if every map call failed, the joined failure
    placeholders are returned so the caller still gets diagnostic text.

    Args:
        context: Full assembled context to analyze.
        question: Question passed to every map call and to the reduce call.
        llm: LangChain-compatible runnable; piped into ``StrOutputParser``.
        chunk_chars: Target chunk size; defaults to ``config.graph``.
        map_timeout: Per-map-call timeout (s); defaults to ``config.graph``.
        reduce_timeout: Reduce timeout (s); defaults to ``config.graph``.
        max_parallel: Max concurrent map calls; defaults to ``config.graph``.

    Returns:
        The final synthesized answer, or a degraded fallback (see above).
    """
    _chunk_chars = chunk_chars if chunk_chars is not None else _cfg_chunk_chars()
    _map_timeout = map_timeout if map_timeout is not None else _cfg_map_timeout()
    _reduce_timeout = reduce_timeout if reduce_timeout is not None else _cfg_reduce_timeout()
    _max_parallel = max_parallel if max_parallel is not None else _cfg_max_parallel()

    chunks = _split_context(context, _chunk_chars)
    logger.info(
        "Map-reduce: %d chunks (%.1fk chars total), max_parallel=%d",
        len(chunks),
        len(context) / 1000,
        _max_parallel,
    )

    # Map phase — run in batches of _max_parallel to bound concurrency.
    # _map_one_chunk never raises (failures become "[Chunk ...]" placeholder
    # strings), so a plain gather without return_exceptions is safe.
    all_summaries: List[str] = []
    for batch_start in range(0, len(chunks), _max_parallel):
        batch = chunks[batch_start : batch_start + _max_parallel]
        tasks = [
            _map_one_chunk(llm, chunk, question, batch_start + i, _map_timeout)
            for i, chunk in enumerate(batch)
        ]
        all_summaries.extend(await asyncio.gather(*tasks))

    # Drop empty results and the "[Chunk ...]" failure placeholders.
    # NOTE(review): this prefix test would also discard a genuine summary
    # that happens to begin with "[Chunk" — brittle, but kept as-is.
    valid_summaries = [s for s in all_summaries if s and not s.startswith("[Chunk")]
    if not valid_summaries:
        # Fixed: the previous message claimed "returning concatenated
        # chunks", but what is actually returned is the joined failure
        # placeholders, never the raw context chunks.
        logger.warning("All map calls failed; returning joined failure placeholders")
        return "\n\n".join(all_summaries)

    # A single surviving summary needs no synthesis step.
    if len(valid_summaries) == 1:
        return valid_summaries[0]

    # Reduce phase — fall back to the first partial summary on failure.
    try:
        return await _reduce(llm, valid_summaries, question, _reduce_timeout)
    except asyncio.TimeoutError:
        logger.warning("Reduce timed out; returning best partial summary")
        return valid_summaries[0]
    except Exception as e:
        logger.warning("Reduce failed: %s; returning best partial summary", e)
        return valid_summaries[0]
def should_use_map_reduce(context: str, threshold: Optional[int] = None) -> bool:
    """
    Return ``True`` when ``context`` is long enough to warrant map-reduce.

    Uses ``config.graph.map_reduce_threshold`` when ``threshold`` is None.
    """
    limit = _cfg_map_reduce_threshold() if threshold is None else threshold
    return len(context) > limit