Source code for api_gateway.service_metrics

"""Prometheus helpers for api-gateway runtime metrics."""

import asyncio
from typing import Any, Optional

import httpx
from fastapi import FastAPI
from lalandre_core.utils.metrics_utils import (
    LATENCY_BUCKETS as _LATENCY_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
    TOP_K_BUCKETS as _TOP_K_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
    classify_error as _classify_error_full,
)
from lalandre_core.utils.metrics_utils import (
    normalize_granularity as _normalize_granularity,
)
from lalandre_core.utils.metrics_utils import (
    normalize_label as _normalize_label,
)
from lalandre_core.utils.metrics_utils import (
    normalize_search_mode as _normalize_search_mode,
)
from lalandre_core.utils.metrics_utils import (
    status_class as _status_class,
)
from prometheus_client import Counter, Gauge, Histogram

# --- Generic HTTP traffic -------------------------------------------------

# Request count, labelled by normalized path, HTTP method and status class.
HTTP_REQUESTS_TOTAL = Counter(
    name="lalandre_api_gateway_http_requests_total",
    documentation="api-gateway HTTP requests by path, method, and status class.",
    labelnames=["path", "method", "status_class"],
)

# Request latency, labelled by normalized path and HTTP method.
HTTP_REQUEST_DURATION_SECONDS = Histogram(
    name="lalandre_api_gateway_http_request_duration_seconds",
    documentation="api-gateway HTTP request duration.",
    labelnames=["path", "method"],
    buckets=_LATENCY_BUCKETS,
)

# --- /api/v1/query ---------------------------------------------------------

QUERY_REQUESTS_TOTAL = Counter(
    name="lalandre_api_gateway_query_requests_total",
    documentation="api-gateway /api/v1/query requests by mode, granularity, and outcome.",
    labelnames=["mode", "granularity", "outcome"],
)

QUERY_DURATION_SECONDS = Histogram(
    name="lalandre_api_gateway_query_duration_seconds",
    documentation="api-gateway /api/v1/query duration by mode.",
    labelnames=["mode"],
    buckets=_LATENCY_BUCKETS,
)

QUERY_TOP_K = Histogram(
    name="lalandre_api_gateway_query_top_k",
    documentation="Requested top_k for api-gateway /api/v1/query calls.",
    labelnames=["mode"],
    buckets=_TOP_K_BUCKETS,
)

# --- /api/v1/search --------------------------------------------------------

SEARCH_REQUESTS_TOTAL = Counter(
    name="lalandre_api_gateway_search_requests_total",
    documentation="api-gateway /api/v1/search requests by mode, granularity, and outcome.",
    labelnames=["mode", "granularity", "outcome"],
)

SEARCH_DURATION_SECONDS = Histogram(
    name="lalandre_api_gateway_search_duration_seconds",
    documentation="api-gateway /api/v1/search duration by mode.",
    labelnames=["mode"],
    buckets=_LATENCY_BUCKETS,
)

SEARCH_TOP_K = Histogram(
    name="lalandre_api_gateway_search_top_k",
    documentation="Requested top_k for api-gateway /api/v1/search calls.",
    labelnames=["mode"],
    buckets=_TOP_K_BUCKETS,
)

# --- Backend proxying and health -------------------------------------------

PROXY_ERRORS_TOTAL = Counter(
    name="lalandre_api_gateway_proxy_errors_total",
    documentation="api-gateway proxy errors by endpoint and backend.",
    labelnames=["endpoint", "target", "error_type"],
)

# Set to 1.0 or 0.0 per backend by the health-refresh helpers in this module.
BACKEND_HEALTH = Gauge(
    name="lalandre_api_gateway_backend_health",
    documentation="api-gateway backend health probes.",
    labelnames=["backend"],
)


def observe_http_request(
    *,
    path: str,
    method: str,
    status_code: int,
    duration_seconds: float,
) -> None:
    """Record one HTTP request handled by the API gateway.

    Increments ``HTTP_REQUESTS_TOTAL`` and adds one sample to
    ``HTTP_REQUEST_DURATION_SECONDS``.

    Args:
        path: Raw request path; collapsed to a label via ``_normalize_path``.
        method: HTTP method; upper-cased, then passed through
            ``_normalize_label``.
        status_code: Response status code, converted with ``_status_class``.
        duration_seconds: Elapsed time; negative values are recorded as 0.0.
    """
    # Both metrics share the same path/method label pair, so build it once.
    route_labels = {
        "path": _normalize_path(path),
        "method": _normalize_label(method.upper()),
    }

    request_counter = HTTP_REQUESTS_TOTAL.labels(
        status_class=_status_class(status_code),
        **route_labels,
    )
    request_counter.inc()

    elapsed = max(float(duration_seconds), 0.0)
    HTTP_REQUEST_DURATION_SECONDS.labels(**route_labels).observe(elapsed)


def observe_query_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: int,
    duration_seconds: float,
    outcome: str,
) -> None:
    """Record one proxied `/query` request.

    Updates ``QUERY_REQUESTS_TOTAL``, ``QUERY_DURATION_SECONDS`` and
    ``QUERY_TOP_K``.

    Args:
        mode: Requested mode; mapped to a label by ``_normalize_query_mode``.
        granularity: Requested granularity; passed through
            ``_normalize_granularity``.
        top_k: Requested result count; recorded as at least 1.0.
        duration_seconds: Elapsed time; negative values are recorded as 0.0.
        outcome: Outcome label; passed through ``_normalize_label``.
    """
    mode_label = _normalize_query_mode(mode)

    request_counter = QUERY_REQUESTS_TOTAL.labels(
        mode=mode_label,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    )
    request_counter.inc()

    elapsed = max(float(duration_seconds), 0.0)
    QUERY_DURATION_SECONDS.labels(mode=mode_label).observe(elapsed)

    requested_k = max(float(top_k), 1.0)
    QUERY_TOP_K.labels(mode=mode_label).observe(requested_k)


def observe_search_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: int,
    duration_seconds: float,
    outcome: str,
) -> None:
    """Record one proxied `/search` request.

    Updates ``SEARCH_REQUESTS_TOTAL``, ``SEARCH_DURATION_SECONDS`` and
    ``SEARCH_TOP_K``.

    Args:
        mode: Requested mode; passed through ``_normalize_search_mode``.
        granularity: Requested granularity; passed through
            ``_normalize_granularity``.
        top_k: Requested result count; recorded as at least 1.0.
        duration_seconds: Elapsed time; negative values are recorded as 0.0.
        outcome: Outcome label; passed through ``_normalize_label``.
    """
    mode_label = _normalize_search_mode(mode)

    request_counter = SEARCH_REQUESTS_TOTAL.labels(
        mode=mode_label,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    )
    request_counter.inc()

    elapsed = max(float(duration_seconds), 0.0)
    SEARCH_DURATION_SECONDS.labels(mode=mode_label).observe(elapsed)

    requested_k = max(float(top_k), 1.0)
    SEARCH_TOP_K.labels(mode=mode_label).observe(requested_k)


def observe_proxy_error(
    *,
    endpoint: str,
    target: str,
    exc_or_reason: Any,
) -> None:
    """Record a backend proxy error observed by the gateway.

    Increments ``PROXY_ERRORS_TOTAL`` once.

    Args:
        endpoint: Gateway path that failed; collapsed via ``_normalize_path``.
        target: Backend identifier; passed through ``_normalize_label``.
        exc_or_reason: Value handed to ``classify_error``; the item at index 1
            of its result is used as the ``error_type`` label.
    """
    endpoint_label = _normalize_path(endpoint)
    target_label = _normalize_label(target)
    classification = _classify_error_full(exc_or_reason)

    error_counter = PROXY_ERRORS_TOTAL.labels(
        endpoint=endpoint_label,
        target=target_label,
        error_type=classification[1],
    )
    error_counter.inc()


async def _probe_redis_backend(redis_conn: Any) -> None:
    """Ping Redis once and set the ``redis`` health gauge to 1.0 or 0.0."""
    gauge = BACKEND_HEALTH.labels(backend="redis")
    if redis_conn is None:
        gauge.set(0.0)
        return
    try:
        await redis_conn.ping()
    except Exception:
        # Best-effort probe: any failure simply marks Redis as unhealthy.
        gauge.set(0.0)
    else:
        gauge.set(1.0)


async def refresh_backend_health(app: FastAPI) -> None:
    """Probe Redis and downstream HTTP services and update health gauges.

    Reads ``redis``, ``rag_service_url``, ``embedding_service_url``,
    ``rerank_service_url`` and ``healthcheck_timeout_seconds`` (default 5.0)
    from ``app.state``; missing attributes are treated as unconfigured.

    The probes are independent (each writes its own ``backend`` label and
    swallows its own errors), so they run concurrently. One slow backend then
    no longer delays the others, and a refresh takes about as long as the
    slowest probe rather than the sum of all of them.

    Args:
        app: Application whose ``state`` holds the connections and URLs.
    """
    timeout_seconds = float(getattr(app.state, "healthcheck_timeout_seconds", 5.0))

    # (gauge label, app.state attribute holding the backend base URL)
    http_backends = (
        ("rag_service", "rag_service_url"),
        ("embedding_service", "embedding_service_url"),
        ("rerank_service", "rerank_service_url"),
    )

    await asyncio.gather(
        _probe_redis_backend(getattr(app.state, "redis", None)),
        *(
            _probe_http_backend(
                backend=backend,
                url=getattr(app.state, state_attr, None),
                timeout_seconds=timeout_seconds,
            )
            for backend, state_attr in http_backends
        ),
    )


async def _probe_http_backend(
    *,
    backend: str,
    url: Optional[str],
    timeout_seconds: float,
) -> None:
    """GET ``{url}/health`` and set the backend's health gauge.

    The gauge is set to 1.0 only when the response status is exactly 200.
    An empty or missing URL, any other status, or any exception sets it
    to 0.0.
    """
    if not url:
        BACKEND_HEALTH.labels(backend=backend).set(0.0)
        return

    try:
        async with httpx.AsyncClient(timeout=timeout_seconds) as client:
            response = await client.get(f"{url}/health")
        is_healthy = response.status_code == 200
        BACKEND_HEALTH.labels(backend=backend).set(1.0 if is_healthy else 0.0)
    except Exception:
        # Best-effort probe: any failure simply marks the backend unhealthy.
        BACKEND_HEALTH.labels(backend=backend).set(0.0)


def _normalize_path(path: str) -> str:
    """Collapse a request path into a small, fixed set of label values.

    Known exact paths are returned unchanged, paths under a known prefix
    become ``<prefix>*``, and everything else becomes ``"other"``.
    """
    exact_paths = {
        "/",
        "/health",
        "/metrics",
        "/api/v1/query",
        "/api/v1/query/stream",
        "/api/v1/search",
    }
    if path in exact_paths:
        return path

    # (prefix to match, label to emit), checked in order.
    wildcard_routes = (
        ("/api/v1/chunk/", "/api/v1/chunk/*"),
        ("/api/v1/embed/", "/api/v1/embed/*"),
        ("/api/v1/extract/", "/api/v1/extract/*"),
    )
    for prefix, label in wildcard_routes:
        if path.startswith(prefix):
            return label

    return "other"


def _normalize_query_mode(mode: Optional[str]) -> str:
    """Map a requested query mode onto the label set used by query metrics.

    ``semantic`` and ``hybrid`` are reported as ``rag``, ``lexical`` as
    ``search``, the known mode names pass through unchanged, and anything
    else is reported as ``unknown``.
    """
    label = _normalize_label(mode)

    aliases = {"semantic": "rag", "hybrid": "rag", "lexical": "search"}
    if label in aliases:
        return aliases[label]

    passthrough = {"rag", "search", "llm_only", "graph", "summarize", "compare"}
    if label in passthrough:
        return label

    return "unknown"