"""
Embedding Service
Generates embeddings with support for multiple providers (Mistral, OpenAI, Local)
"""
from __future__ import annotations
import logging
import time
from contextlib import asynccontextmanager
from typing import List
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import PlainTextResponse
from lalandre_core.config import get_config, reset_config
from lalandre_core.http.middleware import make_http_instrumentation_middleware
from lalandre_core.logging_setup import setup_worker_logging
from lalandre_embedding.base import SupportsCacheSize, SupportsNumKeys
from lalandre_embedding.service import EmbeddingService
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
from pydantic import BaseModel, Field

try:
    from .service_metrics import (
        observe_embed_error,
        observe_embed_request,
        observe_http_request,
    )
except ImportError:  # pragma: no cover - keeps direct script execution working
    from service_metrics import (
        observe_embed_error,
        observe_embed_request,
        observe_http_request,
    )
setup_worker_logging()
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize embedding service on startup"""
try:
config = get_config()
logger.info("Initializing embedding service...")
logger.info("Provider: %s", config.embedding.provider)
logger.info("Model: %s", config.embedding.model_name)
logger.info("Device: %s", config.embedding.device)
if config.embedding.provider == "local":
logger.info("Preloading local embedding model...")
app.state.embedding_service = EmbeddingService()
vector_size = app.state.embedding_service.get_vector_size()
logger.info("Embedding service ready!")
logger.info("Vector dimension: %d", vector_size)
app.state.service_ready = True
except Exception as e:
logger.exception("Failed to initialize embedding service: %s", e)
app.state.embedding_service = None
app.state.service_ready = False
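    # Startup failures are swallowed deliberately: the app still boots, and
    # /health reports 503 until a successful /reload reinitializes the service.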
yield
app = FastAPI(
title="Lalandre Embedding Service",
description="Text embedding generation with multiple providers (Mistral, OpenAI, Local)",
version="0.1.0",
lifespan=lifespan,
)
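# Wrap every request in the shared instrumentation middleware; request
# latency and status are reported through observe_http_request.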
app.middleware("http")(make_http_instrumentation_middleware(observe_http_request))
def get_embedding_service(request: Request) -> EmbeddingService:
"""Get the embedding service instance"""
service = getattr(request.app.state, "embedding_service", None)
service_ready = getattr(request.app.state, "service_ready", False)
if not service_ready or service is None:
raise ValueError("Embedding service not initialized")
return service
# ============================================================================
# Models
# ============================================================================
class EmbedRequest(BaseModel):
"""Request model for embedding"""
texts: List[str] = Field(..., description="Texts to embed", min_length=1, max_length=128)
use_cache: bool = Field(
default=True,
description="Use provider's cache (Redis for Mistral, in-memory for Local)",
)
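# Example payload (illustrative):
#   {"texts": ["first chunk", "second chunk"], "use_cache": true}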
class EmbedResponse(BaseModel):
"""Response model for embedding"""
embeddings: List[List[float]]
model: str
provider: str
vector_dimension: int
class EmbedSingleResponse(BaseModel):
"""Response model for single embedding"""
embedding: List[float]
model: str
provider: str
vector_dimension: int
class HealthResponse(BaseModel):
"""Response model for health check"""
status: str
service: str
provider: str | None
model: str | None
device: str | None
vector_dimension: int
class ServiceInfoResponse(BaseModel):
"""Response model for service metadata"""
provider: str | None
model: str | None
device: str | None
vector_dimension: int
batch_size: int | None
normalize_embeddings: bool
cache_enabled: bool | None = None
cache_max_size: int | None = None
current_cache_size: int | None = None
num_api_keys: int | None = None
# ============================================================================
# Endpoints
# ============================================================================
@app.get("/health")
async def health_check(request: Request) -> HealthResponse:
"""Health check"""
if not getattr(request.app.state, "service_ready", False):
raise HTTPException(
status_code=503,
detail={
"status": "unhealthy",
"service": "embedding-service",
"error": "Embedding service not initialized",
},
)
try:
service = get_embedding_service(request)
config = get_config()
return HealthResponse(
status="healthy",
service="embedding-service",
provider=config.embedding.provider,
model=config.embedding.model_name,
device=config.embedding.device,
vector_dimension=service.get_vector_size(),
)
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail={"status": "unhealthy", "service": "embedding-service", "error": str(e)},
        ) from e
@app.post("/embed", response_model=EmbedResponse)
async def embed_texts(request: EmbedRequest, http_request: Request) -> EmbedResponse:
"""
Generate embeddings for texts using configured provider
Supports:
- Mistral AI (with Redis cache)
- OpenAI
- Local models (with in-memory LRU cache)
"""
started_at = time.perf_counter()
provider_for_metrics = "unknown"
try:
service = get_embedding_service(http_request)
config = get_config()
provider_for_metrics = str(config.embedding.provider or "unknown")
# Generate embeddings using the service
embeddings = service.embed_batch(request.texts)
model_name = config.embedding.model_name
if model_name is None:
raise ValueError("embedding.model_name must be configured")
provider_name = config.embedding.provider
if provider_name is None:
raise ValueError("embedding.provider must be configured")
response = EmbedResponse(
embeddings=embeddings,
model=model_name,
provider=provider_name,
vector_dimension=service.get_vector_size(),
)
observe_embed_request(
endpoint="embed",
provider=provider_for_metrics,
batch_size=len(request.texts),
duration_seconds=time.perf_counter() - started_at,
outcome="success",
)
return response
except Exception as e:
logger.exception("Embedding generation failed: %s", e)
observe_embed_error(
endpoint="embed",
provider=provider_for_metrics,
exc_or_reason=e,
)
observe_embed_request(
endpoint="embed",
provider=provider_for_metrics,
batch_size=len(request.texts),
duration_seconds=time.perf_counter() - started_at,
outcome="error",
)
        raise HTTPException(status_code=500, detail=f"Embedding generation failed: {e}") from e
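# Minimal client sketch (illustrative; assumes the service listens on localhost:8002):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:8002/embed",
#       json={"texts": ["hello", "world"], "use_cache": True},
#   )
#   resp.raise_for_status()
#   vectors = resp.json()["embeddings"]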
@app.post("/embed/single")
async def embed_single(text: str, http_request: Request) -> EmbedSingleResponse:
"""
Embed a single text (convenience endpoint)
"""
started_at = time.perf_counter()
provider_for_metrics = "unknown"
try:
service = get_embedding_service(http_request)
config = get_config()
provider_name = config.embedding.provider
model_name = config.embedding.model_name
if provider_name is None:
raise ValueError("embedding.provider must be configured")
if model_name is None:
raise ValueError("embedding.model_name must be configured")
provider_for_metrics = provider_name
embedding_values = service.embed_batch([text])
if not embedding_values:
raise ValueError("No embedding generated for input text")
observe_embed_request(
endpoint="embed_single",
provider=provider_for_metrics,
batch_size=1,
duration_seconds=time.perf_counter() - started_at,
outcome="success",
)
return EmbedSingleResponse(
embedding=embedding_values[0],
model=model_name,
provider=provider_name,
vector_dimension=service.get_vector_size(),
)
except Exception as exc:
observe_embed_error(
endpoint="embed_single",
provider=provider_for_metrics,
exc_or_reason=exc,
)
observe_embed_request(
endpoint="embed_single",
provider=provider_for_metrics,
batch_size=1,
duration_seconds=time.perf_counter() - started_at,
outcome="error",
)
logger.exception("Single embedding generation failed: %s", exc)
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {exc}",
        ) from exc
@app.get("/info")
async def service_info(request: Request) -> ServiceInfoResponse:
"""
Get embedding service configuration info
"""
try:
service = get_embedding_service(request)
config = get_config()
info = ServiceInfoResponse(
provider=config.embedding.provider,
model=config.embedding.model_name,
device=config.embedding.device,
vector_dimension=service.get_vector_size(),
batch_size=config.embedding.batch_size,
normalize_embeddings=config.embedding.normalize_embeddings,
)
# Add provider-specific info
provider = service.provider
if config.embedding.provider == "local":
info.cache_enabled = config.embedding.enable_cache
info.cache_max_size = config.embedding.cache_max_size
if isinstance(provider, SupportsCacheSize):
info.current_cache_size = provider.get_cache_size()
elif config.embedding.provider == "mistral":
if isinstance(provider, SupportsNumKeys):
info.num_api_keys = provider.get_num_keys()
return info
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get service info: {str(e)}")
@app.post("/reload")
async def reload_model(request: Request) -> ServiceInfoResponse:
"""
Hot-reload: re-read the config override file and reinitialize the embedding model.
Called by the api-gateway after writing a new embedding.yaml override.
"""
try:
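        # Drop the cached config so the next get_config() re-reads the override file.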
reset_config()
new_service = EmbeddingService()
request.app.state.embedding_service = new_service
request.app.state.service_ready = True
config = get_config()
logger.info(
"Embedding model reloaded: provider=%s model=%s",
config.embedding.provider,
config.embedding.model_name,
)
        return build_service_info(new_service)
except Exception as e:
logger.exception("Failed to reload embedding model: %s", e)
request.app.state.service_ready = False
        raise HTTPException(status_code=500, detail=f"Reload failed: {e}") from e
@app.get("/metrics")
async def metrics() -> PlainTextResponse:
"""Expose Prometheus metrics for the embedding service."""
return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)
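# Prometheus can scrape GET /metrics directly; generate_latest() serializes the
# default registry, so no extra wiring is needed here.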
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8002)