# Source code for rag_service.metrics.service

"""Prometheus helpers for rag-service runtime metrics."""

from typing import Any, Dict, Optional, cast

from lalandre_core.utils.metrics_utils import (
    LATENCY_BUCKETS as _LATENCY_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
    TOP_K_BUCKETS as _TOP_K_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
    classify_error as _classify_error,
)
from lalandre_core.utils.metrics_utils import (
    normalize_granularity as _normalize_granularity,
)
from lalandre_core.utils.metrics_utils import (
    normalize_label as _normalize_label,
)
from lalandre_core.utils.metrics_utils import (
    normalize_search_mode as _normalize_search_mode,
)
from lalandre_core.utils.metrics_utils import (
    status_class as _status_class,
)
from prometheus_client import Counter, Gauge, Histogram
from rag_service.bootstrap import RagComponents

# --- HTTP transport metrics -------------------------------------------------

# Per-request counter; path/method are normalized by the helpers below to keep
# label cardinality bounded, and status codes are bucketed into "Nxx" classes.
HTTP_REQUESTS_TOTAL = Counter(
    "lalandre_rag_service_http_requests_total",
    "rag-service HTTP requests by path, method, and status class.",
    ["path", "method", "status_class"],
)

# Wall-clock latency per (path, method), using the shared latency buckets so
# dashboards can compare services on the same scale.
HTTP_REQUEST_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_http_request_duration_seconds",
    "rag-service HTTP request duration.",
    ["path", "method"],
    buckets=_LATENCY_BUCKETS,
)

# --- /query endpoint metrics ------------------------------------------------

QUERY_REQUESTS_TOTAL = Counter(
    "lalandre_rag_service_query_requests_total",
    "rag-service /query requests by mode, granularity, and outcome.",
    ["mode", "granularity", "outcome"],
)

QUERY_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_query_duration_seconds",
    "rag-service /query request duration by mode.",
    ["mode"],
    buckets=_LATENCY_BUCKETS,
)

# Distribution of the caller-requested top_k, bucketed with the shared top-k
# buckets from metrics_utils.
QUERY_TOP_K = Histogram(
    "lalandre_rag_service_query_top_k",
    "Requested top_k for rag-service /query calls by mode.",
    ["mode"],
    buckets=_TOP_K_BUCKETS,
)

# --- /search endpoint metrics (mirrors the /query trio) ---------------------

SEARCH_REQUESTS_TOTAL = Counter(
    "lalandre_rag_service_search_requests_total",
    "rag-service /search requests by mode, granularity, and outcome.",
    ["mode", "granularity", "outcome"],
)

SEARCH_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_search_duration_seconds",
    "rag-service /search request duration by mode.",
    ["mode"],
    buckets=_LATENCY_BUCKETS,
)

SEARCH_TOP_K = Histogram(
    "lalandre_rag_service_search_top_k",
    "Requested top_k for rag-service /search calls by mode.",
    ["mode"],
    buckets=_TOP_K_BUCKETS,
)

# --- Runtime internals ------------------------------------------------------

# Per-phase durations (retrieve, rerank, generate, ...) parsed out of the
# "*_phase_timings_ms" keys of response metadata; values arrive in ms and are
# converted to seconds before observation.
PHASE_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_phase_duration_seconds",
    "rag-service runtime phase durations extracted from response metadata.",
    ["mode", "phase"],
    buckets=_LATENCY_BUCKETS,
)

# Provider/runtime failures and fallback activations; provider and error_type
# are derived from the exception (or reason string) via _classify_error.
PROVIDER_ERRORS_TOTAL = Counter(
    "lalandre_rag_service_provider_errors_total",
    "rag-service provider/runtime errors by mode and stage.",
    ["mode", "stage", "provider", "error_type"],
)

# 1.0 = healthy/available, 0.0 = down/missing; refreshed by
# refresh_backend_health for postgres, qdrant (and its collections), neo4j.
BACKEND_HEALTH = Gauge(
    "lalandre_rag_service_backend_health",
    "rag-service backend health and collection availability.",
    ["backend", "target"],
)


def observe_http_request(
    *,
    path: str,
    method: str,
    status_code: int,
    duration_seconds: float,
) -> None:
    """Record one HTTP request handled by rag-service.

    Increments the request counter keyed by normalized path, method, and
    status class, and observes the request latency (clamped to >= 0).
    """
    route = _normalize_path(path)
    verb = _normalize_label(method.upper())
    elapsed = max(float(duration_seconds), 0.0)
    HTTP_REQUESTS_TOTAL.labels(
        path=route,
        method=verb,
        status_class=_status_class(status_code),
    ).inc()
    HTTP_REQUEST_DURATION_SECONDS.labels(path=route, method=verb).observe(elapsed)
def observe_query_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: Optional[int],
    duration_seconds: float,
    outcome: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Record one `/query` request handled by rag-service.

    Updates the request counter, latency histogram, and requested-top_k
    histogram; when response metadata is supplied, phase timings embedded
    in it are observed as well.
    """
    query_mode = _normalize_query_mode(mode)
    elapsed = max(float(duration_seconds), 0.0)
    requested_k = max(float(top_k or 1), 1.0)  # missing/zero top_k counts as 1
    QUERY_REQUESTS_TOTAL.labels(
        mode=query_mode,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    ).inc()
    QUERY_DURATION_SECONDS.labels(mode=query_mode).observe(elapsed)
    QUERY_TOP_K.labels(mode=query_mode).observe(requested_k)
    if metadata:
        _observe_phase_timings(query_mode, metadata)
def observe_search_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: Optional[int],
    duration_seconds: float,
    outcome: str,
) -> None:
    """Record one `/search` request handled by rag-service.

    Mirrors observe_query_request minus the phase-timing extraction:
    counter by mode/granularity/outcome, latency, and requested top_k.
    """
    search_mode = _normalize_search_mode(mode)
    elapsed = max(float(duration_seconds), 0.0)
    requested_k = max(float(top_k or 1), 1.0)  # missing/zero top_k counts as 1
    SEARCH_REQUESTS_TOTAL.labels(
        mode=search_mode,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    ).inc()
    SEARCH_DURATION_SECONDS.labels(mode=search_mode).observe(elapsed)
    SEARCH_TOP_K.labels(mode=search_mode).observe(requested_k)
def observe_provider_error(
    *,
    mode: Optional[str],
    stage: str,
    exc_or_reason: Any,
) -> None:
    """Record one provider-side error or fallback reason.

    The provider name and error type are derived from the exception (or a
    plain reason string) by the shared `_classify_error` helper.
    """
    provider_name, error_kind = _classify_error(exc_or_reason)
    labels = {
        "mode": _normalize_query_mode(mode),
        "stage": _normalize_label(stage),
        "provider": provider_name,
        "error_type": error_kind,
    }
    PROVIDER_ERRORS_TOTAL.labels(**labels).inc()
def observe_provider_fallbacks(
    *,
    mode: Optional[str],
    metadata: Dict[str, Any],
) -> None:
    """Emit provider fallback counters derived from response metadata.

    Each tuple below pairs the metadata flag that marks a fallback with the
    stage label to report and the metadata key holding the fallback reason.
    For "fallback_search_error" the error value itself is both the flag and
    the reason. Replaces four copy-pasted branches with one data-driven loop
    so new fallback kinds only need a new table entry.
    """
    for flag_key, stage, reason_key in (
        ("llm_fallback", "llm_fallback", "llm_fallback_reason"),
        ("graph_fallback", "graph_fallback", "graph_fallback_reason"),
        (
            "graph_answer_fallback",
            "graph_answer_fallback",
            "graph_answer_fallback_reason",
        ),
        ("fallback_search_error", "fallback_search", "fallback_search_error"),
    ):
        if metadata.get(flag_key):
            observe_provider_error(
                mode=mode,
                stage=stage,
                exc_or_reason=metadata.get(reason_key),
            )
def infer_query_outcome(metadata: Optional[Dict[str, Any]]) -> str:
    """Infer the outcome label used for query-level service metrics.

    Precedence: an explicit non-empty "response_policy.state" string wins;
    otherwise any set fallback flag yields "fallback"; default is "success".
    """
    if not metadata:
        return "success"
    policy = metadata.get("response_policy")
    if isinstance(policy, dict):
        state = policy.get("state")
        if isinstance(state, str) and state:
            return state
    fallback_flags = ("llm_fallback", "graph_fallback", "graph_answer_fallback")
    if any(metadata.get(flag) for flag in fallback_flags):
        return "fallback"
    return "success"
def refresh_backend_health(components: "RagComponents") -> None:
    """Probe initialized dependencies and update backend health gauges.

    Gauges are set to 1.0 when the backend answers its health probe (and,
    for Qdrant collections, when the collection exists), else 0.0.
    """
    # Components object itself was constructed, so mark that unconditionally.
    BACKEND_HEALTH.labels(backend="components", target="initialized").set(1.0)

    # Postgres document store.
    pg = cast(Any, getattr(components.context_service, "pg_repo", None))
    BACKEND_HEALTH.labels(backend="postgres", target="document_db").set(
        1.0 if pg is not None and pg.health_check() else 0.0
    )

    # Qdrant: one gauge per collection, plus an aggregate service gauge that
    # is healthy if any repo answered its health probe.
    retrieval = components.retrieval_service
    repos_raw = (
        getattr(retrieval, "qdrant_repos", {}) if retrieval is not None else {}
    )
    any_qdrant_up = False
    if isinstance(repos_raw, dict):
        for collection_name, repo in cast(dict[str, Any], repos_raw).items():
            if bool(repo.health_check()):
                any_qdrant_up = True
            BACKEND_HEALTH.labels(
                backend="qdrant_collection",
                target=_normalize_label(collection_name),
            ).set(1.0 if repo.collection_exists() else 0.0)
    BACKEND_HEALTH.labels(backend="qdrant", target="vector").set(
        1.0 if any_qdrant_up else 0.0
    )

    # Neo4j graph store.
    neo4j = cast(Any, getattr(components.graph_rag_service, "neo4j", None))
    BACKEND_HEALTH.labels(backend="neo4j", target="graph").set(
        1.0 if neo4j is not None and neo4j.health_check() else 0.0
    )
def _observe_phase_timings(mode: str, metadata: Dict[str, Any]) -> None:
    """Observe per-phase durations embedded in response metadata.

    Scans the known "*_phase_timings_ms" metadata dicts, converts each
    numeric value from milliseconds to seconds (clamped to >= 0), and
    records it under a normalized, optionally prefixed phase label.
    """
    timing_sources = (
        ("phase_timings_ms", ""),
        ("graph_query_phase_timings_ms", "graph_query"),
        ("graph_retrieval_phase_timings_ms", "graph_retrieval"),
    )
    for metadata_key, prefix in timing_sources:
        candidate = metadata.get(metadata_key)
        if not isinstance(candidate, dict):
            continue
        for phase, raw in cast(Dict[str, Any], candidate).items():
            try:
                millis = float(raw)
            except (TypeError, ValueError):
                # Non-numeric timing entries are skipped rather than failing
                # the whole request's metric emission.
                continue
            PHASE_DURATION_SECONDS.labels(
                mode=mode,
                phase=_normalize_phase_name(prefix, phase),
            ).observe(max(millis, 0.0) / 1000.0)


def _normalize_path(path: str) -> str:
    """Collapse unknown routes into "other" to bound label cardinality."""
    known_routes = {"/query", "/query/stream", "/search", "/health", "/metrics", "/"}
    return path if path in known_routes else "other"


def _normalize_query_mode(mode: Optional[str]) -> str:
    """Map a raw mode string onto the closed set of query-mode labels."""
    allowed_modes = {
        "rag",
        "search",
        "llm_only",
        "graph",
        "summarize",
        "compare",
        "semantic",
        "lexical",
        "hybrid",
    }
    label = _normalize_label(mode)
    return label if label in allowed_modes else "unknown"


def _normalize_phase_name(prefix: str, phase_name: str) -> str:
    """Canonicalize a phase-timing key into a stable metric label.

    Strips a trailing "_ms", applies legacy-name renames, and prepends the
    normalized prefix (when given) with an underscore separator.
    """
    label = _normalize_label(phase_name).removesuffix("_ms")
    label = label.replace("source_build", "sources_build")
    label = {"reranking": "rerank", "total_graph_pipeline": "total"}.get(label, label)
    return f"{_normalize_label(prefix)}_{label}" if prefix else label