"""
Embedding Worker
Processes embedding jobs from Redis queue
Creates vector embeddings for chunks and acts and stores them in Qdrant
"""
import logging
import os
from dataclasses import dataclass
from typing import Any, cast
from embedding_worker.bootstrap import EmbeddingComponents, init_components
from embedding_worker.service_metrics import observe_job_error, observe_job_execution
from lalandre_core.config import get_config
from lalandre_core.embedding_presets import resolve_worker_embedding_preset
from lalandre_core.logging_setup import setup_worker_logging
from lalandre_core.queue.dispatch_all import dispatch_all_act_jobs
from lalandre_core.queue.job_queue import (
QueueRuntime,
update_job_status,
)
from lalandre_core.queue.job_queue import (
enqueue_job as _enqueue_job,
)
from lalandre_core.queue.job_queue import (
job_already_queued as _job_already_queued,
)
from lalandre_core.queue.reconcile import with_reconcile_lock
from lalandre_core.queue.worker_config import get_reconcile_params, require_gateway_config
from lalandre_core.queue.worker_loop import (
instrumented_process_job,
resolve_base_runtime_params,
run_worker_loop,
)
from lalandre_core.utils import coerce_bool, is_eurlex_celex, normalize_celex, to_optional_int
from lalandre_db_qdrant.repository import QdrantRepository
from lalandre_embedding import EmbeddingService
from lalandre_embedding.pipeline import (
prepare_act_document_embedding,
prepare_chunk_embeddings,
)
from prometheus_client import start_http_server
# Configure worker-wide structured logging before any logger is used.
setup_worker_logging()
logger = logging.getLogger(__name__)
# Job types consumed by this worker (see process_job dispatch table).
JOB_TYPE_EMBED_ALL = "embed_all"
JOB_TYPE_EMBED_ACT = "embed_act"
@dataclass
class WorkerRuntime(QueueRuntime):
    """Runtime state and lazy dependencies for the embedding worker loop.

    Extends QueueRuntime with embedding-specific configuration and a
    lazily-initialized bundle of heavyweight components (see
    ``ensure_components``), so building the runtime stays cheap until the
    first job actually runs.
    """

    # Upper bound applied to any batch_size requested in job params.
    embed_worker_max_batch_size: int
    # Batch size used when upserting points into Qdrant.
    embed_qdrant_upsert_batch_size: int
    # Timeout (seconds) for blocking Redis queue pops.
    brpop_timeout_seconds: int
    # Name of the Redis queue this worker consumes.
    embed_queue_name: str
    # Identifier of the embedding preset this worker is bound to.
    preset_id: str
    # Lazily initialized; None until ensure_components() first runs.
    components: EmbeddingComponents | None = None

    def ensure_components(self) -> EmbeddingComponents:
        """Initialize worker dependencies on first use and return them."""
        if self.components is None:
            self.components = init_components()
            logger.info(
                "Components initialized successfully for preset '%s'",
                self.components.preset_id,
            )
        return self.components
def build_runtime() -> WorkerRuntime:
    """Build the queue runtime used by the embedding worker.

    Resolves the embedding preset and queue name (overridable via the
    EMBED_QUEUE_NAME env var), derives the shared Redis/TTL parameters,
    and bundles everything into a WorkerRuntime.
    """
    config = get_config()
    preset = resolve_worker_embedding_preset()
    # Operators may override the preset's default queue name per deployment.
    embed_queue_name = os.getenv("EMBED_QUEUE_NAME", preset.resolved_queue_name())
    base = resolve_base_runtime_params(
        redis_host=config.gateway.redis_host,
        redis_port=config.gateway.redis_port,
        job_ttl_seconds=config.gateway.job_ttl_seconds,
        brpop_timeout_seconds=config.workers.brpop_timeout_seconds,
    )
    return WorkerRuntime(
        redis_client=base.redis_client,
        job_ttl_seconds=base.job_ttl_seconds,
        embed_worker_max_batch_size=config.workers.embed_worker_max_batch_size,
        embed_qdrant_upsert_batch_size=config.workers.embed_qdrant_upsert_batch_size,
        brpop_timeout_seconds=base.brpop_timeout_seconds,
        embed_queue_name=embed_queue_name,
        preset_id=preset.preset_id,
    )
def _default_embed_batch_size() -> int:
    """Return the gateway-configured default batch size for embed jobs."""
    default_size: int = require_gateway_config("job_embed_batch_size")
    return default_size
def _resolve_embed_batch_size(runtime: WorkerRuntime, params: dict[str, Any]) -> int:
    """Extract batch_size from job params and clamp it to the worker-safe maximum."""
    requested = to_optional_int(params.get("batch_size")) or _default_embed_batch_size()
    # The ceiling itself is forced to at least 1 so misconfiguration cannot
    # produce a zero/negative batch size.
    ceiling = max(int(runtime.embed_worker_max_batch_size), 1)
    effective = min(int(requested), ceiling)
    if effective < 1:
        effective = 1
    if effective != requested:
        logger.info(
            "Clamped embed batch size for preset '%s': requested=%d max=%d effective=%d",
            runtime.preset_id,
            requested,
            ceiling,
            effective,
        )
    return effective
def _embed_all_job_already_queued(runtime: WorkerRuntime) -> bool:
    """Report whether an embed_all job is already pending on this worker's queue."""
    queued: bool = _job_already_queued(
        runtime,
        queue_name=runtime.embed_queue_name,
        job_type=JOB_TYPE_EMBED_ALL,
    )
    return queued
def _maybe_enqueue_reconcile_job(runtime: WorkerRuntime) -> None:
    """
    Enqueue an embed_all job on startup if embeddings are missing
    for the current provider/model/vector_size.

    Controlled by AUTO_EMBED_RECONCILE / workers.auto_embed_reconcile.
    The actual check runs under a Redis lock keyed by preset, so only
    one worker per preset performs the reconcile at a time.
    """
    enabled, ttl, _ = get_reconcile_params("embed")
    if not enabled:
        return

    def _do_reconcile() -> None:
        components = runtime.ensure_components()
        pg_repo = components.pg_repo
        embedding_service = components.embedding_service
        provider: str = embedding_service.provider_name
        model_name: str = embedding_service.model_name
        vector_size: int = embedding_service.get_vector_size()
        logger.info(
            "[Reconcile] Checking embeddings for preset=%s provider=%s model=%s vector_size=%s",
            runtime.preset_id,
            provider,
            model_name,
            vector_size,
        )
        with pg_repo.get_session() as session:
            pg_repo.ensure_embedding_state_table(session)
            # Drop state rows whose underlying objects no longer exist so the
            # coverage count below is accurate.
            purged_states = pg_repo.purge_orphan_embedding_states(session)
            if purged_states > 0:
                session.commit()
                logger.info("[Reconcile] Purged %d orphan embedding_state rows", purged_states)
            total_chunks = pg_repo.count_chunks(session)
            if total_chunks == 0:
                logger.info("[Reconcile] No chunks found, skipping")
                return
            embedded_chunks = pg_repo.count_embedded_chunks(
                session=session,
                provider=provider,
                model_name=model_name,
                vector_size=vector_size,
            )
            # total_chunks > 0 is guaranteed by the early return above.
            if embedded_chunks >= total_chunks:
                logger.info("[Reconcile] Embeddings already up to date, skipping")
                return
            if _embed_all_job_already_queued(runtime):
                logger.info("[Reconcile] embed_all already queued, skipping")
                return
            _enqueue_job(
                runtime,
                queue_name=runtime.embed_queue_name,
                job_type=JOB_TYPE_EMBED_ALL,
                params={
                    "batch_size": _default_embed_batch_size(),
                },
            )
            logger.info(
                "[Reconcile] Enqueued embed_all for preset=%s (chunks %d/%d)",
                runtime.preset_id,
                embedded_chunks,
                total_chunks,
            )

    with_reconcile_lock(
        redis_client=runtime.redis_client,
        lock_key=f"embed:reconcile:lock:{runtime.preset_id}",
        lock_ttl=ttl,
        action=_do_reconcile,
    )
def _persist_batch_results(
    batch_results: list[Any],
    *,
    pg_repo: Any,
    session: Any,
    qdrant_repo: QdrantRepository,
    upsert_batch_size: int,
    label: str,
) -> tuple[int, int]:
    """Persist embedding batch results — upsert to Qdrant + Postgres, commit per batch.

    Returns (total_embedded, total_skipped).
    """
    embedded = 0
    skipped = 0
    for batch in batch_results:
        skipped += batch.skipped_count
        if not batch.points:
            continue
        try:
            # Apply any stale-point deletions before upserting replacements.
            for delete_filter in getattr(batch, "delete_filters", []):
                cast(Any, qdrant_repo).delete_points_by_filter(delete_filter)
            qdrant_repo.upsert_points(points=batch.points, batch_size=upsert_batch_size)
            pg_repo.upsert_embedding_states(session, batch.state_records)
            session.commit()
        except Exception as exc:
            # Roll back only this batch and keep going; later batches may
            # still persist successfully.
            session.rollback()
            logger.error("Failed to persist %s batch: %s", label, exc, exc_info=True)
        else:
            embedded += batch.embedded_count
    return embedded, skipped
def process_embed_all(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Dispatch embed_act jobs for all acts (event-driven pipeline).

    Fans out one embed_act job per act found in Postgres; the per-act
    jobs perform the actual embedding work.
    """
    components = runtime.ensure_components()
    pg_repo = components.pg_repo
    # Resolve once here so every dispatched embed_act job carries the same
    # (already clamped) batch size.
    batch_size = _resolve_embed_batch_size(runtime, params)
    with pg_repo.get_session() as session:
        acts = pg_repo.list_acts_with_metadata(session)
        dispatch_all_act_jobs(
            runtime=runtime,
            job_id=job_id,
            queue_name=runtime.embed_queue_name,
            job_type=JOB_TYPE_EMBED_ACT,
            label="embed_act",
            acts=acts,
            build_params=lambda celex: {"celex": celex, "batch_size": batch_size},
            error_label="Embedding",
        )
def process_embed_act(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Embed chunks and (EUR-Lex only) whole-act vector for one act.

    Job params:
        celex: CELEX identifier of the act to embed (required).
        batch_size: optional embedding batch size (clamped to the worker max).
        force: when truthy, re-embed even objects whose state is up to date.

    Raises:
        ValueError: if 'celex' is missing/invalid or the act is not found.
        Exception: any failure is recorded on the job status, then re-raised.
    """
    celex_value: Any = params.get("celex")
    celex = normalize_celex(celex_value if isinstance(celex_value, str) else "")
    if not celex:
        raise ValueError("Missing or invalid 'celex' in job params")
    logger.info("[Job %s] Starting embed for act %s", job_id, celex)
    update_job_status(runtime, job_id, "running", f"Embedding act {celex}...")
    try:
        components = runtime.ensure_components()
        pg_repo = components.pg_repo
        embedding_service: EmbeddingService = components.embedding_service
        qdrant_chunks: QdrantRepository = components.qdrant_chunks
        qdrant_acts: QdrantRepository = components.qdrant_acts
        batch_size = _resolve_embed_batch_size(runtime, params)
        force: bool = coerce_bool(params.get("force"), False)
        payload_builder = components.payload_builder
        provider = embedding_service.provider_name
        model_name = embedding_service.model_name
        vector_size = embedding_service.get_vector_size()
        upsert_batch_size = runtime.embed_qdrant_upsert_batch_size
        config = get_config()
        retrieval_segment_chars = config.chunking.resolve_max_chunk_size(config.token_limits)
        # Overlap must stay below the segment size, otherwise segmenting
        # could not advance; it is also floored at 0.
        retrieval_segment_overlap = min(
            max(int(config.chunking.chunk_overlap), 0),
            max(retrieval_segment_chars - 1, 0),
        )
        with pg_repo.get_session() as session:
            pg_repo.ensure_embedding_state_table(session)
            purged = pg_repo.purge_orphan_embedding_states(session)
            if purged > 0:
                session.commit()
                logger.info("[Job %s] Purged %d orphan embedding_state rows", job_id, purged)
            act = pg_repo.get_act_by_celex(session, celex)
            if not act:
                raise ValueError(f"Act {celex} not found")
            # ── 1. Chunks ─────────────────────────────────────────────
            chunks = pg_repo.list_chunks_for_act(session, act.id)
            logger.info("[Job %s] Found %d chunks for act %s", job_id, len(chunks), celex)
            chunk_ids = [c.id for c, _s in chunks]
            # With force=True the previous embedding state is ignored entirely.
            chunk_state_map: dict[int, str] = (
                {}
                if force
                else pg_repo.get_embedding_state_map(
                    session=session,
                    object_type="chunk",
                    object_ids=chunk_ids,
                    provider=provider,
                    model_name=model_name,
                    vector_size=vector_size,
                )
            )
            chunk_batches = cast(Any, prepare_chunk_embeddings)(
                chunks=chunks,
                act=act,
                embedding_service=embedding_service,
                payload_builder=payload_builder,
                state_map=chunk_state_map,
                batch_size=batch_size,
                retrieval_segment_chars=retrieval_segment_chars,
                retrieval_segment_overlap=retrieval_segment_overlap,
                provider=provider,
                model_name=model_name,
                vector_size=vector_size,
                force=force,
            )
            embedded_chunks, skipped_chunks = _persist_batch_results(
                chunk_batches,
                pg_repo=pg_repo,
                session=session,
                qdrant_repo=qdrant_chunks,
                upsert_batch_size=upsert_batch_size,
                label="chunk",
            )
            # ── 2. Act-level vector (EUR-Lex only, mean-pool chunk vectors) ──
            act_suffix = "0 act-vectors (skipped)"
            if is_eurlex_celex(celex):
                chunk_vectors = qdrant_chunks.retrieve_vectors_by_ids(chunk_ids)
                act_state_map: dict[int, str] = (
                    {}
                    if force
                    else pg_repo.get_embedding_state_map(
                        session=session,
                        object_type="act",
                        object_ids=[act.id],
                        provider=provider,
                        model_name=model_name,
                        vector_size=vector_size,
                    )
                )
                subdivisions = pg_repo.list_subdivisions_for_act(session, act.id)
                act_result = prepare_act_document_embedding(
                    act=act,
                    subdivisions=subdivisions,
                    chunks=chunks,
                    chunk_vectors=chunk_vectors,
                    payload_builder=payload_builder,
                    state_map=act_state_map,
                    provider=provider,
                    model_name=model_name,
                    vector_size=vector_size,
                    force=force,
                )
                if act_result is not None and act_result.points:
                    try:
                        qdrant_acts.upsert_points(points=act_result.points)
                        pg_repo.upsert_embedding_states(session, act_result.state_records)
                        session.commit()
                        act_suffix = "1 act-vector"
                    except Exception as e:
                        # Best-effort: a failed act-vector must not fail the
                        # whole job; chunk embeddings are already committed.
                        session.rollback()
                        logger.error("Failed to persist act-vector: %s", e, exc_info=True)
            else:
                act_suffix = "act-vector skipped (non EUR-Lex)"
        final_message = (
            f"Embedded {embedded_chunks} chunks and {act_suffix} for act {celex} "
            f"(skipped {skipped_chunks} chunks already embedded)"
        )
        logger.info("[Job %s] %s", job_id, final_message)
        update_job_status(runtime, job_id, "completed", final_message, 100)
    except Exception as e:
        error_msg = f"Embedding failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def process_job(runtime: WorkerRuntime, job_data: dict[str, Any]) -> None:
    """Dispatch one embedding-related job through the instrumented worker loop."""
    instrumented_process_job(
        runtime=runtime,
        job_data=job_data,
        # Job-type → handler table; handlers share the (runtime, job_id, params)
        # signature expected by the instrumented loop.
        dispatch={
            JOB_TYPE_EMBED_ALL: process_embed_all,
            JOB_TYPE_EMBED_ACT: process_embed_act,
        },
        observe_execution=observe_job_execution,
        observe_error=observe_job_error,
    )
def main() -> None:
    """Start the embedding worker loop and its Prometheus endpoint."""
    # Expose metrics before entering the (blocking) worker loop.
    metrics_port = int(get_config().workers.embed_metrics_port)
    start_http_server(metrics_port)
    logger.info("Embedding worker metrics endpoint started on port %d", metrics_port)
    runtime = build_runtime()
    logger.info(
        "Embedding worker bound to preset '%s' on queue '%s'",
        runtime.preset_id,
        runtime.embed_queue_name,
    )
    _, _, reconcile_interval = get_reconcile_params("embed")
    run_worker_loop(
        queue_name=runtime.embed_queue_name,
        worker_name=f"Embedding Worker [{runtime.preset_id}]",
        redis_client=runtime.redis_client,
        brpop_timeout_seconds=runtime.brpop_timeout_seconds,
        process_job=lambda job_data: process_job(runtime, job_data),
        reconcile_callback=lambda: _maybe_enqueue_reconcile_job(runtime),
        reconcile_interval_seconds=reconcile_interval,
    )
# Script entry point: run the worker loop until the process is stopped.
if __name__ == "__main__":
    main()