# Source code for lalandre_chunking.sac_chunker

"""
Semantic Aware Chunking (SAC) for Legal Documents.
Detects chunk boundaries by measuring embedding similarity between consecutive
sentences: breaks happen where the meaning shifts the most.
"""

import logging
import re
from typing import List, Optional

import numpy as np
from lalandre_core.models import Chunks
from lalandre_embedding.base import EmbeddingProvider
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .chunker import Chunker

# Subdivision types that are treated as headings: emitted as a single chunk
# whenever the whole content fits inside max_chunk_size.
HEADING_TYPES = {"title", "chapter", "section"}

# Roman numerals i–xx (lowercase here; matched case-insensitively below).
ROMAN_NUMERALS = (
    r"(?:i|ii|iii|iv|v|vi|vii|viii|ix|x|xi|xii|xiii|xiv|"
    r"xv|xvi|xvii|xviii|xix|xx)"
)
# Matches "I. -" / "ii. –" style list-item markers, including the dash variants
# (hyphen, en dash, em dash) common in French legal texts.
ROMAN_LIST_ITEM = rf"(?:^|\s)(?i:{ROMAN_NUMERALS})\s*\.\s*[-–—]\s+"

# Separator regexes for the recursive fallback splitter, ordered from the most
# structurally significant list marker down to the least.
LIST_SEPARATORS = [
    ROMAN_LIST_ITEM,  # I. - / II. – list items
    r"(?:^|\s)\(\d{1,3}\)\s+",  # (1) list items
    r"(?:^|\s)\d{1,3}°\s+",  # 1° list items (French "degree" numbering)
    r"(?:^|\s)\d{1,3}\.\s+",  # 1. list items
    r"(?:^|\s)(?i:[a-z])\)\s+",  # a) list items
    r"(?:^|\s)(?i:[a-z])\.\s+",  # a. list items
    r"\n\(\d+\)",  # (1) list items (newline)
    r"\n\d+\.\s+",  # 1. list items (newline)
    r"\n[a-z]\)\s+",  # a) list items (newline)
    r"\n[a-z]\.\s+",  # a. list items (newline)
]

# Lower-priority separators tried after the list markers above.
WHITESPACE_SEPARATORS = [
    r"\n\n",  # paragraph breaks
    r"\n",  # line breaks
    r"(?<=[.;:])\s+",  # sentence boundary
    r"\s+",  # whitespace fallback
]

logger = logging.getLogger(__name__)

# ── Sentence splitting for legal French / English text ────────────────────

# Common legal abbreviations that should NOT be treated as sentence-ending periods.
# The lookbehinds include the trailing period so that "art. " is correctly skipped.
_ABBREV_LOOKBEHINDS = (
    r"(?<!art\.)(?<!al\.)(?<!par\.)(?<!no\.)(?<!cf\.)(?<!etc\.)(?<!dir\.)(?<!reg\.)"
    r"(?<!ch\.)(?<!sect\.)(?<!pt\.)(?<!vol\.)(?<!eds\.)"
    r"(?<![ \t][A-Z]\.)"  # single uppercase initial preceded by space: M., J., etc.
)

_SENTENCE_BOUNDARY = re.compile(rf"{_ABBREV_LOOKBEHINDS}(?<=[.!?;])\s+", re.UNICODE)

_FALLBACK_BOUNDARY = re.compile(r"\n\n+")


def _split_sentences(text: str) -> List[str]:
    """Split *text* into sentences, tuned for legal prose."""
    parts = _SENTENCE_BOUNDARY.split(text)
    # If the regex produced nothing useful, fall back to paragraph breaks.
    if len(parts) <= 1:
        parts = _FALLBACK_BOUNDARY.split(text)
    # Last resort: single-newline split.
    if len(parts) <= 1:
        parts = text.split("\n")
    # Drop empty / whitespace-only fragments.
    return [p for p in parts if p.strip()]


def _cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    a = np.asarray(vec_a, dtype=np.float32)
    b = np.asarray(vec_b, dtype=np.float32)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom > 0 else 0.0


# ── SACChunker ────────────────────────────────────────────────────────────


class SACChunker(Chunker):
    """
    Semantic Aware Chunker.

    1. Split text into sentences.
    2. Embed each sentence (optionally using a sliding window).
    3. Compute cosine similarity between consecutive embeddings.
    4. Place breakpoints where similarity drops below a percentile threshold.
    5. Group sentences between breakpoints, then enforce min/max size constraints.
    """

    method_name = "sac"

    def __init__(
        self,
        embedding_provider: EmbeddingProvider,
        min_chunk_size: int,
        max_chunk_size: int,
        chunk_overlap: int,
        chars_per_token: float = 3.3,
        breakpoint_percentile: float = 90.0,
        breakpoint_max_threshold: float = 1.0,
        sentence_window_size: int = 1,
        batch_size: int = 32,
    ):
        """
        Args:
            embedding_provider: Provider used to embed sentence windows.
            min_chunk_size: Minimum chunk length in characters.
            max_chunk_size: Maximum chunk length in characters.
            chunk_overlap: Overlap (chars) passed to the fallback splitter.
            chars_per_token: Characters-per-token estimate forwarded to the base class.
            breakpoint_percentile: Percentile of similarity scores below which a
                boundary is placed (higher → fewer breakpoints).
            breakpoint_max_threshold: Hard cap on the similarity threshold.
            sentence_window_size: Sentences per embedding window (clamped to >= 1).
            batch_size: Embedding batch size (clamped to >= 1).
        """
        super().__init__(
            min_chunk_size=min_chunk_size,
            max_chunk_size=max_chunk_size,
            chunk_overlap=chunk_overlap,
            chars_per_token=chars_per_token,
        )
        self._embedding_provider = embedding_provider
        self._breakpoint_percentile = breakpoint_percentile
        self._breakpoint_max_threshold = breakpoint_max_threshold
        self._sentence_window_size = max(1, sentence_window_size)
        self._batch_size = max(1, batch_size)
        # Lazy fallback splitter for oversized groups.
        self._fallback_splitter: Optional[RecursiveCharacterTextSplitter] = None

    # ── public interface ──────────────────────────────────────────────

    def chunk_subdivision(
        self,
        subdivision_id: int,
        content: str,
        subdivision_type: Optional[str] = None,
    ) -> List[Chunks]:
        """Split one subdivision into semantic chunks with legal-text safeguards."""
        normalized_type = self._normalize_subdivision_type(subdivision_type)

        # Edge case: empty content
        if not content:
            logger.debug("Empty content for subdivision %d, skipping", subdivision_id)
            return []

        # Edge case: short content → single chunk
        if len(content) < self.min_chunk_size:
            return [
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=0,
                    content=content,
                    char_start=0,
                    char_end=len(content),
                    metadata={"is_single_chunk": True, "subdivision_type": normalized_type},
                )
            ]

        # Edge case: heading types → single chunk if fits
        if normalized_type in HEADING_TYPES and len(content) <= self.max_chunk_size:
            return [
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=0,
                    content=content,
                    char_start=0,
                    char_end=len(content),
                    metadata={
                        "is_single_chunk": True,
                        "is_heading": True,
                        "subdivision_type": normalized_type,
                    },
                )
            ]

        # ── SAC pipeline ──────────────────────────────────────────────
        sentences = _split_sentences(content)

        # Too few sentences for meaningful similarity → fallback to rule-based split
        if len(sentences) <= 2:
            return self._fallback_chunk(subdivision_id, content, normalized_type)

        # 1. Build windows & embed
        windows = self._build_windows(sentences)
        embeddings = self._compute_embeddings(windows)

        # 2. Inter-sentence similarities
        similarities = [
            _cosine_similarity(embeddings[i], embeddings[i + 1])
            for i in range(len(embeddings) - 1)
        ]

        # 3. Detect breakpoints
        breakpoints = self._detect_breakpoints(similarities)

        # 4. Group sentences
        groups = self._group_sentences(sentences, breakpoints)

        # 5. Enforce size constraints
        final_texts = self._enforce_size_constraints(groups)

        # 6. Map to positions & merge tiny tail
        positions = self._map_pieces_to_positions(content, final_texts)
        positions = self._merge_small_tail(content, positions)

        # Effective similarity threshold, for metadata.  FIX: previously this
        # reported the raw percentile value; _detect_breakpoints actually caps
        # it at breakpoint_max_threshold, so report the capped value.
        threshold = min(
            self._percentile_threshold(similarities), self._breakpoint_max_threshold
        )

        logger.debug(
            "Subdivision %d: len=%d, sentences=%d, breakpoints=%d, chunks=%d",
            subdivision_id,
            len(content),
            len(sentences),
            len(breakpoints),
            len(positions),
        )

        chunks: List[Chunks] = []
        for idx, (start, end) in enumerate(positions):
            chunk_text = content[start:end]
            if not chunk_text:
                continue
            chunks.append(
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=idx,
                    content=chunk_text,
                    char_start=start,
                    char_end=end,
                    metadata={
                        "subdivision_type": normalized_type,
                        "chunk_overlap": self.chunk_overlap,
                        "breakpoint_percentile": self._breakpoint_percentile,
                        "breakpoint_max_threshold": self._breakpoint_max_threshold,
                        "sentence_window_size": self._sentence_window_size,
                        "similarity_threshold": round(threshold, 4),
                    },
                )
            )
        return chunks

    # ── internal helpers ──────────────────────────────────────────────

    def _build_windows(self, sentences: List[str]) -> List[str]:
        """Return one text window per sentence, centered on it.

        NOTE: the window spans ``i - size//2 .. i + size//2`` inclusive, so
        even window sizes behave like the next odd size.
        """
        if self._sentence_window_size <= 1:
            return sentences
        windows: List[str] = []
        half = self._sentence_window_size // 2
        for i in range(len(sentences)):
            start = max(0, i - half)
            end = min(len(sentences), i + half + 1)
            windows.append(" ".join(sentences[start:end]))
        return windows

    def _compute_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* in batches of ``self._batch_size``."""
        all_embeddings: List[List[float]] = []
        for i in range(0, len(texts), self._batch_size):
            batch = texts[i : i + self._batch_size]
            all_embeddings.extend(self._embedding_provider.embed_batch(batch))
        return all_embeddings

    def _percentile_threshold(self, similarities: List[float]) -> float:
        """Similarity value below which ~(100 - percentile)% of scores fall."""
        if not similarities:
            return 0.0
        sorted_sims = sorted(similarities)
        idx = int(len(sorted_sims) * (1 - self._breakpoint_percentile / 100.0))
        return sorted_sims[max(0, idx)]

    def _detect_breakpoints(self, similarities: List[float]) -> List[int]:
        """Sentence indices where a new group starts (similarity dip)."""
        if not similarities:
            return []
        # Cap the percentile threshold so a uniform document cannot break everywhere.
        threshold = min(self._percentile_threshold(similarities), self._breakpoint_max_threshold)
        # similarities[i] compares sentence i with i+1, so the break index is i+1.
        return [i + 1 for i, sim in enumerate(similarities) if sim <= threshold]

    def _group_sentences(self, sentences: List[str], breakpoints: List[int]) -> List[str]:
        """Join sentences between consecutive breakpoints into groups."""
        groups: List[str] = []
        prev = 0
        for bp in breakpoints:
            group = " ".join(sentences[prev:bp])
            if group.strip():
                groups.append(group)
            prev = bp
        # Last group
        tail = " ".join(sentences[prev:])
        if tail.strip():
            groups.append(tail)
        return groups

    def _enforce_size_constraints(self, groups: List[str]) -> List[str]:
        """Merge groups that are too small, re-split those that are too large."""
        # ── merge small consecutive groups ──
        merged: List[str] = []
        for g in groups:
            if merged and len(merged[-1]) < self.min_chunk_size:
                merged[-1] = merged[-1] + " " + g
            else:
                merged.append(g)

        # Handle last group being too small
        if len(merged) >= 2 and len(merged[-1]) < self.min_chunk_size:
            combined = merged[-2] + " " + merged[-1]
            if len(combined) <= self.max_chunk_size:
                merged[-2] = combined
                merged.pop()

        # ── re-split oversized groups ──
        result: List[str] = []
        for g in merged:
            if len(g) <= self.max_chunk_size:
                result.append(g)
            else:
                result.extend(self._split_oversized(g))
        return result

    def _split_oversized(self, text: str) -> List[str]:
        """Rule-based split of a group that exceeds max_chunk_size."""
        splitter = self._get_fallback_splitter()
        pieces = [p for p in splitter.split_text(text) if p]
        return pieces if pieces else [text]

    def _get_fallback_splitter(self) -> RecursiveCharacterTextSplitter:
        """Lazily build the shared recursive splitter (list markers first)."""
        if self._fallback_splitter is None:
            self._fallback_splitter = RecursiveCharacterTextSplitter(
                separators=[*LIST_SEPARATORS, *WHITESPACE_SEPARATORS],
                keep_separator="start",
                is_separator_regex=True,
                chunk_size=self.max_chunk_size,
                chunk_overlap=self.chunk_overlap,
                length_function=len,
            )
        return self._fallback_splitter

    def _fallback_chunk(
        self,
        subdivision_id: int,
        content: str,
        normalized_type: Optional[str],
    ) -> List[Chunks]:
        """Rule-based fallback for content with too few sentences."""
        splitter = self._get_fallback_splitter()
        pieces = [p for p in splitter.split_text(content) if p]
        if not pieces:
            return [
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=0,
                    content=content,
                    char_start=0,
                    char_end=len(content),
                    metadata={"is_single_chunk": True, "subdivision_type": normalized_type},
                )
            ]
        positions = self._map_pieces_to_positions(content, pieces)
        positions = self._merge_small_tail(content, positions)
        chunks: List[Chunks] = []
        for idx, (start, end) in enumerate(positions):
            chunk_text = content[start:end]
            if not chunk_text:
                continue
            chunks.append(
                self._create_chunk(
                    subdivision_id=subdivision_id,
                    chunk_index=idx,
                    content=chunk_text,
                    char_start=start,
                    char_end=end,
                    metadata={
                        "splitter": "recursive_fallback",
                        "subdivision_type": normalized_type,
                        "chunk_overlap": self.chunk_overlap,
                    },
                )
            )
        return chunks