# Source code for api_gateway.auth
"""
JWT authentication middleware for the API Gateway.
Validates Bearer tokens against Keycloak JWKS.
"""
import logging
import os
import time
from typing import Any, Optional
import httpx
import jwt
from fastapi import HTTPException, Request
from jwt.algorithms import RSAAlgorithm # noqa: F401 — used in _get_signing_key body
from lalandre_core.config import get_config
logger = logging.getLogger(__name__)

# Keycloak base URL under which `/protocol/openid-connect/certs` is served.
# An empty value disables JWT authentication entirely.
KEYCLOAK_URL = os.environ.get("KEYCLOAK_URL", "")
# Value expected in the token `iss` claim. Set it separately when the
# externally visible issuer differs from KEYCLOAK_URL; otherwise reuse it.
KEYCLOAK_ISSUER_URL = os.environ.get("KEYCLOAK_ISSUER_URL") or KEYCLOAK_URL
# Maximum age, in seconds, of the cached JWKS before a refresh is attempted.
JWKS_REFRESH_INTERVAL = 60 * 60

# kid -> raw JWK dict, as served by the JWKS endpoint.
_jwks_cache: dict[str, dict] = {}
# time.time() of the last successful JWKS fetch; 0 means "never fetched".
_jwks_last_refresh: float = 0
def _index_signing_keys(jwks: Any) -> dict[str, dict]:
    """Index the signature keys of a JWKS document by ``kid``.

    Entries that are not JSON objects, lack a string ``kid``, or declare a
    ``use`` other than ``"sig"`` (RFC 7517 section 4.2, e.g. encryption keys)
    are skipped. A single malformed entry therefore does not discard the
    rest of the key set.

    Args:
        jwks: Decoded JSON body of the JWKS endpoint.

    Returns:
        Mapping of kid -> JWK dict for every usable signing key.

    Raises:
        ValueError: if *jwks* is not a JSON object.
    """
    if not isinstance(jwks, dict):
        raise ValueError("JWKS document is not a JSON object")
    return {
        key["kid"]: key
        for key in jwks.get("keys") or []
        if isinstance(key, dict)
        and isinstance(key.get("kid"), str)
        # `use` is optional; when absent the key may be used for signatures.
        and key.get("use", "sig") == "sig"
    }


def _fetch_jwks() -> None:
    """Fetch JWKS from Keycloak and replace the cached signing keys.

    Best-effort: any failure (network, HTTP status, malformed body, config
    lookup) is logged, and the previous cache and refresh timestamp are kept.
    Does nothing but log a warning when ``KEYCLOAK_URL`` is unset.
    """
    global _jwks_cache, _jwks_last_refresh
    if not KEYCLOAK_URL:
        logger.warning("KEYCLOAK_URL not set — JWT auth disabled")
        return
    # rstrip avoids a `//protocol/...` path when KEYCLOAK_URL ends with "/".
    jwks_url = f"{KEYCLOAK_URL.rstrip('/')}/protocol/openid-connect/certs"
    try:
        resp = httpx.get(jwks_url, timeout=get_config().gateway.healthcheck_timeout_seconds)
        resp.raise_for_status()
        keys = _index_signing_keys(resp.json())
    except Exception as e:
        # Deliberately broad: this also runs at import time and must never crash.
        logger.error("Failed to fetch JWKS: %s", e)
        return
    _jwks_cache = keys
    _jwks_last_refresh = time.time()
    logger.info("JWKS refreshed: %d keys loaded", len(keys))
# Minimum seconds between JWKS fetches triggered from _get_signing_key. This
# stops tokens carrying unknown or random `kid` values from forcing a
# (blocking) Keycloak round trip on every request.
_JWKS_MIN_REFETCH_INTERVAL = 10.0
# time.time() of the last fetch attempted from _get_signing_key, whether it
# succeeded or not; 0 means never.
_jwks_last_attempt: float = 0.0


def _get_signing_key(kid: str) -> Optional[Any]:
    """Return the public key for *kid*, refreshing the JWKS cache if needed.

    A refresh is attempted when *kid* is not cached or the cache is older
    than ``JWKS_REFRESH_INTERVAL``. It is attempted at most once per
    ``_JWKS_MIN_REFETCH_INTERVAL`` seconds.

    Args:
        kid: Key id taken from the (unverified) token header.

    Returns:
        The key object produced by ``RSAAlgorithm.from_jwk``. Returns ``None``
        when *kid* is not a string, is unknown, or its JWK cannot be loaded
        as an RSA key.
    """
    global _jwks_last_attempt
    # A non-str kid (e.g. a list from a crafted header) may be unhashable.
    if not isinstance(kid, str):
        return None
    now = time.time()
    needs_refresh = kid not in _jwks_cache or (now - _jwks_last_refresh) > JWKS_REFRESH_INTERVAL
    if needs_refresh and (now - _jwks_last_attempt) >= _JWKS_MIN_REFETCH_INTERVAL:
        _jwks_last_attempt = now
        _fetch_jwks()
    key_data = _jwks_cache.get(kid)
    if not key_data:
        return None
    try:
        return RSAAlgorithm.from_jwk(key_data)
    except (jwt.PyJWTError, ValueError, TypeError) as e:
        # e.g. a non-RSA key published under this kid. InvalidKeyError is not
        # an InvalidTokenError, so letting it escape would surface as a 500.
        logger.warning("Cannot load JWK %r: %s", kid, e)
        return None
# [docs] — Sphinx viewcode marker: public, documented FastAPI dependency.
async def require_auth(request: Request) -> dict:
    """FastAPI dependency that validates a JWT Bearer token.

    When ``KEYCLOAK_URL`` is unset, authentication is disabled and an empty
    claims dict is returned for every request.

    Args:
        request: Incoming request; only its ``Authorization`` header is read.

    Returns:
        The verified token claims.

    Raises:
        HTTPException: 401 in any of these cases:
            - the header is missing or not a Bearer credential;
            - the token has no usable ``kid`` or its signing key is unknown;
            - the token is expired, lacks ``exp``, or fails verification.
    """
    if not KEYCLOAK_URL:
        # Auth disabled — pass through
        return {}

    # RFC 7235: the auth-scheme is case-insensitive ("bearer" == "Bearer").
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

    try:
        # The header is read unverified only to choose the key; the signature
        # itself is checked by jwt.decode below.
        kid = jwt.get_unverified_header(token).get("kid")
        if not kid or not isinstance(kid, str):
            raise HTTPException(status_code=401, detail="Token missing kid")
        # NOTE(review): this may perform a synchronous JWKS fetch (httpx.get)
        # on the event loop when the key is unknown or the cache is stale.
        public_key = _get_signing_key(kid)
        if not public_key:
            raise HTTPException(status_code=401, detail="Unknown signing key")
        return jwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            issuer=KEYCLOAK_ISSUER_URL,
            # PyJWT does not require `exp` by default; reject non-expiring tokens.
            # NOTE(review): audience is not verified, so any token from this
            # issuer is accepted regardless of client — consider checking `aud`/`azp`.
            options={"verify_aud": False, "require": ["exp"]},
        )
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired") from None
    except jwt.PyJWTError as e:
        # Broader than InvalidTokenError so key-loading errors (InvalidKeyError)
        # also map to 401 instead of escaping as a 500.
        logger.debug("JWT validation failed: %s", e)
        raise HTTPException(status_code=401, detail="Invalid token") from None
# Realm role that require_admin looks for in the token's realm_access.roles;
# override with the ADMIN_ROLE environment variable.
ADMIN_ROLE = os.environ.get("ADMIN_ROLE", "admin")
# [docs] — Sphinx viewcode marker: public, documented FastAPI dependency.
async def require_admin(request: Request) -> dict:
    """FastAPI dependency that requires the caller to hold the admin role.

    Validates the token via ``require_auth``. It then checks that ``ADMIN_ROLE``
    is an element of the ``realm_access.roles`` list in the Keycloak JWT.
    When auth is disabled (no KEYCLOAK_URL) every caller is treated as admin.

    Args:
        request: Incoming request, forwarded to ``require_auth``.

    Returns:
        The verified claims (``{}`` when auth is disabled).

    Raises:
        HTTPException: 401 propagated from ``require_auth``. 403 when the role
            is absent or ``realm_access.roles`` is not a list.
    """
    claims = await require_auth(request)
    if not KEYCLOAK_URL:
        # Auth disabled — dev mode, everyone is admin
        return claims
    realm_access = claims.get("realm_access")
    roles = realm_access.get("roles") if isinstance(realm_access, dict) else None
    # Require a real list. For a string, `in` would be a substring test
    # ("superadmin" would grant "admin"). For None it would raise TypeError (500).
    if not isinstance(roles, list) or ADMIN_ROLE not in roles:
        raise HTTPException(status_code=403, detail="Admin role required")
    return claims
# Warm the JWKS cache once at import time (best-effort). _fetch_jwks only logs
# on failure, and returns early with a warning when KEYCLOAK_URL is unset.
# _get_signing_key fetches again later when a kid is missing or the cache is stale.
_fetch_jwks()