"""Prometheus helpers for api-gateway runtime metrics."""
from typing import Any, Optional
import httpx
from fastapi import FastAPI
from lalandre_core.utils.metrics_utils import (
LATENCY_BUCKETS as _LATENCY_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
TOP_K_BUCKETS as _TOP_K_BUCKETS,
)
from lalandre_core.utils.metrics_utils import (
classify_error as _classify_error_full,
)
from lalandre_core.utils.metrics_utils import (
normalize_granularity as _normalize_granularity,
)
from lalandre_core.utils.metrics_utils import (
normalize_label as _normalize_label,
)
from lalandre_core.utils.metrics_utils import (
normalize_search_mode as _normalize_search_mode,
)
from lalandre_core.utils.metrics_utils import (
status_class as _status_class,
)
from prometheus_client import Counter, Gauge, Histogram
# --- Gateway-wide HTTP traffic ------------------------------------------
# Request counts keyed by normalized path, method, and status class, plus
# a latency histogram sharing the project-wide bucket layout.
HTTP_REQUESTS_TOTAL = Counter(
    "lalandre_api_gateway_http_requests_total",
    "api-gateway HTTP requests by path, method, and status class.",
    ["path", "method", "status_class"],
)
HTTP_REQUEST_DURATION_SECONDS = Histogram(
    "lalandre_api_gateway_http_request_duration_seconds",
    "api-gateway HTTP request duration.",
    ["path", "method"],
    buckets=_LATENCY_BUCKETS,
)
# --- /api/v1/query endpoint ---------------------------------------------
# Outcome counter, duration histogram, and a histogram of the requested
# top_k (bucketed with the shared TOP_K bucket layout).
QUERY_REQUESTS_TOTAL = Counter(
    "lalandre_api_gateway_query_requests_total",
    "api-gateway /api/v1/query requests by mode, granularity, and outcome.",
    ["mode", "granularity", "outcome"],
)
QUERY_DURATION_SECONDS = Histogram(
    "lalandre_api_gateway_query_duration_seconds",
    "api-gateway /api/v1/query duration by mode.",
    ["mode"],
    buckets=_LATENCY_BUCKETS,
)
QUERY_TOP_K = Histogram(
    "lalandre_api_gateway_query_top_k",
    "Requested top_k for api-gateway /api/v1/query calls.",
    ["mode"],
    buckets=_TOP_K_BUCKETS,
)
# --- /api/v1/search endpoint --------------------------------------------
# Mirrors the /query metrics so both endpoints can be compared directly.
SEARCH_REQUESTS_TOTAL = Counter(
    "lalandre_api_gateway_search_requests_total",
    "api-gateway /api/v1/search requests by mode, granularity, and outcome.",
    ["mode", "granularity", "outcome"],
)
SEARCH_DURATION_SECONDS = Histogram(
    "lalandre_api_gateway_search_duration_seconds",
    "api-gateway /api/v1/search duration by mode.",
    ["mode"],
    buckets=_LATENCY_BUCKETS,
)
SEARCH_TOP_K = Histogram(
    "lalandre_api_gateway_search_top_k",
    "Requested top_k for api-gateway /api/v1/search calls.",
    ["mode"],
    buckets=_TOP_K_BUCKETS,
)
# --- Proxying and backend health ----------------------------------------
# Proxy errors keyed by (endpoint, backend target, classified error type);
# BACKEND_HEALTH is a 1.0/0.0 gauge per probed backend.
PROXY_ERRORS_TOTAL = Counter(
    "lalandre_api_gateway_proxy_errors_total",
    "api-gateway proxy errors by endpoint and backend.",
    ["endpoint", "target", "error_type"],
)
BACKEND_HEALTH = Gauge(
    "lalandre_api_gateway_backend_health",
    "api-gateway backend health probes.",
    ["backend"],
)
def observe_http_request(
    *,
    path: str,
    method: str,
    status_code: int,
    duration_seconds: float,
) -> None:
    """Record one HTTP request handled by the API gateway.

    Args:
        path: Raw request path; collapsed to a bounded label set via
            ``_normalize_path`` to keep metric cardinality low.
        method: HTTP method; upper-cased, then normalized as a label.
        status_code: Response status, bucketed by ``_status_class``.
        duration_seconds: Wall-clock handling time; clamped to >= 0 so a
            skewed clock can never produce a negative observation.
    """
    normalized_path = _normalize_path(path)
    normalized_method = _normalize_label(method.upper())
    HTTP_REQUESTS_TOTAL.labels(
        path=normalized_path,
        method=normalized_method,
        status_class=_status_class(status_code),
    ).inc()
    HTTP_REQUEST_DURATION_SECONDS.labels(
        path=normalized_path,
        method=normalized_method,
    ).observe(max(float(duration_seconds), 0.0))
def observe_query_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: int,
    duration_seconds: float,
    outcome: str,
) -> None:
    """Record one proxied `/query` request.

    Args:
        mode: Requested query mode; mapped onto the canonical mode set by
            ``_normalize_query_mode`` (e.g. semantic/hybrid fold into "rag").
        granularity: Requested granularity; normalized before labeling.
        top_k: Requested result count; clamped to >= 1 for the histogram.
        duration_seconds: End-to-end duration; clamped to >= 0.
        outcome: Caller-supplied outcome tag (e.g. success/error label).
    """
    normalized_mode = _normalize_query_mode(mode)
    QUERY_REQUESTS_TOTAL.labels(
        mode=normalized_mode,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    ).inc()
    QUERY_DURATION_SECONDS.labels(mode=normalized_mode).observe(
        max(float(duration_seconds), 0.0)
    )
    QUERY_TOP_K.labels(mode=normalized_mode).observe(max(float(top_k), 1.0))
def observe_search_request(
    *,
    mode: Optional[str],
    granularity: Optional[str],
    top_k: int,
    duration_seconds: float,
    outcome: str,
) -> None:
    """Record one proxied `/search` request.

    Args:
        mode: Requested search mode; normalized by the shared
            ``_normalize_search_mode`` helper.
        granularity: Requested granularity; normalized before labeling.
        top_k: Requested result count; clamped to >= 1 for the histogram.
        duration_seconds: End-to-end duration; clamped to >= 0.
        outcome: Caller-supplied outcome tag (e.g. success/error label).
    """
    normalized_mode = _normalize_search_mode(mode)
    SEARCH_REQUESTS_TOTAL.labels(
        mode=normalized_mode,
        granularity=_normalize_granularity(granularity),
        outcome=_normalize_label(outcome),
    ).inc()
    SEARCH_DURATION_SECONDS.labels(mode=normalized_mode).observe(
        max(float(duration_seconds), 0.0)
    )
    SEARCH_TOP_K.labels(mode=normalized_mode).observe(max(float(top_k), 1.0))
def observe_proxy_error(
    *,
    endpoint: str,
    target: str,
    exc_or_reason: Any,
) -> None:
    """Record a backend proxy error observed by the gateway.

    Args:
        endpoint: Gateway path on which the error occurred; normalized to
            the bounded label set via ``_normalize_path``.
        target: Backend service the request was proxied to.
        exc_or_reason: Exception instance or reason string; classified by
            ``_classify_error_full`` (index [1] is presumably the error-type
            label of the returned tuple — confirm against metrics_utils).
    """
    PROXY_ERRORS_TOTAL.labels(
        endpoint=_normalize_path(endpoint),
        target=_normalize_label(target),
        error_type=_classify_error_full(exc_or_reason)[1],
    ).inc()
async def refresh_backend_health(app: FastAPI) -> None:
    """Probe Redis and downstream HTTP services and update health gauges.

    Best-effort: every probe failure is recorded as 0.0 (unhealthy) and no
    exception is ever raised to the caller, so this is safe to run from a
    periodic task.

    Args:
        app: FastAPI application whose ``state`` carries the Redis
            connection, backend URLs, and optional healthcheck timeout.
    """
    redis_conn = getattr(app.state, "redis", None)
    if redis_conn is None:
        BACKEND_HEALTH.labels(backend="redis").set(0.0)
    else:
        try:
            await redis_conn.ping()
            BACKEND_HEALTH.labels(backend="redis").set(1.0)
        except Exception:
            # Deliberately broad: any ping failure (refused connection,
            # timeout, auth error) just marks Redis unhealthy.
            BACKEND_HEALTH.labels(backend="redis").set(0.0)
    # Hoist the timeout conversion; 5.0s default when app.state omits it.
    timeout_seconds = float(
        getattr(app.state, "healthcheck_timeout_seconds", 5.0)
    )
    # Probe each HTTP backend the gateway fronts. A missing/empty URL is
    # reported unhealthy by _probe_http_backend itself.
    for backend, url_attr in (
        ("rag_service", "rag_service_url"),
        ("embedding_service", "embedding_service_url"),
        ("rerank_service", "rerank_service_url"),
    ):
        await _probe_http_backend(
            backend=backend,
            url=getattr(app.state, url_attr, None),
            timeout_seconds=timeout_seconds,
        )
async def _probe_http_backend(
    *,
    backend: str,
    url: Optional[str],
    timeout_seconds: float,
) -> None:
    """GET ``{url}/health`` and set the backend's gauge to 1.0 or 0.0.

    No URL, any exception, or a non-200 response all count as unhealthy;
    errors are never propagated to the caller.
    """
    gauge = BACKEND_HEALTH.labels(backend=backend)
    if not url:
        gauge.set(0.0)
        return
    healthy = 0.0
    try:
        async with httpx.AsyncClient(timeout=timeout_seconds) as client:
            response = await client.get(f"{url}/health")
        if response.status_code == 200:
            healthy = 1.0
    except Exception:
        healthy = 0.0
    gauge.set(healthy)
def _normalize_path(path: str) -> str:
if path in {
"/",
"/health",
"/metrics",
"/api/v1/query",
"/api/v1/query/stream",
"/api/v1/search",
}:
return path
if path.startswith("/api/v1/chunk/"):
return "/api/v1/chunk/*"
if path.startswith("/api/v1/embed/"):
return "/api/v1/embed/*"
if path.startswith("/api/v1/extract/"):
return "/api/v1/extract/*"
return "other"
def _normalize_query_mode(mode: Optional[str]) -> str:
    """Map a raw query mode onto the canonical mode label set.

    semantic/hybrid fold into "rag", lexical into "search"; already
    canonical modes pass through; anything else becomes "unknown".
    """
    normalized = _normalize_label(mode)
    aliases = {"semantic": "rag", "hybrid": "rag", "lexical": "search"}
    if normalized in aliases:
        return aliases[normalized]
    canonical = {"rag", "search", "llm_only", "graph", "summarize", "compare"}
    return normalized if normalized in canonical else "unknown"