"""Prometheus helpers for rag-service runtime metrics."""
from typing import Any, Dict, Optional, cast
from lalandre_core.utils.metrics_utils import (
LATENCY_BUCKETS as _LATENCY_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
TOP_K_BUCKETS as _TOP_K_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
classify_error as _classify_error,
)
from lalandre_core.utils.metrics_utils import (
normalize_granularity as _normalize_granularity,
)
from lalandre_core.utils.metrics_utils import (
normalize_label as _normalize_label,
)
from lalandre_core.utils.metrics_utils import (
normalize_search_mode as _normalize_search_mode,
)
from lalandre_core.utils.metrics_utils import (
status_class as _status_class,
)
from prometheus_client import Counter, Gauge, Histogram
from rag_service.bootstrap import RagComponents
# --- Prometheus metric definitions (registered once at import time) ---

# HTTP traffic, labelled by normalized path/method and a coarse status
# class so arbitrary client input cannot explode label cardinality.
HTTP_REQUESTS_TOTAL = Counter(
    "lalandre_rag_service_http_requests_total",
    "rag-service HTTP requests by path, method, and status class.",
    ["path", "method", "status_class"],
)
HTTP_REQUEST_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_http_request_duration_seconds",
    "rag-service HTTP request duration.",
    ["path", "method"],
    buckets=_LATENCY_BUCKETS,
)

# /query endpoint: request counts, latency, and requested top_k per mode.
QUERY_REQUESTS_TOTAL = Counter(
    "lalandre_rag_service_query_requests_total",
    "rag-service /query requests by mode, granularity, and outcome.",
    ["mode", "granularity", "outcome"],
)
QUERY_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_query_duration_seconds",
    "rag-service /query request duration by mode.",
    ["mode"],
    buckets=_LATENCY_BUCKETS,
)
QUERY_TOP_K = Histogram(
    "lalandre_rag_service_query_top_k",
    "Requested top_k for rag-service /query calls by mode.",
    ["mode"],
    buckets=_TOP_K_BUCKETS,
)

# /search endpoint: same trio of metrics as /query.
SEARCH_REQUESTS_TOTAL = Counter(
    "lalandre_rag_service_search_requests_total",
    "rag-service /search requests by mode, granularity, and outcome.",
    ["mode", "granularity", "outcome"],
)
SEARCH_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_search_duration_seconds",
    "rag-service /search request duration by mode.",
    ["mode"],
    buckets=_LATENCY_BUCKETS,
)
SEARCH_TOP_K = Histogram(
    "lalandre_rag_service_search_top_k",
    "Requested top_k for rag-service /search calls by mode.",
    ["mode"],
    buckets=_TOP_K_BUCKETS,
)

# Per-phase timings parsed out of response metadata (see _observe_phase_timings).
PHASE_DURATION_SECONDS = Histogram(
    "lalandre_rag_service_phase_duration_seconds",
    "rag-service runtime phase durations extracted from response metadata.",
    ["mode", "phase"],
    buckets=_LATENCY_BUCKETS,
)

# Provider/runtime failures and fallback events.
PROVIDER_ERRORS_TOTAL = Counter(
    "lalandre_rag_service_provider_errors_total",
    "rag-service provider/runtime errors by mode and stage.",
    ["mode", "stage", "provider", "error_type"],
)

# 1.0 = healthy/available, 0.0 = down/missing (see refresh_backend_health).
BACKEND_HEALTH = Gauge(
    "lalandre_rag_service_backend_health",
    "rag-service backend health and collection availability.",
    ["backend", "target"],
)
def observe_http_request(
    *,
    path: str,
    method: str,
    status_code: int,
    duration_seconds: float,
) -> None:
    """Record one HTTP request handled by rag-service.

    Args:
        path: Raw request path; collapsed onto a known-endpoint whitelist
            to keep label cardinality bounded.
        method: HTTP method; upper-cased and normalized before labelling.
        status_code: Response status code, bucketed into a status class.
        duration_seconds: Wall-clock handling time; clamped at zero so a
            skewed clock cannot produce a negative observation.
    """
    normalized_path = _normalize_path(path)
    normalized_method = _normalize_label(method.upper())
    HTTP_REQUESTS_TOTAL.labels(
        path=normalized_path,
        method=normalized_method,
        status_class=_status_class(status_code),
    ).inc()
    HTTP_REQUEST_DURATION_SECONDS.labels(
        path=normalized_path,
        method=normalized_method,
    ).observe(max(float(duration_seconds), 0.0))
def observe_query_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: Optional[int],
    duration_seconds: float,
    outcome: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Record one `/query` request handled by rag-service.

    Args:
        mode: Requested query mode; collapsed onto the known-mode whitelist.
        granularity: Requested granularity label.
        top_k: Requested result count; treated as 1 when falsy and clamped
            to at least 1.0 before observation.
        duration_seconds: End-to-end handling time, clamped at zero.
        outcome: Outcome label (e.g. "success" or "fallback";
            see ``infer_query_outcome``).
        metadata: Optional response metadata; any phase-timing dicts found
            in it are emitted as per-phase duration histograms.
    """
    normalized_mode = _normalize_query_mode(mode)
    QUERY_REQUESTS_TOTAL.labels(
        mode=normalized_mode,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    ).inc()
    QUERY_DURATION_SECONDS.labels(mode=normalized_mode).observe(
        max(float(duration_seconds), 0.0)
    )
    QUERY_TOP_K.labels(mode=normalized_mode).observe(max(float(top_k or 1), 1.0))
    if metadata:
        _observe_phase_timings(normalized_mode, metadata)
def observe_search_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: Optional[int],
    duration_seconds: float,
    outcome: str,
) -> None:
    """Record one `/search` request handled by rag-service.

    Args:
        mode: Search mode; normalized via ``_normalize_search_mode``.
        granularity: Requested granularity label.
        top_k: Requested result count; treated as 1 when falsy and clamped
            to at least 1.0 before observation.
        duration_seconds: End-to-end handling time, clamped at zero.
        outcome: Request outcome label.
    """
    normalized_mode = _normalize_search_mode(mode)
    SEARCH_REQUESTS_TOTAL.labels(
        mode=normalized_mode,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    ).inc()
    SEARCH_DURATION_SECONDS.labels(mode=normalized_mode).observe(
        max(float(duration_seconds), 0.0)
    )
    SEARCH_TOP_K.labels(mode=normalized_mode).observe(max(float(top_k or 1), 1.0))
def observe_provider_error(
    *,
    mode: Optional[str],
    stage: str,
    exc_or_reason: Any,
) -> None:
    """Record one provider-side error or fallback reason.

    Args:
        mode: Query mode in effect when the error occurred.
        stage: Pipeline stage label (e.g. "llm_fallback", "fallback_search").
        exc_or_reason: Exception instance or free-form reason string;
            classified into (provider, error_type) labels by
            ``_classify_error``.
    """
    provider, error_type = _classify_error(exc_or_reason)
    PROVIDER_ERRORS_TOTAL.labels(
        mode=_normalize_query_mode(mode),
        stage=_normalize_label(stage),
        provider=provider,
        error_type=error_type,
    ).inc()
def observe_provider_fallbacks(
    *,
    mode: Optional[str],
    metadata: Dict[str, Any],
) -> None:
    """Emit provider fallback counters derived from response metadata.

    Walks a table of (flag key, stage label, reason key) triples and
    records one provider error for each truthy flag. For the
    "fallback_search_error" entry the flag value itself is the reason,
    so flag key and reason key coincide.

    Args:
        mode: Query mode the response was produced under.
        metadata: Response metadata dict carrying fallback flags/reasons.
    """
    fallback_signals = (
        ("llm_fallback", "llm_fallback", "llm_fallback_reason"),
        ("graph_fallback", "graph_fallback", "graph_fallback_reason"),
        ("graph_answer_fallback", "graph_answer_fallback", "graph_answer_fallback_reason"),
        ("fallback_search_error", "fallback_search", "fallback_search_error"),
    )
    for flag_key, stage, reason_key in fallback_signals:
        if metadata.get(flag_key):
            observe_provider_error(
                mode=mode,
                stage=stage,
                exc_or_reason=metadata.get(reason_key),
            )
def infer_query_outcome(metadata: Optional[Dict[str, Any]]) -> str:
    """Infer the outcome label used for query-level service metrics.

    Precedence:
        1. An explicit, non-empty ``response_policy["state"]`` string wins.
        2. Otherwise any truthy fallback flag yields "fallback".
        3. Otherwise (including missing metadata) the outcome is "success".

    Args:
        metadata: Response metadata dict, or ``None``/empty.

    Returns:
        The outcome label string.
    """
    if not metadata:
        return "success"
    response_policy = metadata.get("response_policy")
    if isinstance(response_policy, dict):
        state = response_policy.get("state")
        if isinstance(state, str) and state:
            return state
    fallback_flags = ("llm_fallback", "graph_fallback", "graph_answer_fallback")
    if any(metadata.get(flag) for flag in fallback_flags):
        return "fallback"
    return "success"
def refresh_backend_health(components: "RagComponents") -> None:
    """Probe initialized dependencies and update backend health gauges.

    Sets the constant "components initialized" gauge, then probes the
    Postgres document repository, each configured Qdrant collection
    repository, and the Neo4j graph repository via their ``health_check``
    hooks. A missing repository is reported as unhealthy (0.0) rather
    than raising.

    Args:
        components: The service's initialized dependency container.
    """
    BACKEND_HEALTH.labels(backend="components", target="initialized").set(1.0)

    context_service = components.context_service
    pg_repo = cast(Any, getattr(context_service, "pg_repo", None))
    BACKEND_HEALTH.labels(backend="postgres", target="document_db").set(
        1.0 if pg_repo is not None and pg_repo.health_check() else 0.0
    )

    retrieval_service = components.retrieval_service
    qdrant_repos_obj = (
        getattr(retrieval_service, "qdrant_repos", {})
        if retrieval_service is not None
        else {}
    )
    # The Qdrant service counts as healthy when any collection repo responds;
    # every repo is probed unconditionally so per-collection gauges stay fresh.
    qdrant_service_ok = False
    if isinstance(qdrant_repos_obj, dict):
        qdrant_repos = cast(dict[str, Any], qdrant_repos_obj)
        for name, repo in qdrant_repos.items():
            service_ok = bool(repo.health_check())
            qdrant_service_ok = qdrant_service_ok or service_ok
            BACKEND_HEALTH.labels(
                backend="qdrant_collection",
                target=_normalize_label(name),
            ).set(1.0 if repo.collection_exists() else 0.0)
    BACKEND_HEALTH.labels(backend="qdrant", target="vector").set(
        1.0 if qdrant_service_ok else 0.0
    )

    graph_rag_service = components.graph_rag_service
    neo4j_repo = cast(Any, getattr(graph_rag_service, "neo4j", None))
    BACKEND_HEALTH.labels(backend="neo4j", target="graph").set(
        1.0 if neo4j_repo is not None and neo4j_repo.health_check() else 0.0
    )
def _observe_phase_timings(mode: str, metadata: Dict[str, Any]) -> None:
    """Emit per-phase duration histograms from metadata timing dicts.

    Each recognized metadata key maps phase names to durations in
    milliseconds. Non-dict entries are skipped wholesale; values that
    cannot be coerced to ``float`` are skipped individually. Durations
    are clamped at zero and converted to seconds before observation.

    Args:
        mode: Normalized query mode used as the histogram's mode label.
        metadata: Response metadata possibly containing timing dicts.
    """
    for key, prefix in (
        ("phase_timings_ms", ""),
        ("graph_query_phase_timings_ms", "graph_query"),
        ("graph_retrieval_phase_timings_ms", "graph_retrieval"),
    ):
        raw_phase_timings = metadata.get(key)
        if not isinstance(raw_phase_timings, dict):
            continue
        phase_timings = cast(Dict[str, Any], raw_phase_timings)
        for phase_name, raw_value in phase_timings.items():
            try:
                duration_ms = float(raw_value)
            except (TypeError, ValueError):
                continue
            PHASE_DURATION_SECONDS.labels(
                mode=mode,
                phase=_normalize_phase_name(prefix, phase_name),
            ).observe(max(duration_ms, 0.0) / 1000.0)
def _normalize_path(path: str) -> str:
if path in {"/query", "/query/stream", "/search", "/health", "/metrics", "/"}:
return path
return "other"
def _normalize_query_mode(mode: Optional[str]) -> str:
    """Collapse a raw query mode onto a bounded whitelist of labels.

    The mode is first normalized via the shared ``_normalize_label``
    helper; anything outside the known set maps to "unknown" so metric
    label cardinality stays bounded.
    """
    known_modes = {
        "rag",
        "search",
        "llm_only",
        "graph",
        "summarize",
        "compare",
        "semantic",
        "lexical",
        "hybrid",
    }
    normalized = _normalize_label(mode)
    return normalized if normalized in known_modes else "unknown"
def _normalize_phase_name(prefix: str, phase_name: str) -> str:
    """Normalize a raw phase name and optionally apply a source prefix.

    Strips a trailing "_ms" unit suffix and maps legacy/verbose phase
    spellings onto the canonical names used by PHASE_DURATION_SECONDS so
    dashboards see one series per phase.

    Args:
        prefix: Optional origin prefix ("graph_query"/"graph_retrieval");
            empty string means no prefix.
        phase_name: Raw phase name taken from a metadata timing dict.
    """
    normalized = _normalize_label(phase_name)
    if normalized.endswith("_ms"):
        normalized = normalized[: -len("_ms")]
    normalized = normalized.replace("source_build", "sources_build")
    # The two renames below are mutually exclusive with each other.
    if normalized == "reranking":
        normalized = "rerank"
    elif normalized == "total_graph_pipeline":
        normalized = "total"
    if prefix:
        return f"{_normalize_label(prefix)}_{normalized}"
    return normalized