# Source code for embedding_service.main

"""
Embedding Service
Generates embeddings with support for multiple providers (Mistral, OpenAI, Local)
"""

from __future__ import annotations

import logging
import time
from contextlib import asynccontextmanager
from typing import List

import uvicorn
from fastapi import FastAPI, HTTPException, Request  # noqa: E402
from fastapi.responses import PlainTextResponse  # noqa: E402
from lalandre_core.config import get_config, reset_config  # noqa: E402
from lalandre_core.http.middleware import make_http_instrumentation_middleware  # noqa: E402
from lalandre_core.logging_setup import setup_worker_logging  # noqa: E402
from lalandre_embedding.base import SupportsCacheSize, SupportsNumKeys  # noqa: E402
from lalandre_embedding.service import EmbeddingService  # noqa: E402
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest  # noqa: E402
from pydantic import BaseModel, Field  # noqa: E402

try:
    from .service_metrics import (  # noqa: E402
        observe_embed_error,
        observe_embed_request,
        observe_http_request,
    )
except ImportError:  # pragma: no cover - keeps direct script execution working
    from service_metrics import (  # noqa: E402
        observe_embed_error,
        observe_embed_request,
        observe_http_request,
    )

# Configure structured worker logging before any module-level code logs;
# the module logger is used throughout the endpoint handlers below.
setup_worker_logging()
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize embedding service on startup.

    On success, stores the service on ``app.state`` and flips
    ``service_ready``; on failure, logs the error and leaves the app
    running in a degraded (not-ready) state so /health can report 503.
    """
    try:
        config = get_config()
        logger.info("Initializing embedding service...")
        logger.info("Provider: %s", config.embedding.provider)
        logger.info("Model: %s", config.embedding.model_name)
        logger.info("Device: %s", config.embedding.device)
        if config.embedding.provider == "local":
            logger.info("Preloading local embedding model...")
        service = EmbeddingService()
        app.state.embedding_service = service
        dimension = service.get_vector_size()
        logger.info("Embedding service ready!")
        logger.info("Vector dimension: %d", dimension)
        app.state.service_ready = True
    except Exception as exc:
        logger.exception("Failed to initialize embedding service: %s", exc)
        app.state.embedding_service = None
        app.state.service_ready = False
    yield
# Application object plus HTTP instrumentation middleware (request metrics).
app = FastAPI(
    title="Lalandre Embedding Service",
    description="Text embedding generation with multiple providers (Mistral, OpenAI, Local)",
    version="0.1.0",
    lifespan=lifespan,
)
app.middleware("http")(make_http_instrumentation_middleware(observe_http_request))
def get_embedding_service(request: Request) -> EmbeddingService:
    """Return the app's initialized embedding service.

    Raises:
        ValueError: if startup failed or has not completed, i.e. the
            service is missing from ``app.state`` or not marked ready.
    """
    state = request.app.state
    ready = getattr(state, "service_ready", False)
    svc = getattr(state, "embedding_service", None)
    if svc is None or not ready:
        raise ValueError("Embedding service not initialized")
    return svc
# ============================================================================
# Models
# ============================================================================
class EmbedRequest(BaseModel):
    """Request model for embedding"""

    # Batch of input strings; bounded so a single request cannot overload a provider.
    texts: List[str] = Field(..., description="Texts to embed", min_length=1, max_length=128)
    # Whether to consult the provider-side cache before recomputing.
    use_cache: bool = Field(
        default=True,
        description="Use provider's cache (Redis for Mistral, in-memory for Local)",
    )
class EmbedResponse(BaseModel):
    """Response model for embedding"""

    embeddings: List[List[float]]  # one vector per input text, order preserved
    model: str
    provider: str
    vector_dimension: int
class EmbedSingleResponse(BaseModel):
    """Response model for single embedding"""

    embedding: List[float]  # single vector for the one input text
    model: str
    provider: str
    vector_dimension: int
class HealthResponse(BaseModel):
    """Response model for health check"""

    status: str
    service: str
    # Provider/model/device mirror the active config and may be unset.
    provider: str | None
    model: str | None
    device: str | None
    vector_dimension: int
class ServiceInfoResponse(BaseModel):
    """Response model for service metadata"""

    provider: str | None
    model: str | None
    device: str | None
    vector_dimension: int
    batch_size: int | None
    normalize_embeddings: bool
    # Local-provider extras (populated only when provider == "local").
    cache_enabled: bool | None = None
    cache_max_size: int | None = None
    current_cache_size: int | None = None
    # Mistral-provider extra (populated only when provider == "mistral").
    num_api_keys: int | None = None
# ============================================================================
# Endpoints
# ============================================================================
@app.get("/health")
async def health_check(request: Request) -> HealthResponse:
    """Health check.

    Returns service metadata when the embedding service is initialized;
    responds 503 with an "unhealthy" detail payload otherwise.
    """
    # Fast path: startup failed or has not completed yet.
    if not getattr(request.app.state, "service_ready", False):
        raise HTTPException(
            status_code=503,
            detail={
                "status": "unhealthy",
                "service": "embedding-service",
                "error": "Embedding service not initialized",
            },
        )
    try:
        service = get_embedding_service(request)
        config = get_config()
        return HealthResponse(
            status="healthy",
            service="embedding-service",
            provider=config.embedding.provider,
            model=config.embedding.model_name,
            device=config.embedding.device,
            vector_dimension=service.get_vector_size(),
        )
    except Exception as e:
        # Chain the original exception so its traceback is preserved (B904).
        raise HTTPException(
            status_code=503,
            detail={"status": "unhealthy", "service": "embedding-service", "error": str(e)},
        ) from e
@app.post("/embed", response_model=EmbedResponse)
async def embed_texts(request: EmbedRequest, http_request: Request) -> EmbedResponse:
    """
    Generate embeddings for texts using configured provider

    Supports:
    - Mistral AI (with Redis cache)
    - OpenAI
    - Local models (with in-memory LRU cache)

    Records success/error metrics per request and converts any failure
    into an HTTP 500 with the underlying error message.
    """
    started_at = time.perf_counter()
    # Default label used for metrics emitted before config is read.
    provider_for_metrics = "unknown"
    try:
        service = get_embedding_service(http_request)
        config = get_config()
        provider_for_metrics = str(config.embedding.provider or "unknown")
        # Generate embeddings using the service
        embeddings = service.embed_batch(request.texts)
        model_name = config.embedding.model_name
        if model_name is None:
            raise ValueError("embedding.model_name must be configured")
        provider_name = config.embedding.provider
        if provider_name is None:
            raise ValueError("embedding.provider must be configured")
        response = EmbedResponse(
            embeddings=embeddings,
            model=model_name,
            provider=provider_name,
            vector_dimension=service.get_vector_size(),
        )
        observe_embed_request(
            endpoint="embed",
            provider=provider_for_metrics,
            batch_size=len(request.texts),
            duration_seconds=time.perf_counter() - started_at,
            outcome="success",
        )
        return response
    except Exception as e:
        logger.exception("Embedding generation failed: %s", e)
        observe_embed_error(
            endpoint="embed",
            provider=provider_for_metrics,
            exc_or_reason=e,
        )
        observe_embed_request(
            endpoint="embed",
            provider=provider_for_metrics,
            batch_size=len(request.texts),
            duration_seconds=time.perf_counter() - started_at,
            outcome="error",
        )
        # Chain the original exception so its traceback is preserved (B904).
        raise HTTPException(status_code=500, detail=f"Embedding generation failed: {str(e)}") from e
@app.post("/embed/single")
async def embed_single(text: str, http_request: Request) -> EmbedSingleResponse:
    """
    Embed a single text (convenience endpoint)

    Wraps the batch path with a one-element batch; emits the same
    success/error metrics as /embed under the "embed_single" endpoint label.
    """
    started_at = time.perf_counter()
    # Default label used for metrics emitted before config is read.
    provider_for_metrics = "unknown"
    try:
        service = get_embedding_service(http_request)
        config = get_config()
        provider_name = config.embedding.provider
        model_name = config.embedding.model_name
        if provider_name is None:
            raise ValueError("embedding.provider must be configured")
        if model_name is None:
            raise ValueError("embedding.model_name must be configured")
        provider_for_metrics = provider_name
        embedding_values = service.embed_batch([text])
        if not embedding_values:
            raise ValueError("No embedding generated for input text")
        observe_embed_request(
            endpoint="embed_single",
            provider=provider_for_metrics,
            batch_size=1,
            duration_seconds=time.perf_counter() - started_at,
            outcome="success",
        )
        return EmbedSingleResponse(
            embedding=embedding_values[0],
            model=model_name,
            provider=provider_name,
            vector_dimension=service.get_vector_size(),
        )
    except Exception as exc:
        observe_embed_error(
            endpoint="embed_single",
            provider=provider_for_metrics,
            exc_or_reason=exc,
        )
        observe_embed_request(
            endpoint="embed_single",
            provider=provider_for_metrics,
            batch_size=1,
            duration_seconds=time.perf_counter() - started_at,
            outcome="error",
        )
        logger.exception("Single embedding generation failed: %s", exc)
        # Chain the original exception so its traceback is preserved (B904).
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(exc)}",
        ) from exc
@app.get("/info")
async def service_info(request: Request) -> ServiceInfoResponse:
    """
    Get embedding service configuration info

    Includes provider-specific extras: cache stats for the local
    provider, API-key count for the Mistral provider.
    """
    try:
        service = get_embedding_service(request)
        config = get_config()
        info = ServiceInfoResponse(
            provider=config.embedding.provider,
            model=config.embedding.model_name,
            device=config.embedding.device,
            vector_dimension=service.get_vector_size(),
            batch_size=config.embedding.batch_size,
            normalize_embeddings=config.embedding.normalize_embeddings,
        )
        # Add provider-specific info
        provider = service.provider
        if config.embedding.provider == "local":
            info.cache_enabled = config.embedding.enable_cache
            info.cache_max_size = config.embedding.cache_max_size
            if isinstance(provider, SupportsCacheSize):
                info.current_cache_size = provider.get_cache_size()
        elif config.embedding.provider == "mistral":
            if isinstance(provider, SupportsNumKeys):
                info.num_api_keys = provider.get_num_keys()
        return info
    except Exception as e:
        # Chain the original exception so its traceback is preserved (B904).
        raise HTTPException(status_code=500, detail=f"Failed to get service info: {str(e)}") from e
@app.post("/reload")
async def reload_model(request: Request) -> ServiceInfoResponse:
    """
    Hot-reload: re-read the config override file and reinitialize the embedding model.

    Called by the api-gateway after writing a new embedding.yaml override.
    On failure the service is marked not-ready so /health reports 503.
    """
    try:
        reset_config()
        new_service = EmbeddingService()
        request.app.state.embedding_service = new_service
        request.app.state.service_ready = True
        config = get_config()
        logger.info(
            "Embedding model reloaded: provider=%s model=%s",
            config.embedding.provider,
            config.embedding.model_name,
        )
        info = ServiceInfoResponse(
            provider=config.embedding.provider,
            model=config.embedding.model_name,
            device=config.embedding.device,
            vector_dimension=new_service.get_vector_size(),
            batch_size=config.embedding.batch_size,
            normalize_embeddings=config.embedding.normalize_embeddings,
        )
        # Provider-specific extras, mirroring the /info endpoint.
        provider_obj = new_service.provider
        if config.embedding.provider == "local":
            info.cache_enabled = config.embedding.enable_cache
            info.cache_max_size = config.embedding.cache_max_size
            if isinstance(provider_obj, SupportsCacheSize):
                info.current_cache_size = provider_obj.get_cache_size()
        elif config.embedding.provider == "mistral":
            if isinstance(provider_obj, SupportsNumKeys):
                info.num_api_keys = provider_obj.get_num_keys()
        return info
    except Exception as e:
        logger.exception("Failed to reload embedding model: %s", e)
        request.app.state.service_ready = False
        # Chain the original exception so its traceback is preserved (B904).
        raise HTTPException(status_code=500, detail=f"Reload failed: {e}") from e
@app.get("/metrics")
async def metrics() -> PlainTextResponse:
    """Expose Prometheus metrics for the embedding service."""
    payload = generate_latest()
    return PlainTextResponse(payload, media_type=CONTENT_TYPE_LATEST)
if __name__ == "__main__":
    # Local/dev entrypoint; production deployments run uvicorn externally.
    uvicorn.run(app, host="0.0.0.0", port=8002)