"""
API Gateway
Entry point for all API requests.
"""
import logging
from contextlib import asynccontextmanager
from typing import Awaitable, Callable, Optional, Sequence, cast
import redis.asyncio as redis_lib
from api_gateway.auth import require_auth
from api_gateway.bootstrap import RedisClient
from api_gateway.bootstrap import bootstrap_system as _bootstrap_system
from api_gateway.rate_limit import limiter
from api_gateway.routers import (
admin_pipeline_router,
config_router,
conversations_router,
embedding_router,
health_router,
jobs_router,
rag_proxy_router,
)
from api_gateway.service_metrics import observe_http_request, refresh_backend_health
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, PlainTextResponse
from lalandre_core.config import get_config, get_gateway_config
from lalandre_core.http.middleware import make_http_instrumentation_middleware
from lalandre_core.logging_setup import setup_worker_logging
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from starlette.exceptions import HTTPException as StarletteHTTPException
# --- Configuration & Constants ---
# Import-time setup: logging is configured first, then the module logger is
# created, then both configuration objects are loaded once for the module.
setup_worker_logging()
logger = logging.getLogger(__name__)
# `gateway` supplies the Redis, backend-URL, CORS, bootstrap and timeout settings
# read below; `config` is only consulted here for `search.rerank_enabled`.
gateway = get_gateway_config()
config = get_config()
def _coalesce_str(value: Optional[str], default: str) -> str:
if value is None:
return default
return value
def _coalesce_int(value: Optional[int], default: int) -> int:
if value is None:
return default
return int(value)
def _coalesce_sequence(value: Optional[Sequence[str]], default: Sequence[str]) -> Sequence[str]:
if value is None:
return default
return value
# Typed module-level alias for the imported bootstrap coroutine; lifespan()
# awaits it with the Redis client when AUTO_BOOTSTRAP is truthy.
bootstrap_system: Callable[[RedisClient], Awaitable[bool]] = _bootstrap_system
# Redis Settings
# Fall back to host "redis" / port 6379 when the gateway config leaves them as None.
REDIS_HOST = _coalesce_str(gateway.redis_host, "redis")
REDIS_PORT = _coalesce_int(gateway.redis_port, 6379)
# Service Settings
# Note: RAG Service is for synchronous queries; Workers handle async jobs via Redis.
# These values are copied onto app.state by initialize_routers().
RAG_SERVICE_URL = gateway.rag_service_url
EMBEDDING_SERVICE_URL = gateway.embedding_service_url
RERANK_SERVICE_URL = gateway.rerank_service_url
RERANK_ENABLED = bool(config.search.rerank_enabled)
# CORS origins passed to CORSMiddleware; an unset (None) value becomes an empty list.
ALLOWED_ORIGINS = _coalesce_sequence(gateway.allowed_origins, [])
# When truthy, lifespan() runs bootstrap_system at startup.
AUTO_BOOTSTRAP = gateway.auto_bootstrap
# Timeouts taken directly from the gateway config (no fallback applied here).
HEALTHCHECK_TIMEOUT_SECONDS = gateway.healthcheck_timeout_seconds
RAG_PROXY_TIMEOUT_SECONDS = gateway.rag_proxy_timeout_seconds
# --- Lifecycle Management ---
[docs]
async def get_redis(app: FastAPI) -> redis_lib.Redis:
    """
    Return the Redis client stored on ``app.state.redis``, creating it on first use.

    If no client is stored yet (the attribute is missing or ``None``), a new
    ``redis_lib.Redis`` client is built from ``REDIS_HOST`` / ``REDIS_PORT``
    with ``decode_responses=True``, cached on ``app.state`` and returned.
    """
    cached_client = getattr(app.state, "redis", None)
    if cached_client is not None:
        return cached_client

    # Nothing cached yet: build the client once and remember it for later calls.
    logger.info(f"Connecting to Redis at {REDIS_HOST}:{REDIS_PORT}...")
    new_client = redis_lib.Redis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
    app.state.redis = new_client
    return new_client
[docs]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application startup and shutdown.

    Startup:
        1. Obtain the Redis client via ``get_redis``.
        2. If ``AUTO_BOOTSTRAP`` is truthy, await ``bootstrap_system``; a failure
           there is logged and startup continues.
        3. Populate ``app.state`` via ``initialize_routers``.

    Shutdown:
        Close the Redis client stored on ``app.state`` (if any) and reset it to
        ``None``. This runs in a ``finally`` block so the client is released
        even when the lifespan generator is exited through an exception or
        cancellation rather than a normal shutdown.

    Any exception raised during steps 1-2 is logged at CRITICAL level with its
    traceback and then swallowed, so the application still starts.
    """
    banner = "=" * 80
    logger.info(banner)
    logger.info("API GATEWAY STARTUP")
    logger.info(banner)

    # 1. Initialize Redis
    try:
        redis_conn = await get_redis(app)
        # NOTE(review): this only confirms a client object exists; whether a
        # network connection has actually been opened at this point depends on
        # the Redis client implementation - confirm, and consider an explicit
        # ping if startup is expected to detect an unreachable Redis.
        logger.info("Redis connection established.")

        # 2. Auto-Bootstrap Logic
        if AUTO_BOOTSTRAP:
            logger.info("Auto-bootstrap enabled. Triggering system initialization...")
            try:
                await bootstrap_system(cast(RedisClient, redis_conn))
            except Exception as e:
                # Deliberately non-fatal: the gateway keeps serving without bootstrap.
                logger.error("Bootstrap failed (non-fatal): %s", e, exc_info=True)
                logger.warning("System continuing with graceful degradation.")
        else:
            logger.info("Auto-bootstrap disabled.")
    except Exception as e:
        # Keep the traceback: the message alone is rarely enough to diagnose startup failures.
        logger.critical("Critical startup failure: %s", e, exc_info=True)
        # We might choose to let the app crash here if Redis is mandatory

    logger.info(banner)

    # Initialize runtime state after Redis is ready
    await initialize_routers()

    try:
        yield  # Application runs here
    finally:
        # 3. Shutdown / Cleanup
        logger.info("Shutting down API Gateway...")
        active_conn = getattr(app.state, "redis", None)
        if active_conn:
            await active_conn.close()
            app.state.redis = None
            logger.info("Redis connection closed.")
# --- FastAPI Application Setup ---
# The ASGI application object. Interactive docs are served at /docs and /redoc,
# and startup/shutdown are delegated to lifespan().
app = FastAPI(
title="Lalandre API Gateway",
description="Legal RAG System API",
version="0.1.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
# Expose the limiter on app.state and render RateLimitExceeded errors with
# slowapi's own handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore[arg-type]
# HTTP instrumentation middleware, built with observe_http_request as its callback.
app.middleware("http")(make_http_instrumentation_middleware(observe_http_request))
# CORS: only ALLOWED_ORIGINS may call the API from a browser, with credentials,
# restricted to the listed methods and request headers.
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
# Initialize routers with dependencies
[docs]
async def initialize_routers():
    """Publish the Redis client and the module-level settings on ``app.state``.

    Sets ``redis``, the three backend service URLs, ``rerank_enabled`` and the
    two timeout values, then logs that the runtime state is ready.
    """
    state = app.state
    state.redis = await get_redis(app)

    # Attribute name on app.state -> module-level setting it mirrors.
    runtime_settings = {
        "rag_service_url": RAG_SERVICE_URL,
        "embedding_service_url": EMBEDDING_SERVICE_URL,
        "rerank_service_url": RERANK_SERVICE_URL,
        "rerank_enabled": RERANK_ENABLED,
        "healthcheck_timeout_seconds": HEALTHCHECK_TIMEOUT_SECONDS,
        "rag_proxy_timeout_seconds": RAG_PROXY_TIMEOUT_SECONDS,
    }
    for attr_name, attr_value in runtime_settings.items():
        setattr(state, attr_name, attr_value)

    logger.info("Router runtime state initialized successfully")
# Register routers (health is public, others require auth)
# Every router except health_router is mounted with a require_auth dependency.
app.include_router(health_router)
app.include_router(rag_proxy_router, dependencies=[Depends(require_auth)])
app.include_router(conversations_router, dependencies=[Depends(require_auth)])
app.include_router(jobs_router, dependencies=[Depends(require_auth)])
app.include_router(config_router, dependencies=[Depends(require_auth)])
# The embedding router is the only one given an explicit prefix at registration.
app.include_router(embedding_router, prefix="/api/v1", dependencies=[Depends(require_auth)])
app.include_router(admin_pipeline_router, dependencies=[Depends(require_auth)])
# --- Root Endpoint ---
[docs]
@app.get("/")
async def root():
    """Return the service name, its version, and the main discovery paths."""
    service_info = {"service": "lalandre-api-gateway", "version": app.version}
    # Paths a client can follow next: interactive docs and the health probe.
    entrypoints = {"docs": "/docs", "redoc": "/redoc", "health": "/health"}
    return {**service_info, **entrypoints}
[docs]
@app.get("/metrics")
async def metrics():
    """Refresh the backend health probes, then serve the Prometheus exposition."""
    # Probe first so the scrape includes the refreshed health values.
    await refresh_backend_health(app)
    exposition = generate_latest()
    return PlainTextResponse(exposition, media_type=CONTENT_TYPE_LATEST)
# --- Global Exception Handlers ---
[docs]
@app.exception_handler(404)
async def not_found_handler(request: Request, exc: StarletteHTTPException):
    """Answer 404 errors with a uniform JSON body that echoes the requested URL.

    NOTE(review): the handler is registered by status code, so it presumably
    also replaces the detail of 404s raised deliberately by routes - confirm
    that this is intended.
    """
    payload = {
        "error": "Not Found",
        "message": "The requested endpoint does not exist",
        "path": str(request.url),
    }
    return JSONResponse(content=payload, status_code=404)
[docs]
@app.exception_handler(500)
async def internal_error_handler(request: Request, exc: Exception):
    """Return a normalized JSON payload for uncaught internal errors.

    The exception and its traceback are written to the log together with the
    request method and path. The response body keeps the ``error`` /
    ``message`` / ``detail`` keys, but ``detail`` is a static string: the
    exception text is no longer sent to the client, because it can reveal
    internal details (queries, hostnames, file paths).
    """
    # Pass the exception explicitly so the traceback is logged regardless of
    # whether this handler runs inside an active ``except`` block.
    logger.error(
        "Unhandled internal error on %s %s",
        request.method,
        request.url.path,
        exc_info=exc,
    )
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "message": "An unexpected error occurred",
            "detail": "See server logs for details",
        },
    )