Source code for lalandre_embedding.pipeline

"""Reusable business logic for vector embedding.

Contains payload construction, batched embedding, mean-pooling, and
incremental state tracking. The worker imports this module and delegates
the heavy lifting here.
"""

from __future__ import annotations

import hashlib
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any, Protocol

from lalandre_core.repositories.common import PayloadBuilder
from lalandre_db_qdrant import VectorPoint
from langchain_text_splitters import RecursiveCharacterTextSplitter

logger = logging.getLogger(__name__)

_SEGMENT_UUID_NAMESPACE = uuid.UUID("a6f2e1c4-7b3d-4e9f-8a2c-1d5e4f7b0c93")


def _segment_point_id(chunk_id: int, segment_index: int) -> str:
    return str(uuid.uuid5(_SEGMENT_UUID_NAMESPACE, f"{chunk_id}::seg::{segment_index}"))


class ChunkEmbeddingService(Protocol):
    """Embedding interface required by the chunk embedding pipeline."""

    def embed_batch(
        self,
        texts: list[str],
        batch_size: int | None = None,
    ) -> list[list[float]]:
        """Embed a batch of texts and preserve input ordering."""
        ...

    def estimate_tokens(self, text: str) -> int | None:
        """Estimate the token count for one input text when supported."""
        ...
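# Illustrative sketch (not part of the pipeline): a deterministic stub that
# satisfies ``ChunkEmbeddingService``, useful for exercising the pure pipeline
# functions in tests without a real model. The hash-based vectors below are an
# assumption made for this example, not how any production provider works.
class _StubEmbeddingService:
    """Hash-based fake embeddings with a fixed dimensionality."""

    def __init__(self, vector_size: int = 8) -> None:
        self._vector_size = vector_size

    def embed_batch(
        self,
        texts: list[str],
        batch_size: int | None = None,
    ) -> list[list[float]]:
        # One pseudo-vector per input, preserving order as the protocol requires.
        vectors: list[list[float]] = []
        for text in texts:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            vectors.append([b / 255.0 for b in digest[: self._vector_size]])
        return vectors

    def estimate_tokens(self, text: str) -> int | None:
        # Crude whitespace proxy; real providers use their own tokenizers.
        return len(text.split())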
# ═══════════════════════════════════════════════════════════════════════════
# 1. Pure helpers
# ═══════════════════════════════════════════════════════════════════════════
def compute_payload_hash(payload: dict[str, Any]) -> str:
    """Stable hash of *payload* to detect changes across runs."""
    payload_json = json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str)
    return hashlib.sha256(payload_json.encode("utf-8")).hexdigest()
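# Example (illustrative, not called by the pipeline): key order does not affect
# the digest because the JSON dump sorts keys, so logically-equal payloads map
# to the same hash across runs.
def _example_payload_hash_is_order_insensitive() -> None:
    assert compute_payload_hash({"a": 1, "b": 2}) == compute_payload_hash({"b": 2, "a": 1})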
def mean_pool_vectors(vectors: list[list[float]]) -> list[float] | None:
    """Return the element-wise mean of *vectors*, or ``None`` if empty."""
    if not vectors:
        return None
    dim = len(vectors[0])
    acc = [0.0] * dim
    for vec in vectors:
        for i, v in enumerate(vec):
            acc[i] += v
    n = len(vectors)
    return [x / n for x in acc]
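# Example (illustrative, not called by the pipeline): the element-wise mean of
# [1, 3] and [3, 5] is [2.0, 4.0]; an empty input yields ``None``.
def _example_mean_pool() -> None:
    assert mean_pool_vectors([[1.0, 3.0], [3.0, 5.0]]) == [2.0, 4.0]
    assert mean_pool_vectors([]) is None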
def truncate_subdivision_content(content: str, max_chars: int) -> str:
    """Truncate *content* for embedding if it exceeds *max_chars*."""
    if not content:
        return content
    if len(content) <= max_chars:
        return content
    return content[:max_chars] + "\n\n[...Contenu tronque pour embedding...]"


def resolved_subdivision_embed_max_chars(config: Any) -> int:
    """Compute the max embedding chars from a ``LalandreConfig``."""
    safe_embed_chars = int(
        config.token_limits.embedding_max_chars * config.token_limits.embedding_safety_ratio
    )
    return min(config.chunking.subdivision_max_chars, max(safe_embed_chars, 256))
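# Worked example (illustrative, not called by the pipeline; the config values
# below are assumptions for the demo, not real defaults): with
# embedding_max_chars=8000 and embedding_safety_ratio=0.9 the safe budget is
# int(8000 * 0.9) = 7200, and min(6000, max(7200, 256)) = 6000 — the chunking
# limit wins.
def _example_resolved_max_chars() -> None:
    from types import SimpleNamespace

    config = SimpleNamespace(
        token_limits=SimpleNamespace(embedding_max_chars=8000, embedding_safety_ratio=0.9),
        chunking=SimpleNamespace(subdivision_max_chars=6000),
    )
    assert resolved_subdivision_embed_max_chars(config) == 6000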
def make_state_record(
    *,
    object_type: str,
    object_id: int,
    provider: str,
    model_name: str,
    vector_size: int,
    content_hash: str,
) -> dict[str, Any]:
    """Build one embedding-state upsert record."""
    return {
        "object_type": object_type,
        "object_id": object_id,
        "provider": provider,
        "model_name": model_name,
        "vector_size": vector_size,
        "content_hash": content_hash,
    }


# ═══════════════════════════════════════════════════════════════════════════
# 3. ORM → dict extractors
# ═══════════════════════════════════════════════════════════════════════════


def extract_metadata(act: Any) -> dict[str, str] | None:
    """Extract metadata entries from an ORM act, or ``None``."""
    if act.metadata_entries:
        return {entry.key: entry.value for entry in act.metadata_entries}
    return None


def _extract_subjects(act: Any) -> list[dict[str, Any]]:
    """Extract subject entries from an ORM act (via junction table)."""
    if not act.subjects:
        return []
    return [
        {
            "eurovoc_code": s.subject.eurovoc_code,
            "label_fr": s.subject.label_fr,
            "label_en": s.subject.label_en,
        }
        for s in act.subjects
    ]


def _build_summary_text(act: Any, subdivisions: list[Any]) -> str:
    """Build a summary text from act title + first subdivision contents."""
    parts = [act.title or ""]
    for subdiv in subdivisions[:5]:
        part = str(subdiv.content or "")[:300].strip()
        if part:
            parts.append(part)
    return "\n\n".join(p for p in parts if p)


# ═══════════════════════════════════════════════════════════════════════════
# 4. Payload builders (ORM → Qdrant payload dict)
# ═══════════════════════════════════════════════════════════════════════════
def build_subdivision_payload(
    payload_builder: PayloadBuilder,
    subdivision: Any,
    act: Any,
    version: Any | None,
    metadata: dict[str, str] | None,
    content_override: str | None = None,
) -> dict[str, Any]:
    """Assemble the Qdrant payload dict from ORM models for a *subdivision*."""
    subdivision_data: dict[str, Any] = {
        "id": subdivision.id,
        "subdivision_type": subdivision.subdivision_type.value,
        "number": subdivision.number,
        "title": subdivision.title,
        "sequence_order": subdivision.sequence_order,
        "hierarchy_path": subdivision.hierarchy_path,
        "depth": subdivision.depth,
        "parent_id": subdivision.parent_id,
        "content": content_override if content_override is not None else subdivision.content,
    }
    act_data: dict[str, Any] = {
        "id": act.id,
        "celex": act.celex,
        "title": act.title,
        "act_type": act.act_type.value,
        "language": act.language.value,
        "adoption_date": act.adoption_date,
        "force_date": act.force_date,
        "level": act.level,
    }
    version_data: dict[str, Any] | None = None
    if version:
        version_data = {
            "id": version.id,
            "version_number": version.version_number,
            "version_type": version.version_type.value,
            "version_date": version.version_date,
            "is_current": version.is_current,
        }
    return payload_builder.build_subdivision_payload(
        subdivision_data=subdivision_data,
        act_data=act_data,
        version_data=version_data,
        metadata=metadata or {},
    )
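# Illustrative call shape (not part of the pipeline): which ORM attributes this
# builder reads, using ``SimpleNamespace`` stand-ins. All values below are
# placeholders, and the output shape is defined by the real ``PayloadBuilder``
# from lalandre_core, which is not reproduced here.
def _example_build_subdivision_payload(payload_builder: PayloadBuilder) -> dict[str, Any]:
    from types import SimpleNamespace

    subdivision = SimpleNamespace(
        id=1,
        subdivision_type=SimpleNamespace(value="article"),
        number="1",
        title="Subject matter",
        sequence_order=0,
        hierarchy_path="/art-1",
        depth=1,
        parent_id=None,
        content="Illustrative subdivision content.",
    )
    act = SimpleNamespace(
        id=10,
        celex="00000X0000",  # placeholder CELEX
        title="Illustrative act title",
        act_type=SimpleNamespace(value="regulation"),
        language=SimpleNamespace(value="fr"),
        adoption_date=None,
        force_date=None,
        level=1,
    )
    return build_subdivision_payload(
        payload_builder, subdivision=subdivision, act=act, version=None, metadata=None
    )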
def build_chunk_payload(
    payload_builder: PayloadBuilder,
    chunk: Any,
    subdivision: Any,
    act: Any,
) -> dict[str, Any]:
    """Assemble the Qdrant payload dict from ORM models for a *chunk*."""
    return _build_chunk_payload_from_data(
        payload_builder=payload_builder,
        chunk_data={
            "id": chunk.id,
            "subdivision_id": chunk.subdivision_id,
            "chunk_index": chunk.chunk_index,
            "content": chunk.content,
            "char_start": chunk.char_start,
            "char_end": chunk.char_end,
            "token_count": chunk.token_count,
            "chunk_metadata": chunk.chunk_metadata,
        },
        subdivision=subdivision,
        act=act,
    )


def _build_chunk_payload_from_data(
    *,
    payload_builder: PayloadBuilder,
    chunk_data: dict[str, Any],
    subdivision: Any,
    act: Any,
) -> dict[str, Any]:
    subdivision_data: dict[str, Any] = {
        "subdivision_type": subdivision.subdivision_type.value,
        "number": subdivision.number,
        "title": subdivision.title,
    }
    act_data: dict[str, Any] = {
        "id": act.id,
        "celex": act.celex,
        "title": act.title,
        "act_type": act.act_type.value,
    }
    return payload_builder.build_chunk_payload(
        chunk_data=chunk_data,
        subdivision_data=subdivision_data,
        act_data=act_data,
    )


def build_act_document_payload(
    payload_builder: PayloadBuilder,
    act: Any,
    subdivisions: list[Any],
    *,
    vector_method: str,
    chunk_count: int,
) -> dict[str, Any]:
    """Assemble the Qdrant payload dict from ORM models for a whole-act vector."""
    act_data: dict[str, Any] = {
        "id": act.id,
        "celex": act.celex,
        "title": act.title,
        "act_type": act.act_type.value,
        "language": act.language.value,
        "adoption_date": act.adoption_date,
        "force_date": act.force_date,
        "end_date": act.end_date,
        "level": act.level,
        "sector": act.sector,
        "official_journal_reference": act.official_journal_reference,
        "form_number": act.form_number,
        "url_eurlex": act.url_eurlex,
        "eli": act.eli,
    }
    metadata = extract_metadata(act) or {}
    metadata["vector_method"] = vector_method
    metadata["chunk_count"] = str(chunk_count)
    metadata["subdivision_count"] = str(len(subdivisions))
    return payload_builder.build_act_payload(
        act_data=act_data,
        full_text=_build_summary_text(act, subdivisions),
        subjects=_extract_subjects(act),
        metadata=metadata,
    )
def _normalize_subdivision_type(subdivision_type: Any) -> str | None:
    if subdivision_type is None:
        return None
    raw_value = getattr(subdivision_type, "value", subdivision_type)
    return str(raw_value).strip().lower() or None


def _is_article_level_chunk(chunk: Any, subdivision: Any) -> bool:
    metadata = getattr(chunk, "chunk_metadata", None) or {}
    subdivision_type = _normalize_subdivision_type(getattr(subdivision, "subdivision_type", None))
    return bool(metadata.get("article_level_chunking")) and subdivision_type == "article"


def _locate_piece(
    content: str,
    piece: str,
    *,
    search_from: int,
    cursor: int,
) -> tuple[int, int]:
    start = content.find(piece, search_from)
    if start != -1:
        return start, start + len(piece)
    stripped = piece.strip()
    if stripped:
        start = content.find(stripped, search_from)
        if start != -1:
            return start, start + len(stripped)
    if stripped and len(stripped) > 10:
        anchor_len = min(50, len(stripped) // 2)
        prefix = stripped[:anchor_len]
        start = content.find(prefix, search_from)
        if start != -1:
            suffix = stripped[-anchor_len:]
            end_search = max(start + len(stripped) - anchor_len - 20, start)
            end_pos = content.find(suffix, end_search)
            if end_pos != -1:
                return start, end_pos + len(suffix)
            return start, min(len(content), start + len(piece))
    return cursor, min(len(content), cursor + len(piece))


def _map_pieces_to_positions(
    content: str,
    pieces: list[str],
    *,
    overlap: int,
) -> list[tuple[int, int]]:
    positions: list[tuple[int, int]] = []
    cursor = 0
    for piece in pieces:
        search_from = max(0, cursor - overlap) if overlap else cursor
        start, end = _locate_piece(content, piece, search_from=search_from, cursor=cursor)
        positions.append((start, end))
        cursor = end
    if positions and cursor < len(content):
        trailing = content[cursor:]
        if not trailing.strip():
            last_start, _ = positions[-1]
            positions[-1] = (last_start, len(content))
    return positions


def _split_article_retrieval_segments(
    content: str,
    *,
    segment_chars: int,
    segment_overlap: int,
) -> list[tuple[str, int, int]]:
    if not content:
        return []
    if len(content) <= segment_chars:
        return [(content, 0, len(content))]
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=max(int(segment_chars), 1),
        chunk_overlap=max(int(segment_overlap), 0),
        separators=["\n\n", "\n", r"(?<=[\.\!\?\:\;])\s+", " ", ""],
        is_separator_regex=True,
    )
    pieces = [piece for piece in splitter.split_text(content) if piece]
    if not pieces:
        return [(content, 0, len(content))]
    positions = _map_pieces_to_positions(content, pieces, overlap=max(int(segment_overlap), 0))
    return [(content[start:end], start, end) for start, end in positions if end > start]


def _build_chunk_vector_plan(
    *,
    chunk: Any,
    subdivision: Any,
    act: Any,
    embedding_service: ChunkEmbeddingService,
    payload_builder: PayloadBuilder,
    retrieval_segment_chars: int,
    retrieval_segment_overlap: int,
) -> dict[str, Any]:
    canonical_payload = build_chunk_payload(
        payload_builder=payload_builder,
        chunk=chunk,
        subdivision=subdivision,
        act=act,
    )
    canonical_payload["canonical_chunk_vector"] = True
    canonical_payload["is_retrieval_segment"] = False
    canonical_payload.setdefault("retrieval_enabled", True)
    vector_specs: list[dict[str, Any]] = [
        {
            "point_id": chunk.id,
            "text": chunk.content,
            "payload": canonical_payload,
        }
    ]
    hash_payloads: list[dict[str, Any]] = [dict(canonical_payload)]
    if _is_article_level_chunk(chunk, subdivision) and len(str(chunk.content or "")) > retrieval_segment_chars:
        segments = _split_article_retrieval_segments(
            str(chunk.content or ""),
            segment_chars=retrieval_segment_chars,
            segment_overlap=retrieval_segment_overlap,
        )
        if len(segments) > 1:
            canonical_payload["retrieval_enabled"] = False
            canonical_payload["retrieval_segment_count"] = len(segments)
            hash_payloads = [dict(canonical_payload)]
            vector_specs = [
                {
                    "point_id": chunk.id,
                    "text": chunk.content,
                    "payload": canonical_payload,
                }
            ]
            for segment_index, (segment_text, char_start, char_end) in enumerate(segments):
                segment_metadata = dict(getattr(chunk, "chunk_metadata", None) or {})
                segment_metadata.update(
                    {
                        "retrieval_segment_index": segment_index,
                        "retrieval_segment_count": len(segments),
                        "retrieval_vector_role": "segment",
                    }
                )
                segment_payload = _build_chunk_payload_from_data(
                    payload_builder=payload_builder,
                    chunk_data={
                        "id": chunk.id,
                        "subdivision_id": chunk.subdivision_id,
                        "chunk_index": chunk.chunk_index,
                        "content": segment_text,
                        "char_start": char_start,
                        "char_end": char_end,
                        "token_count": embedding_service.estimate_tokens(segment_text),
                        "chunk_metadata": segment_metadata,
                    },
                    subdivision=subdivision,
                    act=act,
                )
                segment_payload["canonical_chunk_vector"] = False
                segment_payload["is_retrieval_segment"] = True
                segment_payload["retrieval_enabled"] = True
                segment_payload["retrieval_segment_index"] = segment_index
                segment_payload["retrieval_segment_count"] = len(segments)
                vector_specs.append(
                    {
                        "point_id": _segment_point_id(chunk.id, segment_index),
                        "text": segment_text,
                        "payload": segment_payload,
                    }
                )
                hash_payloads.append(segment_payload)
    return {
        "chunk_id": chunk.id,
        "hash": compute_payload_hash({"vectors": hash_payloads}),
        "vector_specs": vector_specs,
    }


# ═══════════════════════════════════════════════════════════════════════════
# 5. Embedding pipelines (pure — no DB, no sessions)
# ═══════════════════════════════════════════════════════════════════════════
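# Illustrative check for the section-4 segmentation helpers above (not part of
# the pipeline; the text and budget values are assumptions for the demo). Each
# segment carries character offsets back into the original content, so
# ``content[start:end]`` round-trips to the segment text; overlap regions may
# be shared between neighbouring segments.
def _example_split_segments() -> None:
    text = ("Article premier. " * 40).strip()
    segments = _split_article_retrieval_segments(text, segment_chars=200, segment_overlap=20)
    assert len(segments) > 1
    for segment_text, start, end in segments:
        assert text[start:end] == segment_text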
@dataclass
class EmbedBatchResult:
    """Result of preparing one embedding batch — ready to be persisted by the caller."""

    points: list[VectorPoint] = field(default_factory=list)
    state_records: list[dict[str, Any]] = field(default_factory=list)
    delete_filters: list[dict[str, Any]] = field(default_factory=list)
    embedded_count: int = 0
    skipped_count: int = 0


def prepare_subdivision_embeddings(
    *,
    act: Any,
    subdivisions: list[Any],
    version: Any | None,
    metadata: dict[str, str] | None,
    embedding_service: ChunkEmbeddingService,
    payload_builder: PayloadBuilder,
    state_map: dict[int, str],
    batch_size: int,
    max_batch_size: int,
    max_chars: int,
    provider: str,
    model_name: str,
    vector_size: int,
    force: bool = False,
) -> list[EmbedBatchResult]:
    """Prepare embedding batches for *subdivisions*. Pure logic — no DB access.

    Args:
        act: The ORM act object (for payload building).
        subdivisions: ORM subdivision objects.
        version: Current version for the act (or ``None``).
        metadata: Act metadata dict (or ``None``).
        embedding_service: Service to produce vectors.
        payload_builder: Builds Qdrant payloads.
        state_map: Pre-fetched ``{subdivision_id: content_hash}`` from embedding_state.
        batch_size / max_batch_size: Embedding batch sizing.
        max_chars: Truncation limit for subdivision content.
        provider / model_name / vector_size: Embedding identity.
        force: If ``True``, embed everything regardless of state_map.

    Returns:
        One :class:`EmbedBatchResult` per batch, ready to be persisted.
    """
    if not subdivisions:
        return []
    total = len(subdivisions)
    effective_batch_size = min(batch_size, max_batch_size)
    results: list[EmbedBatchResult] = []
    for i in range(0, total, effective_batch_size):
        batch: list[Any] = subdivisions[i : i + effective_batch_size]
        try:
            payloads: list[dict[str, Any]] = []
            texts: list[str] = []
            ids: list[int] = []
            hashes: list[str] = []
            for subdiv in batch:
                content = truncate_subdivision_content(subdiv.content, max_chars)
                payload = build_subdivision_payload(
                    payload_builder,
                    subdivision=subdiv,
                    act=act,
                    version=version,
                    metadata=metadata,
                    content_override=content,
                )
                payload_hash = compute_payload_hash(payload)
                ids.append(subdiv.id)
                texts.append(content)
                payloads.append(payload)
                hashes.append(payload_hash)
            if force:
                embed_indices = list(range(len(batch)))
                batch_skipped = 0
            else:
                embed_indices = [
                    idx for idx, subdiv_id in enumerate(ids) if state_map.get(subdiv_id) != hashes[idx]
                ]
                batch_skipped = len(batch) - len(embed_indices)
            if not embed_indices:
                results.append(EmbedBatchResult(skipped_count=batch_skipped))
                continue
            embed_texts = [texts[idx] for idx in embed_indices]
            embeddings: list[list[float]] = embedding_service.embed_batch(
                embed_texts, batch_size=effective_batch_size
            )
            points: list[VectorPoint] = []
            state_records: list[dict[str, Any]] = []
            for idx, embedding in zip(embed_indices, embeddings):
                points.append(VectorPoint(id=ids[idx], vector=embedding, payload=payloads[idx]))
                state_records.append(
                    make_state_record(
                        object_type="subdivision",
                        object_id=ids[idx],
                        provider=provider,
                        model_name=model_name,
                        vector_size=vector_size,
                        content_hash=hashes[idx],
                    )
                )
            results.append(
                EmbedBatchResult(
                    points=points,
                    state_records=state_records,
                    embedded_count=len(points),
                    skipped_count=batch_skipped,
                )
            )
        except Exception as e:
            logger.error(
                "Failed to prepare subdivision embedding batch for %s: %s",
                act.celex,
                e,
                exc_info=True,
            )
            continue
    return results
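# Usage sketch (illustrative): how a worker might drive this function and
# persist its results. The repository names (``qdrant_repo``, ``state_repo``)
# and their methods are assumptions standing in for the real persistence layer,
# and the provider/model/size values are placeholders.
def _example_persist_subdivision_batches(
    act: Any,
    subdivisions: list[Any],
    embedding_service: ChunkEmbeddingService,
    payload_builder: PayloadBuilder,
    qdrant_repo: Any,
    state_repo: Any,
) -> None:
    results = prepare_subdivision_embeddings(
        act=act,
        subdivisions=subdivisions,
        version=None,
        metadata=extract_metadata(act),
        embedding_service=embedding_service,
        payload_builder=payload_builder,
        state_map=state_repo.fetch_hashes("subdivision"),  # assumed helper
        batch_size=32,
        max_batch_size=64,
        max_chars=6000,
        provider="openai",
        model_name="text-embedding-3-small",
        vector_size=1536,
    )
    for result in results:
        if result.points:
            qdrant_repo.upsert(result.points)  # assumed method
        if result.state_records:
            state_repo.upsert_many(result.state_records)  # assumed method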
def prepare_chunk_embeddings(
    *,
    chunks: list[Any],
    act: Any,
    embedding_service: ChunkEmbeddingService,
    payload_builder: PayloadBuilder,
    state_map: dict[int, str],
    batch_size: int,
    retrieval_segment_chars: int,
    retrieval_segment_overlap: int,
    provider: str,
    model_name: str,
    vector_size: int,
    force: bool,
) -> list[EmbedBatchResult]:
    """Prepare embedding batches for *chunks*. Pure logic — no DB access.

    Args:
        chunks: List of ``(chunk, subdivision)`` tuples.
        act: The ORM act object (for payload building).
        embedding_service: Service to produce vectors.
        payload_builder: Builds Qdrant payloads.
        state_map: Pre-fetched ``{chunk_id: content_hash}`` from embedding_state.
        batch_size: Embedding batch sizing.
        retrieval_segment_chars / retrieval_segment_overlap: Target segmentation
            budget for long article retrieval vectors.
        provider / model_name / vector_size: Embedding identity.
        force: If ``True``, embed everything regardless of state_map.

    Returns:
        One :class:`EmbedBatchResult` per batch, ready to be persisted.
    """
    total = len(chunks)
    results: list[EmbedBatchResult] = []
    for i in range(0, total, batch_size):
        batch: list[tuple[Any, Any]] = chunks[i : i + batch_size]
        try:
            chunk_plans: list[dict[str, Any]] = []
            chunk_ids: list[int] = []
            hashes: list[str] = []
            for chunk, subdivision in batch:
                chunk_plan = _build_chunk_vector_plan(
                    chunk=chunk,
                    subdivision=subdivision,
                    act=act,
                    embedding_service=embedding_service,
                    payload_builder=payload_builder,
                    retrieval_segment_chars=retrieval_segment_chars,
                    retrieval_segment_overlap=retrieval_segment_overlap,
                )
                payload_hash = str(chunk_plan["hash"])
                chunk_plans.append(chunk_plan)
                chunk_ids.append(chunk.id)
                hashes.append(payload_hash)
            if force:
                embed_indices = list(range(len(batch)))
                batch_skipped = 0
            else:
                embed_indices = [
                    idx for idx, cid in enumerate(chunk_ids) if state_map.get(cid) != hashes[idx]
                ]
                batch_skipped = len(batch) - len(embed_indices)
            if not embed_indices:
                results.append(EmbedBatchResult(skipped_count=batch_skipped))
                continue
            vector_specs: list[dict[str, Any]] = []
            for idx in embed_indices:
                vector_specs.extend(chunk_plans[idx]["vector_specs"])
            embeddings: list[list[float]] = embedding_service.embed_batch(
                [str(spec["text"]) for spec in vector_specs],
                batch_size=batch_size,
            )
            points: list[VectorPoint] = [
                VectorPoint(id=spec["point_id"], vector=emb, payload=spec["payload"])
                for spec, emb in zip(vector_specs, embeddings)
            ]
            state_records = [
                make_state_record(
                    object_type="chunk",
                    object_id=chunk_ids[idx],
                    provider=provider,
                    model_name=model_name,
                    vector_size=vector_size,
                    content_hash=hashes[idx],
                )
                for idx in embed_indices
            ]
            results.append(
                EmbedBatchResult(
                    points=points,
                    state_records=state_records,
                    delete_filters=[{"chunk_id": [chunk_ids[idx] for idx in embed_indices]}],
                    embedded_count=len(embed_indices),
                    skipped_count=batch_skipped,
                )
            )
        except Exception as e:
            logger.error(
                "Failed to prepare chunk embedding batch for %s: %s",
                act.celex,
                e,
                exc_info=True,
            )
            continue
    return results
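# Persistence-order sketch (illustrative; the repository API is an assumption).
# A re-embedded article chunk can presumably end up with fewer retrieval
# segments than before, so applying ``delete_filters`` before upserting the new
# points keeps stale segment vectors from lingering in the collection.
def _example_persist_chunk_batch(result: EmbedBatchResult, qdrant_repo: Any) -> None:
    for delete_filter in result.delete_filters:
        qdrant_repo.delete_by_filter(delete_filter)  # assumed method
    if result.points:
        qdrant_repo.upsert(result.points)  # assumed method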
def prepare_act_document_embedding(
    *,
    act: Any,
    subdivisions: list[Any],
    chunks: list[Any],
    chunk_vectors: dict[int, list[float]],
    payload_builder: PayloadBuilder,
    state_map: dict[int, str],
    provider: str,
    model_name: str,
    vector_size: int,
    force: bool = False,
) -> EmbedBatchResult | None:
    """Prepare whole-act embedding via mean pooling. Pure logic — no DB access.

    Args:
        act: The ORM act object.
        subdivisions: ORM subdivision objects for the act.
        chunks: ORM ``(chunk, subdivision)`` tuples for the act.
        chunk_vectors: Pre-fetched ``{chunk_id: vector}`` from Qdrant.
        payload_builder: Builds Qdrant payloads.
        state_map: Pre-fetched ``{act_id: content_hash}`` from embedding_state.
        provider / model_name / vector_size: Embedding identity.
        force: If ``True``, embed regardless of state_map.

    Returns:
        :class:`EmbedBatchResult` with a single point, or ``None`` if skipped.
    """
    if not chunks:
        logger.info("[EmbedAct] No chunks for act %s, skipping", act.celex)
        return None
    chunk_ids = [chunk.id for chunk, _subdivision in chunks]
    vectors = [chunk_vectors[chunk_id] for chunk_id in chunk_ids if chunk_id in chunk_vectors]
    if not vectors:
        logger.info(
            "[EmbedAct] No chunk vectors found for act %s (%d chunks) — chunks not embedded yet?",
            act.celex,
            len(chunk_ids),
        )
        return None
    act_vector = mean_pool_vectors(vectors)
    assert act_vector is not None  # guaranteed: vectors is non-empty
    logger.info(
        "[EmbedAct] Mean pooled %d/%d chunk vectors for act %s",
        len(vectors),
        len(chunk_ids),
        act.celex,
    )
    payload = build_act_document_payload(
        payload_builder,
        act=act,
        subdivisions=subdivisions,
        vector_method="mean_pool_chunks",
        chunk_count=len(vectors),
    )
    payload_hash = compute_payload_hash(payload)
    if not force and state_map.get(act.id) == payload_hash:
        logger.info("[EmbedAct] Act %s act-vector already up to date, skipping", act.celex)
        return None
    point = VectorPoint(id=act.id, vector=act_vector, payload=payload)
    state_record = make_state_record(
        object_type="act",
        object_id=act.id,
        provider=provider,
        model_name=model_name,
        vector_size=vector_size,
        content_hash=payload_hash,
    )
    logger.info(
        "[EmbedAct] Prepared act-vector for %s (mean of %d chunks)",
        act.celex,
        len(vectors),
    )
    return EmbedBatchResult(
        points=[point],
        state_records=[state_record],
        embedded_count=1,
    )
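# End-to-end ordering sketch (illustrative): the act-level vector is mean-pooled
# from already-persisted chunk vectors, so chunks must be embedded and their
# vectors fetched back from Qdrant first. ``fetch_vectors`` is an assumed
# repository method, and the provider/model/size values are placeholders.
def _example_embed_act(
    act: Any,
    subdivisions: list[Any],
    chunks: list[Any],
    payload_builder: PayloadBuilder,
    qdrant_repo: Any,
) -> EmbedBatchResult | None:
    chunk_vectors: dict[int, list[float]] = qdrant_repo.fetch_vectors(
        [chunk.id for chunk, _subdivision in chunks]
    )  # assumed method
    return prepare_act_document_embedding(
        act=act,
        subdivisions=subdivisions,
        chunks=chunks,
        chunk_vectors=chunk_vectors,
        payload_builder=payload_builder,
        state_map={},
        provider="openai",
        model_name="text-embedding-3-small",
        vector_size=1536,
    )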