# Source code for lalandre_rag.retrieval.context.compressor

"""
Context Compressor — reduces context size via per-act LLM summarization.

Groups context slices by act, and for acts exceeding a character budget,
uses an LLM call to compress the fragments into a dense summary.
Preserves the ContextSlice structure so downstream code is unchanged.
"""

import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from lalandre_core.config import get_config
from langchain_core.output_parsers import StrOutputParser

from ...prompts import render_compressor_prompt
from .models import ContextSlice, DocumentMeta

logger = logging.getLogger(__name__)


def compress_context(
    slices: List[ContextSlice],
    llm: Any,
    *,
    budget_chars: int,
    max_slices: int = 20,
) -> List[ContextSlice]:
    """Compress context slices to fit within a character budget.

    Strategy:
        1. Group slices by act_id
        2. For acts with multiple large slices, compress into a single dense slice
        3. Keep single/small slices as-is
        4. Return compressed slices sorted by original score (best first)

    Args:
        slices: Input context slices (already scored and sorted)
        llm: LangChain-compatible LLM for compression calls
        budget_chars: Target total character budget
        max_slices: Max slices to keep after compression

    Returns:
        Compressed list of ContextSlice objects
    """
    if not slices:
        return []

    cfg = get_config()
    budget_cfg = cfg.context_budget
    max_workers = cfg.search.max_parallel_workers

    total_chars = sum(len(s.content or "") for s in slices)
    if total_chars <= budget_chars and len(slices) <= max_slices:
        return slices  # Already within budget

    # Group by act_id
    by_act: Dict[int, List[ContextSlice]] = defaultdict(list)
    for s in slices:
        by_act[s.act.act_id].append(s)

    compressed: List[ContextSlice] = []

    # Per-act budget: distribute proportionally by score.
    # Scores may be falsy (None/0) — only real scores contribute.
    total_score = sum(s.score for s in slices if s.score)
    if total_score <= 0:
        total_score = 1.0  # avoid division by zero below

    # Categorize acts: keep-as-is, truncate, or compress (LLM call)
    acts_to_compress: List[tuple[List[ContextSlice], int]] = []
    for act_slices in by_act.values():
        act_total_chars = sum(len(s.content or "") for s in act_slices)
        act_score_sum = sum(s.score for s in act_slices if s.score)
        act_budget = int(budget_chars * (act_score_sum / total_score))
        act_budget = max(act_budget, budget_cfg.compression_min_budget)
        if (
            act_total_chars <= act_budget
            or act_total_chars < budget_cfg.compression_min_chars
        ):
            # Fits its budget (or too small to be worth an LLM call): keep as-is.
            compressed.extend(act_slices)
        elif len(act_slices) == 1:
            # Single oversized slice: plain truncation, no LLM call needed.
            s = act_slices[0]
            compressed.append(
                ContextSlice(
                    content=(s.content or "")[:act_budget],
                    score=s.score,
                    act=s.act,
                    doc=s.doc,
                    trace=s.trace,
                )
            )
        else:
            acts_to_compress.append((act_slices, act_budget))

    # Compress acts in parallel (each makes an LLM call)
    if acts_to_compress:
        with ThreadPoolExecutor(
            max_workers=min(len(acts_to_compress), max_workers)
        ) as executor:
            futures = [
                executor.submit(_compress_act_slices, act_slices, llm, act_budget)
                for act_slices, act_budget in acts_to_compress
            ]
            for future in futures:
                compressed.append(future.result())

    # Sort by score (best first) and cap.
    # BUG FIX: scores may be None (the guards above use `if s.score`); a bare
    # `s.score` key would raise TypeError comparing None with float, so treat
    # a missing score as 0.0.
    compressed.sort(key=lambda s: s.score or 0.0, reverse=True)
    return compressed[:max_slices]
def _compress_act_slices(
    act_slices: List[ContextSlice],
    llm: Any,
    target_chars: int,
) -> ContextSlice:
    """Compress multiple slices from one act into a single dense slice.

    Args:
        act_slices: Two or more slices belonging to the same act.
        llm: LangChain-compatible LLM used for the summarization call.
        target_chars: Character budget the summary should fit within.

    Returns:
        A single ContextSlice carrying the LLM summary (or, if the LLM call
        fails, a truncated concatenation of the original fragments).
    """
    act = act_slices[0].act
    # BUG FIX: scores may be None (callers guard with `if s.score`); a bare
    # max(s.score ...) would raise TypeError — treat missing scores as 0.0.
    best_score = max((s.score or 0.0) for s in act_slices)
    best_doc = max(act_slices, key=lambda s: s.score or 0.0).doc

    # Build fragments text, labelling each fragment with its location
    # (chunk id when available, otherwise the subdivision type).
    fragment_parts = []
    for s in act_slices:
        loc = s.doc.subdivision_type
        if s.doc.chunk_id is not None:
            loc = f"chunk {s.doc.chunk_id}"
        fragment_parts.append(f"[{loc}]\n{s.content or ''}")
    fragments = "\n---\n".join(fragment_parts)

    started_at = time.perf_counter()
    try:
        prompt = render_compressor_prompt(
            celex=act.celex,
            title=act.title,
            level=act.regulatory_level or "inconnu",
            fragments=fragments,
            max_chars=target_chars,
        )
        chain = llm | StrOutputParser()
        summary: str = chain.invoke(prompt)
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        # BUG FIX: format was "%d%d chars", fusing the before/after sizes into
        # one unreadable number; separate the two counts.
        logger.debug(
            "Compressed %d slices for %s: %d -> %d chars in %.0fms",
            len(act_slices),
            act.celex,
            len(fragments),
            len(summary),
            elapsed_ms,
        )
    except Exception as exc:
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.warning(
            "Compression failed for %s (%.0fms): %s — falling back to truncation",
            act.celex,
            elapsed_ms,
            exc,
        )
        # Fallback: concatenate and truncate
        summary = fragments[:target_chars]

    # Build merged doc metadata; location fields come from the best-scored slice.
    merged_doc = DocumentMeta(
        source_kind="compressed",
        subdivision_id=best_doc.subdivision_id,
        subdivision_type=best_doc.subdivision_type,
        sequence_order=best_doc.sequence_order,
        chunk_id=best_doc.chunk_id,
        chunk_index=best_doc.chunk_index,
        payload={"compressed_from": len(act_slices)},
    )

    # Propagate the first slice's trace (if any), annotated with compression info.
    trace: Optional[Dict[str, Any]] = None
    if act_slices[0].trace:
        trace = dict(act_slices[0].trace)
        trace["compressed"] = True
        trace["compressed_from_count"] = len(act_slices)
        trace["compression_ms"] = round(elapsed_ms, 1)

    return ContextSlice(
        content=summary,
        score=best_score,
        act=act,
        doc=merged_doc,
        trace=trace,
    )