"""
Context Compressor — reduces context size via per-act LLM summarization.
Groups context slices by act, and for acts exceeding a character budget,
uses an LLM call to compress the fragments into a dense summary.
Preserves the ContextSlice structure so downstream code is unchanged.
"""
import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from lalandre_core.config import get_config
from langchain_core.output_parsers import StrOutputParser
from ...prompts import render_compressor_prompt
from .models import ContextSlice, DocumentMeta
logger = logging.getLogger(__name__)
# NOTE(review): removed stray "[docs]" Sphinx-HTML extraction artifact — as a
# bare expression it raised NameError at import time and broke the module.
def compress_context(
    slices: List[ContextSlice],
    llm: Any,
    *,
    budget_chars: int,
    max_slices: int = 20,
) -> List[ContextSlice]:
    """Compress context slices to fit within a character budget.

    Strategy:
        1. Group slices by act_id.
        2. For acts with multiple large slices, compress them into a single
           dense slice via an LLM call (acts are processed in parallel).
        3. Keep single/small slices as-is; a single oversized slice is
           merely truncated to its act budget (no LLM call).
        4. Return compressed slices sorted by score (best first), capped at
           ``max_slices``.

    Args:
        slices: Input context slices (already scored and sorted)
        llm: LangChain-compatible LLM for compression calls
        budget_chars: Target total character budget
        max_slices: Max slices to keep after compression

    Returns:
        Compressed list of ContextSlice objects
    """
    if not slices:
        return []
    cfg = get_config()
    budget_cfg = cfg.context_budget
    max_workers = cfg.search.max_parallel_workers
    total_chars = sum(len(s.content or "") for s in slices)
    if total_chars <= budget_chars and len(slices) <= max_slices:
        return slices  # Already within budget — nothing to do
    # Group by act_id so each act is budgeted and compressed as a unit
    by_act: Dict[int, List[ContextSlice]] = defaultdict(list)
    for s in slices:
        by_act[s.act.act_id].append(s)
    compressed: List[ContextSlice] = []
    # Per-act budget: distribute proportionally by score. Scores may be
    # None/0 (hence the `if s.score` filter), so guard the denominator
    # against a zero/empty total to avoid a ZeroDivisionError below.
    total_score = sum(s.score for s in slices if s.score)
    if total_score <= 0:
        total_score = 1.0
    # Categorize acts: keep-as-is, truncate, or compress (LLM call)
    acts_to_compress: List[tuple[List[ContextSlice], int]] = []
    for act_slices in by_act.values():
        act_total_chars = sum(len(s.content or "") for s in act_slices)
        act_score_sum = sum(s.score for s in act_slices if s.score)
        act_budget = int(budget_chars * (act_score_sum / total_score))
        act_budget = max(act_budget, budget_cfg.compression_min_budget)
        if act_total_chars <= act_budget or act_total_chars < budget_cfg.compression_min_chars:
            # Fits its budget (or too small to be worth an LLM call): keep untouched
            compressed.extend(act_slices)
        elif len(act_slices) == 1:
            # Single oversized slice: cheap truncation instead of an LLM call
            s = act_slices[0]
            compressed.append(
                ContextSlice(
                    content=(s.content or "")[:act_budget],
                    score=s.score,
                    act=s.act,
                    doc=s.doc,
                    trace=s.trace,
                )
            )
        else:
            acts_to_compress.append((act_slices, act_budget))
    # Compress acts in parallel (each makes an LLM call).
    # _compress_act_slices handles its own failures with a truncation
    # fallback, so future.result() is not expected to raise.
    if acts_to_compress:
        # max(1, ...): ThreadPoolExecutor raises ValueError on max_workers <= 0,
        # so guard against a misconfigured max_parallel_workers value.
        workers = max(1, min(len(acts_to_compress), max_workers))
        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(_compress_act_slices, act_slices, llm, act_budget)
                for act_slices, act_budget in acts_to_compress
            ]
            for future in futures:
                compressed.append(future.result())
    # Sort by score (best first) and cap. Scores may be None, which would
    # make the comparison raise TypeError — treat missing scores as 0.0.
    compressed.sort(key=lambda s: s.score or 0.0, reverse=True)
    return compressed[:max_slices]
def _compress_act_slices(
    act_slices: List[ContextSlice],
    llm: Any,
    target_chars: int,
) -> ContextSlice:
    """Compress multiple slices from one act into a single dense slice.

    Renders a compression prompt from the slices' fragments and invokes the
    LLM; on any failure the concatenated fragments are truncated to
    ``target_chars`` instead, so the caller always receives a usable slice.

    Args:
        act_slices: Non-empty list of slices, all from the same act.
        llm: LangChain-compatible LLM for the compression call.
        target_chars: Approximate character budget for the summary.

    Returns:
        A single ContextSlice carrying the summary, the best slice's score,
        and merged "compressed" DocumentMeta.
    """
    act = act_slices[0].act
    # Pick the highest-scoring slice once; it supplies both the score and
    # the doc metadata. Scores may be None (the caller filters on
    # `if s.score`), so fall back to 0.0 to keep max() from raising
    # TypeError on None comparisons.
    best = max(act_slices, key=lambda s: s.score or 0.0)
    best_score = best.score
    best_doc = best.doc
    # Build fragments text, labelling each fragment with its location
    # (chunk id when present, otherwise the subdivision type).
    fragment_parts = []
    for s in act_slices:
        loc = s.doc.subdivision_type
        if s.doc.chunk_id is not None:
            loc = f"chunk {s.doc.chunk_id}"
        fragment_parts.append(f"[{loc}]\n{s.content or ''}")
    fragments = "\n---\n".join(fragment_parts)
    started_at = time.perf_counter()
    try:
        prompt = render_compressor_prompt(
            celex=act.celex,
            title=act.title,
            level=act.regulatory_level or "inconnu",
            fragments=fragments,
            max_chars=target_chars,
        )
        chain = llm | StrOutputParser()
        summary: str = chain.invoke(prompt)
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.debug(
            "Compressed %d slices for %s: %d→%d chars in %.0fms",
            len(act_slices),
            act.celex,
            len(fragments),
            len(summary),
            elapsed_ms,
        )
    except Exception as exc:
        elapsed_ms = (time.perf_counter() - started_at) * 1000.0
        logger.warning(
            "Compression failed for %s (%.0fms): %s — falling back to truncation",
            act.celex,
            elapsed_ms,
            exc,
        )
        # Fallback: concatenate and truncate — never propagate the error,
        # so the parallel caller always gets a result.
        summary = fragments[:target_chars]
    # Merged doc metadata: flagged as "compressed", positioned like the
    # best-scoring source slice so ordering downstream stays sensible.
    merged_doc = DocumentMeta(
        source_kind="compressed",
        subdivision_id=best_doc.subdivision_id,
        subdivision_type=best_doc.subdivision_type,
        sequence_order=best_doc.sequence_order,
        chunk_id=best_doc.chunk_id,
        chunk_index=best_doc.chunk_index,
        payload={"compressed_from": len(act_slices)},
    )
    # Propagate the first slice's trace (if any), annotated with
    # compression bookkeeping for observability.
    trace: Optional[Dict[str, Any]] = None
    if act_slices[0].trace:
        trace = dict(act_slices[0].trace)
        trace["compressed"] = True
        trace["compressed_from_count"] = len(act_slices)
        trace["compression_ms"] = round(elapsed_ms, 1)
    return ContextSlice(
        content=summary,
        score=best_score,
        act=act,
        doc=merged_doc,
        trace=trace,
    )