Source code for lalandre_embedding.service

"""
Embedding service.

Supports multiple providers (Mistral API, local sentence-transformers)
with automatic token-limit guard, adaptive text splitting, and
weighted-average aggregation for long documents.
"""

import logging
import math
from typing import Any, List, Optional, cast

import redis
from lalandre_core.config import get_config
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .base import EmbeddingProvider

logger = logging.getLogger(__name__)


class EmbeddingService(EmbeddingProvider):
    """
    Supports multiple providers: Mistral and local models.

    The service itself exposes the EmbeddingProvider interface so callers can
    benefit from token guards and adaptive splitting without reaching into the
    raw provider implementation.
    """

    # ═══════════════════════════════════════════════════════════════════
    # 1. Initialization
    # ═══════════════════════════════════════════════════════════════════

    def __init__(
        self,
        provider: Optional[str] = None,  # mistral, local
        model_name: Optional[str] = None,  # model name specific to the provider
        api_key: Optional[str] = None,  # reserved for future providers
        device: Optional[str] = None,  # device for local models (cpu, cuda, mps)
        cache_dir: Optional[str] = None,  # cache directory for local models
        normalize_embeddings: Optional[bool] = None,  # normalize local embeddings
        enable_cache: Optional[bool] = None,  # enable local in-memory cache
        cache_max_size: Optional[int] = None,  # max local cache size
        key_pool: Optional[Any] = None,  # APIKeyPool for Mistral provider
    ):
        """Initialize embedding service"""
        config = get_config()

        resolved_provider = provider or config.embedding.provider
        if not resolved_provider:
            raise ValueError("embedding.provider must be configured")
        resolved_model_name = model_name or config.embedding.model_name
        if not resolved_model_name:
            raise ValueError("embedding.model_name must be configured")
        resolved_device = device or config.embedding.device
        if not resolved_device:
            raise ValueError("embedding.device must be configured")

        self.provider_name: str = resolved_provider
        self.model_name: str = resolved_model_name
        self.api_key = api_key
        self.device: str = resolved_device
        self.cache_dir = cache_dir or config.embedding.cache_dir or config.models_cache_dir
        self.batch_size = config.embedding.batch_size
        self.configured_embedding_max_input_tokens = config.token_limits.embedding_max_input_tokens
        self.embedding_max_input_tokens = self.configured_embedding_max_input_tokens
        self.chars_per_token = config.token_limits.chars_per_token
        self.embedding_safety_ratio = config.token_limits.embedding_safety_ratio
        self.safe_embedding_max_chars = 0
        self.safe_embedding_max_tokens = 0

        normalize = (
            normalize_embeddings
            if normalize_embeddings is not None
            else config.embedding.normalize_embeddings
        )
        self.normalize_embeddings = normalize
        self.enable_cache = enable_cache if enable_cache is not None else config.embedding.enable_cache
        self.cache_max_size = cache_max_size if cache_max_size is not None else config.embedding.cache_max_size
        self._redis_socket_timeout: int = config.embedding.redis_socket_timeout
        self._key_pool = key_pool

        # Initialize the appropriate provider
        self.provider = self._create_provider()
        self.embedding_max_input_tokens = self._resolve_embedding_max_input_tokens()

        # Conservative token budget to stay under provider token limits.
        self.safe_embedding_max_tokens = max(
            64,
            int(self.embedding_max_input_tokens * self.embedding_safety_ratio),
        )
        # Conservative char budget used only as a cheap pre-check before token counting.
        self.safe_embedding_max_chars = max(
            256,
            int(
                self.embedding_max_input_tokens
                * max(float(self.chars_per_token), 0.1)
                * self.embedding_safety_ratio
            ),
        )
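    # Worked example of the budget arithmetic above (a sketch with assumed
    # config values, not the project defaults): with
    # embedding_max_input_tokens=8192, embedding_safety_ratio=0.9 and
    # chars_per_token=4.0, the guards resolve to
    #   safe_embedding_max_tokens = max(64, int(8192 * 0.9))        -> 7372
    #   safe_embedding_max_chars  = max(256, int(8192 * 4.0 * 0.9)) -> 29491
    # so texts longer than ~29k characters are treated as over budget without
    # any token counting at all.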
    def _create_provider(self) -> EmbeddingProvider:
        """Factory method to create the appropriate embedding provider"""
        if self.provider_name == "mistral":
            from .providers.mistral import MistralEmbedding

            config = get_config()
            redis_host = config.gateway.redis_host
            redis_port = config.gateway.redis_port
            cache_ttl = config.embedding.cache_ttl_seconds
            if redis_host is None:
                raise ValueError("gateway.redis_host must be configured for embedding cache")
            if redis_port is None:
                raise ValueError("gateway.redis_port must be configured for embedding cache")

            redis_client: Optional[Any] = None
            try:
                raw_client = redis.Redis(
                    host=redis_host,
                    port=redis_port,
                    decode_responses=True,
                    socket_connect_timeout=self._redis_socket_timeout,
                    socket_timeout=self._redis_socket_timeout,
                )
                redis_client = raw_client
                redis_client.ping()
                logger.info(
                    "Embedding cache enabled (Redis at %s:%s, TTL=%ss)",
                    redis_host,
                    redis_port,
                    cache_ttl,
                )
            except Exception as e:
                logger.warning(
                    "Redis cache unavailable (%s), embeddings will not be cached",
                    e,
                )
                redis_client = None

            mistral_model_name: str = self.model_name
            return MistralEmbedding(
                model_name=mistral_model_name,
                redis_client=redis_client,
                cache_ttl=cache_ttl,
                key_pool=self._key_pool,
            )

        elif self.provider_name == "local":
            from .providers.local import LocalEmbedding

            local_model_name: str = self.model_name
            local_device: str = self.device
            return LocalEmbedding(
                model_name=local_model_name,
                device=local_device,
                cache_dir=self.cache_dir,
                normalize_embeddings=self.normalize_embeddings,
                enable_cache=self.enable_cache,
                cache_max_size=self.cache_max_size,
            )

        else:
            raise ValueError(
                f"Unknown embedding provider: {self.provider_name}. Supported providers: mistral, local"
            )

    # ═══════════════════════════════════════════════════════════════════
    # 2. Token estimation
    # ═══════════════════════════════════════════════════════════════════

    def _resolve_embedding_max_input_tokens(self) -> int:
        provider_limit = self.provider.get_max_input_tokens()
        configured_limit = int(self.configured_embedding_max_input_tokens)
        if provider_limit is None or provider_limit <= 0:
            return configured_limit
        effective_limit = min(configured_limit, int(provider_limit))
        if effective_limit < configured_limit:
            logger.info(
                "Embedding token budget capped by provider model limit: configured=%s provider=%s effective=%s",
                configured_limit,
                provider_limit,
                effective_limit,
            )
        return effective_limit

    def _estimate_tokens_provider(self, text: str) -> Optional[int]:
        try:
            return self.provider.estimate_tokens(text)
        except Exception as exc:
            logger.warning(
                "Provider-native token counting unavailable (%s); falling back to generic tokenizers",
                exc,
            )
            return None

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count: provider-native first, char-ratio fallback."""
        provider_count = self._estimate_tokens_provider(text)
        if provider_count is not None:
            return provider_count
        chars_per_token = max(float(self.chars_per_token), 0.1)
        return max(1, int(math.ceil(len(text) / chars_per_token)))

    def _is_text_within_safe_budget(self, text: str) -> bool:
        if len(text) > self.safe_embedding_max_chars:
            return False
        return self._estimate_tokens(text) <= self.safe_embedding_max_tokens
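    # Illustrative fallback arithmetic for _estimate_tokens (the chars_per_token
    # value is an assumption for the example, not a guaranteed default): if the
    # provider cannot count tokens natively and chars_per_token=4.0, a
    # 1000-character text is estimated as max(1, ceil(1000 / 4.0)) = 250 tokens.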
    # ═══════════════════════════════════════════════════════════════════
    # 3. Text splitting (long-document guard)
    # ═══════════════════════════════════════════════════════════════════

    @staticmethod
    def _is_token_limit_error(exc: Exception) -> bool:
        message = str(exc).lower()
        return "exceeding max" in message and "token" in message

    @staticmethod
    def _weighted_average(vectors: List[List[float]], weights: List[float]) -> List[float]:
        if not vectors:
            raise ValueError("Cannot average an empty list of vectors")
        if len(vectors) != len(weights):
            raise ValueError("vectors and weights length mismatch")
        dimension = len(vectors[0])
        if dimension == 0:
            raise ValueError("Vector dimension cannot be zero")

        weighted_sum = [0.0] * dimension
        total_weight = 0.0
        for vec, weight in zip(vectors, weights):
            if len(vec) != dimension:
                raise ValueError("Inconsistent vector dimensions in aggregation")
            w = max(float(weight), 0.0)
            total_weight += w
            for i, value in enumerate(vec):
                weighted_sum[i] += float(value) * w

        if total_weight <= 0.0:
            total_weight = float(len(vectors))
            for i in range(dimension):
                weighted_sum[i] = sum(float(vec[i]) for vec in vectors)
        return [value / total_weight for value in weighted_sum]

    @staticmethod
    def _split_by_limit(text: str, max_chars: int) -> List[str]:
        clean = text.strip()
        if not clean:
            return [text]
        if len(clean) <= max_chars:
            return [clean]
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=max_chars,
            chunk_overlap=0,
            separators=["\n\n", "\n", r"(?<=[\.\!\?\:\;])\s+", " ", ""],
            is_separator_regex=True,
        )
        parts = splitter.split_text(clean)
        return parts if parts else [clean]

    def _embed_long_text(self, text: str, max_chars: Optional[int] = None) -> List[float]:
        current_limit = max_chars if max_chars is not None else self.safe_embedding_max_chars
        embed_cfg = get_config().embedding
        current_limit = max(embed_cfg.retry_min_tokens, int(current_limit))
        segments = self._split_by_limit(text, current_limit)

        if any(self._estimate_tokens(segment) > self.safe_embedding_max_tokens for segment in segments):
            if current_limit <= embed_cfg.retry_fallback_threshold:
                # Last fallback: call provider and let its native guard decide.
                if len(segments) == 1:
                    return self.provider.embed_text(segments[0])
            else:
                reduced_limit = max(
                    embed_cfg.retry_min_tokens,
                    int(current_limit * embed_cfg.retry_reduction_factor),
                )
                return self._embed_long_text(text, max_chars=reduced_limit)

        if len(segments) == 1:
            try:
                return self.provider.embed_text(segments[0])
            except Exception as exc:
                if not self._is_token_limit_error(exc) or current_limit <= embed_cfg.retry_fallback_threshold:
                    raise
                reduced_limit = max(
                    embed_cfg.retry_min_tokens,
                    int(current_limit * embed_cfg.retry_reduction_factor),
                )
                return self._embed_long_text(text, max_chars=reduced_limit)

        try:
            vectors = self.provider.embed_batch(segments)
        except Exception as exc:
            if not self._is_token_limit_error(exc) or current_limit <= embed_cfg.retry_fallback_threshold:
                raise
            # Tighten the split budget and retry recursively.
            reduced_limit = max(
                embed_cfg.retry_min_tokens,
                int(current_limit * embed_cfg.retry_reduction_factor),
            )
            return self._embed_long_text(text, max_chars=reduced_limit)

        weights = [float(max(self._estimate_tokens(seg), 1)) for seg in segments]
        return self._weighted_average(vectors, weights)

    def _embed_batch_with_token_guard(self, texts: List[str]) -> List[List[float]]:
        if not texts:
            return []
        try:
            return self.provider.embed_batch(texts)
        except Exception as exc:
            if not self._is_token_limit_error(exc):
                raise
            logger.warning("Provider token limit hit on batch. Retrying with adaptive per-text split.")
            return [self._embed_long_text(text) for text in texts]
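    # Small worked example for _weighted_average (illustrative numbers only):
    # averaging [1.0, 0.0] and [0.0, 1.0] with token weights 3 and 1 gives
    # weighted_sum = [3.0, 1.0], total_weight = 4.0, result = [0.75, 0.25],
    # i.e. longer segments pull the aggregated document vector toward them.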
    # ═══════════════════════════════════════════════════════════════════
    # 4. Public API
    # ═══════════════════════════════════════════════════════════════════
    def embed_text(self, text: str) -> List[float]:
        """
        Generate embedding for a single text

        Returns:
            List of floats representing the embedding vector
        """
        if not text:
            return self.provider.embed_text(text)

        if self._is_text_within_safe_budget(text):
            try:
                return self.provider.embed_text(text)
            except Exception as exc:
                if not self._is_token_limit_error(exc):
                    raise
                logger.warning(
                    "Token limit hit for single text under char budget. Retrying with adaptive split."
                )

        return self._embed_long_text(text)
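    # Hedged usage sketch for embed_text (the provider and model name are
    # illustrative assumptions, not project defaults):
    #
    #   service = EmbeddingService(provider="local", model_name="all-MiniLM-L6-v2")
    #   vector = service.embed_text(very_long_document)
    #
    # Callers never split long inputs themselves: texts over the safe budget
    # are chunked, embedded per segment, and merged via _weighted_average.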
    def embed_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """
        Generate embeddings for multiple texts efficiently

        Args:
            texts: List of input texts
            batch_size: Number of texts to process at once (None = use config default)

        Returns:
            List of embedding vectors
        """
        if not texts:
            return []

        effective_batch_size = batch_size or self.batch_size or len(texts) or 1
        batch_size = max(int(effective_batch_size), 1)
        total_texts = len(texts)

        # Show progress for large batches (>500 texts)
        show_progress = total_texts > 500

        # For local models, process in batches to avoid memory issues
        if self.provider_name == "local":
            all_embeddings: List[List[float]] = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                short_texts: List[str] = []
                short_indices: List[int] = []
                batch_results: List[Optional[List[float]]] = [None] * len(batch)
                for idx, text in enumerate(batch):
                    if self._is_text_within_safe_budget(text):
                        short_indices.append(idx)
                        short_texts.append(text)
                    else:
                        batch_results[idx] = self._embed_long_text(text)
                if short_texts:
                    batch_embeddings = self.provider.embed_batch(short_texts)
                    for idx, vector in zip(short_indices, batch_embeddings):
                        batch_results[idx] = vector
                all_embeddings.extend(cast(List[List[float]], batch_results))

                # Progress indicator
                if show_progress and (i + batch_size) % (batch_size * 10) == 0:
                    progress = min(i + batch_size, total_texts)
                    logger.info("[EMBEDDING] %d/%d texts processed...", progress, total_texts)
            return all_embeddings
        else:
            # API providers: process in chunks with progress
            if total_texts <= batch_size:
                # Single API call
                short_texts: List[str] = []
                short_indices: List[int] = []
                results: List[Optional[List[float]]] = [None] * total_texts
                for idx, text in enumerate(texts):
                    if self._is_text_within_safe_budget(text):
                        short_indices.append(idx)
                        short_texts.append(text)
                    else:
                        results[idx] = self._embed_long_text(text)
                if short_texts:
                    short_vectors = self._embed_batch_with_token_guard(short_texts)
                    for idx, vector in zip(short_indices, short_vectors):
                        results[idx] = vector
                return [cast(List[float], vector) for vector in results]
            else:
                # Multiple API calls with progress
                all_embeddings: List[List[float]] = []
                for i in range(0, total_texts, batch_size):
                    batch = texts[i : i + batch_size]
                    batch_embeddings = self.embed_batch(batch, batch_size=len(batch))
                    all_embeddings.extend(batch_embeddings)

                    # Progress indicator every 10 batches
                    if show_progress and ((i // batch_size) + 1) % 10 == 0:
                        progress = min(i + batch_size, total_texts)
                        logger.info("[EMBEDDING] %d/%d texts processed...", progress, total_texts)
                return all_embeddings
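    # Batching arithmetic sketch (assumed numbers, not defaults): with 1200
    # texts and batch_size=100, the API path issues 12 sub-calls via recursive
    # embed_batch invocations; since 1200 > 500, progress is logged every 10
    # batches, i.e. once after 1000 of the 1200 texts.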
    def estimate_tokens(self, text: str) -> Optional[int]:
        """Delegate to the internal token estimation chain (provider → char-ratio)."""
        return self._estimate_tokens(text)

    def get_max_input_tokens(self) -> Optional[int]:
        """Return the effective token limit (config ∩ provider)."""
        return self.embedding_max_input_tokens

    def get_vector_size(self) -> int:
        """Get the dimension of the embedding vectors"""
        return self.provider.get_vector_size()
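
# Example wiring (a minimal sketch; the model name and the assumption that the
# surrounding config provides Redis settings are illustrative, not requirements):
#
#   from lalandre_embedding.service import EmbeddingService
#
#   service = EmbeddingService(provider="mistral", model_name="mistral-embed")
#   vectors = service.embed_batch(["short text", "a very long document ..."])
#   assert len(vectors[0]) == service.get_vector_size()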