# Source code for api_gateway.main

"""
API Gateway
Entry point for all API requests.
"""

import logging
from contextlib import asynccontextmanager
from typing import Awaitable, Callable, Optional, Sequence, cast

import redis.asyncio as redis_lib
from api_gateway.auth import require_auth
from api_gateway.bootstrap import RedisClient
from api_gateway.bootstrap import bootstrap_system as _bootstrap_system
from api_gateway.rate_limit import limiter
from api_gateway.routers import (
    admin_pipeline_router,
    config_router,
    conversations_router,
    embedding_router,
    health_router,
    jobs_router,
    rag_proxy_router,
)
from api_gateway.service_metrics import observe_http_request, refresh_backend_health
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, PlainTextResponse
from lalandre_core.config import get_config, get_gateway_config
from lalandre_core.http.middleware import make_http_instrumentation_middleware
from lalandre_core.logging_setup import setup_worker_logging
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from starlette.exceptions import HTTPException as StarletteHTTPException

# --- Configuration & Constants ---

# Set up logging through the shared lalandre_core helper before the module
# logger is obtained.
setup_worker_logging()
logger = logging.getLogger(__name__)


# Gateway settings (Redis target, downstream service URLs, CORS origins,
# bootstrap flag, timeouts) and the shared application config (read below for
# search.rerank_enabled). Both are loaded once, at import time.
gateway = get_gateway_config()
config = get_config()


def _coalesce_str(value: Optional[str], default: str) -> str:
    if value is None:
        return default
    return value


def _coalesce_int(value: Optional[int], default: int) -> int:
    if value is None:
        return default
    return int(value)


def _coalesce_sequence(value: Optional[Sequence[str]], default: Sequence[str]) -> Sequence[str]:
    if value is None:
        return default
    return value


# Module-level alias for the bootstrap entry point, typed explicitly as an
# async callable taking a Redis client and returning a bool. lifespan() calls
# it through this module-level name.
# NOTE(review): presumably kept as a rebindable name so tests can substitute
# the bootstrap implementation — confirm.
bootstrap_system: Callable[[RedisClient], Awaitable[bool]] = _bootstrap_system


# Redis Settings
# When the gateway config leaves these as None, fall back to host "redis"
# on port 6379 (the port is coerced with int()).
REDIS_HOST = _coalesce_str(gateway.redis_host, "redis")
REDIS_PORT = _coalesce_int(gateway.redis_port, 6379)

# Service Settings
# Note: RAG Service is for synchronous queries; Workers handle async jobs via Redis.
# Downstream service base URLs, taken from the gateway config as-is.
RAG_SERVICE_URL = gateway.rag_service_url
EMBEDDING_SERVICE_URL = gateway.embedding_service_url
RERANK_SERVICE_URL = gateway.rerank_service_url
# Reranking toggle comes from the shared config's search section, not the gateway's.
RERANK_ENABLED = bool(config.search.rerank_enabled)
# CORS allow-list; None becomes an empty list, i.e. no cross-origin origins allowed.
ALLOWED_ORIGINS = _coalesce_sequence(gateway.allowed_origins, [])
# Whether startup runs the bootstrap sequence.
AUTO_BOOTSTRAP = gateway.auto_bootstrap
# Timeouts exposed to routers via app.state.
HEALTHCHECK_TIMEOUT_SECONDS = gateway.healthcheck_timeout_seconds
RAG_PROXY_TIMEOUT_SECONDS = gateway.rag_proxy_timeout_seconds

# --- Lifecycle Management ---


async def get_redis(app: FastAPI) -> redis_lib.Redis:
    """Return the Redis client cached on ``app.state.redis``, creating it lazily.

    On first call an asyncio Redis client is built for REDIS_HOST:REDIS_PORT
    with ``decode_responses=True`` and stored on ``app.state`` for reuse.
    Constructing the client does not itself open a connection.
    """
    cached_client = getattr(app.state, "redis", None)
    if cached_client is not None:
        return cached_client

    logger.info(f"Connecting to Redis at {REDIS_HOST}:{REDIS_PORT}...")
    fresh_client = redis_lib.Redis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
    app.state.redis = fresh_client
    return fresh_client
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manages application startup and shutdown events.

    1. Creates the Redis client and verifies it with a PING.
    2. Optionally triggers the bootstrap sequence (when AUTO_BOOTSTRAP is set).
    3. Populates router runtime state via initialize_routers().
    4. Closes the Redis client on shutdown, even if the app exits with an error.

    Redis and bootstrap failures are logged (with tracebacks) rather than
    raised, so the gateway keeps starting in a degraded mode.
    """
    logger.info("=" * 80)
    logger.info("API GATEWAY STARTUP")
    logger.info("=" * 80)

    # 1. Initialize Redis
    try:
        redis_conn = await get_redis(app)
        # Building the client opens no socket. PING so that an unreachable
        # Redis lands in the critical-failure branch below instead of the
        # success message being logged unconditionally.
        await redis_conn.ping()
        logger.info("Redis connection established.")

        # 2. Auto-Bootstrap Logic
        if AUTO_BOOTSTRAP:
            logger.info("Auto-bootstrap enabled. Triggering system initialization...")
            try:
                await bootstrap_system(cast(RedisClient, redis_conn))
            except Exception as e:
                logger.error(f"Bootstrap failed (non-fatal): {e}", exc_info=True)
                logger.warning("System continuing with graceful degradation.")
        else:
            logger.info("Auto-bootstrap disabled.")
    except Exception as e:
        # Deliberately non-fatal (as before); keep the traceback for diagnosis.
        logger.critical(f"Critical startup failure: {e}", exc_info=True)

    logger.info("=" * 80)

    # Initialize runtime state after Redis is ready
    await initialize_routers()

    try:
        yield  # Application runs here
    finally:
        # 3. Shutdown / Cleanup -- in `finally` so the client is released even
        # when the application exits through an exception.
        logger.info("Shutting down API Gateway...")
        redis_conn = getattr(app.state, "redis", None)
        if redis_conn is not None:
            # redis-py >= 5.0.1 deprecates close() in favour of aclose();
            # prefer aclose() when available and fall back for older releases.
            close_client = getattr(redis_conn, "aclose", None) or redis_conn.close
            await close_client()
            app.state.redis = None
            logger.info("Redis connection closed.")
# --- FastAPI Application Setup --- app = FastAPI( title="Lalandre API Gateway", description="Legal RAG System API", version="0.1.0", docs_url="/docs", redoc_url="/redoc", lifespan=lifespan, ) app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore[arg-type] app.middleware("http")(make_http_instrumentation_middleware(observe_http_request)) app.add_middleware( CORSMiddleware, allow_origins=ALLOWED_ORIGINS, allow_credentials=True, allow_methods=["GET", "POST", "DELETE", "OPTIONS"], allow_headers=["Authorization", "Content-Type"], ) # Initialize routers with dependencies
async def initialize_routers():
    """Publish the runtime settings routers read from ``app.state``.

    Stores the shared Redis client plus the downstream service URLs, the
    rerank toggle and the timeout values as attributes on ``app.state``.
    """
    state = app.state
    state.redis = await get_redis(app)

    runtime_settings = {
        "rag_service_url": RAG_SERVICE_URL,
        "embedding_service_url": EMBEDDING_SERVICE_URL,
        "rerank_service_url": RERANK_SERVICE_URL,
        "rerank_enabled": RERANK_ENABLED,
        "healthcheck_timeout_seconds": HEALTHCHECK_TIMEOUT_SECONDS,
        "rag_proxy_timeout_seconds": RAG_PROXY_TIMEOUT_SECONDS,
    }
    for attribute, setting in runtime_settings.items():
        setattr(state, attribute, setting)

    logger.info("Router runtime state initialized successfully")
def _include_routers(application: FastAPI) -> None:
    """Mount the API routers on ``application`` in a fixed order.

    The health router is public. Every other router is mounted with a
    ``require_auth`` dependency, and the embedding router lives under
    ``/api/v1``. The list order is the inclusion order.
    """
    application.include_router(health_router)

    protected_routers = (
        (rag_proxy_router, ""),
        (conversations_router, ""),
        (jobs_router, ""),
        (config_router, ""),
        (embedding_router, "/api/v1"),
        (admin_pipeline_router, ""),
    )
    for router, route_prefix in protected_routers:
        application.include_router(
            router,
            prefix=route_prefix,
            dependencies=[Depends(require_auth)],
        )


# Register routers (health is public, others require auth)
_include_routers(app)

# --- Root Endpoint ---
@app.get("/")
async def root():
    """Describe the service and link to its docs and health endpoints."""
    return dict(
        service="lalandre-api-gateway",
        version=app.version,
        docs="/docs",
        redoc="/redoc",
        health="/health",
    )
@app.get("/metrics")
async def metrics():
    """Serve Prometheus metrics in the text exposition format.

    Backend health probes are refreshed first, so the exported snapshot is
    generated only after that refresh completes.
    """
    await refresh_backend_health(app)
    exposition = generate_latest()
    return PlainTextResponse(content=exposition, media_type=CONTENT_TYPE_LATEST)
# --- Global Exception Handlers ---
@app.exception_handler(404)
async def not_found_handler(request: Request, exc: StarletteHTTPException):
    """Return a normalized JSON payload for 404 responses.

    This handler receives both unmatched URLs (Starlette's router raises a
    bare ``HTTPException(404)`` whose detail is the reason phrase
    "Not Found") and any ``HTTPException(404)`` raised by an endpoint. An
    endpoint-supplied string detail is kept as the message instead of being
    replaced with "endpoint does not exist", and any headers on the exception
    are preserved.

    The ``path`` field carries only the URL path. The full URL would echo the
    query string, which may contain credentials.
    """
    detail = getattr(exc, "detail", None)
    if isinstance(detail, str) and detail and detail != "Not Found":
        message = detail
    else:
        message = "The requested endpoint does not exist"
    return JSONResponse(
        status_code=404,
        content={"error": "Not Found", "message": message, "path": request.url.path},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(500)
async def internal_error_handler(request: Request, exc: Exception):
    """Return a normalized JSON payload for uncaught internal errors.

    The full traceback goes to the server log, together with the request
    method and path. The client receives only a generic message: the raw
    exception text can expose internals (hosts, queries, file paths) and is
    no longer included in the response body.
    """
    # Pass the exception explicitly so the traceback is logged even if this
    # handler is ever invoked outside an active `except` block.
    logger.error(
        "Unhandled internal error (%s %s)",
        request.method,
        request.url.path,
        exc_info=exc,
    )
    return JSONResponse(
        status_code=500,
        content={"error": "Internal Server Error", "message": "An unexpected error occurred"},
    )