"""
LlamaIndex Adapter
Utilities for using LlamaIndex with context slices
"""
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from llama_index.core import response_synthesizers as _li_response_synth
from llama_index.core.llms import LLM
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.schema import NodeWithScore, TextNode
from lalandre_rag.retrieval.context import ContextSlice
from ..prompts import get_llamaindex_prompt
from ..response import format_doc_location, format_source_header
# Broad callable alias for the upstream synthesizer factory: its real
# signature is large and churns between llama_index releases, so call
# sites only rely on "callable returning a synthesizer".
ResponseSynthesizerFactory = Callable[..., Any]

# Looked up via getattr (not direct attribute access) so static analysis
# does not complain about an attribute the upstream module may not
# re-export; the cast then pins the loose type for our callers.
_factory = getattr(_li_response_synth, "get_response_synthesizer")
get_response_synthesizer = cast(ResponseSynthesizerFactory, _factory)
class LlamaIndexAdapter:
    """
    Adapter for using LlamaIndex with context slice objects.

    Provides:
    - Document to node conversion
    - TreeSummarize for long documents
    - Multi-document comparison
    """

    def __init__(self, llama_llm: LLM):
        """
        Initialize LlamaIndex adapter.

        Args:
            llama_llm: LlamaIndex-compatible LLM client
        """
        self.llama_llm = llama_llm

    @staticmethod
    def context_slice_key(doc: ContextSlice) -> Tuple[str, int, Optional[int]]:
        """
        Return the stable lookup key used for source identifiers.

        Chunk-backed slices key on ("chunk", chunk_id, act_id); slices
        without a chunk fall back to ("subdivision", subdivision_id, act_id).
        """
        if doc.doc.chunk_id is not None:
            return ("chunk", doc.doc.chunk_id, doc.act.act_id)
        return ("subdivision", doc.doc.subdivision_id, doc.act.act_id)

    def context_slices_to_nodes(
        self,
        context_slices: List[ContextSlice],
        source_id_map: Optional[Dict[Tuple[str, int, Optional[int]], str]] = None,
    ) -> List[NodeWithScore]:
        """
        Convert ContextSlice objects to LlamaIndex NodeWithScore.

        Args:
            context_slices: List of context slices
            source_id_map: Optional mapping from context_slice_key() keys to
                stable source identifiers; when absent (or missing a key),
                positional ids "S1", "S2", ... are used instead.

        Returns:
            List of LlamaIndex nodes with scores
        """
        nodes: List[NodeWithScore] = []
        for idx, doc in enumerate(context_slices, start=1):
            # Prefer upstream payload metadata (already standardized in common
            # builders); setdefault below only fills keys the payload omitted.
            metadata: Dict[str, Any] = dict(doc.doc.payload or {})
            key = self.context_slice_key(doc)
            source_id = (source_id_map.get(key) if source_id_map else None) or f"S{idx}"
            metadata.setdefault("celex", doc.act.celex)
            metadata.setdefault("title", doc.act.title)
            metadata.setdefault("subdivision_type", doc.doc.subdivision_type)
            metadata.setdefault("act_type", doc.act.act_type)
            metadata.setdefault("sequence_order", doc.doc.sequence_order)
            metadata.setdefault("subdivision_id", doc.doc.subdivision_id)
            metadata.setdefault("act_id", doc.act.act_id)
            metadata.setdefault("source_kind", doc.doc.source_kind)
            metadata.setdefault("source_id", source_id)
            # Chunk-level fields are optional; record only those present.
            if doc.doc.chunk_id is not None:
                metadata.setdefault("chunk_id", doc.doc.chunk_id)
            if doc.doc.chunk_index is not None:
                metadata.setdefault("chunk_index", doc.doc.chunk_index)
            if doc.doc.char_start is not None:
                metadata.setdefault("char_start", doc.doc.char_start)
            if doc.doc.char_end is not None:
                metadata.setdefault("char_end", doc.doc.char_end)
            if doc.act.url_eurlex:
                metadata.setdefault("url_eurlex", doc.act.url_eurlex)
            # Node id mirrors context_slice_key(): chunk id when available,
            # subdivision id otherwise, both scoped by CELEX.
            node_id = (
                f"{doc.act.celex}_{doc.doc.chunk_id}"
                if doc.doc.chunk_id is not None
                else f"{doc.act.celex}_{doc.doc.subdivision_id}"
            )
            # Create TextNode with an inline source header so the model can
            # cite [source_id] in its answer.
            location = format_doc_location(
                doc.doc.chunk_id,
                doc.doc.chunk_index,
                doc.doc.subdivision_type,
                doc.doc.subdivision_id,
            )
            header = format_source_header(
                source_id,
                doc.act.celex,
                location,
                doc.act.title,
                regulatory_level=doc.act.regulatory_level,
            )
            node = TextNode(text=f"{header}\n{doc.content}", id_=node_id)
            node.metadata = metadata
            # Wrap in NodeWithScore, carrying the retrieval score through.
            node_with_score = NodeWithScore.model_validate(
                {
                    "node": node,
                    "score": doc.score,
                }
            )
            nodes.append(node_with_score)
        return nodes

    def summarize(
        self,
        topic: str,
        context_slices: List[ContextSlice],
        source_id_map: Optional[Dict[Tuple[str, int, Optional[int]], str]] = None,
    ) -> str:
        """
        Use LlamaIndex TreeSummarize for hierarchical summarization.

        Better for long documents as it summarizes in chunks then combines.

        Args:
            topic: Topic to summarize
            context_slices: Context slices to summarize
            source_id_map: Optional stable source-id mapping forwarded to
                context_slices_to_nodes()

        Returns:
            Summary text
        """
        nodes = self.context_slices_to_nodes(context_slices, source_id_map=source_id_map)
        # Summary template is centralized in prompts.py
        summary_template = get_llamaindex_prompt("summary")
        synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            llm=self.llama_llm,
            summary_template=summary_template,
            use_async=False,
        )
        response = synthesizer.synthesize(query=topic, nodes=nodes)
        return str(response)

    def compare(
        self,
        comparison_question: str,
        context_slices: List[ContextSlice],
        celex_list: List[str],
        source_id_map: Optional[Dict[Tuple[str, int, Optional[int]], str]] = None,
    ) -> str:
        """
        Use LlamaIndex for intelligent multi-document comparison.

        Groups documents by CELEX and compares systematically.

        Args:
            comparison_question: Question for comparison
            context_slices: Context slices to compare
            celex_list: List of CELEX codes being compared (kept for the
                caller-facing interface; grouping derives CELEX from the
                slices themselves)
            source_id_map: Optional stable source-id mapping forwarded to
                context_slices_to_nodes()

        Returns:
            Comparison text
        """
        # Group documents by CELEX; slices without a CELEX are skipped.
        docs_by_celex: Dict[str, List[ContextSlice]] = {}
        for doc in context_slices:
            if doc.act.celex:
                docs_by_celex.setdefault(doc.act.celex, []).append(doc)
        # Convert per CELEX group so nodes for the same act stay contiguous.
        # NOTE(review): without a source_id_map, fallback ids restart at "S1"
        # for each group — pass a map when ids must be unique across acts.
        all_nodes: List[NodeWithScore] = []
        for docs in docs_by_celex.values():
            all_nodes.extend(self.context_slices_to_nodes(docs, source_id_map=source_id_map))
        # Comparison template is centralized in prompts.py
        comparison_template = get_llamaindex_prompt("comparison")
        synthesizer = get_response_synthesizer(
            response_mode=ResponseMode.TREE_SUMMARIZE,
            llm=self.llama_llm,
            summary_template=comparison_template,
            use_async=False,
        )
        response = synthesizer.synthesize(query=comparison_question, nodes=all_nodes)
        return str(response)