# Source code for api_gateway.auth

"""
JWT authentication middleware for the API Gateway.
Validates Bearer tokens against Keycloak JWKS.
"""

import logging
import os
import time
from typing import Any, Optional

import httpx
import jwt
from fastapi import HTTPException, Request
from jwt.algorithms import RSAAlgorithm  # noqa: F401 — used in _get_signing_key body
from lalandre_core.config import get_config

logger = logging.getLogger(__name__)

# Keycloak base URL; the JWKS endpoint is derived from it. When it is empty,
# the functions in this module treat authentication as disabled.
KEYCLOAK_URL = os.environ.get("KEYCLOAK_URL", "")
# Issuer expected in the token `iss` claim; an unset or empty value falls back
# to KEYCLOAK_URL.
KEYCLOAK_ISSUER_URL = os.environ.get("KEYCLOAK_ISSUER_URL", "") or KEYCLOAK_URL
# Maximum age of the cached JWKS, in seconds (1 hour), before it is refetched.
JWKS_REFRESH_INTERVAL = 60 * 60

# JWK entries keyed by their `kid`.
_jwks_cache: dict[str, dict] = {}
# Epoch timestamp of the last successful refresh; 0 means "never refreshed".
_jwks_last_refresh: float = 0


def _fetch_jwks() -> None:
    """Fetch the JWKS document from Keycloak and replace the cached keys.

    Best-effort: when ``KEYCLOAK_URL`` is empty the function only logs a
    warning, and any failure while fetching or parsing is logged and leaves
    the previous cache and refresh timestamp untouched.

    JWKS entries that carry no usable ``kid`` are skipped instead of aborting
    the whole refresh, because they could never be looked up by ``kid`` anyway.
    """
    global _jwks_cache, _jwks_last_refresh

    if not KEYCLOAK_URL:
        logger.warning("KEYCLOAK_URL not set — JWT auth disabled")
        return

    jwks_url = f"{KEYCLOAK_URL}/protocol/openid-connect/certs"
    try:
        # The gateway's health-check timeout setting is reused for this request.
        resp = httpx.get(jwks_url, timeout=get_config().gateway.healthcheck_timeout_seconds)
        resp.raise_for_status()
        jwks = resp.json()
        keys_by_kid = {
            key["kid"]: key
            for key in jwks.get("keys", [])
            # `kid` is optional in a JWK; one entry without it must not
            # discard every other key in the set.
            if isinstance(key, dict) and key.get("kid")
        }
    except Exception as e:
        # Deliberately broad: a JWKS outage must not crash the caller.
        logger.error("Failed to fetch JWKS: %s", e)
        return

    # Swap in the new key set only after the whole document parsed cleanly.
    _jwks_cache = keys_by_kid
    _jwks_last_refresh = time.time()
    logger.info("JWKS refreshed: %d keys loaded", len(_jwks_cache))


def _get_signing_key(kid: str) -> Optional[Any]:
    """Return the RSA public key registered under ``kid``, or ``None``.

    The JWKS cache is refetched first when ``kid`` is not cached or when the
    last successful refresh is older than ``JWKS_REFRESH_INTERVAL`` seconds.
    """
    needs_refresh = kid not in _jwks_cache
    if not needs_refresh:
        cache_age = time.time() - _jwks_last_refresh
        needs_refresh = cache_age > JWKS_REFRESH_INTERVAL
    if needs_refresh:
        _fetch_jwks()

    # Look the key up again: the refresh may have replaced the cache.
    jwk = _jwks_cache.get(kid)
    return RSAAlgorithm.from_jwk(jwk) if jwk else None


async def require_auth(request: Request) -> dict:
    """FastAPI dependency that validates JWT Bearer tokens.

    Args:
        request: Incoming request; only its ``authorization`` header is read.

    Returns:
        The decoded token claims, or an empty dict when ``KEYCLOAK_URL`` is
        not set (authentication disabled).

    Raises:
        HTTPException: 401 when the header is missing or is not a Bearer
            credential, when the token has no ``kid``, when the ``kid`` is
            unknown, when the matching key cannot be used as an RSA key, or
            when the token is expired or otherwise invalid.
    """
    if not KEYCLOAK_URL:
        # Auth disabled — pass through
        return {}

    auth_header = request.headers.get("authorization", "")
    if not auth_header.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

    token = auth_header[7:]  # strip the "Bearer " prefix

    try:
        # Decode header to get kid
        unverified = jwt.get_unverified_header(token)
        kid = unverified.get("kid")
        if not kid:
            raise HTTPException(status_code=401, detail="Token missing kid")

        public_key = _get_signing_key(kid)
        if not public_key:
            raise HTTPException(status_code=401, detail="Unknown signing key")

        # Signature, expiry and issuer are verified; audience is not.
        return jwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            issuer=KEYCLOAK_ISSUER_URL,
            options={"verify_aud": False},
        )
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(status_code=401, detail="Token expired") from exc
    except (jwt.InvalidTokenError, jwt.InvalidKeyError) as exc:
        # InvalidKeyError is not an InvalidTokenError; without catching it a
        # kid that maps to an unusable (e.g. non-RSA) JWK would surface as a
        # 500 instead of a 401.
        logger.debug("JWT validation failed: %s", exc)
        raise HTTPException(status_code=401, detail="Invalid token") from exc
# Realm role name that grants admin access; overridable via the ADMIN_ROLE env var.
ADMIN_ROLE = os.environ.get("ADMIN_ROLE", "admin")
async def require_admin(request: Request) -> dict:
    """FastAPI dependency that requires the caller to hold the admin role.

    Checks ``realm_access.roles`` in the Keycloak JWT. When auth is disabled
    (no KEYCLOAK_URL) every caller is treated as admin.

    Args:
        request: Incoming request, passed through to ``require_auth``.

    Returns:
        The token claims returned by ``require_auth``.

    Raises:
        HTTPException: 401 propagated from ``require_auth``; 403 when the
            claims do not contain ``ADMIN_ROLE`` in a ``realm_access.roles``
            list.
    """
    claims = await require_auth(request)
    if not KEYCLOAK_URL:
        # Auth disabled — dev mode, everyone is admin
        return claims

    realm_access = claims.get("realm_access")
    roles = realm_access.get("roles") if isinstance(realm_access, dict) else None
    # Insist on a real list: a null claim would make `in` raise, and a bare
    # string would turn the membership test into a substring match.
    if not isinstance(roles, list) or ADMIN_ROLE not in roles:
        raise HTTPException(status_code=403, detail="Admin role required")
    return claims
# Fetch keys at import time (best-effort)
_fetch_jwks()