"""
Summarize Mode
Generate summaries of documents related to a topic
"""
import time
from typing import Any, Dict, List, Optional
from lalandre_core.config import get_config
from lalandre_rag.retrieval import RetrievalService
from lalandre_rag.retrieval.context.service import ContextService
from ..adapters import LlamaIndexAdapter
from ..response import (
build_source_document,
collect_act_contexts,
create_blocked_sourced_response,
create_compare_response,
create_summarize_response,
describe_citation_validation_failure,
enforce_cited_answer,
)
from ..summaries import ActSummaryService, QuestionSummaryService
class SummarizeMode:
    """
    MODE 4: Summarization

    Generate summaries using TreeSummarize for hierarchical processing.
    """

    def __init__(
        self,
        retrieval_service: "RetrievalService",
        context_service: "ContextService",
        llamaindex_adapter: Optional["LlamaIndexAdapter"],
        citation_llm: Any = None,
        act_summary_service: Optional["ActSummaryService"] = None,
        question_summary_service: Optional["QuestionSummaryService"] = None,
    ):
        """
        Initialize summarize mode.

        Args:
            retrieval_service: Service for document retrieval
            context_service: Service for context enrichment
            llamaindex_adapter: LlamaIndex adapter (optional)
            citation_llm: LLM used for citation enforcement/repair (optional)
            act_summary_service: Cache of canonical per-act summaries (optional)
            question_summary_service: Service that augments questions with
                cached summary context (optional)
        """
        self.retrieval_service = retrieval_service
        self.context_service = context_service
        self.llamaindex_adapter = llamaindex_adapter
        self.citation_llm = citation_llm
        self.act_summary_service = act_summary_service
        self.question_summary_service = question_summary_service

    @staticmethod
    def _phase_timings(
        retrieval_ms: float = 0.0,
        context_ms: float = 0.0,
        generation_ms: float = 0.0,
        sources_build_ms: float = 0.0,
        total_ms: float = 0.0,
    ) -> Dict[str, float]:
        """Build the standard ``phase_timings_ms`` metadata dict (ms, 1 decimal)."""
        return {
            "retrieval_ms": round(retrieval_ms, 1),
            "context_enrichment_ms": round(context_ms, 1),
            "generation_ms": round(generation_ms, 1),
            "sources_build_ms": round(sources_build_ms, 1),
            "total_ms": round(total_ms, 1),
        }

    @staticmethod
    def _cap_context(context_slices: List[Any]) -> List[Any]:
        """Trim context slices to the configured character budget.

        Keeps whole slices in order until the budget would be exceeded;
        always retains at least one slice so generation has some context.
        """
        max_context_chars = get_config().generation.summarize_max_context_chars
        total_chars = 0
        budgeted_slices: List[Any] = []
        for cs in context_slices:
            slice_chars = len(cs.content)
            if total_chars + slice_chars > max_context_chars and budgeted_slices:
                break
            budgeted_slices.append(cs)
            total_chars += slice_chars
        return budgeted_slices or context_slices[:1]

    def summarize_canonical(
        self,
        *,
        celex: str,
        question: str,
    ) -> Optional[Dict[str, Any]]:
        """Return a cached canonical summary response when one is available.

        Args:
            celex: CELEX identifier of the act whose summary is requested.
            question: Original user question (echoed into the response).

        Returns:
            A summarize response built from the cached snapshot, or ``None``
            when no summary service is configured or no usable snapshot exists.
        """
        if self.act_summary_service is None:
            return None
        snapshot = self.act_summary_service.get_canonical_summary_by_celex(celex)
        if snapshot is None or not snapshot.available:
            return None
        response = create_summarize_response(
            query=question,
            answer=str(snapshot.summary),
            documents=[],
            acts={},
        )
        # Cache hits perform no retrieval/generation, so all timings are zero.
        response["metadata"].update(
            {
                "summary_source": "canonical",
                "summary_cache_hit": True,
                "summary_stale": snapshot.is_stale,
                "canonical_summary_status": snapshot.status,
                "canonical_summary_generated_at": (
                    snapshot.generated_at.isoformat() if snapshot.generated_at is not None else None
                ),
                "canonical_summary_prompt_version": snapshot.prompt_version,
                "canonical_summary_model_id": snapshot.model_id,
                "canonical_summary_trace": snapshot.trace,
                "phase_timings_ms": self._phase_timings(),
            }
        )
        return response

    def summarize_question(
        self,
        *,
        topic: str,
        top_k: int,
        score_threshold: Optional[float],
        filters: Optional[Dict[str, Any]],
        include_relations: bool,
        include_full_content: bool,
    ) -> Dict[str, Any]:
        """Run question summarization, optionally augmented with canonical memory."""
        # A CELEX filter (when present) lets the question-summary service pull
        # the matching canonical summary into the generation prompt.
        celex = str(filters.get("celex")) if isinstance(filters, dict) and filters.get("celex") else None
        generation_topic = topic
        summary_meta: Dict[str, Any] = {
            "summary_source": "question_augmented",
            "summary_cache_hit": False,
            "summary_stale": False,
        }
        if self.question_summary_service is not None:
            generation_topic, summary_meta = self.question_summary_service.augment_question(
                celex=celex,
                question=topic,
            )
        return self._summarize_impl(
            topic=topic,
            generation_topic=generation_topic,
            top_k=top_k,
            score_threshold=score_threshold,
            filters=filters,
            include_relations=include_relations,
            include_full_content=include_full_content,
            summary_meta=summary_meta,
        )

    def summarize(
        self,
        topic: str,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
        include_relations: bool = True,
        include_full_content: bool = False,
    ) -> Dict[str, Any]:
        """
        Generate a summary of documents related to a topic.

        Args:
            topic: Topic or question to summarize
            top_k: Number of documents to retrieve
            score_threshold: Minimum retrieval score; ``None`` disables the cut-off
            filters: Metadata filters
            include_relations: Include relations in context
            include_full_content: Include full document content in the sources

        Returns:
            Dictionary with summary and sources
        """
        return self.summarize_question(
            topic=topic,
            top_k=top_k,
            score_threshold=score_threshold,
            filters=filters,
            include_relations=include_relations,
            include_full_content=include_full_content,
        )

    def _summarize_impl(
        self,
        *,
        topic: str,
        generation_topic: str,
        top_k: int,
        score_threshold: Optional[float],
        filters: Optional[Dict[str, Any]],
        include_relations: bool,
        include_full_content: bool,
        summary_meta: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Retrieve, enrich, summarize, and citation-check a topic summary.

        ``topic`` is the user-facing question (echoed into responses);
        ``generation_topic`` is the possibly-augmented prompt sent to the LLM.

        Raises:
            ValueError: If no LlamaIndex adapter is configured.
        """
        if self.llamaindex_adapter is None:
            raise ValueError("No LlamaIndex adapter is available for the current LLM provider.")
        total_started_at = time.perf_counter()

        # Phase 1: retrieve documents.
        retrieval_started_at = time.perf_counter()
        retrieval_results = self.retrieval_service.retrieve(
            query=topic, top_k=top_k, score_threshold=score_threshold, filters=filters
        )
        retrieval_ms = (time.perf_counter() - retrieval_started_at) * 1000.0
        if not retrieval_results:
            return create_blocked_sourced_response(
                mode="summarize",
                query=topic,
                reason="empty_retrieval",
                metadata={
                    **(summary_meta or {}),
                    "phase_timings_ms": self._phase_timings(
                        retrieval_ms=retrieval_ms,
                        total_ms=(time.perf_counter() - total_started_at) * 1000.0,
                    ),
                },
            )

        # Phase 2: enrich results for consistent metadata access.
        context_started_at = time.perf_counter()
        context_slices = self.context_service.enrich_results(
            retrieval_results, include_relations=include_relations, include_subjects=True
        )
        context_ms = (time.perf_counter() - context_started_at) * 1000.0

        # Budget: cap context to leave room for prompt + output.
        context_slices = self._cap_context(context_slices)
        key_fn = LlamaIndexAdapter.context_slice_key
        # Stable S1..Sn source ids so citations in the answer map back to sources.
        source_id_map = {key_fn(doc): f"S{idx}" for idx, doc in enumerate(context_slices, start=1)}

        # Phase 3: generate via LlamaIndex TreeSummarize.
        generation_started_at = time.perf_counter()
        summary = self.llamaindex_adapter.summarize(
            generation_topic,
            context_slices,
            source_id_map=source_id_map,
        )
        generation_ms = (time.perf_counter() - generation_started_at) * 1000.0

        # Phase 4: build the source documents payload.
        sources_build_started_at = time.perf_counter()
        documents: List[Dict[str, Any]] = []
        for idx, doc in enumerate(context_slices, start=1):
            documents.append(
                build_source_document(
                    doc,
                    include_full_content=include_full_content,
                    include_content_preview=True,
                    source_id=source_id_map.get(key_fn(doc), f"S{idx}"),
                )
            )
        sources_build_ms = (time.perf_counter() - sources_build_started_at) * 1000.0
        if not documents:
            return create_blocked_sourced_response(
                mode="summarize",
                query=topic,
                reason="no_source_context",
                metadata={
                    **(summary_meta or {}),
                    "phase_timings_ms": self._phase_timings(
                        retrieval_ms=retrieval_ms,
                        context_ms=context_ms,
                        generation_ms=generation_ms,
                        sources_build_ms=sources_build_ms,
                        total_ms=(time.perf_counter() - total_started_at) * 1000.0,
                    ),
                },
            )

        # Phase 5: validate (and optionally repair) citations in the draft.
        enforcement = enforce_cited_answer(
            mode="summarize",
            question=topic,
            draft_answer=summary,
            sources=documents,
            llm=self.citation_llm,
        )
        if enforcement["blocked"]:
            return create_blocked_sourced_response(
                mode="summarize",
                query=topic,
                reason="invalid_citations",
                sources={
                    "documents": documents,
                    "acts": collect_act_contexts(context_slices),
                },
                metadata={
                    **(summary_meta or {}),
                    "citation_validation": enforcement["validation"],
                    "citation_failure_detail": describe_citation_validation_failure(enforcement["validation"]),
                    "citation_repair_attempted": bool(enforcement["repair_attempted"]),
                    "citation_repaired": bool(enforcement["repaired"]),
                    "phase_timings_ms": self._phase_timings(
                        retrieval_ms=retrieval_ms,
                        context_ms=context_ms,
                        generation_ms=generation_ms,
                        sources_build_ms=sources_build_ms,
                        total_ms=(time.perf_counter() - total_started_at) * 1000.0,
                    ),
                },
            )
        summary = str(enforcement["answer"])

        response = create_summarize_response(
            query=topic, answer=summary, documents=documents, acts=collect_act_contexts(context_slices)
        )
        response["metadata"]["citation_validation"] = enforcement["validation"]
        response["metadata"]["citation_repair_attempted"] = bool(enforcement["repair_attempted"])
        response["metadata"]["citation_repaired"] = bool(enforcement["repaired"])
        if summary_meta:
            response["metadata"].update(summary_meta)
        response["metadata"]["phase_timings_ms"] = self._phase_timings(
            retrieval_ms=retrieval_ms,
            context_ms=context_ms,
            generation_ms=generation_ms,
            sources_build_ms=sources_build_ms,
            total_ms=(time.perf_counter() - total_started_at) * 1000.0,
        )
        return response
class CompareMode:
    """
    MODE 5: Comparison

    Compare multiple legal documents.
    """

    def __init__(
        self,
        retrieval_service: "RetrievalService",
        context_service: "ContextService",
        llamaindex_adapter: Optional["LlamaIndexAdapter"],
        citation_llm: Any = None,
        question_summary_service: Optional["QuestionSummaryService"] = None,
    ):
        """
        Initialize compare mode.

        Args:
            retrieval_service: Service for document retrieval
            context_service: Service for context enrichment
            llamaindex_adapter: LlamaIndex adapter (optional)
            citation_llm: LLM used for citation enforcement/repair (optional)
            question_summary_service: Service that augments the comparison
                question with cached summary context (optional)
        """
        self.retrieval_service = retrieval_service
        self.context_service = context_service
        self.llamaindex_adapter = llamaindex_adapter
        self.citation_llm = citation_llm
        self.question_summary_service = question_summary_service

    @staticmethod
    def _phase_timings(
        retrieval_ms: float = 0.0,
        context_ms: float = 0.0,
        generation_ms: float = 0.0,
        sources_build_ms: float = 0.0,
        total_ms: float = 0.0,
    ) -> Dict[str, float]:
        """Build the standard ``phase_timings_ms`` metadata dict (ms, 1 decimal)."""
        return {
            "retrieval_ms": round(retrieval_ms, 1),
            "context_enrichment_ms": round(context_ms, 1),
            "generation_ms": round(generation_ms, 1),
            "sources_build_ms": round(sources_build_ms, 1),
            "total_ms": round(total_ms, 1),
        }

    @staticmethod
    def _cap_context(context_slices: List[Any]) -> List[Any]:
        """Trim context slices to the configured character budget.

        Keeps whole slices in order until the budget would be exceeded;
        always retains at least one slice so generation has some context.
        """
        max_context_chars = get_config().generation.summarize_max_context_chars
        total_chars = 0
        budgeted_slices: List[Any] = []
        for cs in context_slices:
            slice_chars = len(cs.content)
            if total_chars + slice_chars > max_context_chars and budgeted_slices:
                break
            budgeted_slices.append(cs)
            total_chars += slice_chars
        return budgeted_slices or context_slices[:1]

    def compare(
        self,
        comparison_question: str,
        celex_list: Optional[List[str]] = None,
        top_k: int = 10,
        score_threshold: Optional[float] = None,
        include_full_content: bool = False,
    ) -> Dict[str, Any]:
        """
        Compare multiple legal documents.

        Args:
            comparison_question: What to compare
            celex_list: Optional list of specific CELEX to compare (at least
                two non-empty identifiers are required)
            top_k: Total retrieval budget, split across the CELEX identifiers
            score_threshold: Minimum retrieval score; ``None`` disables the cut-off
            include_full_content: Include full document content in the sources

        Returns:
            Dictionary with comparison and sources

        Raises:
            ValueError: If no adapter is configured or fewer than two CELEX
                identifiers are supplied.
        """
        if self.llamaindex_adapter is None:
            raise ValueError("No LlamaIndex adapter is available for the current LLM provider.")
        filtered_celex_list = [celex for celex in (celex_list or []) if celex]
        if len(filtered_celex_list) < 2:
            raise ValueError("Compare mode requires at least two CELEX identifiers.")

        generation_question = comparison_question
        summary_meta: Dict[str, Any] = {
            "summary_source": "question_augmented",
            "summary_cache_hit": False,
            "summary_stale": False,
        }
        if self.question_summary_service is not None:
            generation_question, summary_meta = self.question_summary_service.augment_compare_question(
                comparison_question=comparison_question,
                celex_list=filtered_celex_list,
            )

        # Phase 1: retrieve per-CELEX with a balanced top_k budget — each act
        # gets floor(top_k / n), and the first `remainder` acts get one extra.
        total_started_at = time.perf_counter()
        all_results: List[Any] = []
        requested_top_k = max(top_k, 1)
        total_celex = len(filtered_celex_list)
        base_limit = requested_top_k // total_celex if total_celex else 0
        remainder = requested_top_k % total_celex if total_celex else 0
        retrieval_started_at = time.perf_counter()
        for idx, celex in enumerate(filtered_celex_list):
            celex_top_k = base_limit + (1 if idx < remainder else 0)
            if celex_top_k <= 0:
                continue
            results = self.retrieval_service.retrieve(
                query=comparison_question,
                top_k=celex_top_k,
                score_threshold=score_threshold,
                filters={"celex": celex},
            )
            all_results.extend(results)
        retrieval_ms = (time.perf_counter() - retrieval_started_at) * 1000.0
        if not all_results:
            return create_blocked_sourced_response(
                mode="compare",
                query=comparison_question,
                reason="empty_retrieval",
                metadata={
                    **summary_meta,
                    "phase_timings_ms": self._phase_timings(
                        retrieval_ms=retrieval_ms,
                        total_ms=(time.perf_counter() - total_started_at) * 1000.0,
                    ),
                },
            )

        # Phase 2: enrich results for consistent metadata access.
        context_started_at = time.perf_counter()
        context_slices = self.context_service.enrich_results(
            all_results, include_relations=False, include_subjects=True
        )
        context_ms = (time.perf_counter() - context_started_at) * 1000.0

        # Budget: cap context to leave room for prompt + output.
        context_slices = self._cap_context(context_slices)
        key_fn = LlamaIndexAdapter.context_slice_key
        # Stable S1..Sn source ids so citations in the answer map back to sources.
        source_id_map = {key_fn(doc): f"S{idx}" for idx, doc in enumerate(context_slices, start=1)}

        # Phase 3: structured multi-document comparison via LlamaIndex.
        generation_started_at = time.perf_counter()
        comparison = self.llamaindex_adapter.compare(
            generation_question,
            context_slices,
            filtered_celex_list,
            source_id_map=source_id_map,
        )
        generation_ms = (time.perf_counter() - generation_started_at) * 1000.0

        # Phase 4: build the source documents payload.
        sources_build_started_at = time.perf_counter()
        documents: List[Dict[str, Any]] = []
        for idx, doc in enumerate(context_slices, start=1):
            documents.append(
                build_source_document(
                    doc,
                    include_full_content=include_full_content,
                    include_content_preview=True,
                    source_id=source_id_map.get(key_fn(doc), f"S{idx}"),
                )
            )
        # Ordered de-duplication: dict.fromkeys keeps first-seen order, unlike
        # list(set(...)), which produced a nondeterministic ordering.
        documents_compared = list(
            dict.fromkeys(doc.act.celex for doc in context_slices if doc.act.celex)
        )
        sources_build_ms = (time.perf_counter() - sources_build_started_at) * 1000.0
        if not documents:
            return create_blocked_sourced_response(
                mode="compare",
                query=comparison_question,
                reason="no_source_context",
                metadata={
                    **summary_meta,
                    "phase_timings_ms": self._phase_timings(
                        retrieval_ms=retrieval_ms,
                        context_ms=context_ms,
                        generation_ms=generation_ms,
                        sources_build_ms=sources_build_ms,
                        total_ms=(time.perf_counter() - total_started_at) * 1000.0,
                    ),
                },
            )

        # Phase 5: validate (and optionally repair) citations in the draft.
        enforcement = enforce_cited_answer(
            mode="compare",
            question=comparison_question,
            draft_answer=comparison,
            sources=documents,
            llm=self.citation_llm,
        )
        if enforcement["blocked"]:
            return create_blocked_sourced_response(
                mode="compare",
                query=comparison_question,
                reason="invalid_citations",
                sources={
                    "documents": documents,
                    "acts": collect_act_contexts(context_slices),
                },
                metadata={
                    **summary_meta,
                    "citation_validation": enforcement["validation"],
                    "citation_failure_detail": describe_citation_validation_failure(enforcement["validation"]),
                    "citation_repair_attempted": bool(enforcement["repair_attempted"]),
                    "citation_repaired": bool(enforcement["repaired"]),
                    "phase_timings_ms": self._phase_timings(
                        retrieval_ms=retrieval_ms,
                        context_ms=context_ms,
                        generation_ms=generation_ms,
                        sources_build_ms=sources_build_ms,
                        total_ms=(time.perf_counter() - total_started_at) * 1000.0,
                    ),
                },
            )
        comparison = str(enforcement["answer"])

        response = create_compare_response(
            query=comparison_question,
            answer=comparison,
            documents=documents,
            documents_compared=documents_compared,
            acts=collect_act_contexts(context_slices),
        )
        response["metadata"]["citation_validation"] = enforcement["validation"]
        response["metadata"]["citation_repair_attempted"] = bool(enforcement["repair_attempted"])
        response["metadata"]["citation_repaired"] = bool(enforcement["repaired"])
        response["metadata"].update(summary_meta)
        response["metadata"]["phase_timings_ms"] = self._phase_timings(
            retrieval_ms=retrieval_ms,
            context_ms=context_ms,
            generation_ms=generation_ms,
            sources_build_ms=sources_build_ms,
            total_ms=(time.perf_counter() - total_started_at) * 1000.0,
        )
        return response