"""
Chunking Worker - Lalandre
Processes chunking jobs from Redis queue
"""

import logging
import time
from dataclasses import dataclass
from typing import Any, cast

from chunking_worker.bootstrap import ChunkingComponents, init_components
from chunking_worker.service_metrics import observe_job_error, observe_job_execution
from lalandre_chunking import pipeline as chunking_pipeline
from lalandre_core.config import get_config
from lalandre_core.embedding_presets import list_embedding_presets
from lalandre_core.logging_setup import setup_worker_logging
from lalandre_core.queue.dispatch_all import dispatch_all_act_jobs
from lalandre_core.queue.job_queue import (
    QueueRuntime,
    update_job_status,
)
from lalandre_core.queue.job_queue import (
    enqueue_job as _enqueue_job,
)
from lalandre_core.queue.job_queue import (
    job_already_queued as _job_already_queued,
)
from lalandre_core.queue.reconcile import with_reconcile_lock
from lalandre_core.queue.worker_config import get_reconcile_params, require_gateway_config
from lalandre_core.queue.worker_loop import (
    instrumented_process_job,
    resolve_base_runtime_params,
    run_worker_loop,
)
from lalandre_core.utils import coerce_bool, normalize_celex, to_optional_int
from prometheus_client import start_http_server  # type: ignore[import-untyped]

setup_worker_logging()
logger = logging.getLogger(__name__)

prepare_article_level_plan = cast(Any, getattr(chunking_pipeline, "prepare_article_level_plan"))
make_article_level_chunks = cast(Any, getattr(chunking_pipeline, "make_article_level_chunks"))
serialize_chunk_records = cast(Any, getattr(chunking_pipeline, "serialize_chunk_records"))


CHUNK_QUEUE_NAME = "chunk_jobs"
EXTRACT_QUEUE_NAME = "extract_jobs"
JOB_TYPE_CHUNK_ALL = "chunk_all"
JOB_TYPE_CHUNK_ACT = "chunk_act"
JOB_TYPE_EMBED_ACT = "embed_act"
JOB_TYPE_EXTRACT_ACT = "extract_act"


@dataclass
class WorkerRuntime(QueueRuntime):
    """Runtime state and lazy dependencies for the chunking worker loop."""

    chunk_db_commit_batch_size: int
    brpop_timeout_seconds: int
    components: ChunkingComponents | None = None

    def ensure_components(self) -> ChunkingComponents:
        """Initialize worker dependencies on first use and return them."""
        if self.components is None:
            self.components = init_components()
            logger.info("Components initialized successfully")
        return self.components

def build_runtime() -> WorkerRuntime:
    """Build the queue runtime used by the chunking worker."""
    config = get_config()
    base = resolve_base_runtime_params(
        redis_host=config.gateway.redis_host,
        redis_port=config.gateway.redis_port,
        job_ttl_seconds=config.gateway.job_ttl_seconds,
        brpop_timeout_seconds=config.workers.brpop_timeout_seconds,
    )
    return WorkerRuntime(
        redis_client=base.redis_client,
        job_ttl_seconds=base.job_ttl_seconds,
        chunk_db_commit_batch_size=config.workers.chunk_db_commit_batch_size,
        brpop_timeout_seconds=base.brpop_timeout_seconds,
    )

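# Illustrative usage (a sketch, not executed at import time): building the
# runtime is cheap because heavyweight dependencies are created lazily, so the
# first processed job pays the initialization cost.
#
#     runtime = build_runtime()
#     components = runtime.ensure_components()   # calls init_components() once
#     runtime.ensure_components() is components  # True: cached afterwards
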
def _default_chunk_min_content_length() -> int:
    return require_gateway_config("job_chunk_min_content_length")


def _enqueue_embed_act_jobs(runtime: WorkerRuntime, *, celex: str, force: bool) -> int:
    queued = 0
    for preset in list_embedding_presets(indexing_only=True):
        queued_job_id = _enqueue_job(
            runtime,
            queue_name=preset.resolved_queue_name(),
            job_type=JOB_TYPE_EMBED_ACT,
            params={
                "celex": celex,
                "force": force,
            },
            dedupe_celex=celex,
        )
        if queued_job_id:
            queued += 1
            logger.info(
                "[Pipeline] Enqueued embed_act for %s on preset %s (job_id=%s)",
                celex,
                preset.preset_id,
                queued_job_id,
            )
    if queued == 0:
        logger.warning(
            "[Pipeline] No embed_act job queued for %s (no indexing-enabled presets or jobs already queued)",
            celex,
        )
    return queued


def _enqueue_extract_act_job(runtime: WorkerRuntime, *, celex: str, force: bool) -> None:
    queued_job_id = _enqueue_job(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_EXTRACT_ACT,
        params={
            "celex": celex,
            "force": force,
        },
        dedupe_celex=celex,
    )
    if queued_job_id:
        logger.info("[Pipeline] Enqueued extract_act for %s (job_id=%s)", celex, queued_job_id)


def _delete_chunk_vectors_for_ids(repos: dict[str, Any], chunk_ids: list[int]) -> None:
    if not chunk_ids:
        return
    typed_chunk_ids = cast(list[str | int], chunk_ids)
    for preset_id, repo in repos.items():
        try:
            if not repo.collection_exists():
                continue
            # Delete by explicit point id, then also by chunk_id payload filter.
            repo.delete_points(typed_chunk_ids)
            cast(Any, repo).delete_points_by_filter({"chunk_id": chunk_ids})
        except Exception as exc:
            logger.warning(
                "Failed to delete chunk vectors for preset %s (chunk_ids=%d): %s",
                preset_id,
                len(chunk_ids),
                exc,
            )


def _delete_act_vectors_for_ids(repos: dict[str, Any], act_ids: list[int]) -> None:
    if not act_ids:
        return
    typed_act_ids = cast(list[str | int], act_ids)
    for preset_id, repo in repos.items():
        try:
            if not repo.collection_exists():
                continue
            repo.delete_points(typed_act_ids)
        except Exception as exc:
            logger.warning(
                "Failed to delete act vectors for preset %s (act_ids=%d): %s",
                preset_id,
                len(act_ids),
                exc,
            )


def _chunk_all_job_already_queued(runtime: WorkerRuntime) -> bool:
    return _job_already_queued(
        runtime,
        queue_name=CHUNK_QUEUE_NAME,
        job_type=JOB_TYPE_CHUNK_ALL,
    )


def _maybe_enqueue_reconcile_job(runtime: WorkerRuntime) -> None:
    """
    Enqueue a chunk_all job on startup if eligible subdivisions are not chunked.

    Controlled by AUTO_CHUNK_RECONCILE / workers.auto_chunk_reconcile.
    """
    enabled, ttl, _ = get_reconcile_params("chunk")
    if not enabled:
        return

    def _do_reconcile() -> None:
        components = runtime.ensure_components()
        pg_repo = components.pg_repo
        min_length = _default_chunk_min_content_length()
        with pg_repo.get_session() as session:
            missing_count = pg_repo.count_subdivisions_without_chunks(
                session=session,
                min_content_length=min_length,
            )
        if missing_count == 0:
            logger.info("[Reconcile] All eligible subdivisions already chunked, skipping")
            return
        if _chunk_all_job_already_queued(runtime):
            logger.info("[Reconcile] chunk_all already queued, skipping")
            return
        queued_job_id = _enqueue_job(
            runtime,
            queue_name=CHUNK_QUEUE_NAME,
            job_type=JOB_TYPE_CHUNK_ALL,
            params={
                "min_content_length": min_length,
            },
        )
        if queued_job_id:
            logger.info("[Reconcile] Enqueued chunk_all (missing=%d)", missing_count)

    with_reconcile_lock(
        redis_client=runtime.redis_client,
        lock_key="chunk:reconcile:lock",
        lock_ttl=ttl,
        action=_do_reconcile,
    )

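# For reference, the payload shapes the helpers above enqueue (field names
# taken from the _enqueue_job calls; the surrounding queue envelope is owned
# by lalandre_core.queue.job_queue):
#
#     embed_act   -> {"celex": <str>, "force": <bool>}  per indexing-enabled preset queue
#     extract_act -> {"celex": <str>, "force": <bool>}  on "extract_jobs"
#     chunk_all   -> {"min_content_length": <int>}      on "chunk_jobs"
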
def process_chunk_all(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Dispatch chunk_act jobs for all acts (event-driven pipeline)."""
    components = runtime.ensure_components()
    pg_repo = components.pg_repo
    min_length: int = to_optional_int(params.get("min_content_length")) or _default_chunk_min_content_length()
    force: bool = coerce_bool(params.get("force"), False)
    with pg_repo.get_session() as session:
        acts: list[Any] = pg_repo.list_acts_with_metadata(session)
        dispatch_all_act_jobs(
            runtime=runtime,
            job_id=job_id,
            queue_name=CHUNK_QUEUE_NAME,
            job_type=JOB_TYPE_CHUNK_ACT,
            label="chunk_act",
            acts=acts,
            build_params=lambda celex: {"celex": celex, "min_content_length": min_length, "force": force},
            error_label="Chunking",
        )

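# Fan-out note: process_chunk_all performs no chunking itself; it enqueues one
# chunk_act job per act, and each successful chunk_act in turn enqueues
# embed_act and extract_act jobs (see the tail of process_chunk_act below).
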
def process_chunk_act(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Process a chunk_act job for a single act."""
    celex_value: Any = params.get("celex")
    celex = normalize_celex(celex_value if isinstance(celex_value, str) else "")
    if not celex:
        raise ValueError("Missing or invalid 'celex' in job params")
    force: bool = coerce_bool(params.get("force"), False)
    min_content_length: int = to_optional_int(params.get("min_content_length")) or _default_chunk_min_content_length()

    logger.info(f"[Job {job_id}] Starting chunk for act {celex}")
    update_job_status(runtime, job_id, "running", f"Chunking act {celex}...")

    try:
        components = runtime.ensure_components()
        pg_repo = components.pg_repo
        chunk_vector_repos = components.chunk_vector_repos
        act_vector_repos = components.act_vector_repos
        chunker: Any = components.chunker

        with pg_repo.get_session() as session:
            # Get the act
            act: Any = pg_repo.get_act_by_celex(session, celex)
            if not act:
                raise ValueError(f"Act {celex} not found")

            # Get subdivisions for this act
            subdivisions: list[Any] = pg_repo.list_subdivisions_for_act(session, act.id)
            eligible_subdivisions = [
                subdiv for subdiv in subdivisions if len(str(subdiv.content or "")) >= min_content_length
            ]
            total = len(eligible_subdivisions)

            # Prepare article-level chunking plan (EUR-Lex / Légifrance)
            config = get_config()
            art_plan = prepare_article_level_plan(
                celex=celex,
                subdivisions=subdivisions,
                article_level_enabled=config.chunking.article_level_chunking,
            )

            logger.info(
                f"[Job {job_id}] Found {total} eligible subdivisions for act {celex} "
                f"(min_content_length={min_content_length}, article_level={art_plan.active})"
            )

            # Delete existing chunks if force
            if force:
                if min_content_length > 0:
                    for subdiv in eligible_subdivisions:
                        existing_chunk_ids = pg_repo.list_chunk_ids_for_subdivision(
                            session=session,
                            subdivision_id=subdiv.id,
                        )
                        if existing_chunk_ids:
                            _delete_chunk_vectors_for_ids(chunk_vector_repos, existing_chunk_ids)
                            pg_repo.delete_embedding_states_for_chunk_ids(session, existing_chunk_ids)
                        pg_repo.delete_chunks_for_subdivision(
                            session=session,
                            subdivision_id=subdiv.id,
                        )
                else:
                    existing_chunk_ids = pg_repo.list_chunk_ids_for_act(session, act.id)
                    if existing_chunk_ids:
                        _delete_chunk_vectors_for_ids(chunk_vector_repos, existing_chunk_ids)
                        pg_repo.delete_embedding_states_for_chunk_ids(session, existing_chunk_ids)
                    pg_repo.delete_chunks_for_act(session, act.id)
                _delete_act_vectors_for_ids(act_vector_repos, [act.id])
                pg_repo.delete_embedding_states_for_act_ids(session, [act.id])
                pg_repo.reset_extraction_status(session, act.id)
                session.commit()

            chunk_count = 0
            skipped_existing = 0
            skipped_aggregated = 0
            MAX_RETRIES = 3
            skipped_errors = 0

            for i, subdiv in enumerate(eligible_subdivisions):
                # Skip child paragraphs already aggregated into their parent article
                if subdiv.id in art_plan.skip_ids:
                    skipped_aggregated += 1
                    continue

                # Skip if already chunked (unless force)
                if not force and pg_repo.subdivision_has_chunks(session, subdiv.id):
                    skipped_existing += 1
                    continue

                for attempt in range(1, MAX_RETRIES + 1):
                    try:
                        article_level_chunks = make_article_level_chunks(
                            chunker=chunker,
                            subdivision=subdiv,
                            article_level_plan=art_plan,
                        )
                        if article_level_chunks is not None:
                            if not article_level_chunks:
                                break
                            chunks: list[Any] = article_level_chunks
                        else:
                            chunks = chunker.chunk_subdivision(
                                subdivision_id=subdiv.id,
                                content=subdiv.content,
                                subdivision_type=subdiv.subdivision_type,
                            )
                        pg_repo.insert_chunk_records(
                            session=session,
                            records=serialize_chunk_records(chunks),
                        )
                        chunk_count += len(chunks)
                        break  # success
                    except Exception as e:
                        if attempt < MAX_RETRIES:
                            logger.warning(
                                "[Job %s] Retry %d/%d for subdivision %d: %s",
                                job_id,
                                attempt,
                                MAX_RETRIES,
                                subdiv.id,
                                e,
                            )
                            time.sleep(2**attempt)
                        else:
                            logger.error(
                                "[Job %s] Skipping subdivision %d after %d retries: %s",
                                job_id,
                                subdiv.id,
                                MAX_RETRIES,
                                e,
                            )
                            skipped_errors += 1

                # Update progress
                progress = int((i + 1) / total * 100)
                if (i + 1) % 100 == 0 or (i + 1) == total:
                    update_job_status(
                        runtime,
                        job_id,
                        "running",
                        (
                            f"Chunking {celex}: {i + 1}/{total} eligible subdivisions "
                            f"({chunk_count} chunks, skipped_existing={skipped_existing})"
                        ),
                        progress,
                    )

            if chunk_count > 0:
                _delete_act_vectors_for_ids(act_vector_repos, [act.id])
                pg_repo.delete_embedding_states_for_act_ids(session, [act.id])
                pg_repo.reset_extraction_status(session, act.id)
                session.commit()

        message = (
            f"Chunked act {celex}: {chunk_count} chunks from {total} eligible subdivisions "
            f"(skipped_existing={skipped_existing}, aggregated_children={skipped_aggregated}, errors={skipped_errors})"
        )
        logger.info(f"[Job {job_id}] {message}")
        update_job_status(runtime, job_id, "completed", message, 100)

        _enqueue_embed_act_jobs(runtime, celex=celex, force=force)
        _enqueue_extract_act_job(runtime, celex=celex, force=force)

    except Exception as e:
        error_msg = f"Chunking failed: {str(e)}"
        logger.error(f"[Job {job_id}] {error_msg}", exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise

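# Retry timing note: with MAX_RETRIES = 3, a failing subdivision in
# process_chunk_act sleeps 2**attempt seconds between attempts (2s after the
# first failure, 4s after the second) before being skipped and counted in
# skipped_errors.
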
def process_job(runtime: WorkerRuntime, job_data: dict[str, Any]) -> None:
    """Dispatch one chunking-related job through the instrumented worker loop."""
    instrumented_process_job(
        runtime=runtime,
        job_data=job_data,
        dispatch={
            JOB_TYPE_CHUNK_ALL: process_chunk_all,
            JOB_TYPE_CHUNK_ACT: process_chunk_act,
        },
        observe_execution=observe_job_execution,
        observe_error=observe_job_error,
    )

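# Illustrative job envelope (assumed shape: the dispatch table above implies a
# job type and per-type params, but the authoritative envelope schema lives in
# lalandre_core.queue and is consumed by instrumented_process_job):
#
#     {"job_id": "…", "type": "chunk_act",
#      "params": {"celex": "32016R0679", "force": False}}
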
def main() -> None:
    """Start the chunking worker loop and its Prometheus endpoint."""
    metrics_port = int(get_config().workers.chunk_metrics_port)
    start_http_server(metrics_port)
    logger.info("Chunking worker metrics endpoint started on port %d", metrics_port)

    runtime = build_runtime()
    _, _, reconcile_interval = get_reconcile_params("chunk")
    run_worker_loop(
        queue_name=CHUNK_QUEUE_NAME,
        worker_name="Chunking Worker",
        redis_client=runtime.redis_client,
        brpop_timeout_seconds=runtime.brpop_timeout_seconds,
        process_job=lambda job_data: process_job(runtime, job_data),
        reconcile_callback=lambda: _maybe_enqueue_reconcile_job(runtime),
        reconcile_interval_seconds=reconcile_interval,
    )

if __name__ == "__main__":
    main()