# Source code for extraction_worker.worker

"""
Extraction Worker
Processes relation extraction jobs from Redis queue
"""

import logging
from dataclasses import dataclass
from typing import Any

from extraction_worker.bootstrap import ExtractionComponents, init_components
from extraction_worker.graph import run_community_detection
from extraction_worker.service_metrics import (
    PrometheusExtractionMetricsRecorder,
    observe_job_error,
    observe_job_execution,
)
from lalandre_core.config import get_config
from lalandre_core.http.llm_client import JSONHTTPLLMClient
from lalandre_core.logging_setup import setup_worker_logging
from lalandre_core.queue.dispatch_all import dispatch_all_act_jobs
from lalandre_core.queue.job_queue import (
    QueueRuntime,
    update_job_status,
)
from lalandre_core.queue.job_queue import (
    enqueue_job as _enqueue_job,
)
from lalandre_core.queue.job_queue import (
    job_already_queued as _job_already_queued,
)
from lalandre_core.queue.reconcile import with_reconcile_lock
from lalandre_core.queue.worker_config import get_reconcile_params, require_gateway_config
from lalandre_core.queue.worker_loop import (
    instrumented_process_job,
    resolve_base_runtime_params,
    run_worker_loop,
)
from lalandre_core.utils import APIKeyPool, coerce_bool, coerce_float, normalize_celex
from lalandre_extraction.metrics import set_extraction_metrics_recorder
from prometheus_client import start_http_server

# Configure structured logging before anything else in this module logs.
setup_worker_logging()
logger = logging.getLogger(__name__)

# Name of the Redis list this worker pops jobs from.
EXTRACT_QUEUE_NAME = "extract_jobs"
# Job types handled by this worker; process_job maps each to a handler.
JOB_TYPE_EXTRACT_ALL = "extract_all"
JOB_TYPE_EXTRACT_ACT = "extract_act"
JOB_TYPE_SUMMARIZE_ALL = "summarize_all"
JOB_TYPE_SUMMARIZE_ACT = "summarize_act"
JOB_TYPE_BUILD_COMMUNITIES = "build_communities"

@dataclass
class WorkerRuntime(QueueRuntime):
    """Runtime state and lazy dependencies for the extraction worker loop."""

    # Seconds the queue loop blocks on Redis BRPOP before waking up.
    brpop_timeout_seconds: int
    # Heavy dependencies (repos, services); built lazily on first use.
    components: ExtractionComponents | None = None

    def ensure_components(self) -> ExtractionComponents:
        """Initialize worker dependencies on first use and return them."""
        ready = self.components
        if ready is None:
            ready = init_components()
            self.components = ready
            logger.info("Components initialized successfully")
        return ready
def build_runtime() -> WorkerRuntime:
    """Build the queue runtime used by the extraction worker."""
    config = get_config()
    gateway_cfg = config.gateway
    # Shared helper resolves the Redis client and TTL/timeout defaults.
    base = resolve_base_runtime_params(
        redis_host=gateway_cfg.redis_host,
        redis_port=gateway_cfg.redis_port,
        job_ttl_seconds=gateway_cfg.job_ttl_seconds,
        brpop_timeout_seconds=config.workers.brpop_timeout_seconds,
    )
    return WorkerRuntime(
        redis_client=base.redis_client,
        job_ttl_seconds=base.job_ttl_seconds,
        brpop_timeout_seconds=base.brpop_timeout_seconds,
    )
def _extract_all_job_already_queued(runtime: WorkerRuntime) -> bool:
    """Return True when an extract_all job is already sitting in the queue."""
    return _job_already_queued(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_EXTRACT_ALL,
    )


def _summarize_all_job_already_queued(runtime: WorkerRuntime) -> bool:
    """Return True when a summarize_all job is already sitting in the queue."""
    return _job_already_queued(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_SUMMARIZE_ALL,
    )


def _maybe_enqueue_summarize_act_job(runtime: WorkerRuntime, *, celex: str) -> None:
    """Enqueue a summarize_act job for *celex* (deduplicated on the CELEX)."""
    queued_job_id = _enqueue_job(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_SUMMARIZE_ACT,
        params={"celex": celex},
        dedupe_celex=celex,
    )
    if queued_job_id:
        logger.info("[Summaries] Enqueued summarize_act for %s (job_id=%s)", celex, queued_job_id)


def _maybe_enqueue_build_communities(runtime: WorkerRuntime) -> None:
    """
    Enqueue a build_communities job unless one is already queued.

    Called after each successful extract_act so that community detection
    runs automatically once all pending extraction jobs have drained.
    Because job_already_queued deduplicates on job_type when celex=None,
    only one build_communities job will ever sit in the queue at a time.
    """
    if _job_already_queued(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_BUILD_COMMUNITIES,
    ):
        return
    config = get_config()
    queued_job_id = _enqueue_job(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_BUILD_COMMUNITIES,
        params={
            "resolution": config.workers.community_resolution,
            "min_community_size": config.workers.community_min_size,
        },
    )
    if queued_job_id:
        logger.info("[Communities] Enqueued build_communities job %s", queued_job_id)


def _maybe_enqueue_reconcile_job(runtime: WorkerRuntime) -> None:
    """
    Enqueue an extract_all job on startup if acts are not extracted.

    Controlled by AUTO_EXTRACT_RECONCILE / workers.auto_extract_reconcile.
    """
    enabled, ttl, _ = get_reconcile_params("extract")
    if not enabled:
        return

    def _do_reconcile() -> None:
        # Runs under the distributed reconcile lock below.
        components = runtime.ensure_components()
        pg_repo = components.pg_repo
        min_confidence: float = require_gateway_config("job_extract_min_confidence")
        skip_existing: bool = require_gateway_config("job_extract_skip_existing_default")
        with pg_repo.get_session() as session:
            total_acts = pg_repo.count_acts(session)
            if total_acts == 0:
                logger.info("[Reconcile] No acts found, skipping")
                return
            # Acts stuck in 'extracting' past the timeout are reset so they
            # become pending again and get picked up by the new extract_all.
            stale_timeout = get_config().workers.extract_stale_timeout_minutes
            reset_count = pg_repo.reset_stale_extracting_acts(session, stale_timeout)
            if reset_count > 0:
                logger.warning(
                    "[Reconcile] Reset %d acts stuck in 'extracting' (timeout=%d min)",
                    reset_count,
                    stale_timeout,
                )
                session.commit()
            pending = pg_repo.count_acts_pending_extraction(session)
            if pending == 0:
                logger.info("[Reconcile] All acts extracted, skipping")
                return
        if _extract_all_job_already_queued(runtime):
            logger.info("[Reconcile] extract_all already queued, skipping")
            return
        queued_job_id = _enqueue_job(
            runtime,
            queue_name=EXTRACT_QUEUE_NAME,
            job_type=JOB_TYPE_EXTRACT_ALL,
            params={
                "min_confidence": min_confidence,
                "skip_existing": skip_existing,
            },
        )
        if queued_job_id:
            # Lazy %-formatting for consistency with the rest of this module
            # (was an eager f-string passed to logger.info).
            logger.info("[Reconcile] Enqueued extract_all (pending=%d)", pending)

    with_reconcile_lock(
        redis_client=runtime.redis_client,
        lock_key="extract:reconcile:lock",
        lock_ttl=ttl,
        action=_do_reconcile,
    )


def _maybe_enqueue_summary_reconcile_job(runtime: WorkerRuntime) -> None:
    """
    Enqueue a summarize_all job on startup if canonical summaries are missing.

    NOTE(review): reuses the "extract" reconcile parameters (enabled flag and
    lock TTL) rather than a dedicated "summary" entry — confirm intentional.
    """
    enabled, ttl, _ = get_reconcile_params("extract")
    if not enabled:
        return

    def _do_reconcile() -> None:
        components = runtime.ensure_components()
        celexes = components.act_summary_service.list_celex_needing_canonical_summary()
        if not celexes:
            logger.info("[Summary Reconcile] All act summaries are up to date, skipping")
            return
        if _summarize_all_job_already_queued(runtime):
            logger.info("[Summary Reconcile] summarize_all already queued, skipping")
            return
        queued_job_id = _enqueue_job(
            runtime,
            queue_name=EXTRACT_QUEUE_NAME,
            job_type=JOB_TYPE_SUMMARIZE_ALL,
            params={},
        )
        if queued_job_id:
            logger.info(
                "[Summary Reconcile] Enqueued summarize_all (pending=%d)",
                len(celexes),
            )

    with_reconcile_lock(
        redis_client=runtime.redis_client,
        lock_key="summary:reconcile:lock",
        lock_ttl=ttl,
        action=_do_reconcile,
    )
def process_extract_all(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Dispatch extract_act jobs for all acts (event-driven pipeline)."""
    components = runtime.ensure_components()
    pg_repo = components.pg_repo

    # Job params override the gateway-configured defaults.
    skip_existing = coerce_bool(
        params.get("skip_existing"),
        require_gateway_config("job_extract_skip_existing_default"),
    )
    min_confidence = coerce_float(
        params.get("min_confidence"),
        require_gateway_config("job_extract_min_confidence"),
    )

    def _build_params(celex: str) -> dict[str, Any]:
        # Parameters forwarded to every per-act extract job.
        return {"celex": celex, "min_confidence": min_confidence}

    skip_filter = None
    if skip_existing:
        # Skip acts already marked as extracted.
        def skip_filter(act: Any) -> bool:
            return getattr(act, "extraction_status", None) == "extracted"

    with pg_repo.get_session() as session:
        acts = pg_repo.list_acts_with_metadata(session)
        dispatch_all_act_jobs(
            runtime=runtime,
            job_id=job_id,
            queue_name=EXTRACT_QUEUE_NAME,
            job_type=JOB_TYPE_EXTRACT_ACT,
            label="extract_act",
            acts=acts,
            build_params=_build_params,
            skip_filter=skip_filter,
            error_label="Extraction",
        )
def process_extract_act(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """
    Process extract single act job.

    Extracts relations for one CELEX identifier (synced to Neo4j), then
    chains follow-up work: a summarize_act job for the same act and a
    deduplicated build_communities job.

    Args:
        runtime: Worker runtime (Redis client + lazily built components).
        job_id: Job identifier used for Redis status updates.
        params: Expects "celex" (required); optional "min_confidence",
            "force".

    Raises:
        ValueError: If "celex" is missing or empty after normalization.
        Exception: Re-raised after marking the job failed.
    """
    try:
        celex_value: Any = params.get("celex")
        celex_raw = celex_value.strip() if isinstance(celex_value, str) else ""
        celex = normalize_celex(celex_raw)
        if not celex:
            raise ValueError("Missing required 'celex' for extract_act job")

        min_confidence = coerce_float(
            params.get("min_confidence"),
            require_gateway_config("job_extract_min_confidence"),
        )
        force = coerce_bool(params.get("force"), False)

        # Lazy %-args for consistency with the other handlers in this module
        # (was eager f-string formatting).
        logger.info("[Job %s] Starting extract for act %s (force=%s)", job_id, celex, force)
        update_job_status(runtime, job_id, "running", f"Extracting relations from {celex}...")

        components = runtime.ensure_components()
        relation_service = components.relation_graph
        # NOTE(review): mutates shared component state; fine while jobs run
        # sequentially in this worker, not safe under concurrent handlers.
        relation_service.min_confidence = min_confidence
        logger.info("[Job %s] Reusing initialized relation graph service...", job_id)

        # Extract relations for this act
        result: dict[str, Any] = relation_service.extract_and_store_for_act(
            celex=celex,
            sync_to_neo4j=True,
            force=force,
        )

        relations_count_raw: Any = result.get("relations_stored", 0)
        relations_count = int(relations_count_raw) if isinstance(relations_count_raw, (int, float)) else 0
        status_raw: Any = result.get("status", "unknown")
        status = str(status_raw)
        if status == "success":
            message = f"Extracted {relations_count} relations from {celex}"
        else:
            message = f"Extraction for {celex} completed with status: {status}"
        logger.info("[Job %s] %s", job_id, message)
        update_job_status(runtime, job_id, "completed", message, 100)

        _maybe_enqueue_summarize_act_job(runtime, celex=celex)
        # Trigger community rebuild after this act's extraction succeeds.
        _maybe_enqueue_build_communities(runtime)

    except Exception as e:
        error_msg = f"Extraction failed: {str(e)}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def process_summarize_all(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Enqueue summarize jobs for every act missing a canonical summary."""
    logger.info("[Job %s] Dispatching summarize_act jobs for acts needing summaries", job_id)
    update_job_status(runtime, job_id, "running", "Dispatching summarize_act jobs...")
    try:
        components = runtime.ensure_components()
        pending_celexes = components.act_summary_service.list_celex_needing_canonical_summary()
        total = len(pending_celexes)
        if not pending_celexes:
            update_job_status(runtime, job_id, "completed", "No act summaries to refresh", 100)
            return

        queued = skipped = 0
        for position, act_celex in enumerate(pending_celexes, start=1):
            # dedupe_celex prevents queuing duplicates for the same act.
            new_job_id = _enqueue_job(
                runtime,
                queue_name=EXTRACT_QUEUE_NAME,
                job_type=JOB_TYPE_SUMMARIZE_ACT,
                params={"celex": act_celex},
                dedupe_celex=act_celex,
            )
            if new_job_id:
                queued += 1
            else:
                skipped += 1
            percent_done = int((position / total) * 100)
            update_job_status(
                runtime,
                job_id,
                "running",
                (f"Queued summarize_act for {position}/{total} acts (queued={queued}, skipped={skipped})"),
                percent_done,
            )

        message = f"Queued {queued} summarize_act jobs (skipped {skipped})"
        logger.info("[Job %s] %s", job_id, message)
        update_job_status(runtime, job_id, "completed", message, 100)
    except Exception as e:
        error_msg = f"Summary reconcile failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def process_summarize_act(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Refresh the canonical summary for one CELEX identifier."""
    try:
        raw_param: Any = params.get("celex")
        stripped = raw_param.strip() if isinstance(raw_param, str) else ""
        celex = normalize_celex(stripped)
        if not celex:
            raise ValueError("Missing required 'celex' for summarize_act job")

        logger.info("[Job %s] Starting canonical summary refresh for %s", job_id, celex)
        update_job_status(runtime, job_id, "running", f"Summarizing act {celex}...")

        summary_service = runtime.ensure_components().act_summary_service
        snapshot = summary_service.refresh_canonical_summary_for_celex(celex)
        if not snapshot.available:
            # Surface the service-provided reason when one exists; the outer
            # handler turns it into a failed job status.
            raise RuntimeError(snapshot.error_text or "Summary could not be materialized")

        suffix = " (stale)" if snapshot.is_stale else ""
        message = f"Canonical summary ready for {celex}{suffix}"
        logger.info("[Job %s] %s", job_id, message)
        update_job_status(runtime, job_id, "completed", message, 100)
    except Exception as e:
        error_msg = f"Canonical summary failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def _build_community_llm_client() -> JSONHTTPLLMClient | None:
    """Build a LLM client for community summary generation, or None if unavailable."""
    try:
        key_pool = APIKeyPool.from_env("MISTRAL_API_KEY", start_index=6)
    except ValueError:
        # Without a key the community summaries fall back to a deterministic path.
        logger.info("No MISTRAL_API_KEY found; community summaries will be deterministic")
        return None

    llm_cfg = get_config().extraction
    return JSONHTTPLLMClient(
        provider=llm_cfg.llm_provider,
        model=llm_cfg.llm_model,
        base_url=llm_cfg.llm_base_url,
        timeout_seconds=llm_cfg.llm_timeout_seconds,
        api_key=key_pool.next_key(),
        max_output_tokens=512,
        temperature=0.1,
        system_prompt="Tu es un expert en droit européen et français.",
    )
def process_build_communities(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """
    Detect Louvain communities on the Act graph and write back to Neo4j.

    Args:
        runtime: Worker runtime (Redis client + lazily built components).
        job_id: Job identifier used for Redis status updates.
        params: Optional "resolution" and "min_community_size"; worker
            config supplies the value when a key is absent or None.

    Raises:
        Exception: Re-raised after marking the job failed.
    """
    logger.info("[Job %s] Starting community detection", job_id)
    update_job_status(runtime, job_id, "running", "Running Louvain community detection...")
    try:
        components = runtime.ensure_components()
        neo4j_repo = components.neo4j_repo
        driver = neo4j_repo.driver
        database = neo4j_repo.settings.database or "neo4j"

        workers_cfg = get_config().workers
        # Fall back to config only when the param is missing/None. The
        # previous `param or default` form silently discarded explicit
        # falsy values such as 0.
        resolution_raw = params.get("resolution")
        resolution = float(workers_cfg.community_resolution if resolution_raw is None else resolution_raw)
        min_size_raw = params.get("min_community_size")
        min_community_size = int(workers_cfg.community_min_size if min_size_raw is None else min_size_raw)

        llm_client = _build_community_llm_client()
        result = run_community_detection(
            driver=driver,
            database=database,
            resolution=resolution,
            min_community_size=min_community_size,
            llm_client=llm_client,
        )

        status_val = result.get("status", "unknown")
        if status_val == "skipped":
            message = "Community detection skipped: no Act nodes found"
        else:
            message = (
                f"Community detection complete: "
                f"{result.get('num_communities', 0)} communities, "
                f"{result.get('num_nodes', 0)} nodes"
            )
        logger.info("[Job %s] %s", job_id, message)
        update_job_status(runtime, job_id, "completed", message, 100)
    except Exception as e:
        error_msg = f"Community detection failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def process_job(runtime: WorkerRuntime, job_data: dict[str, Any]) -> None:
    """Dispatch one extraction-related job through the instrumented worker loop."""
    # Maps each job type to its handler; unknown types are handled by the
    # shared instrumented loop.
    handlers = {
        JOB_TYPE_EXTRACT_ALL: process_extract_all,
        JOB_TYPE_EXTRACT_ACT: process_extract_act,
        JOB_TYPE_SUMMARIZE_ALL: process_summarize_all,
        JOB_TYPE_SUMMARIZE_ACT: process_summarize_act,
        JOB_TYPE_BUILD_COMMUNITIES: process_build_communities,
    }
    instrumented_process_job(
        runtime=runtime,
        job_data=job_data,
        dispatch=handlers,
        observe_execution=observe_job_execution,
        observe_error=observe_job_error,
    )
def _run_reconcile(runtime: WorkerRuntime) -> None:
    """Run every periodic reconcile check for this worker, in order."""
    for check in (_maybe_enqueue_reconcile_job, _maybe_enqueue_summary_reconcile_job):
        check(runtime)
def main() -> None:
    """Start the extraction worker loop and its Prometheus endpoint."""
    set_extraction_metrics_recorder(PrometheusExtractionMetricsRecorder())

    # Expose Prometheus metrics before the loop starts processing jobs.
    metrics_port = int(get_config().workers.extract_metrics_port)
    start_http_server(metrics_port)
    logger.info("Extraction worker metrics endpoint started on port %d", metrics_port)

    runtime = build_runtime()
    _, _, reconcile_interval = get_reconcile_params("extract")
    run_worker_loop(
        queue_name=EXTRACT_QUEUE_NAME,
        worker_name="Extraction Worker",
        redis_client=runtime.redis_client,
        brpop_timeout_seconds=runtime.brpop_timeout_seconds,
        process_job=lambda job_data: process_job(runtime, job_data),
        reconcile_callback=lambda: _run_reconcile(runtime),
        reconcile_interval_seconds=reconcile_interval,
    )
if __name__ == "__main__":
    main()