"""
Extraction Worker
Processes relation extraction jobs from Redis queue
"""
import logging
from dataclasses import dataclass
from typing import Any
from extraction_worker.bootstrap import ExtractionComponents, init_components
from extraction_worker.graph import run_community_detection
from extraction_worker.service_metrics import (
PrometheusExtractionMetricsRecorder,
observe_job_error,
observe_job_execution,
)
from lalandre_core.config import get_config
from lalandre_core.http.llm_client import JSONHTTPLLMClient
from lalandre_core.logging_setup import setup_worker_logging
from lalandre_core.queue.dispatch_all import dispatch_all_act_jobs
from lalandre_core.queue.job_queue import (
QueueRuntime,
update_job_status,
)
from lalandre_core.queue.job_queue import (
enqueue_job as _enqueue_job,
)
from lalandre_core.queue.job_queue import (
job_already_queued as _job_already_queued,
)
from lalandre_core.queue.reconcile import with_reconcile_lock
from lalandre_core.queue.worker_config import get_reconcile_params, require_gateway_config
from lalandre_core.queue.worker_loop import (
instrumented_process_job,
resolve_base_runtime_params,
run_worker_loop,
)
from lalandre_core.utils import APIKeyPool, coerce_bool, coerce_float, normalize_celex
from lalandre_extraction.metrics import set_extraction_metrics_recorder
from prometheus_client import start_http_server
# Configure worker-wide structured logging before any logger is created.
setup_worker_logging()
logger = logging.getLogger(__name__)
# Redis list name that extraction-related jobs are pushed to / popped from.
EXTRACT_QUEUE_NAME = "extract_jobs"
# Job type identifiers dispatched by process_job().
JOB_TYPE_EXTRACT_ALL = "extract_all"
JOB_TYPE_EXTRACT_ACT = "extract_act"
JOB_TYPE_SUMMARIZE_ALL = "summarize_all"
JOB_TYPE_SUMMARIZE_ACT = "summarize_act"
JOB_TYPE_BUILD_COMMUNITIES = "build_communities"
@dataclass
class WorkerRuntime(QueueRuntime):
    """Runtime state and lazy dependencies for the extraction worker loop.

    Attributes:
        brpop_timeout_seconds: Timeout (seconds) used when blocking on the
            Redis queue (BRPOP).
        components: Lazily-initialized worker dependencies; stays None until
            the first call to ensure_components().
    """

    brpop_timeout_seconds: int
    components: ExtractionComponents | None = None

    def ensure_components(self) -> ExtractionComponents:
        """Initialize worker dependencies on first use and return them."""
        # Lazy init: components are only built when a job actually needs them,
        # so the worker can start (and poll) before heavy dependencies exist.
        if self.components is None:
            self.components = init_components()
            logger.info("Components initialized successfully")
        return self.components
def build_runtime() -> WorkerRuntime:
    """Build the queue runtime used by the extraction worker.

    Returns:
        A WorkerRuntime wired with a Redis client, job TTL, and BRPOP timeout
        resolved from the gateway/worker configuration.
    """
    config = get_config()
    base = resolve_base_runtime_params(
        redis_host=config.gateway.redis_host,
        redis_port=config.gateway.redis_port,
        job_ttl_seconds=config.gateway.job_ttl_seconds,
        brpop_timeout_seconds=config.workers.brpop_timeout_seconds,
    )
    return WorkerRuntime(
        redis_client=base.redis_client,
        job_ttl_seconds=base.job_ttl_seconds,
        brpop_timeout_seconds=base.brpop_timeout_seconds,
    )
def _extract_all_job_already_queued(runtime: WorkerRuntime) -> bool:
    """Return True when an extract_all job is already waiting in the queue."""
    already_queued = _job_already_queued(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_EXTRACT_ALL,
    )
    return already_queued
def _summarize_all_job_already_queued(runtime: WorkerRuntime) -> bool:
    """Return True when a summarize_all job is already waiting in the queue."""
    return _job_already_queued(
        runtime, queue_name=EXTRACT_QUEUE_NAME, job_type=JOB_TYPE_SUMMARIZE_ALL
    )
def _maybe_enqueue_summarize_act_job(runtime: WorkerRuntime, *, celex: str) -> None:
    """Queue a summarize_act job for *celex*; a no-op when one is already queued."""
    # dedupe_celex makes _enqueue_job return a falsy id when a job for this
    # CELEX is already pending, in which case nothing is logged.
    job_id = _enqueue_job(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_SUMMARIZE_ACT,
        params={"celex": celex},
        dedupe_celex=celex,
    )
    if not job_id:
        return
    logger.info("[Summaries] Enqueued summarize_act for %s (job_id=%s)", celex, job_id)
def _maybe_enqueue_build_communities(runtime: WorkerRuntime) -> None:
    """
    Enqueue a build_communities job unless one is already queued.

    Called after each successful extract_act so that community detection
    runs automatically once all pending extraction jobs have drained.
    Because job_already_queued deduplicates on job_type when celex=None,
    only one build_communities job will ever sit in the queue at a time.
    """
    already_pending = _job_already_queued(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_BUILD_COMMUNITIES,
    )
    if already_pending:
        return
    workers_cfg = get_config().workers
    job_params = {
        "resolution": workers_cfg.community_resolution,
        "min_community_size": workers_cfg.community_min_size,
    }
    job_id = _enqueue_job(
        runtime,
        queue_name=EXTRACT_QUEUE_NAME,
        job_type=JOB_TYPE_BUILD_COMMUNITIES,
        params=job_params,
    )
    if job_id:
        logger.info("[Communities] Enqueued build_communities job %s", job_id)
def _maybe_enqueue_reconcile_job(runtime: WorkerRuntime) -> None:
    """
    Enqueue an extract_all job on startup if acts are not extracted.

    Controlled by AUTO_EXTRACT_RECONCILE / workers.auto_extract_reconcile.
    The check runs under a Redis lock so that at most one worker instance
    performs the reconcile pass at a time.
    """
    enabled, ttl, _ = get_reconcile_params("extract")
    if not enabled:
        return

    def _do_reconcile() -> None:
        # Executed while holding the reconcile lock (see with_reconcile_lock below).
        components = runtime.ensure_components()
        pg_repo = components.pg_repo
        min_confidence: float = require_gateway_config("job_extract_min_confidence")
        skip_existing: bool = require_gateway_config("job_extract_skip_existing_default")
        with pg_repo.get_session() as session:
            total_acts = pg_repo.count_acts(session)
            if total_acts == 0:
                logger.info("[Reconcile] No acts found, skipping")
                return
            # Free acts stuck in 'extracting' (e.g. after a worker crash) so
            # they become eligible for re-extraction again.
            stale_timeout = get_config().workers.extract_stale_timeout_minutes
            reset_count = pg_repo.reset_stale_extracting_acts(session, stale_timeout)
            if reset_count > 0:
                logger.warning(
                    "[Reconcile] Reset %d acts stuck in 'extracting' (timeout=%d min)",
                    reset_count,
                    stale_timeout,
                )
                session.commit()
            pending = pg_repo.count_acts_pending_extraction(session)
        if pending == 0:
            logger.info("[Reconcile] All acts extracted, skipping")
            return
        if _extract_all_job_already_queued(runtime):
            logger.info("[Reconcile] extract_all already queued, skipping")
            return
        queued_job_id = _enqueue_job(
            runtime,
            queue_name=EXTRACT_QUEUE_NAME,
            job_type=JOB_TYPE_EXTRACT_ALL,
            params={
                "min_confidence": min_confidence,
                "skip_existing": skip_existing,
            },
        )
        if queued_job_id:
            # Lazy %-formatting, consistent with every other log call in this
            # module (the original used an f-string here).
            logger.info("[Reconcile] Enqueued extract_all (pending=%d)", pending)

    with_reconcile_lock(
        redis_client=runtime.redis_client,
        lock_key="extract:reconcile:lock",
        lock_ttl=ttl,
        action=_do_reconcile,
    )
def _maybe_enqueue_summary_reconcile_job(runtime: WorkerRuntime) -> None:
    """Enqueue a summarize_all job when canonical act summaries are missing.

    NOTE(review): this reuses the "extract" reconcile enable/TTL parameters
    rather than dedicated "summary" ones — confirm that is intentional.
    """
    enabled, ttl, _ = get_reconcile_params("extract")
    if not enabled:
        return

    def _do_reconcile() -> None:
        # Executed while holding the summary reconcile lock.
        components = runtime.ensure_components()
        pending_celexes = components.act_summary_service.list_celex_needing_canonical_summary()
        if not pending_celexes:
            logger.info("[Summary Reconcile] All act summaries are up to date, skipping")
            return
        if _summarize_all_job_already_queued(runtime):
            logger.info("[Summary Reconcile] summarize_all already queued, skipping")
            return
        job_id = _enqueue_job(
            runtime,
            queue_name=EXTRACT_QUEUE_NAME,
            job_type=JOB_TYPE_SUMMARIZE_ALL,
            params={},
        )
        if job_id:
            logger.info(
                "[Summary Reconcile] Enqueued summarize_all (pending=%d)",
                len(pending_celexes),
            )

    with_reconcile_lock(
        redis_client=runtime.redis_client,
        lock_key="summary:reconcile:lock",
        lock_ttl=ttl,
        action=_do_reconcile,
    )
def process_summarize_all(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Enqueue summarize jobs for every act missing a canonical summary.

    Args:
        runtime: Worker runtime providing Redis and lazily-built components.
        job_id: Identifier of this summarize_all job, used for status updates.
        params: Job parameters (currently unused).

    Raises:
        Exception: Re-raised after marking the job as failed so the worker
            loop's instrumentation still observes the error.
    """
    logger.info("[Job %s] Dispatching summarize_act jobs for acts needing summaries", job_id)
    update_job_status(runtime, job_id, "running", "Dispatching summarize_act jobs...")
    try:
        components = runtime.ensure_components()
        celexes = components.act_summary_service.list_celex_needing_canonical_summary()
        total = len(celexes)
        if total == 0:
            update_job_status(runtime, job_id, "completed", "No act summaries to refresh", 100)
            return
        queued = 0
        skipped = 0
        for index, celex in enumerate(celexes, start=1):
            # dedupe_celex makes this a no-op (falsy return) when a
            # summarize_act job for this CELEX is already queued.
            queued_job_id = _enqueue_job(
                runtime,
                queue_name=EXTRACT_QUEUE_NAME,
                job_type=JOB_TYPE_SUMMARIZE_ACT,
                params={"celex": celex},
                dedupe_celex=celex,
            )
            if queued_job_id:
                queued += 1
            else:
                skipped += 1
            progress = int((index / total) * 100)
            update_job_status(
                runtime,
                job_id,
                "running",
                (f"Queued summarize_act for {index}/{total} acts (queued={queued}, skipped={skipped})"),
                progress,
            )
        message = f"Queued {queued} summarize_act jobs (skipped {skipped})"
        logger.info("[Job %s] %s", job_id, message)
        update_job_status(runtime, job_id, "completed", message, 100)
    except Exception as e:
        error_msg = f"Summary reconcile failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def process_summarize_act(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Refresh the canonical summary for one CELEX identifier.

    Args:
        runtime: Worker runtime providing components and job-status storage.
        job_id: Identifier of this summarize_act job.
        params: Must contain a non-empty "celex" string.

    Raises:
        ValueError: If "celex" is missing or fails normalization.
        RuntimeError: If the summary snapshot could not be materialized.
    """
    try:
        celex_value: Any = params.get("celex")
        # Tolerate non-string payloads: anything else normalizes to "".
        celex_raw = celex_value.strip() if isinstance(celex_value, str) else ""
        celex = normalize_celex(celex_raw)
        if not celex:
            raise ValueError("Missing required 'celex' for summarize_act job")
        logger.info("[Job %s] Starting canonical summary refresh for %s", job_id, celex)
        update_job_status(runtime, job_id, "running", f"Summarizing act {celex}...")
        components = runtime.ensure_components()
        snapshot = components.act_summary_service.refresh_canonical_summary_for_celex(celex)
        if snapshot.available:
            suffix = " (stale)" if snapshot.is_stale else ""
            message = f"Canonical summary ready for {celex}{suffix}"
            logger.info("[Job %s] %s", job_id, message)
            update_job_status(runtime, job_id, "completed", message, 100)
            return
        # An unavailable snapshot is treated as a job failure.
        error_detail = snapshot.error_text or "Summary could not be materialized"
        raise RuntimeError(error_detail)
    except Exception as e:
        error_msg = f"Canonical summary failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def _build_community_llm_client() -> JSONHTTPLLMClient | None:
    """Build a LLM client for community summary generation, or None if unavailable."""
    try:
        key_pool = APIKeyPool.from_env("MISTRAL_API_KEY", start_index=6)
    except ValueError:
        # No API key configured: downstream code falls back to a
        # deterministic (non-LLM) summary path.
        logger.info("No MISTRAL_API_KEY found; community summaries will be deterministic")
        return None
    extraction_cfg = get_config().extraction
    client = JSONHTTPLLMClient(
        provider=extraction_cfg.llm_provider,
        model=extraction_cfg.llm_model,
        base_url=extraction_cfg.llm_base_url,
        timeout_seconds=extraction_cfg.llm_timeout_seconds,
        api_key=key_pool.next_key(),
        max_output_tokens=512,
        temperature=0.1,
        system_prompt="Tu es un expert en droit européen et français.",
    )
    return client
def process_build_communities(runtime: WorkerRuntime, job_id: str, params: dict[str, Any]) -> None:
    """Detect Louvain communities on the Act graph and write back to Neo4j.

    Args:
        runtime: Worker runtime providing the Neo4j driver via components.
        job_id: Identifier of this build_communities job.
        params: Optional "resolution" and "min_community_size" overrides;
            worker config values are used when absent or falsy.

    Raises:
        Exception: Re-raised after marking the job as failed.
    """
    logger.info("[Job %s] Starting community detection", job_id)
    update_job_status(runtime, job_id, "running", "Running Louvain community detection...")
    try:
        components = runtime.ensure_components()
        neo4j_repo = components.neo4j_repo
        driver = neo4j_repo.driver
        database = neo4j_repo.settings.database or "neo4j"
        workers_cfg = get_config().workers
        # `or` (not dict.get default) so falsy params (None, 0) also fall
        # back to the configured values.
        resolution = float(params.get("resolution") or workers_cfg.community_resolution)
        min_community_size = int(params.get("min_community_size") or workers_cfg.community_min_size)
        # llm_client may be None; run_community_detection then produces
        # deterministic summaries.
        llm_client = _build_community_llm_client()
        result = run_community_detection(
            driver=driver,
            database=database,
            resolution=resolution,
            min_community_size=min_community_size,
            llm_client=llm_client,
        )
        status_val = result.get("status", "unknown")
        if status_val == "skipped":
            message = "Community detection skipped: no Act nodes found"
        else:
            message = (
                f"Community detection complete: "
                f"{result.get('num_communities', 0)} communities, "
                f"{result.get('num_nodes', 0)} nodes"
            )
        logger.info("[Job %s] %s", job_id, message)
        update_job_status(runtime, job_id, "completed", message, 100)
    except Exception as e:
        error_msg = f"Community detection failed: {e}"
        logger.error("[Job %s] %s", job_id, error_msg, exc_info=True)
        update_job_status(runtime, job_id, "failed", error_msg)
        raise
def process_job(runtime: WorkerRuntime, job_data: dict[str, Any]) -> None:
    """Dispatch one extraction-related job through the instrumented worker loop.

    Args:
        runtime: Worker runtime passed through to each job handler.
        job_data: Raw job payload popped from the Redis queue.
    """
    # process_extract_all / process_extract_act are defined elsewhere in
    # this module (outside this view).
    instrumented_process_job(
        runtime=runtime,
        job_data=job_data,
        dispatch={
            JOB_TYPE_EXTRACT_ALL: process_extract_all,
            JOB_TYPE_EXTRACT_ACT: process_extract_act,
            JOB_TYPE_SUMMARIZE_ALL: process_summarize_all,
            JOB_TYPE_SUMMARIZE_ACT: process_summarize_act,
            JOB_TYPE_BUILD_COMMUNITIES: process_build_communities,
        },
        observe_execution=observe_job_execution,
        observe_error=observe_job_error,
    )
def _run_reconcile(runtime: WorkerRuntime) -> None:
    """Run both periodic reconcile passes: extraction first, then summaries."""
    _maybe_enqueue_reconcile_job(runtime)
    _maybe_enqueue_summary_reconcile_job(runtime)
def main() -> None:
    """Start the extraction worker loop and its Prometheus endpoint."""
    set_extraction_metrics_recorder(PrometheusExtractionMetricsRecorder())
    metrics_port = int(get_config().workers.extract_metrics_port)
    start_http_server(metrics_port)
    logger.info("Extraction worker metrics endpoint started on port %d", metrics_port)
    runtime = build_runtime()
    # Only the interval is needed here; the enable flag is re-checked inside
    # the reconcile helpers themselves.
    _, _, reconcile_interval = get_reconcile_params("extract")
    run_worker_loop(
        queue_name=EXTRACT_QUEUE_NAME,
        worker_name="Extraction Worker",
        redis_client=runtime.redis_client,
        brpop_timeout_seconds=runtime.brpop_timeout_seconds,
        process_job=lambda job_data: process_job(runtime, job_data),
        reconcile_callback=lambda: _run_reconcile(runtime),
        reconcile_interval_seconds=reconcile_interval,
    )


if __name__ == "__main__":
    main()