# Source: lalandre_rag.modes.summarize_mode

"""
Summarize Mode
Generate summaries of documents related to a topic
"""

import time
from typing import Any, Dict, List, Optional

from lalandre_core.config import get_config

from lalandre_rag.retrieval import RetrievalService
from lalandre_rag.retrieval.context.service import ContextService

from ..adapters import LlamaIndexAdapter
from ..response import (
    build_source_document,
    collect_act_contexts,
    create_blocked_sourced_response,
    create_compare_response,
    create_summarize_response,
    describe_citation_validation_failure,
    enforce_cited_answer,
)
from ..summaries import ActSummaryService, QuestionSummaryService


class SummarizeMode:
    """
    MODE 4: Summarization

    Generate summaries using TreeSummarize for hierarchical processing
    """

    def __init__(
        self,
        retrieval_service: RetrievalService,
        context_service: ContextService,
        llamaindex_adapter: Optional[LlamaIndexAdapter],
        citation_llm: Any = None,
        act_summary_service: Optional[ActSummaryService] = None,
        question_summary_service: Optional[QuestionSummaryService] = None,
    ):
        """
        Initialize summarize mode

        Args:
            retrieval_service: Service for document retrieval
            context_service: Service for context enrichment
            llamaindex_adapter: LlamaIndex adapter (optional)
            citation_llm: Optional LLM passed to citation enforcement when
                validating/repairing answers
            act_summary_service: Optional cache of canonical per-act
                summaries (used by ``summarize_canonical``)
            question_summary_service: Optional service that augments the
                question with summary memory (used by ``summarize_question``)
        """
        self.retrieval_service = retrieval_service
        self.context_service = context_service
        self.llamaindex_adapter = llamaindex_adapter
        self.citation_llm = citation_llm
        self.act_summary_service = act_summary_service
        self.question_summary_service = question_summary_service
[docs] def summarize_canonical( self, *, celex: str, question: str, ) -> Optional[Dict[str, Any]]: """Return a cached canonical summary response when one is available.""" if self.act_summary_service is None: return None snapshot = self.act_summary_service.get_canonical_summary_by_celex(celex) if snapshot is None or not snapshot.available: return None response = create_summarize_response( query=question, answer=str(snapshot.summary), documents=[], acts={}, ) response["metadata"].update( { "summary_source": "canonical", "summary_cache_hit": True, "summary_stale": snapshot.is_stale, "canonical_summary_status": snapshot.status, "canonical_summary_generated_at": ( snapshot.generated_at.isoformat() if snapshot.generated_at is not None else None ), "canonical_summary_prompt_version": snapshot.prompt_version, "canonical_summary_model_id": snapshot.model_id, "canonical_summary_trace": snapshot.trace, "phase_timings_ms": { "retrieval_ms": 0.0, "context_enrichment_ms": 0.0, "generation_ms": 0.0, "sources_build_ms": 0.0, "total_ms": 0.0, }, } ) return response
[docs] def summarize_question( self, *, topic: str, top_k: int, score_threshold: Optional[float], filters: Optional[Dict[str, Any]], include_relations: bool, include_full_content: bool, ) -> Dict[str, Any]: """Run question summarization, optionally augmented with canonical memory.""" celex = str(filters.get("celex")) if isinstance(filters, dict) and filters.get("celex") else None generation_topic = topic summary_meta: Dict[str, Any] = { "summary_source": "question_augmented", "summary_cache_hit": False, "summary_stale": False, } if self.question_summary_service is not None: generation_topic, summary_meta = self.question_summary_service.augment_question( celex=celex, question=topic, ) return self._summarize_impl( topic=topic, generation_topic=generation_topic, top_k=top_k, score_threshold=score_threshold, filters=filters, include_relations=include_relations, include_full_content=include_full_content, summary_meta=summary_meta, )
[docs] def summarize( self, topic: str, top_k: int = 10, score_threshold: Optional[float] = None, filters: Optional[Dict[str, Any]] = None, include_relations: bool = True, include_full_content: bool = False, ) -> Dict[str, Any]: """ Generate a summary of documents related to a topic Args: topic: Topic or question to summarize top_k: Number of documents to retrieve filters: Metadata filters include_relations: Include relations in context Returns: Dictionary with summary and sources """ return self.summarize_question( topic=topic, top_k=top_k, score_threshold=score_threshold, filters=filters, include_relations=include_relations, include_full_content=include_full_content, )
    def _summarize_impl(
        self,
        *,
        topic: str,
        generation_topic: str,
        top_k: int,
        score_threshold: Optional[float],
        filters: Optional[Dict[str, Any]],
        include_relations: bool,
        include_full_content: bool,
        summary_meta: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Core summarize pipeline: retrieve, enrich, budget, generate, cite.

        Args:
            topic: The user's original question (used for retrieval and in
                the response/citation enforcement).
            generation_topic: The (possibly augmented) question actually sent
                to the LLM adapter.
            top_k: Number of documents to retrieve.
            score_threshold: Optional minimum retrieval score.
            filters: Metadata filters passed to retrieval.
            include_relations: Include relations during context enrichment.
            include_full_content: Include full content in source documents.
            summary_meta: Extra metadata merged into the response metadata.

        Returns:
            A summarize response dict, or a blocked response when retrieval
            is empty, no source context is built, or citations are invalid.

        Raises:
            ValueError: If no LlamaIndex adapter is configured.
        """
        if self.llamaindex_adapter is None:
            raise ValueError("No LlamaIndex adapter is available for the current LLM provider.")

        # Retrieve documents; each phase is timed separately for the
        # phase_timings_ms metadata reported on every exit path.
        total_started_at = time.perf_counter()
        retrieval_started_at = time.perf_counter()
        retrieval_results = self.retrieval_service.retrieve(
            query=topic, top_k=top_k, score_threshold=score_threshold, filters=filters
        )
        retrieval_ms = (time.perf_counter() - retrieval_started_at) * 1000.0

        if not retrieval_results:
            # Nothing retrieved: return a blocked response rather than
            # generating from an empty context.
            return create_blocked_sourced_response(
                mode="summarize",
                query=topic,
                reason="empty_retrieval",
                metadata={
                    **(summary_meta or {}),
                    "phase_timings_ms": {
                        "retrieval_ms": round(retrieval_ms, 1),
                        "context_enrichment_ms": 0.0,
                        "generation_ms": 0.0,
                        "sources_build_ms": 0.0,
                        "total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1),
                    },
                },
            )

        # Enrich results for consistent metadata access
        context_started_at = time.perf_counter()
        context_slices = self.context_service.enrich_results(
            retrieval_results, include_relations=include_relations, include_subjects=True
        )
        context_ms = (time.perf_counter() - context_started_at) * 1000.0

        # Budget: cap context to leave room for prompt + output
        max_context_chars = get_config().generation.summarize_max_context_chars
        total_chars = 0
        budgeted_slices: List[Any] = []
        for cs in context_slices:
            slice_chars = len(cs.content)
            # Stop once the budget would be exceeded — but always admit the
            # first slice (the `and budgeted_slices` guard) so generation
            # never runs with zero context from an oversized first slice.
            if total_chars + slice_chars > max_context_chars and budgeted_slices:
                break
            budgeted_slices.append(cs)
            total_chars += slice_chars
        context_slices = budgeted_slices or context_slices[:1]

        # Stable S1..Sn source ids keyed by slice identity, shared between
        # the generation prompt and the sources list below.
        key_fn = LlamaIndexAdapter.context_slice_key
        source_id_map = {key_fn(doc): f"S{idx}" for idx, doc in enumerate(context_slices, start=1)}

        # Use LlamaIndex TreeSummarize
        generation_started_at = time.perf_counter()
        summary = self.llamaindex_adapter.summarize(
            generation_topic,
            context_slices,
            source_id_map=source_id_map,
        )
        generation_ms = (time.perf_counter() - generation_started_at) * 1000.0

        # Build the source-document payloads mirroring the ids used above.
        sources_build_started_at = time.perf_counter()
        documents: List[Dict[str, Any]] = []
        for idx, doc in enumerate(context_slices, start=1):
            documents.append(
                build_source_document(
                    doc,
                    include_full_content=include_full_content,
                    include_content_preview=True,
                    source_id=source_id_map.get(key_fn(doc), f"S{idx}"),
                )
            )
        sources_build_ms = (time.perf_counter() - sources_build_started_at) * 1000.0

        if not documents:
            # Enrichment produced no slices, so there is nothing to cite.
            return create_blocked_sourced_response(
                mode="summarize",
                query=topic,
                reason="no_source_context",
                metadata={
                    **(summary_meta or {}),
                    "phase_timings_ms": {
                        "retrieval_ms": round(retrieval_ms, 1),
                        "context_enrichment_ms": round(context_ms, 1),
                        "generation_ms": round(generation_ms, 1),
                        "sources_build_ms": round(sources_build_ms, 1),
                        "total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1),
                    },
                },
            )

        # Validate (and possibly repair) citations in the draft answer.
        enforcement = enforce_cited_answer(
            mode="summarize",
            question=topic,
            draft_answer=summary,
            sources=documents,
            llm=self.citation_llm,
        )
        if enforcement["blocked"]:
            # Citations could not be validated/repaired: return the sources
            # anyway so callers can inspect what was retrieved.
            return create_blocked_sourced_response(
                mode="summarize",
                query=topic,
                reason="invalid_citations",
                sources={
                    "documents": documents,
                    "acts": collect_act_contexts(context_slices),
                },
                metadata={
                    **(summary_meta or {}),
                    "citation_validation": enforcement["validation"],
                    "citation_failure_detail": describe_citation_validation_failure(enforcement["validation"]),
                    "citation_repair_attempted": bool(enforcement["repair_attempted"]),
                    "citation_repaired": bool(enforcement["repaired"]),
                    "phase_timings_ms": {
                        "retrieval_ms": round(retrieval_ms, 1),
                        "context_enrichment_ms": round(context_ms, 1),
                        "generation_ms": round(generation_ms, 1),
                        "sources_build_ms": round(sources_build_ms, 1),
                        "total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1),
                    },
                },
            )

        # Adopt the (possibly repaired) answer from enforcement.
        summary = str(enforcement["answer"])

        response = create_summarize_response(
            query=topic, answer=summary, documents=documents, acts=collect_act_contexts(context_slices)
        )
        response["metadata"]["citation_validation"] = enforcement["validation"]
        response["metadata"]["citation_repair_attempted"] = bool(enforcement["repair_attempted"])
        response["metadata"]["citation_repaired"] = bool(enforcement["repaired"])
        if summary_meta:
            response["metadata"].update(summary_meta)
        response["metadata"]["phase_timings_ms"] = {
            "retrieval_ms": round(retrieval_ms, 1),
            "context_enrichment_ms": round(context_ms, 1),
            "generation_ms": round(generation_ms, 1),
            "sources_build_ms": round(sources_build_ms, 1),
            "total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1),
        }
        return response
class CompareMode:
    """
    MODE 5: Comparison

    Compare multiple legal documents
    """

    def __init__(
        self,
        retrieval_service: RetrievalService,
        context_service: ContextService,
        llamaindex_adapter: Optional[LlamaIndexAdapter],
        citation_llm: Any = None,
        question_summary_service: Optional[QuestionSummaryService] = None,
    ):
        """
        Initialize compare mode

        Args:
            retrieval_service: Service for document retrieval
            context_service: Service for context enrichment
            llamaindex_adapter: LlamaIndex adapter (optional)
            citation_llm: Optional LLM passed to citation enforcement when
                validating/repairing answers
            question_summary_service: Optional service that augments the
                comparison question with summary memory
        """
        self.retrieval_service = retrieval_service
        self.context_service = context_service
        self.llamaindex_adapter = llamaindex_adapter
        self.citation_llm = citation_llm
        self.question_summary_service = question_summary_service
[docs] def compare( self, comparison_question: str, celex_list: Optional[List[str]] = None, top_k: int = 10, score_threshold: Optional[float] = None, include_full_content: bool = False, ) -> Dict[str, Any]: """ Compare multiple legal documents Args: comparison_question: What to compare celex_list: Optional list of specific CELEX to compare top_k: Number of documents if CELEX not specified Returns: Dictionary with comparison and sources """ if self.llamaindex_adapter is None: raise ValueError("No LlamaIndex adapter is available for the current LLM provider.") filtered_celex_list = [celex for celex in (celex_list or []) if celex] if len(filtered_celex_list) < 2: raise ValueError("Compare mode requires at least two CELEX identifiers.") generation_question = comparison_question summary_meta: Dict[str, Any] = { "summary_source": "question_augmented", "summary_cache_hit": False, "summary_stale": False, } if self.question_summary_service is not None: generation_question, summary_meta = self.question_summary_service.augment_compare_question( comparison_question=comparison_question, celex_list=filtered_celex_list, ) # Retrieve specific documents with balanced top_k budget total_started_at = time.perf_counter() all_results: List[Any] = [] requested_top_k = max(top_k, 1) total_celex = len(filtered_celex_list) base_limit = requested_top_k // total_celex if total_celex else 0 remainder = requested_top_k % total_celex if total_celex else 0 retrieval_started_at = time.perf_counter() for idx, celex in enumerate(filtered_celex_list): celex_top_k = base_limit + (1 if idx < remainder else 0) if celex_top_k <= 0: continue results = self.retrieval_service.retrieve( query=comparison_question, top_k=celex_top_k, score_threshold=score_threshold, filters={"celex": celex}, ) all_results.extend(results) retrieval_ms = (time.perf_counter() - retrieval_started_at) * 1000.0 if not all_results: return create_blocked_sourced_response( mode="compare", query=comparison_question, 
reason="empty_retrieval", metadata={ **summary_meta, "phase_timings_ms": { "retrieval_ms": round(retrieval_ms, 1), "context_enrichment_ms": 0.0, "generation_ms": 0.0, "sources_build_ms": 0.0, "total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1), }, }, ) context_started_at = time.perf_counter() context_slices = self.context_service.enrich_results( all_results, include_relations=False, include_subjects=True ) context_ms = (time.perf_counter() - context_started_at) * 1000.0 # Budget: cap context to leave room for prompt + output max_context_chars = get_config().generation.summarize_max_context_chars total_chars = 0 budgeted_slices: List[Any] = [] for cs in context_slices: slice_chars = len(cs.content) if total_chars + slice_chars > max_context_chars and budgeted_slices: break budgeted_slices.append(cs) total_chars += slice_chars context_slices = budgeted_slices or context_slices[:1] key_fn = LlamaIndexAdapter.context_slice_key source_id_map = {key_fn(doc): f"S{idx}" for idx, doc in enumerate(context_slices, start=1)} # Use LlamaIndex for structured multi-document comparison generation_started_at = time.perf_counter() comparison = self.llamaindex_adapter.compare( generation_question, context_slices, filtered_celex_list, source_id_map=source_id_map, ) generation_ms = (time.perf_counter() - generation_started_at) * 1000.0 sources_build_started_at = time.perf_counter() documents: List[Dict[str, Any]] = [] for idx, doc in enumerate(context_slices, start=1): documents.append( build_source_document( doc, include_full_content=include_full_content, include_content_preview=True, source_id=source_id_map.get(key_fn(doc), f"S{idx}"), ) ) documents_compared = list(set(doc.act.celex for doc in context_slices if doc.act.celex)) sources_build_ms = (time.perf_counter() - sources_build_started_at) * 1000.0 if not documents: return create_blocked_sourced_response( mode="compare", query=comparison_question, reason="no_source_context", metadata={ **summary_meta, 
"phase_timings_ms": { "retrieval_ms": round(retrieval_ms, 1), "context_enrichment_ms": round(context_ms, 1), "generation_ms": round(generation_ms, 1), "sources_build_ms": round(sources_build_ms, 1), "total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1), }, }, ) enforcement = enforce_cited_answer( mode="compare", question=comparison_question, draft_answer=comparison, sources=documents, llm=self.citation_llm, ) if enforcement["blocked"]: return create_blocked_sourced_response( mode="compare", query=comparison_question, reason="invalid_citations", sources={ "documents": documents, "acts": collect_act_contexts(context_slices), }, metadata={ **summary_meta, "citation_validation": enforcement["validation"], "citation_failure_detail": describe_citation_validation_failure(enforcement["validation"]), "citation_repair_attempted": bool(enforcement["repair_attempted"]), "citation_repaired": bool(enforcement["repaired"]), "phase_timings_ms": { "retrieval_ms": round(retrieval_ms, 1), "context_enrichment_ms": round(context_ms, 1), "generation_ms": round(generation_ms, 1), "sources_build_ms": round(sources_build_ms, 1), "total_ms": round((time.perf_counter() - total_started_at) * 1000.0, 1), }, }, ) comparison = str(enforcement["answer"]) response = create_compare_response( query=comparison_question, answer=comparison, documents=documents, documents_compared=documents_compared, acts=collect_act_contexts(context_slices), ) response["metadata"]["citation_validation"] = enforcement["validation"] response["metadata"]["citation_repair_attempted"] = bool(enforcement["repair_attempted"]) response["metadata"]["citation_repaired"] = bool(enforcement["repaired"]) response["metadata"].update(summary_meta) response["metadata"]["phase_timings_ms"] = { "retrieval_ms": round(retrieval_ms, 1), "context_enrichment_ms": round(context_ms, 1), "generation_ms": round(generation_ms, 1), "sources_build_ms": round(sources_build_ms, 1), "total_ms": round((time.perf_counter() - 
total_started_at) * 1000.0, 1), } return response