"""
Chunking base classes and helpers
"""
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from lalandre_core.models import Chunks
logger = logging.getLogger(__name__)
_DEFAULT_CHARS_PER_TOKEN = 3.3
# [docs]  -- Sphinx viewcode link artifact from HTML export; not Python code
class Chunker(ABC):
    """Base chunker interface.

    Concrete subclasses implement :meth:`chunk_subdivision`. The shared
    helpers on this class build ``Chunks`` rows, estimate token counts,
    and map chunk text pieces back to character offsets in the source
    content.
    """

    # Stamped into every chunk's metadata under "chunking_method".
    method_name = "chunker"

    def __init__(
        self,
        min_chunk_size: int,
        max_chunk_size: int,
        chunk_overlap: int,
        chars_per_token: float = _DEFAULT_CHARS_PER_TOKEN,
    ):
        """
        Args:
            min_chunk_size: Minimum desired chunk size in characters; a
                final chunk smaller than this may be merged into its
                predecessor (see :meth:`_merge_small_tail`).
            max_chunk_size: Maximum allowed chunk size in characters.
            chunk_overlap: Character overlap between consecutive chunks.
                Negative values are clamped to 0.
            chars_per_token: Average characters per token used by the
                informational token estimate. Clamped to >= 0.1 so the
                estimate never divides by zero.
        """
        if min_chunk_size > max_chunk_size:
            # Inconsistent configuration: the tail-merge size checks become
            # contradictory, so surface it early instead of failing silently.
            logger.warning(
                "min_chunk_size (%d) exceeds max_chunk_size (%d)",
                min_chunk_size,
                max_chunk_size,
            )
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self.chunk_overlap = max(0, chunk_overlap)
        self._chars_per_token = max(0.1, chars_per_token)

    @abstractmethod
    def chunk_subdivision(
        self, subdivision_id: int, content: str, subdivision_type: Optional[str] = None
    ) -> List[Chunks]:
        """Chunk a subdivision's content.

        Args:
            subdivision_id: Identifier of the subdivision being chunked.
            content: Raw text content of the subdivision.
            subdivision_type: Optional type hint for the subdivision;
                semantics are subclass-specific.

        Returns:
            Ordered list of ``Chunks`` covering ``content``.
        """

    def _estimate_tokens(self, text: str) -> int:
        """Approximate token count (informational metadata only)."""
        return max(1, int(len(text) / self._chars_per_token))

    def _create_chunk(
        self,
        subdivision_id: int,
        chunk_index: int,
        content: str,
        char_start: int,
        char_end: int,
        metadata: Dict[str, Any],
    ) -> Chunks:
        """Create a Chunk object with an estimated token count and the
        chunking method stamped into its metadata."""
        return Chunks(
            id=None,
            subdivision_id=subdivision_id,
            chunk_index=chunk_index,
            content=content,
            char_start=char_start,
            char_end=char_end,
            token_count=self._estimate_tokens(content),
            chunk_metadata={**metadata, "chunking_method": self.method_name},
            created_at=None,
        )

    def make_single_chunk(
        self,
        subdivision_id: int,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Chunks:
        """Create a single chunk encompassing all content (no splitting).

        Args:
            subdivision_id: Identifier of the subdivision.
            content: Full text to wrap into one chunk.
            metadata: Extra metadata merged over the default
                ``{"is_single_chunk": True}`` marker (caller keys win).
        """
        base_meta: Dict[str, Any] = {"is_single_chunk": True}
        if metadata:
            base_meta.update(metadata)
        return self._create_chunk(
            subdivision_id=subdivision_id,
            chunk_index=0,
            content=content,
            char_start=0,
            char_end=len(content),
            metadata=base_meta,
        )

    # ──────────────────────────────────────────────────────────────────
    # Shared utilities (used by SemanticChunker and SACChunker)
    # ──────────────────────────────────────────────────────────────────
    def _normalize_subdivision_type(self, subdivision_type: Any) -> Optional[str]:
        """Normalise a subdivision type (enum member or string) to
        lowercase text; returns ``None`` for ``None`` or for values that
        are empty after stripping."""
        if subdivision_type is None:
            return None
        # Enum members expose the raw string via .value; plain strings pass through.
        type_value = getattr(subdivision_type, "value", subdivision_type)
        return str(type_value).strip().lower() or None

    def _map_pieces_to_positions(self, content: str, pieces: List[str]) -> List[tuple[int, int]]:
        """Map each text piece to its ``(start, end)`` character span in
        ``content``.

        Pieces are located left-to-right. When ``chunk_overlap`` is set,
        each search window reaches back by the overlap so pieces that
        share text with their predecessor are still found. A
        whitespace-only tail after the last piece is absorbed into the
        final span so the spans jointly cover all of ``content``.
        """
        positions: List[tuple[int, int]] = []
        cursor = 0
        for piece in pieces:
            search_from = max(0, cursor - self.chunk_overlap) if self.chunk_overlap else cursor
            start, end = self._locate_piece(content, piece, search_from, cursor)
            positions.append((start, end))
            cursor = end
        if positions and cursor < len(content):
            trailing = content[cursor:]
            if not trailing.strip():
                # Extend the last span over trailing whitespace only.
                last_start, _ = positions[-1]
                positions[-1] = (last_start, len(content))
        return positions

    def _locate_piece(self, content: str, piece: str, search_from: int, cursor: int) -> tuple[int, int]:
        """Find piece in content, tolerating whitespace differences.

        Tries progressively looser strategies and falls back to the
        current cursor position when nothing matches.
        """
        # 1. Exact match
        start = content.find(piece, search_from)
        if start != -1:
            return start, start + len(piece)
        stripped = piece.strip()
        # 2. Stripped match
        if stripped:
            start = content.find(stripped, search_from)
            if start != -1:
                return start, start + len(stripped)
        # 3. Prefix + suffix anchor (handles whitespace normalisation from
        #    sentence joining where \n/\r\n/multi-space become a single space)
        if stripped and len(stripped) > 10:
            anchor_len = min(50, len(stripped) // 2)
            prefix = stripped[:anchor_len]
            start = content.find(prefix, search_from)
            if start != -1:
                suffix = stripped[-anchor_len:]
                # search for suffix near the expected end (20 chars of slack
                # for collapsed whitespace)
                end_search = max(start + len(stripped) - anchor_len - 20, start)
                end_pos = content.find(suffix, end_search)
                if end_pos != -1:
                    return start, end_pos + len(suffix)
                return start, min(len(content), start + len(piece))
        # 4. Cursor fallback
        logger.debug(
            "Chunk piece not found in content (len=%d, cursor=%d, piece_len=%d), using cursor as fallback",
            len(content),
            cursor,
            len(piece),
        )
        return cursor, min(len(content), cursor + len(piece))

    def _merge_small_tail(
        self,
        content: str,
        positions: List[tuple[int, int]],
    ) -> List[tuple[int, int]]:
        """Merge an undersized final span into the previous one.

        The merge is skipped when the combined span would exceed
        ``max_chunk_size``. ``content`` is currently unused but kept for
        interface symmetry with the other position helpers.
        """
        if len(positions) < 2:
            return positions
        last_start, last_end = positions[-1]
        if (last_end - last_start) >= self.min_chunk_size:
            return positions
        prev_start, _ = positions[-2]
        merged_size = last_end - prev_start
        if merged_size > self.max_chunk_size:
            logger.debug(
                "Skipping tail merge: merged size %d would exceed max_chunk_size %d",
                merged_size,
                self.max_chunk_size,
            )
            return positions
        logger.debug(
            "Merging small tail (%d chars) into previous chunk",
            last_end - last_start,
        )
        return positions[:-2] + [(prev_start, last_end)]