"""
Embedding service.
Supports multiple providers (Mistral API, local sentence-transformers)
with automatic token-limit guard, adaptive text splitting, and
weighted-average aggregation for long documents.
"""
import logging
import math
from typing import Any, List, Optional, cast
import redis
from lalandre_core.config import get_config
from langchain_text_splitters import RecursiveCharacterTextSplitter
from .base import EmbeddingProvider
logger = logging.getLogger(__name__)
class EmbeddingService(EmbeddingProvider):
"""
Supports multiple providers: Mistral and local models.
The service itself exposes the EmbeddingProvider interface so callers can
benefit from token guards and adaptive splitting without reaching into the
raw provider implementation.
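
    Example (illustrative; the provider and model names depend on your
    deployment and config):

        service = EmbeddingService(provider="local", model_name="all-MiniLM-L6-v2")
        vector = service.embed_text("some text to embed")
        vectors = service.embed_batch(["first text", "second text"])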
"""
# ═══════════════════════════════════════════════════════════════════
# 1. Initialization
# ═══════════════════════════════════════════════════════════════════
def __init__(
self,
provider: Optional[str] = None, # mistral, local
model_name: Optional[str] = None, # model name specific to the provider
api_key: Optional[str] = None, # reserved for future providers
device: Optional[str] = None, # device for local models (cpu, cuda, mps)
cache_dir: Optional[str] = None, # cache directory for local models
normalize_embeddings: Optional[bool] = None, # normalize local embeddings
enable_cache: Optional[bool] = None, # enable local in-memory cache
cache_max_size: Optional[int] = None, # max local cache size
key_pool: Optional[Any] = None, # APIKeyPool for Mistral provider
):
"""
Initialize embedding service
"""
config = get_config()
resolved_provider = provider or config.embedding.provider
if not resolved_provider:
raise ValueError("embedding.provider must be configured")
resolved_model_name = model_name or config.embedding.model_name
if not resolved_model_name:
raise ValueError("embedding.model_name must be configured")
resolved_device = device or config.embedding.device
if not resolved_device:
raise ValueError("embedding.device must be configured")
self.provider_name: str = resolved_provider
self.model_name: str = resolved_model_name
self.api_key = api_key
self.device: str = resolved_device
self.cache_dir = cache_dir or config.embedding.cache_dir or config.models_cache_dir
self.batch_size = config.embedding.batch_size
self.configured_embedding_max_input_tokens = config.token_limits.embedding_max_input_tokens
self.embedding_max_input_tokens = self.configured_embedding_max_input_tokens
self.chars_per_token = config.token_limits.chars_per_token
self.embedding_safety_ratio = config.token_limits.embedding_safety_ratio
self.safe_embedding_max_chars = 0
self.safe_embedding_max_tokens = 0
normalize = normalize_embeddings if normalize_embeddings is not None else config.embedding.normalize_embeddings
self.normalize_embeddings = normalize
self.enable_cache = enable_cache if enable_cache is not None else config.embedding.enable_cache
self.cache_max_size = cache_max_size if cache_max_size is not None else config.embedding.cache_max_size
self._redis_socket_timeout: int = config.embedding.redis_socket_timeout
self._key_pool = key_pool
# Initialize the appropriate provider
self.provider = self._create_provider()
self.embedding_max_input_tokens = self._resolve_embedding_max_input_tokens()
# Conservative token budget to stay under provider token limits.
self.safe_embedding_max_tokens = max(
64,
int(self.embedding_max_input_tokens * self.embedding_safety_ratio),
)
# Conservative char budget used only as a cheap pre-check before token counting.
self.safe_embedding_max_chars = max(
256,
int(self.embedding_max_input_tokens * max(float(self.chars_per_token), 0.1) * self.embedding_safety_ratio),
)
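        # Worked example (illustrative values): with an 8192-token limit,
        # chars_per_token=4.0, and embedding_safety_ratio=0.9, the token
        # budget is int(8192 * 0.9) = 7372 and the char pre-check budget is
        # int(8192 * 4.0 * 0.9) = 29491.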
def _create_provider(self) -> EmbeddingProvider:
"""Factory method to create the appropriate embedding provider"""
if self.provider_name == "mistral":
from .providers.mistral import MistralEmbedding
config = get_config()
redis_host = config.gateway.redis_host
redis_port = config.gateway.redis_port
cache_ttl = config.embedding.cache_ttl_seconds
if redis_host is None:
raise ValueError("gateway.redis_host must be configured for embedding cache")
if redis_port is None:
raise ValueError("gateway.redis_port must be configured for embedding cache")
redis_client: Optional[Any] = None
try:
raw_client = redis.Redis(
host=redis_host,
port=redis_port,
decode_responses=True,
socket_connect_timeout=self._redis_socket_timeout,
socket_timeout=self._redis_socket_timeout,
)
redis_client = raw_client
redis_client.ping()
logger.info(
"Embedding cache enabled (Redis at %s:%s, TTL=%ss)",
redis_host,
redis_port,
cache_ttl,
)
except Exception as e:
logger.warning(
"Redis cache unavailable (%s), embeddings will not be cached",
e,
)
redis_client = None
mistral_model_name: str = self.model_name
return MistralEmbedding(
model_name=mistral_model_name,
redis_client=redis_client,
cache_ttl=cache_ttl,
key_pool=self._key_pool,
)
elif self.provider_name == "local":
from .providers.local import LocalEmbedding
local_model_name: str = self.model_name
local_device: str = self.device
return LocalEmbedding(
model_name=local_model_name,
device=local_device,
cache_dir=self.cache_dir,
normalize_embeddings=self.normalize_embeddings,
enable_cache=self.enable_cache,
cache_max_size=self.cache_max_size,
)
else:
raise ValueError(f"Unknown embedding provider: {self.provider_name}. Supported providers: mistral, local")
# ═══════════════════════════════════════════════════════════════════
# 2. Token estimation
# ═══════════════════════════════════════════════════════════════════
def _resolve_embedding_max_input_tokens(self) -> int:
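        # Effective budget = configured limit, capped by the provider's model
        # limit when the provider reports one (e.g. a configured 16384 with a
        # provider limit of 8192 yields 8192; values are illustrative).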
provider_limit = self.provider.get_max_input_tokens()
configured_limit = int(self.configured_embedding_max_input_tokens)
if provider_limit is None or provider_limit <= 0:
return configured_limit
effective_limit = min(configured_limit, int(provider_limit))
if effective_limit < configured_limit:
logger.info(
"Embedding token budget capped by provider model limit: configured=%s provider=%s effective=%s",
configured_limit,
provider_limit,
effective_limit,
)
return effective_limit
def _estimate_tokens_provider(self, text: str) -> Optional[int]:
try:
return self.provider.estimate_tokens(text)
except Exception as exc:
logger.warning(
"Provider-native token counting unavailable (%s); falling back to generic tokenizers",
exc,
)
return None
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count: provider-native first, char-ratio fallback."""
provider_count = self._estimate_tokens_provider(text)
if provider_count is not None:
return provider_count
chars_per_token = max(float(self.chars_per_token), 0.1)
return max(1, int(math.ceil(len(text) / chars_per_token)))
def _is_text_within_safe_budget(self, text: str) -> bool:
if len(text) > self.safe_embedding_max_chars:
return False
return self._estimate_tokens(text) <= self.safe_embedding_max_tokens
# ═══════════════════════════════════════════════════════════════════
# 3. Text splitting (long-document guard)
# ═══════════════════════════════════════════════════════════════════
@staticmethod
def _is_token_limit_error(exc: Exception) -> bool:
message = str(exc).lower()
return "exceeding max" in message and "token" in message
@staticmethod
def _weighted_average(vectors: List[List[float]], weights: List[float]) -> List[float]:
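        # Token-weighted mean: e.g. vectors [[1.0, 0.0], [0.0, 1.0]] with
        # weights [3, 1] aggregate to [0.75, 0.25].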
if not vectors:
raise ValueError("Cannot average an empty list of vectors")
if len(vectors) != len(weights):
raise ValueError("vectors and weights length mismatch")
dimension = len(vectors[0])
if dimension == 0:
raise ValueError("Vector dimension cannot be zero")
weighted_sum = [0.0] * dimension
total_weight = 0.0
for vec, weight in zip(vectors, weights):
if len(vec) != dimension:
raise ValueError("Inconsistent vector dimensions in aggregation")
w = max(float(weight), 0.0)
total_weight += w
for i, value in enumerate(vec):
weighted_sum[i] += float(value) * w
if total_weight <= 0.0:
total_weight = float(len(vectors))
for i in range(dimension):
weighted_sum[i] = sum(float(vec[i]) for vec in vectors)
return [value / total_weight for value in weighted_sum]
@staticmethod
def _split_by_limit(text: str, max_chars: int) -> List[str]:
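        # Preference order: paragraph breaks, then line breaks, then sentence
        # boundaries (regex lookbehind on .!?:;), then spaces, then single
        # characters.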
clean = text.strip()
if not clean:
return [text]
if len(clean) <= max_chars:
return [clean]
splitter = RecursiveCharacterTextSplitter(
chunk_size=max_chars,
chunk_overlap=0,
separators=["\n\n", "\n", r"(?<=[\.\!\?\:\;])\s+", " ", ""],
is_separator_regex=True,
)
parts = splitter.split_text(clean)
return parts if parts else [clean]
def _embed_long_text(self, text: str, max_chars: Optional[int] = None) -> List[float]:
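        # Strategy: split the text under the current character budget, embed
        # each segment, and aggregate with a token-weighted average. On a
        # provider token-limit error, shrink the budget by
        # retry_reduction_factor and recurse, stopping once the budget falls
        # to retry_fallback_threshold.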
current_limit = max_chars if max_chars is not None else self.safe_embedding_max_chars
embed_cfg = get_config().embedding
current_limit = max(embed_cfg.retry_min_tokens, int(current_limit))
segments = self._split_by_limit(text, current_limit)
if any(self._estimate_tokens(segment) > self.safe_embedding_max_tokens for segment in segments):
if current_limit <= embed_cfg.retry_fallback_threshold:
# Last fallback: call provider and let its native guard decide.
if len(segments) == 1:
return self.provider.embed_text(segments[0])
else:
reduced_limit = max(embed_cfg.retry_min_tokens, int(current_limit * embed_cfg.retry_reduction_factor))
return self._embed_long_text(text, max_chars=reduced_limit)
if len(segments) == 1:
try:
return self.provider.embed_text(segments[0])
except Exception as exc:
if not self._is_token_limit_error(exc) or current_limit <= embed_cfg.retry_fallback_threshold:
raise
reduced_limit = max(embed_cfg.retry_min_tokens, int(current_limit * embed_cfg.retry_reduction_factor))
return self._embed_long_text(text, max_chars=reduced_limit)
try:
vectors = self.provider.embed_batch(segments)
except Exception as exc:
if not self._is_token_limit_error(exc) or current_limit <= embed_cfg.retry_fallback_threshold:
raise
# Tighten the split budget and retry recursively.
reduced_limit = max(embed_cfg.retry_min_tokens, int(current_limit * embed_cfg.retry_reduction_factor))
return self._embed_long_text(text, max_chars=reduced_limit)
weights = [float(max(self._estimate_tokens(seg), 1)) for seg in segments]
return self._weighted_average(vectors, weights)
def _embed_batch_with_token_guard(self, texts: List[str]) -> List[List[float]]:
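        # Fast path: one provider call for the whole batch; on a token-limit
        # error, fall back to adaptive per-text splitting.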
if not texts:
return []
try:
return self.provider.embed_batch(texts)
except Exception as exc:
if not self._is_token_limit_error(exc):
raise
logger.warning("Provider token limit hit on batch. Retrying with adaptive per-text split.")
return [self._embed_long_text(text) for text in texts]
# ═══════════════════════════════════════════════════════════════════
# 4. Public API
# ═══════════════════════════════════════════════════════════════════
def embed_text(self, text: str) -> List[float]:
"""
Generate embedding for a single text
Returns:
List of floats representing the embedding vector
"""
if not text:
return self.provider.embed_text(text)
if self._is_text_within_safe_budget(text):
try:
return self.provider.embed_text(text)
except Exception as exc:
if not self._is_token_limit_error(exc):
raise
logger.warning("Token limit hit for single text under char budget. Retrying with adaptive split.")
return self._embed_long_text(text)
def embed_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
"""
Generate embeddings for multiple texts efficiently
Args:
texts: List of input texts
batch_size: Number of texts to process at once (None = use config default)
Returns:
List of embedding vectors
"""
if not texts:
return []
effective_batch_size = batch_size or self.batch_size or len(texts) or 1
batch_size = max(int(effective_batch_size), 1)
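        # e.g. batch_size=None with a config batch_size of 32 resolves to
        # chunks of 32 texts per provider call.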
total_texts = len(texts)
# Show progress for large batches (>500 texts)
show_progress = total_texts > 500
# For local models, process in batches to avoid memory issues
if self.provider_name == "local":
all_embeddings: List[List[float]] = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
short_texts: List[str] = []
short_indices: List[int] = []
batch_results: List[Optional[List[float]]] = [None] * len(batch)
for idx, text in enumerate(batch):
if self._is_text_within_safe_budget(text):
short_indices.append(idx)
short_texts.append(text)
else:
batch_results[idx] = self._embed_long_text(text)
if short_texts:
batch_embeddings = self.provider.embed_batch(short_texts)
for idx, vector in zip(short_indices, batch_embeddings):
batch_results[idx] = vector
all_embeddings.extend(cast(List[List[float]], batch_results))
# Progress indicator
if show_progress and (i + batch_size) % (batch_size * 10) == 0:
progress = min(i + batch_size, total_texts)
logger.info("[EMBEDDING] %d/%d texts processed...", progress, total_texts)
return all_embeddings
else:
# API providers: process in chunks with progress
if total_texts <= batch_size:
# Single API call
short_texts: List[str] = []
short_indices: List[int] = []
results: List[Optional[List[float]]] = [None] * total_texts
for idx, text in enumerate(texts):
if self._is_text_within_safe_budget(text):
short_indices.append(idx)
short_texts.append(text)
else:
results[idx] = self._embed_long_text(text)
if short_texts:
short_vectors = self._embed_batch_with_token_guard(short_texts)
for idx, vector in zip(short_indices, short_vectors):
results[idx] = vector
return [cast(List[float], vector) for vector in results]
else:
# Multiple API calls with progress
all_embeddings: List[List[float]] = []
for i in range(0, total_texts, batch_size):
batch = texts[i : i + batch_size]
batch_embeddings = self.embed_batch(batch, batch_size=len(batch))
all_embeddings.extend(batch_embeddings)
# Progress indicator every 10 batches
if show_progress and ((i // batch_size) + 1) % 10 == 0:
progress = min(i + batch_size, total_texts)
logger.info("[EMBEDDING] %d/%d texts processed...", progress, total_texts)
return all_embeddings
def estimate_tokens(self, text: str) -> Optional[int]:
"""Delegate to the internal token estimation chain (provider → char-ratio)."""
return self._estimate_tokens(text)
def get_vector_size(self) -> int:
"""Get the dimension of the embedding vectors"""
return self.provider.get_vector_size()