"""Reusable business logic for vector embedding.
Contains payload construction, batched embedding, mean-pooling, and
incremental state tracking. The worker imports this module and delegates
the heavy lifting here.
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any, Protocol
from lalandre_core.repositories.common import PayloadBuilder
from lalandre_db_qdrant import VectorPoint
from langchain_text_splitters import RecursiveCharacterTextSplitter
logger = logging.getLogger(__name__)
_SEGMENT_UUID_NAMESPACE = uuid.UUID("a6f2e1c4-7b3d-4e9f-8a2c-1d5e4f7b0c93")


def _segment_point_id(chunk_id: int, segment_index: int) -> str:
    """Deterministic UUIDv5 point id for one retrieval segment of a chunk.

    The same ``(chunk_id, segment_index)`` pair always yields the same id,
    so re-runs upsert the existing point instead of duplicating it.
    """
    segment_key = f"{chunk_id}::seg::{segment_index}"
    return str(uuid.uuid5(_SEGMENT_UUID_NAMESPACE, segment_key))
class ChunkEmbeddingService(Protocol):
    """Structural interface an embedding backend must satisfy for this pipeline."""

    def embed_batch(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
        """Return one vector per input text, in the same order as *texts*."""
        ...

    def estimate_tokens(self, text: str) -> int | None:
        """Best-effort token count for *text*; ``None`` when not supported."""
        ...
# ═══════════════════════════════════════════════════════════════════════════
# 1. Pure helpers
# ═══════════════════════════════════════════════════════════════════════════
def compute_payload_hash(payload: dict[str, Any]) -> str:
    """Stable SHA-256 hex digest of *payload*, used to detect changes across runs.

    Keys are sorted so logically-equal dicts hash identically; ``default=str``
    stringifies non-JSON values (dates, enums) instead of raising.
    """
    serialized = json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str)
    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
def mean_pool_vectors(vectors: list[list[float]]) -> list[float] | None:
    """Return the element-wise mean of *vectors*, or ``None`` if empty.

    The output dimensionality follows the first vector; all inputs are
    expected to share that dimension.
    """
    if not vectors:
        return None
    count = len(vectors)
    width = len(vectors[0])
    sums = [0.0] * width
    for vector in vectors:
        for position in range(len(vector)):
            sums[position] += vector[position]
    return [total / count for total in sums]
def truncate_subdivision_content(content: str, max_chars: int) -> str:
    """Truncate *content* for embedding if it exceeds *max_chars*.

    A French truncation marker is appended so the payload makes the cut
    visible; short or empty content is returned unchanged.
    """
    if content and len(content) > max_chars:
        return content[:max_chars] + "\n\n[...Contenu tronque pour embedding...]"
    return content
def resolved_subdivision_embed_max_chars(config: Any) -> int:
    """Compute the max embedding chars from a ``LalandreConfig``.

    The provider's char budget scaled by the safety ratio is floored at 256,
    then capped by the chunking subdivision limit.
    """
    limits = config.token_limits
    safe_chars = int(limits.embedding_max_chars * limits.embedding_safety_ratio)
    floor = max(safe_chars, 256)
    return min(config.chunking.subdivision_max_chars, floor)
def make_state_record(
    *,
    object_type: str,
    object_id: int,
    provider: str,
    model_name: str,
    vector_size: int,
    content_hash: str,
) -> dict[str, Any]:
    """Build one embedding-state upsert record.

    Keys mirror the keyword parameters; the record identifies which object
    was embedded, by which model identity, and with what content hash.
    """
    return dict(
        object_type=object_type,
        object_id=object_id,
        provider=provider,
        model_name=model_name,
        vector_size=vector_size,
        content_hash=content_hash,
    )
# ═══════════════════════════════════════════════════════════════════════════
# 3. ORM → dict extractors
# ═══════════════════════════════════════════════════════════════════════════
def _extract_subjects(act: Any) -> list[dict[str, Any]]:
    """Extract EuroVoc subject entries from an ORM act (via the junction table)."""
    entries: list[dict[str, Any]] = []
    for link in act.subjects or []:
        subject = link.subject
        entries.append(
            {
                "eurovoc_code": subject.eurovoc_code,
                "label_fr": subject.label_fr,
                "label_en": subject.label_en,
            }
        )
    return entries
def _build_summary_text(act: Any, subdivisions: list[Any]) -> str:
    """Join the act title with up to five trimmed subdivision snippets.

    Each snippet is capped at 300 chars; blank pieces are dropped before the
    double-newline join.
    """
    snippets = [str(s.content or "")[:300].strip() for s in subdivisions[:5]]
    parts = [act.title or ""] + [snippet for snippet in snippets if snippet]
    return "\n\n".join(part for part in parts if part)
# ═══════════════════════════════════════════════════════════════════════════
# 4. Payload builders (ORM → Qdrant payload dict)
# ═══════════════════════════════════════════════════════════════════════════
def build_subdivision_payload(
    payload_builder: PayloadBuilder,
    subdivision: Any,
    act: Any,
    version: Any | None,
    metadata: dict[str, str] | None,
    content_override: str | None = None,
) -> dict[str, Any]:
    """Assemble the Qdrant payload dict from ORM models for a *subdivision*.

    When *content_override* is provided it replaces the subdivision's own
    content (used to embed a truncated copy without mutating the ORM object).
    """
    content = subdivision.content if content_override is None else content_override
    subdivision_data: dict[str, Any] = {
        "id": subdivision.id,
        "subdivision_type": subdivision.subdivision_type.value,
        "number": subdivision.number,
        "title": subdivision.title,
        "sequence_order": subdivision.sequence_order,
        "hierarchy_path": subdivision.hierarchy_path,
        "depth": subdivision.depth,
        "parent_id": subdivision.parent_id,
        "content": content,
    }
    act_data: dict[str, Any] = {
        "id": act.id,
        "celex": act.celex,
        "title": act.title,
        "act_type": act.act_type.value,
        "language": act.language.value,
        "adoption_date": act.adoption_date,
        "force_date": act.force_date,
        "level": act.level,
    }
    # Version context is optional — only included when a version is attached.
    version_data: dict[str, Any] | None = None
    if version:
        version_data = {
            "id": version.id,
            "version_number": version.version_number,
            "version_type": version.version_type.value,
            "version_date": version.version_date,
            "is_current": version.is_current,
        }
    return payload_builder.build_subdivision_payload(
        subdivision_data=subdivision_data,
        act_data=act_data,
        version_data=version_data,
        metadata=metadata or {},
    )
def build_chunk_payload(
    payload_builder: PayloadBuilder,
    chunk: Any,
    subdivision: Any,
    act: Any,
) -> dict[str, Any]:
    """Assemble the Qdrant payload dict from ORM models for a *chunk*.

    Thin adapter: flattens the ORM chunk into a plain dict and delegates to
    :func:`_build_chunk_payload_from_data`.
    """
    chunk_data: dict[str, Any] = {
        "id": chunk.id,
        "subdivision_id": chunk.subdivision_id,
        "chunk_index": chunk.chunk_index,
        "content": chunk.content,
        "char_start": chunk.char_start,
        "char_end": chunk.char_end,
        "token_count": chunk.token_count,
        "chunk_metadata": chunk.chunk_metadata,
    }
    return _build_chunk_payload_from_data(
        payload_builder=payload_builder,
        chunk_data=chunk_data,
        subdivision=subdivision,
        act=act,
    )
def _build_chunk_payload_from_data(
    *,
    payload_builder: PayloadBuilder,
    chunk_data: dict[str, Any],
    subdivision: Any,
    act: Any,
) -> dict[str, Any]:
    """Forward pre-built chunk data plus slim ORM context to the payload builder."""
    return payload_builder.build_chunk_payload(
        chunk_data=chunk_data,
        subdivision_data={
            "subdivision_type": subdivision.subdivision_type.value,
            "number": subdivision.number,
            "title": subdivision.title,
        },
        act_data={
            "id": act.id,
            "celex": act.celex,
            "title": act.title,
            "act_type": act.act_type.value,
        },
    )
def build_act_document_payload(
    payload_builder: PayloadBuilder,
    act: Any,
    subdivisions: list[Any],
    *,
    vector_method: str,
    chunk_count: int,
) -> dict[str, Any]:
    """Assemble the Qdrant payload dict from ORM models for a whole-act vector."""
    act_data: dict[str, Any] = {
        "id": act.id,
        "celex": act.celex,
        "title": act.title,
        "act_type": act.act_type.value,
        "language": act.language.value,
        "adoption_date": act.adoption_date,
        "force_date": act.force_date,
        "end_date": act.end_date,
        "level": act.level,
        "sector": act.sector,
        "official_journal_reference": act.official_journal_reference,
        "form_number": act.form_number,
        "url_eurlex": act.url_eurlex,
        "eli": act.eli,
    }
    # NOTE(review): extract_metadata is not defined or imported in the visible
    # part of this module (the "2." section banner is missing) — confirm it is
    # in scope at import time.
    metadata = extract_metadata(act) or {}
    metadata.update(
        {
            "vector_method": vector_method,
            "chunk_count": str(chunk_count),
            "subdivision_count": str(len(subdivisions)),
        }
    )
    return payload_builder.build_act_payload(
        act_data=act_data,
        full_text=_build_summary_text(act, subdivisions),
        subjects=_extract_subjects(act),
        metadata=metadata,
    )
def _normalize_subdivision_type(subdivision_type: Any) -> str | None:
    """Lower-cased string form of a subdivision type, or ``None`` when absent/blank.

    Accepts either an enum-like object (uses ``.value``) or a plain string.
    """
    if subdivision_type is None:
        return None
    raw = getattr(subdivision_type, "value", subdivision_type)
    normalized = str(raw).strip().lower()
    return normalized or None


def _is_article_level_chunk(chunk: Any, subdivision: Any) -> bool:
    """True when *chunk* was produced by article-level chunking of an article."""
    meta = getattr(chunk, "chunk_metadata", None) or {}
    if not meta.get("article_level_chunking"):
        return False
    return _normalize_subdivision_type(getattr(subdivision, "subdivision_type", None)) == "article"
def _locate_piece(
    content: str,
    piece: str,
    *,
    search_from: int,
    cursor: int,
) -> tuple[int, int]:
    """Best-effort character span of *piece* in *content*, searching from *search_from*.

    Resolution order: exact match, whitespace-stripped match, prefix/suffix
    anchor match (long pieces only), then a cursor-based estimate.
    """
    # 1) Exact match.
    exact = content.find(piece, search_from)
    if exact != -1:
        return exact, exact + len(piece)
    stripped = piece.strip()
    # 2) Whitespace-stripped match.
    if stripped:
        loose = content.find(stripped, search_from)
        if loose != -1:
            return loose, loose + len(stripped)
    # 3) Anchor match: find a leading prefix, then the matching suffix nearby.
    if stripped and len(stripped) > 10:
        anchor_len = min(50, len(stripped) // 2)
        start = content.find(stripped[:anchor_len], search_from)
        if start != -1:
            suffix = stripped[-anchor_len:]
            # Allow the suffix to appear up to 20 chars before the estimate.
            end_search = max(start + len(stripped) - anchor_len - 20, start)
            end_pos = content.find(suffix, end_search)
            if end_pos != -1:
                return start, end_pos + len(suffix)
            return start, min(len(content), start + len(piece))
    # 4) Last resort: assume the piece sits at the current cursor.
    return cursor, min(len(content), cursor + len(piece))


def _map_pieces_to_positions(
    content: str,
    pieces: list[str],
    *,
    overlap: int,
) -> list[tuple[int, int]]:
    """Map splitter output *pieces* back onto character spans of *content*, in order."""
    spans: list[tuple[int, int]] = []
    cursor = 0
    for piece in pieces:
        # Back the search position up by *overlap* chars so overlapping
        # splitter output can still be located.
        start_at = max(0, cursor - overlap) if overlap else cursor
        span = _locate_piece(content, piece, search_from=start_at, cursor=cursor)
        spans.append(span)
        cursor = span[1]
    # Absorb a purely-whitespace tail into the final span.
    if spans and cursor < len(content) and not content[cursor:].strip():
        first, _last = spans[-1]
        spans[-1] = (first, len(content))
    return spans
def _split_article_retrieval_segments(
    content: str,
    *,
    segment_chars: int,
    segment_overlap: int,
) -> list[tuple[str, int, int]]:
    """Split long article *content* into ``(text, char_start, char_end)`` segments.

    Content at or under *segment_chars* is returned as a single full-span
    segment; longer content is split recursively on paragraph/sentence
    boundaries and each piece is mapped back to its source offsets.
    """
    if not content:
        return []
    if len(content) <= segment_chars:
        return [(content, 0, len(content))]
    overlap = max(int(segment_overlap), 0)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=max(int(segment_chars), 1),
        chunk_overlap=overlap,
        separators=["\n\n", "\n", r"(?<=[\.\!\?\:\;])\s+", " ", ""],
        is_separator_regex=True,
    )
    pieces = [p for p in splitter.split_text(content) if p]
    if not pieces:
        return [(content, 0, len(content))]
    spans = _map_pieces_to_positions(content, pieces, overlap=overlap)
    # Drop degenerate (empty or inverted) spans.
    return [(content[a:b], a, b) for a, b in spans if b > a]
def _build_chunk_vector_plan(
    *,
    chunk: Any,
    subdivision: Any,
    act: Any,
    embedding_service: ChunkEmbeddingService,
    payload_builder: PayloadBuilder,
    retrieval_segment_chars: int,
    retrieval_segment_overlap: int,
) -> dict[str, Any]:
    """Plan every vector to write for one chunk (canonical + optional segments).

    Returns a dict with keys:
        ``chunk_id``: the source chunk's id.
        ``hash``: content hash over all planned payloads (change detection).
        ``vector_specs``: list of ``{point_id, text, payload}`` specs to embed.
    """
    # The canonical vector always exists and reuses the chunk id as point id.
    canonical_payload = build_chunk_payload(
        payload_builder=payload_builder,
        chunk=chunk,
        subdivision=subdivision,
        act=act,
    )
    canonical_payload["canonical_chunk_vector"] = True
    canonical_payload["is_retrieval_segment"] = False
    canonical_payload.setdefault("retrieval_enabled", True)
    vector_specs: list[dict[str, Any]] = [
        {
            "point_id": chunk.id,
            "text": chunk.content,
            "payload": canonical_payload,
        }
    ]
    hash_payloads: list[dict[str, Any]] = [dict(canonical_payload)]
    # Long article-level chunks additionally get smaller retrieval segments.
    if _is_article_level_chunk(chunk, subdivision) and len(str(chunk.content or "")) > retrieval_segment_chars:
        segments = _split_article_retrieval_segments(
            str(chunk.content or ""),
            segment_chars=retrieval_segment_chars,
            segment_overlap=retrieval_segment_overlap,
        )
        if len(segments) > 1:
            # Retrieval goes through the segments; the canonical vector is
            # kept but flagged out of retrieval.
            canonical_payload["retrieval_enabled"] = False
            canonical_payload["retrieval_segment_count"] = len(segments)
            # Re-snapshot: canonical_payload was mutated after the first copy,
            # so the hash must reflect the updated flags.
            hash_payloads = [dict(canonical_payload)]
            vector_specs = [
                {
                    "point_id": chunk.id,
                    "text": chunk.content,
                    "payload": canonical_payload,
                }
            ]
            for segment_index, (segment_text, char_start, char_end) in enumerate(segments):
                segment_metadata = dict(getattr(chunk, "chunk_metadata", None) or {})
                segment_metadata.update(
                    {
                        "retrieval_segment_index": segment_index,
                        "retrieval_segment_count": len(segments),
                        "retrieval_vector_role": "segment",
                    }
                )
                segment_payload = _build_chunk_payload_from_data(
                    payload_builder=payload_builder,
                    chunk_data={
                        "id": chunk.id,
                        "subdivision_id": chunk.subdivision_id,
                        "chunk_index": chunk.chunk_index,
                        "content": segment_text,
                        "char_start": char_start,
                        "char_end": char_end,
                        # May be None when the service does not support counting.
                        "token_count": embedding_service.estimate_tokens(segment_text),
                        "chunk_metadata": segment_metadata,
                    },
                    subdivision=subdivision,
                    act=act,
                )
                segment_payload["canonical_chunk_vector"] = False
                segment_payload["is_retrieval_segment"] = True
                segment_payload["retrieval_enabled"] = True
                segment_payload["retrieval_segment_index"] = segment_index
                segment_payload["retrieval_segment_count"] = len(segments)
                # Segment points get a deterministic UUIDv5 id derived from
                # (chunk_id, segment_index) so re-runs upsert in place.
                vector_specs.append(
                    {
                        "point_id": _segment_point_id(chunk.id, segment_index),
                        "text": segment_text,
                        "payload": segment_payload,
                    }
                )
                hash_payloads.append(segment_payload)
    return {
        "chunk_id": chunk.id,
        "hash": compute_payload_hash({"vectors": hash_payloads}),
        "vector_specs": vector_specs,
    }
# ═══════════════════════════════════════════════════════════════════════════
# 5. Embedding pipelines (pure — no DB, no sessions)
# ═══════════════════════════════════════════════════════════════════════════
@dataclass
class EmbedBatchResult:
    """Result of preparing one embedding batch — ready to be persisted by the caller."""
    # Vector points to upsert into Qdrant.
    points: list[VectorPoint] = field(default_factory=list)
    # embedding_state upsert records (see make_state_record()).
    state_records: list[dict[str, Any]] = field(default_factory=list)
    # Payload filters, e.g. {"chunk_id": [...]} — presumably applied by the
    # caller to remove superseded points before upserting; confirm in worker.
    delete_filters: list[dict[str, Any]] = field(default_factory=list)
    # Counts of objects embedded vs. skipped as already up to date.
    embedded_count: int = 0
    skipped_count: int = 0
def prepare_subdivision_embeddings(
    *,
    act: Any,
    subdivisions: list[Any],
    version: Any | None,
    metadata: dict[str, str] | None,
    embedding_service: ChunkEmbeddingService,
    payload_builder: PayloadBuilder,
    state_map: dict[int, str],
    batch_size: int,
    max_batch_size: int,
    max_chars: int,
    provider: str,
    model_name: str,
    vector_size: int,
    force: bool = False,
) -> list[EmbedBatchResult]:
    """Prepare embedding batches for *subdivisions*. Pure logic — no DB access.

    Args:
        act: The ORM act object (for payload building).
        subdivisions: ORM subdivision objects.
        version: Current version for the act (or ``None``).
        metadata: Act metadata dict (or ``None``).
        embedding_service: Service to produce vectors.
        payload_builder: Builds Qdrant payloads.
        state_map: Pre-fetched ``{subdivision_id: content_hash}`` from embedding_state.
        batch_size / max_batch_size: Embedding batch sizing.
        max_chars: Truncation limit for subdivision content.
        provider / model_name / vector_size: Embedding identity.
        force: If ``True``, embed everything regardless of state_map.

    Returns:
        One :class:`EmbedBatchResult` per batch, ready to be persisted.
    """
    if not subdivisions:
        return []
    total = len(subdivisions)
    effective_batch_size = min(batch_size, max_batch_size)
    results: list[EmbedBatchResult] = []
    for i in range(0, total, effective_batch_size):
        batch: list[Any] = subdivisions[i : i + effective_batch_size]
        try:
            # Build payload + hash for every subdivision in the batch first,
            # so the skip decision happens before any embedding call.
            payloads: list[dict[str, Any]] = []
            texts: list[str] = []
            ids: list[int] = []
            hashes: list[str] = []
            for subdiv in batch:
                content = truncate_subdivision_content(subdiv.content, max_chars)
                payload = build_subdivision_payload(
                    payload_builder,
                    subdivision=subdiv,
                    act=act,
                    version=version,
                    metadata=metadata,
                    content_override=content,
                )
                payload_hash = compute_payload_hash(payload)
                ids.append(subdiv.id)
                texts.append(content)
                payloads.append(payload)
                hashes.append(payload_hash)
            if force:
                embed_indices = list(range(len(batch)))
                batch_skipped = 0
            else:
                # Re-embed only when the stored hash differs (or is missing).
                embed_indices = [idx for idx, subdiv_id in enumerate(ids) if state_map.get(subdiv_id) != hashes[idx]]
                batch_skipped = len(batch) - len(embed_indices)
            if not embed_indices:
                # Whole batch already up to date — record the skips only.
                results.append(EmbedBatchResult(skipped_count=batch_skipped))
                continue
            embed_texts = [texts[idx] for idx in embed_indices]
            embeddings: list[list[float]] = embedding_service.embed_batch(embed_texts, batch_size=effective_batch_size)
            points: list[VectorPoint] = []
            state_records: list[dict[str, Any]] = []
            # embed_batch preserves input ordering (see ChunkEmbeddingService),
            # so zip pairs each embedding with its source subdivision index.
            for idx, embedding in zip(embed_indices, embeddings):
                points.append(VectorPoint(id=ids[idx], vector=embedding, payload=payloads[idx]))
                state_records.append(
                    make_state_record(
                        object_type="subdivision",
                        object_id=ids[idx],
                        provider=provider,
                        model_name=model_name,
                        vector_size=vector_size,
                        content_hash=hashes[idx],
                    )
                )
            results.append(
                EmbedBatchResult(
                    points=points,
                    state_records=state_records,
                    embedded_count=len(points),
                    skipped_count=batch_skipped,
                )
            )
        except Exception as e:
            # Best-effort: a failing batch is logged and dropped so the
            # remaining batches can still be prepared and persisted.
            logger.error(
                "Failed to prepare subdivision embedding batch for %s: %s",
                act.celex,
                e,
                exc_info=True,
            )
            continue
    return results
def prepare_chunk_embeddings(
    *,
    chunks: list[Any],
    act: Any,
    embedding_service: ChunkEmbeddingService,
    payload_builder: PayloadBuilder,
    state_map: dict[int, str],
    batch_size: int,
    retrieval_segment_chars: int,
    retrieval_segment_overlap: int,
    provider: str,
    model_name: str,
    vector_size: int,
    force: bool,
) -> list[EmbedBatchResult]:
    """Prepare embedding batches for *chunks*. Pure logic — no DB access.

    Args:
        chunks: List of ``(chunk, subdivision)`` tuples.
        act: The ORM act object (for payload building).
        embedding_service: Service to produce vectors.
        payload_builder: Builds Qdrant payloads.
        state_map: Pre-fetched ``{chunk_id: content_hash}`` from embedding_state.
        batch_size: Embedding batch sizing.
        retrieval_segment_chars / retrieval_segment_overlap: Target segmentation
            budget for long article retrieval vectors.
        provider / model_name / vector_size: Embedding identity.
        force: If ``True``, embed everything regardless of state_map.

    Returns:
        One :class:`EmbedBatchResult` per batch, ready to be persisted.
    """
    total = len(chunks)
    results: list[EmbedBatchResult] = []
    for i in range(0, total, batch_size):
        batch: list[tuple[Any, Any]] = chunks[i : i + batch_size]
        try:
            # Plan all vectors (canonical + retrieval segments) per chunk first
            # so the skip decision compares hashes before embedding anything.
            chunk_plans: list[dict[str, Any]] = []
            chunk_ids: list[int] = []
            hashes: list[str] = []
            for chunk, subdivision in batch:
                chunk_plan = _build_chunk_vector_plan(
                    chunk=chunk,
                    subdivision=subdivision,
                    act=act,
                    embedding_service=embedding_service,
                    payload_builder=payload_builder,
                    retrieval_segment_chars=retrieval_segment_chars,
                    retrieval_segment_overlap=retrieval_segment_overlap,
                )
                payload_hash = str(chunk_plan["hash"])
                chunk_plans.append(chunk_plan)
                chunk_ids.append(chunk.id)
                hashes.append(payload_hash)
            if force:
                embed_indices = list(range(len(batch)))
                batch_skipped = 0
            else:
                # Re-embed only chunks whose stored hash differs (or is missing).
                embed_indices = [idx for idx, cid in enumerate(chunk_ids) if state_map.get(cid) != hashes[idx]]
                batch_skipped = len(batch) - len(embed_indices)
            if not embed_indices:
                results.append(EmbedBatchResult(skipped_count=batch_skipped))
                continue
            # Flatten every planned vector spec of the selected chunks into a
            # single embed_batch call; ordering is preserved by the service
            # contract, so specs and embeddings can be zipped back together.
            vector_specs: list[dict[str, Any]] = []
            for idx in embed_indices:
                vector_specs.extend(chunk_plans[idx]["vector_specs"])
            embeddings: list[list[float]] = embedding_service.embed_batch(
                [str(spec["text"]) for spec in vector_specs],
                batch_size=batch_size,
            )
            points: list[VectorPoint] = [
                VectorPoint(id=spec["point_id"], vector=emb, payload=spec["payload"])
                for spec, emb in zip(vector_specs, embeddings)
            ]
            # One state record per chunk (not per vector): the stored hash
            # covers all of the chunk's planned payloads.
            state_records = [
                make_state_record(
                    object_type="chunk",
                    object_id=chunk_ids[idx],
                    provider=provider,
                    model_name=model_name,
                    vector_size=vector_size,
                    content_hash=hashes[idx],
                )
                for idx in embed_indices
            ]
            results.append(
                EmbedBatchResult(
                    points=points,
                    state_records=state_records,
                    # Filter over all re-embedded chunk ids — presumably lets
                    # the caller delete superseded segment points first;
                    # confirm against the worker that persists this result.
                    delete_filters=[{"chunk_id": [chunk_ids[idx] for idx in embed_indices]}],
                    embedded_count=len(embed_indices),
                    skipped_count=batch_skipped,
                )
            )
        except Exception as e:
            # Best-effort: log and drop the failing batch, keep the rest.
            logger.error(
                "Failed to prepare chunk embedding batch for %s: %s",
                act.celex,
                e,
                exc_info=True,
            )
            continue
    return results
def prepare_act_document_embedding(
    *,
    act: Any,
    subdivisions: list[Any],
    chunks: list[Any],
    chunk_vectors: dict[int, list[float]],
    payload_builder: PayloadBuilder,
    state_map: dict[int, str],
    provider: str,
    model_name: str,
    vector_size: int,
    force: bool = False,
) -> EmbedBatchResult | None:
    """Prepare whole-act embedding via mean pooling. Pure logic — no DB access.

    Args:
        act: The ORM act object.
        subdivisions: ORM subdivision objects for the act.
        chunks: ORM ``(chunk, subdivision)`` tuples for the act.
        chunk_vectors: Pre-fetched ``{chunk_id: vector}`` from Qdrant.
        payload_builder: Builds Qdrant payloads.
        state_map: Pre-fetched ``{act_id: content_hash}`` from embedding_state.
        provider / model_name / vector_size: Embedding identity.
        force: If ``True``, embed regardless of state_map.

    Returns:
        :class:`EmbedBatchResult` with a single point, or ``None`` if skipped.
    """
    if not chunks:
        logger.info("[EmbedAct] No chunks for act %s, skipping", act.celex)
        return None
    chunk_ids = [chunk.id for chunk, _subdivision in chunks]
    # Missing ids are tolerated: only chunks already embedded in Qdrant
    # contribute to the pooled act vector.
    vectors = [chunk_vectors[chunk_id] for chunk_id in chunk_ids if chunk_id in chunk_vectors]
    if not vectors:
        logger.info(
            "[EmbedAct] No chunk vectors found for act %s (%d chunks) — chunks not embedded yet?",
            act.celex,
            len(chunk_ids),
        )
        return None
    act_vector = mean_pool_vectors(vectors)
    # Internal invariant only (stripped under -O); vectors was checked above.
    assert act_vector is not None  # guaranteed: vectors is non-empty
    logger.info(
        "[EmbedAct] Mean pooled %d/%d chunk vectors for act %s",
        len(vectors),
        len(chunk_ids),
        act.celex,
    )
    payload = build_act_document_payload(
        payload_builder,
        act=act,
        subdivisions=subdivisions,
        vector_method="mean_pool_chunks",
        chunk_count=len(vectors),
    )
    payload_hash = compute_payload_hash(payload)
    if not force:
        # Skip when the stored hash for this act already matches.
        if state_map.get(act.id) == payload_hash:
            logger.info("[EmbedAct] Act %s act-vector already up to date, skipping", act.celex)
            return None
    # The act's own id doubles as the Qdrant point id for the act vector.
    point = VectorPoint(id=act.id, vector=act_vector, payload=payload)
    state_record = make_state_record(
        object_type="act",
        object_id=act.id,
        provider=provider,
        model_name=model_name,
        vector_size=vector_size,
        content_hash=payload_hash,
    )
    logger.info(
        "[EmbedAct] Prepared act-vector for %s (mean of %d chunks)",
        act.celex,
        len(vectors),
    )
    return EmbedBatchResult(
        points=[point],
        state_records=[state_record],
        embedded_count=1,
    )