"""
Chunking Worker - Lalandre
Processes chunking jobs from the Redis queue
"""
import logging
import time
from dataclasses import dataclass
from typing import Any, cast
from chunking_worker.bootstrap import ChunkingComponents, init_components
from chunking_worker.service_metrics import observe_job_error, observe_job_execution
from lalandre_chunking import pipeline as chunking_pipeline
from lalandre_core.config import get_config
from lalandre_core.embedding_presets import list_embedding_presets
from lalandre_core.logging_setup import setup_worker_logging
from lalandre_core.queue.dispatch_all import dispatch_all_act_jobs
from lalandre_core.queue.job_queue import (
    QueueRuntime,
    enqueue_job as _enqueue_job,
    job_already_queued as _job_already_queued,
    update_job_status,
)
from lalandre_core.queue.reconcile import with_reconcile_lock
from lalandre_core.queue.worker_config import get_reconcile_params, require_gateway_config
from lalandre_core.queue.worker_loop import (
instrumented_process_job,
resolve_base_runtime_params,
run_worker_loop,
)
from lalandre_core.utils import coerce_bool, normalize_celex, to_optional_int
from prometheus_client import start_http_server # type: ignore[import-untyped]
setup_worker_logging()
logger = logging.getLogger(__name__)
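# The pipeline helpers are resolved via getattr and cast to Any so the type
# checker does not constrain them; lalandre_chunking is treated as untyped here.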
prepare_article_level_plan = cast(Any, getattr(chunking_pipeline, "prepare_article_level_plan"))
make_article_level_chunks = cast(Any, getattr(chunking_pipeline, "make_article_level_chunks"))
serialize_chunk_records = cast(Any, getattr(chunking_pipeline, "serialize_chunk_records"))
CHUNK_QUEUE_NAME = "chunk_jobs"
EXTRACT_QUEUE_NAME = "extract_jobs"
JOB_TYPE_CHUNK_ALL = "chunk_all"
JOB_TYPE_CHUNK_ACT = "chunk_act"
JOB_TYPE_EMBED_ACT = "embed_act"
JOB_TYPE_EXTRACT_ACT = "extract_act"
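# Pipeline flow: chunk_all fans out one chunk_act job per act; a completed
# chunk_act then enqueues embed_act (one per indexing-enabled preset) and an
# extract_act follow-up job.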
@dataclass
class WorkerRuntime(QueueRuntime):
"""Runtime state and lazy dependencies for the chunking worker loop."""
chunk_db_commit_batch_size: int
brpop_timeout_seconds: int
components: ChunkingComponents | None = None
def ensure_components(self) -> ChunkingComponents:
"""Initialize worker dependencies on first use and return them."""
if self.components is None:
self.components = init_components()
logger.info("Components initialized successfully")
return self.components
def build_runtime() -> WorkerRuntime:
"""Build the queue runtime used by the chunking worker."""
config = get_config()
base = resolve_base_runtime_params(
redis_host=config.gateway.redis_host,
redis_port=config.gateway.redis_port,
job_ttl_seconds=config.gateway.job_ttl_seconds,
brpop_timeout_seconds=config.workers.brpop_timeout_seconds,
)
return WorkerRuntime(
redis_client=base.redis_client,
job_ttl_seconds=base.job_ttl_seconds,
chunk_db_commit_batch_size=config.workers.chunk_db_commit_batch_size,
brpop_timeout_seconds=base.brpop_timeout_seconds,
)
def _default_chunk_min_content_length() -> int:
return require_gateway_config("job_chunk_min_content_length")
def _enqueue_embed_act_jobs(runtime: WorkerRuntime, *, celex: str, force: bool) -> int:
queued = 0
for preset in list_embedding_presets(indexing_only=True):
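        # dedupe_celex lets the queue skip the enqueue when an embed_act job
        # for this act is already queued on the preset's queue.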
queued_job_id = _enqueue_job(
runtime,
queue_name=preset.resolved_queue_name(),
job_type=JOB_TYPE_EMBED_ACT,
params={
"celex": celex,
"force": force,
},
dedupe_celex=celex,
)
if queued_job_id:
queued += 1
logger.info(
"[Pipeline] Enqueued embed_act for %s on preset %s (job_id=%s)",
celex,
preset.preset_id,
queued_job_id,
)
if queued == 0:
logger.warning(
"[Pipeline] No embed_act job queued for %s (no indexing-enabled presets or jobs already queued)",
celex,
)
return queued
def _enqueue_extract_act_job(runtime: WorkerRuntime, *, celex: str, force: bool) -> None:
queued_job_id = _enqueue_job(
runtime,
queue_name=EXTRACT_QUEUE_NAME,
job_type=JOB_TYPE_EXTRACT_ACT,
params={
"celex": celex,
"force": force,
},
dedupe_celex=celex,
)
if queued_job_id:
logger.info("[Pipeline] Enqueued extract_act for %s (job_id=%s)", celex, queued_job_id)
def _delete_chunk_vectors_for_ids(repos: dict[str, Any], chunk_ids: list[int]) -> None:
if not chunk_ids:
return
typed_chunk_ids = cast(list[str | int], chunk_ids)
for preset_id, repo in repos.items():
try:
if not repo.collection_exists():
continue
repo.delete_points(typed_chunk_ids)
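            # Also delete by payload filter, presumably to catch points stored
            # under a different id scheme; delete_points_by_filter is not part
            # of the typed repo interface, hence the cast.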
cast(Any, repo).delete_points_by_filter({"chunk_id": chunk_ids})
except Exception as exc:
            logger.warning(
                "Failed to delete chunk vectors for preset %s (%d chunk ids): %s",
                preset_id,
                len(chunk_ids),
                exc,
            )
def _delete_act_vectors_for_ids(repos: dict[str, Any], act_ids: list[int]) -> None:
if not act_ids:
return
typed_act_ids = cast(list[str | int], act_ids)
for preset_id, repo in repos.items():
try:
if not repo.collection_exists():
continue
repo.delete_points(typed_act_ids)
except Exception as exc:
            logger.warning(
                "Failed to delete act vectors for preset %s (%d act ids): %s",
                preset_id,
                len(act_ids),
                exc,
            )
def _chunk_all_job_already_queued(runtime: WorkerRuntime) -> bool:
return _job_already_queued(
runtime,
queue_name=CHUNK_QUEUE_NAME,
job_type=JOB_TYPE_CHUNK_ALL,
)
def _maybe_enqueue_reconcile_job(runtime: WorkerRuntime) -> None:
"""
    Enqueue a chunk_all job on startup if any eligible subdivisions are not yet chunked.
Controlled by AUTO_CHUNK_RECONCILE / workers.auto_chunk_reconcile.
"""
enabled, ttl, _ = get_reconcile_params("chunk")
if not enabled:
return
def _do_reconcile() -> None:
components = runtime.ensure_components()
pg_repo = components.pg_repo
min_length = _default_chunk_min_content_length()
with pg_repo.get_session() as session:
missing_count = pg_repo.count_subdivisions_without_chunks(
session=session,
min_content_length=min_length,
)
if missing_count == 0:
logger.info("[Reconcile] All eligible subdivisions already chunked, skipping")
return
if _chunk_all_job_already_queued(runtime):
logger.info("[Reconcile] chunk_all already queued, skipping")
return
queued_job_id = _enqueue_job(
runtime,
queue_name=CHUNK_QUEUE_NAME,
job_type=JOB_TYPE_CHUNK_ALL,
params={
"min_content_length": min_length,
},
)
if queued_job_id:
logger.info("[Reconcile] Enqueued chunk_all (missing=%d)", missing_count)
with_reconcile_lock(
redis_client=runtime.redis_client,
lock_key="chunk:reconcile:lock",
lock_ttl=ttl,
action=_do_reconcile,
)
def process_chunk_all(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
"""Dispatch chunk_act jobs for all acts (event-driven pipeline)."""
components = runtime.ensure_components()
pg_repo = components.pg_repo
min_length: int = to_optional_int(params.get("min_content_length")) or _default_chunk_min_content_length()
force: bool = coerce_bool(params.get("force"), False)
with pg_repo.get_session() as session:
acts: list[Any] = pg_repo.list_acts_with_metadata(session)
dispatch_all_act_jobs(
runtime=runtime,
job_id=job_id,
queue_name=CHUNK_QUEUE_NAME,
job_type=JOB_TYPE_CHUNK_ACT,
label="chunk_act",
acts=acts,
build_params=lambda celex: {"celex": celex, "min_content_length": min_length, "force": force},
error_label="Chunking",
)
def process_chunk_act(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
"""Process chunk single act job"""
celex_value: Any = params.get("celex")
celex = normalize_celex(celex_value if isinstance(celex_value, str) else "")
if not celex:
raise ValueError("Missing or invalid 'celex' in job params")
force: bool = coerce_bool(params.get("force"), False)
min_content_length: int = to_optional_int(params.get("min_content_length")) or _default_chunk_min_content_length()
logger.info(f"[Job {job_id}] Starting chunk for act {celex}")
update_job_status(runtime, job_id, "running", f"Chunking act {celex}...")
try:
components = runtime.ensure_components()
pg_repo = components.pg_repo
chunk_vector_repos = components.chunk_vector_repos
act_vector_repos = components.act_vector_repos
chunker: Any = components.chunker
with pg_repo.get_session() as session:
# Get the act
act: Any = pg_repo.get_act_by_celex(session, celex)
if not act:
raise ValueError(f"Act {celex} not found")
# Get subdivisions for this act
subdivisions: list[Any] = pg_repo.list_subdivisions_for_act(session, act.id)
eligible_subdivisions = [
subdiv for subdiv in subdivisions if len(str(subdiv.content or "")) >= min_content_length
]
total = len(eligible_subdivisions)
# Prepare article-level chunking plan (EUR-Lex / Légifrance)
config = get_config()
art_plan = prepare_article_level_plan(
celex=celex,
subdivisions=subdivisions,
article_level_enabled=config.chunking.article_level_chunking,
)
            logger.info(
                "[Job %s] Found %d eligible subdivisions for act %s "
                "(min_content_length=%d, article_level=%s)",
                job_id, total, celex, min_content_length, art_plan.active,
            )
# Delete existing chunks if force
if force:
if min_content_length > 0:
for subdiv in eligible_subdivisions:
existing_chunk_ids = pg_repo.list_chunk_ids_for_subdivision(
session=session,
subdivision_id=subdiv.id,
)
if existing_chunk_ids:
_delete_chunk_vectors_for_ids(chunk_vector_repos, existing_chunk_ids)
pg_repo.delete_embedding_states_for_chunk_ids(session, existing_chunk_ids)
pg_repo.delete_chunks_for_subdivision(
session=session,
subdivision_id=subdiv.id,
)
else:
existing_chunk_ids = pg_repo.list_chunk_ids_for_act(session, act.id)
if existing_chunk_ids:
_delete_chunk_vectors_for_ids(chunk_vector_repos, existing_chunk_ids)
pg_repo.delete_embedding_states_for_chunk_ids(session, existing_chunk_ids)
pg_repo.delete_chunks_for_act(session, act.id)
_delete_act_vectors_for_ids(act_vector_repos, [act.id])
pg_repo.delete_embedding_states_for_act_ids(session, [act.id])
pg_repo.reset_extraction_status(session, act.id)
session.commit()
            chunk_count = 0
            skipped_existing = 0
            skipped_aggregated = 0
            skipped_errors = 0
            MAX_RETRIES = 3
for i, subdiv in enumerate(eligible_subdivisions):
# Skip child paragraphs already aggregated into their parent article
if subdiv.id in art_plan.skip_ids:
skipped_aggregated += 1
continue
# Skip if already chunked (unless force)
if not force and pg_repo.subdivision_has_chunks(session, subdiv.id):
skipped_existing += 1
continue
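                # Retry transient chunking/DB failures with exponential backoff
                # (2 s, then 4 s) before giving up on this subdivision.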
for attempt in range(1, MAX_RETRIES + 1):
try:
article_level_chunks = make_article_level_chunks(
chunker=chunker,
subdivision=subdiv,
article_level_plan=art_plan,
)
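                        # None means article-level chunking does not apply here,
                        # so fall back to plain per-subdivision chunking; an
                        # empty list means there is nothing to insert.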
if article_level_chunks is not None:
if not article_level_chunks:
break
chunks: list[Any] = article_level_chunks
else:
chunks = chunker.chunk_subdivision(
subdivision_id=subdiv.id,
content=subdiv.content,
subdivision_type=subdiv.subdivision_type,
)
pg_repo.insert_chunk_records(
session=session,
records=serialize_chunk_records(chunks),
)
chunk_count += len(chunks)
break # success
except Exception as e:
if attempt < MAX_RETRIES:
logger.warning(
"[Job %s] Retry %d/%d for subdivision %d: %s",
job_id,
attempt,
MAX_RETRIES,
subdiv.id,
e,
)
time.sleep(2**attempt)
else:
logger.error(
"[Job %s] Skipping subdivision %d after %d retries: %s",
job_id,
subdiv.id,
MAX_RETRIES,
e,
)
skipped_errors += 1
                # Update progress every 100 subdivisions and on the last one
                if (i + 1) % 100 == 0 or (i + 1) == total:
                    progress = int((i + 1) / total * 100)
                    update_job_status(
                        runtime,
                        job_id,
                        "running",
                        (
                            f"Chunking {celex}: {i + 1}/{total} eligible subdivisions "
                            f"({chunk_count} chunks, skipped_existing={skipped_existing})"
                        ),
                        progress,
                    )
if chunk_count > 0:
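                # New chunks invalidate downstream artifacts: drop act-level
                # vectors, embedding states, and extraction status so the
                # follow-up jobs rebuild them.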
_delete_act_vectors_for_ids(act_vector_repos, [act.id])
pg_repo.delete_embedding_states_for_act_ids(session, [act.id])
pg_repo.reset_extraction_status(session, act.id)
session.commit()
message = (
f"Chunked act {celex}: {chunk_count} chunks from {total} eligible subdivisions "
f"(skipped_existing={skipped_existing}, aggregated_children={skipped_aggregated}, errors={skipped_errors})"
)
logger.info(f"[Job {job_id}] {message}")
update_job_status(runtime, job_id, "completed", message, 100)
_enqueue_embed_act_jobs(runtime, celex=celex, force=force)
_enqueue_extract_act_job(runtime, celex=celex, force=force)
    except Exception as e:
        error_msg = f"Chunking failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def process_job(runtime: WorkerRuntime, job_data: dict[str, Any]) -> None:
"""Dispatch one chunking-related job through the instrumented worker loop."""
instrumented_process_job(
runtime=runtime,
job_data=job_data,
dispatch={
JOB_TYPE_CHUNK_ALL: process_chunk_all,
JOB_TYPE_CHUNK_ACT: process_chunk_act,
},
observe_execution=observe_job_execution,
observe_error=observe_job_error,
)
def main() -> None:
"""Start the chunking worker loop and its Prometheus endpoint."""
metrics_port = int(get_config().workers.chunk_metrics_port)
start_http_server(metrics_port)
logger.info("Chunking worker metrics endpoint started on port %d", metrics_port)
runtime = build_runtime()
_, _, reconcile_interval = get_reconcile_params("chunk")
run_worker_loop(
queue_name=CHUNK_QUEUE_NAME,
worker_name="Chunking Worker",
redis_client=runtime.redis_client,
brpop_timeout_seconds=runtime.brpop_timeout_seconds,
process_job=lambda job_data: process_job(runtime, job_data),
reconcile_callback=lambda: _maybe_enqueue_reconcile_job(runtime),
reconcile_interval_seconds=reconcile_interval,
)
if __name__ == "__main__":
main()