"""
Semantic Aware Chunking (SAC) for Legal Documents.
Detects chunk boundaries by measuring embedding similarity between consecutive
sentences: breaks happen where the meaning shifts the most.
"""
import logging
import re
from typing import List, Optional
import numpy as np
from lalandre_core.models import Chunks
from lalandre_embedding.base import EmbeddingProvider
from langchain_text_splitters import RecursiveCharacterTextSplitter
from .chunker import Chunker
# Subdivision types treated as headings: when short enough they are kept as
# a single chunk (see SACChunker.chunk_subdivision).
HEADING_TYPES = {"title", "chapter", "section"}
# Lowercase roman numerals 1-20; case-insensitivity is applied at the use
# site via a scoped (?i:...) inline flag.
ROMAN_NUMERALS = (
r"(?:i|ii|iii|iv|v|vi|vii|viii|ix|x|xi|xii|xiii|xiv|"
r"xv|xvi|xvii|xviii|xix|xx)"
)
# List items of the form "I. - " / "ii. – " (roman numeral, dot, dash).
ROMAN_LIST_ITEM = rf"(?:^|\s)(?i:{ROMAN_NUMERALS})\s*\.\s*[-–—]\s+"
# Regex separators describing list-style structure, tried in this order
# (most specific first) by the RecursiveCharacterTextSplitter fallback.
LIST_SEPARATORS = [
ROMAN_LIST_ITEM, # I. - / II. – list items
r"(?:^|\s)\(\d{1,3}\)\s+", # (1) list items
r"(?:^|\s)\d{1,3}°\s+", # 1° list items
r"(?:^|\s)\d{1,3}\.\s+", # 1. list items
r"(?:^|\s)(?i:[a-z])\)\s+", # a) list items
r"(?:^|\s)(?i:[a-z])\.\s+", # a. list items
r"\n\(\d+\)", # (1) list items (newline)
r"\n\d+\.\s+", # 1. list items (newline)
r"\n[a-z]\)\s+", # a) list items (newline)
r"\n[a-z]\.\s+", # a. list items (newline)
]
# Lower-priority separators tried after the list patterns: paragraph breaks,
# line breaks, sentence-ending punctuation, then any whitespace.
WHITESPACE_SEPARATORS = [
r"\n\n", # paragraph breaks
r"\n", # line breaks
r"(?<=[.;:])\s+", # sentence boundary
r"\s+", # whitespace fallback
]
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
# ── Sentence splitting for legal French / English text ────────────────────
# Common legal abbreviations that should NOT be treated as sentence-ending periods.
# The lookbehinds include the trailing period so that "art. " is correctly skipped.
_ABBREV_LOOKBEHINDS = (
r"(?<!art\.)(?<!al\.)(?<!par\.)(?<!no\.)(?<!cf\.)(?<!etc\.)(?<!dir\.)(?<!reg\.)"
r"(?<!ch\.)(?<!sect\.)(?<!pt\.)(?<!vol\.)(?<!eds\.)"
r"(?<![ \t][A-Z]\.)" # single uppercase initial preceded by space: M., J., etc.
)
_SENTENCE_BOUNDARY = re.compile(rf"{_ABBREV_LOOKBEHINDS}(?<=[.!?;])\s+", re.UNICODE)
_FALLBACK_BOUNDARY = re.compile(r"\n\n+")
def _split_sentences(text: str) -> List[str]:
"""Split *text* into sentences, tuned for legal prose."""
parts = _SENTENCE_BOUNDARY.split(text)
# If the regex produced nothing useful, fall back to paragraph breaks.
if len(parts) <= 1:
parts = _FALLBACK_BOUNDARY.split(text)
# Last resort: single-newline split.
if len(parts) <= 1:
parts = text.split("\n")
# Drop empty / whitespace-only fragments.
return [p for p in parts if p.strip()]
def _cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
a = np.asarray(vec_a, dtype=np.float32)
b = np.asarray(vec_b, dtype=np.float32)
denom = np.linalg.norm(a) * np.linalg.norm(b)
return float(np.dot(a, b) / denom) if denom > 0 else 0.0
# ── SACChunker ────────────────────────────────────────────────────────────
class SACChunker(Chunker):
    """
    Semantic Aware Chunker.

    1. Split text into sentences.
    2. Embed each sentence (optionally using a sliding window).
    3. Compute cosine similarity between consecutive embeddings.
    4. Place breakpoints where similarity drops below a percentile threshold.
    5. Group sentences between breakpoints, then enforce min/max size constraints.
    """

    method_name = "sac"

    def __init__(
        self,
        embedding_provider: EmbeddingProvider,
        min_chunk_size: int,
        max_chunk_size: int,
        chunk_overlap: int,
        chars_per_token: float = 3.3,
        breakpoint_percentile: float = 90.0,
        breakpoint_max_threshold: float = 1.0,
        sentence_window_size: int = 1,
        batch_size: int = 32,
    ):
        """
        Args:
            embedding_provider: Provider used to embed sentence windows.
            min_chunk_size: Minimum chunk length in characters.
            max_chunk_size: Maximum chunk length in characters.
            chunk_overlap: Overlap (characters) used by the fallback splitter.
            chars_per_token: Average characters per token (forwarded to base).
            breakpoint_percentile: Percentile of the similarity distribution
                below which a boundary becomes a breakpoint (90.0 → roughly
                the lowest 10% of similarities).
            breakpoint_max_threshold: Hard upper cap on the similarity
                threshold used for breakpoint detection.
            sentence_window_size: Sliding-window size for embeddings; each
                sentence is embedded together with ``size // 2`` neighbours
                on each side. Values < 1 are clamped to 1 (no window).
            batch_size: Texts per embedding request; values < 1 clamp to 1.
        """
        super().__init__(
            min_chunk_size=min_chunk_size,
            max_chunk_size=max_chunk_size,
            chunk_overlap=chunk_overlap,
            chars_per_token=chars_per_token,
        )
        self._embedding_provider = embedding_provider
        self._breakpoint_percentile = breakpoint_percentile
        self._breakpoint_max_threshold = breakpoint_max_threshold
        self._sentence_window_size = max(1, sentence_window_size)
        self._batch_size = max(1, batch_size)
        # Lazily built fallback splitter for oversized groups.
        self._fallback_splitter: Optional[RecursiveCharacterTextSplitter] = None

    # ── public interface ──────────────────────────────────────────────

    def chunk_subdivision(
        self,
        subdivision_id: int,
        content: str,
        subdivision_type: Optional[str] = None,
    ) -> List[Chunks]:
        """Split one subdivision into semantic chunks with legal-text safeguards.

        Args:
            subdivision_id: Identifier stored on every produced chunk.
            content: Raw subdivision text.
            subdivision_type: Optional type label; heading-like types that
                fit within ``max_chunk_size`` are kept as a single chunk.

        Returns:
            List of ``Chunks`` (empty for empty content).
        """
        normalized_type = self._normalize_subdivision_type(subdivision_type)
        # Edge case: empty content → nothing to chunk.
        if not content:
            logger.debug("Empty content for subdivision %d, skipping", subdivision_id)
            return []
        # Edge case: content shorter than the minimum → single chunk.
        if len(content) < self.min_chunk_size:
            return [
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=0,
                    content=content,
                    char_start=0,
                    char_end=len(content),
                    metadata={"is_single_chunk": True, "subdivision_type": normalized_type},
                )
            ]
        # Edge case: heading types → single chunk when it fits.
        if normalized_type in HEADING_TYPES and len(content) <= self.max_chunk_size:
            return [
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=0,
                    content=content,
                    char_start=0,
                    char_end=len(content),
                    metadata={
                        "is_single_chunk": True,
                        "is_heading": True,
                        "subdivision_type": normalized_type,
                    },
                )
            ]

        # ── SAC pipeline ──────────────────────────────────────────────
        sentences = _split_sentences(content)
        # Too few sentences for meaningful similarity → rule-based fallback.
        if len(sentences) <= 2:
            return self._fallback_chunk(subdivision_id, content, normalized_type)

        # 1. Build windows & embed.
        windows = self._build_windows(sentences)
        embeddings = self._compute_embeddings(windows)
        # 2. Similarity between each pair of consecutive embeddings.
        similarities = [
            _cosine_similarity(embeddings[i], embeddings[i + 1])
            for i in range(len(embeddings) - 1)
        ]
        # 3. Detect breakpoints.
        breakpoints = self._detect_breakpoints(similarities)
        # 4. Group sentences between breakpoints.
        groups = self._group_sentences(sentences, breakpoints)
        # 5. Enforce min/max size constraints.
        final_texts = self._enforce_size_constraints(groups)
        # 6. Map pieces back to character positions & merge a tiny tail.
        #    (Both helpers are inherited from Chunker — presumably returning
        #    (start, end) offset pairs into *content*; verified by usage below.)
        positions = self._map_pieces_to_positions(content, final_texts)
        positions = self._merge_small_tail(content, positions)

        # Record the threshold actually used by _detect_breakpoints (the
        # percentile value capped at breakpoint_max_threshold) — previously
        # the uncapped value was stored, so metadata disagreed with the
        # detection logic whenever the cap was binding.
        threshold = self._effective_threshold(similarities)
        logger.debug(
            "Subdivision %d: len=%d, sentences=%d, breakpoints=%d, chunks=%d",
            subdivision_id,
            len(content),
            len(sentences),
            len(breakpoints),
            len(positions),
        )
        chunks: List[Chunks] = []
        for idx, (start, end) in enumerate(positions):
            chunk_text = content[start:end]
            if not chunk_text:
                continue
            chunks.append(
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=idx,
                    content=chunk_text,
                    char_start=start,
                    char_end=end,
                    metadata={
                        "subdivision_type": normalized_type,
                        "chunk_overlap": self.chunk_overlap,
                        "breakpoint_percentile": self._breakpoint_percentile,
                        "breakpoint_max_threshold": self._breakpoint_max_threshold,
                        "sentence_window_size": self._sentence_window_size,
                        "similarity_threshold": round(threshold, 4),
                    },
                )
            )
        return chunks

    # ── internal helpers ──────────────────────────────────────────────

    def _build_windows(self, sentences: List[str]) -> List[str]:
        """Return one embedding text per sentence.

        With a window size > 1, each sentence is joined with up to
        ``window_size // 2`` neighbours on each side so the embedding
        captures local context (the full span is 2*(size//2)+1 sentences).
        """
        if self._sentence_window_size <= 1:
            return sentences
        windows: List[str] = []
        half = self._sentence_window_size // 2
        for i in range(len(sentences)):
            start = max(0, i - half)
            end = min(len(sentences), i + half + 1)
            windows.append(" ".join(sentences[start:end]))
        return windows

    def _compute_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* in batches of ``self._batch_size``, preserving order."""
        all_embeddings: List[List[float]] = []
        for i in range(0, len(texts), self._batch_size):
            batch = texts[i : i + self._batch_size]
            all_embeddings.extend(self._embedding_provider.embed_batch(batch))
        return all_embeddings

    def _percentile_threshold(self, similarities: List[float]) -> float:
        """Similarity value at the configured breakpoint percentile.

        With ``breakpoint_percentile=90`` this is the value below which the
        lowest ~10% of similarity scores fall. Returns 0.0 for empty input.
        """
        if not similarities:
            return 0.0
        sorted_sims = sorted(similarities)
        idx = int(len(sorted_sims) * (1 - self._breakpoint_percentile / 100.0))
        # Clamp both ends: a percentile <= 0 would otherwise yield
        # idx == len(sorted_sims) and raise IndexError.
        return sorted_sims[min(max(0, idx), len(sorted_sims) - 1)]

    def _effective_threshold(self, similarities: List[float]) -> float:
        """Threshold actually used for breakpoint detection: the percentile
        value capped at ``breakpoint_max_threshold``."""
        return min(self._percentile_threshold(similarities), self._breakpoint_max_threshold)

    def _detect_breakpoints(self, similarities: List[float]) -> List[int]:
        """Sentence indices where a new semantic group starts.

        ``similarities[i]`` compares sentence ``i`` with ``i + 1``, so a drop
        at position ``i`` places the break before sentence ``i + 1``.
        """
        if not similarities:
            return []
        threshold = self._effective_threshold(similarities)
        return [i + 1 for i, sim in enumerate(similarities) if sim <= threshold]

    def _group_sentences(self, sentences: List[str], breakpoints: List[int]) -> List[str]:
        """Join sentences between consecutive breakpoints into group strings."""
        groups: List[str] = []
        prev = 0
        for bp in breakpoints:
            group = " ".join(sentences[prev:bp])
            if group.strip():
                groups.append(group)
            prev = bp
        # Last group (everything after the final breakpoint).
        tail = " ".join(sentences[prev:])
        if tail.strip():
            groups.append(tail)
        return groups

    def _enforce_size_constraints(self, groups: List[str]) -> List[str]:
        """Merge groups that are too small, re-split those that are too large."""
        # ── merge small consecutive groups ──
        merged: List[str] = []
        for g in groups:
            if merged and len(merged[-1]) < self.min_chunk_size:
                merged[-1] = merged[-1] + " " + g
            else:
                merged.append(g)
        # Fold a too-small final group into its predecessor, but only when
        # the combination still fits within the maximum.
        if len(merged) >= 2 and len(merged[-1]) < self.min_chunk_size:
            combined = merged[-2] + " " + merged[-1]
            if len(combined) <= self.max_chunk_size:
                merged[-2] = combined
                merged.pop()
        # ── re-split oversized groups ──
        result: List[str] = []
        for g in merged:
            if len(g) <= self.max_chunk_size:
                result.append(g)
            else:
                result.extend(self._split_oversized(g))
        return result

    def _split_oversized(self, text: str) -> List[str]:
        """Split an over-long group using the rule-based fallback splitter."""
        splitter = self._get_fallback_splitter()
        pieces = [p for p in splitter.split_text(text) if p]
        # Never lose content: if the splitter yields nothing, keep the text whole.
        return pieces if pieces else [text]

    def _get_fallback_splitter(self) -> RecursiveCharacterTextSplitter:
        """Lazily build and cache the shared recursive fallback splitter."""
        if self._fallback_splitter is None:
            self._fallback_splitter = RecursiveCharacterTextSplitter(
                separators=[*LIST_SEPARATORS, *WHITESPACE_SEPARATORS],
                keep_separator="start",
                is_separator_regex=True,
                chunk_size=self.max_chunk_size,
                chunk_overlap=self.chunk_overlap,
                length_function=len,
            )
        return self._fallback_splitter

    def _fallback_chunk(
        self,
        subdivision_id: int,
        content: str,
        normalized_type: Optional[str],
    ) -> List[Chunks]:
        """Rule-based fallback for content with too few sentences."""
        splitter = self._get_fallback_splitter()
        pieces = [p for p in splitter.split_text(content) if p]
        if not pieces:
            # Degenerate case: emit the whole content as a single chunk.
            return [
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=0,
                    content=content,
                    char_start=0,
                    char_end=len(content),
                    metadata={"is_single_chunk": True, "subdivision_type": normalized_type},
                )
            ]
        positions = self._map_pieces_to_positions(content, pieces)
        positions = self._merge_small_tail(content, positions)
        chunks: List[Chunks] = []
        for idx, (start, end) in enumerate(positions):
            chunk_text = content[start:end]
            if not chunk_text:
                continue
            chunks.append(
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=idx,
                    content=chunk_text,
                    char_start=start,
                    char_end=end,
                    metadata={
                        "splitter": "recursive_fallback",
                        "subdivision_type": normalized_type,
                        "chunk_overlap": self.chunk_overlap,
                    },
                )
            )
        return chunks