# Source code for lalandre_embedding.providers.mistral

"""
Mistral embedding provider with multi-key round-robin support and Redis cache.

Token counting uses ``mistral-common`` with the v1 SentencePiece tokenizer
(the tokenizer family used by ``mistral-embed``).
"""

import hashlib
import itertools
import json
import logging
from threading import Lock
from typing import Any, Iterator, List, Optional, cast

from lalandre_core.config import get_config
from lalandre_core.utils import APIKeyPool
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistralai import Mistral

from ..base import EmbeddingProvider, SupportsNumKeys

logger = logging.getLogger(__name__)

# Documented maximum input length (in tokens) accepted by ``mistral-embed``.
_MISTRAL_MAX_INPUT_TOKENS = 8192

# mistral-embed uses the v1 SentencePiece tokenizer.
# Loaded once at module level (cheap, ~2 MB, no network).
_v1_tokenizer = MistralTokenizer.v1()
# Raw SentencePiece tokenizer used for offline token counting in estimate_tokens().
_v1_sp = _v1_tokenizer.instruct_tokenizer.tokenizer


# ═══════════════════════════════════════════════════════════════════════════
# Provider
# ═══════════════════════════════════════════════════════════════════════════


class MistralEmbedding(EmbeddingProvider, SupportsNumKeys):
    """Mistral embedding provider with multi-key round-robin support and Redis cache"""

    # ───────────────────────────────────────────────────────────────────
    # Initialization
    # ───────────────────────────────────────────────────────────────────

    def __init__(
        self,
        model_name: Optional[str] = None,
        redis_client: Any | None = None,
        cache_ttl: Optional[int] = None,
        key_pool: Optional[APIKeyPool] = None,
    ):
        """Initialize the provider.

        Args:
            model_name: Embedding model to use; falls back to
                ``config.embedding.model_name`` when omitted.
            redis_client: Optional Redis client used as an embedding cache.
                When ``None``, caching is disabled entirely.
            cache_ttl: Cache TTL in seconds; falls back to
                ``config.embedding.cache_ttl_seconds`` when omitted.
            key_pool: Optional pre-built API key pool; otherwise keys are
                loaded from the environment.

        Raises:
            ValueError: If no model name or no API key can be resolved.
        """
        config = get_config()

        resolved_model_name = model_name if model_name is not None else config.embedding.model_name
        if not resolved_model_name:
            raise ValueError("embedding.model_name must be configured for Mistral embeddings")
        resolved_cache_ttl = cache_ttl if cache_ttl is not None else config.embedding.cache_ttl_seconds

        self.model_name: str = resolved_model_name
        self.vector_size = config.vector.vector_size
        self.redis_client: Any | None = redis_client
        self.cache_ttl: int = int(resolved_cache_ttl)

        # Use injected API key pool or build from env (supports 1 or multiple keys)
        self.key_pool = key_pool or APIKeyPool.from_env()
        self.api_keys = self.key_pool.keys
        if not self.api_keys:
            raise ValueError(
                "Mistral API key required. Set MISTRAL_API_KEY env var.\n"
                "For multiple keys (faster): MISTRAL_API_KEY, MISTRAL_API_KEY_2, ..., MISTRAL_API_KEY_10"
            )

        # One SDK client per key so round-robin spreads requests across keys.
        self.clients: List[Mistral] = [Mistral(api_key=key) for key in self.api_keys]

        # Round-robin iterator over instantiated clients; advanced only
        # under _lock so concurrent callers get distinct clients.
        self._client_cycle: Iterator[Mistral] = itertools.cycle(self.clients)
        self._lock = Lock()
        # Monotonic counter used only to derive the key index for logging.
        self._client_index = 0

        for idx in range(len(self.clients)):
            logger.info(
                "MistralEmbedding client %d uses API key index %d/%d",
                idx + 1,
                idx + 1,
                len(self.clients),
            )

        # Log multi-key setup
        if len(self.api_keys) > 1:
            logger.info(
                "Mistral multi-key mode: %d API keys loaded (round-robin)",
                len(self.api_keys),
            )

    # ───────────────────────────────────────────────────────────────────
    # Token estimation
    # ───────────────────────────────────────────────────────────────────

    def get_max_input_tokens(self) -> Optional[int]:
        """Return the documented maximum input length for ``mistral-embed``."""
        return _MISTRAL_MAX_INPUT_TOKENS

    def estimate_tokens(self, text: str) -> Optional[int]:
        """Count tokens using the Mistral v1 SentencePiece tokenizer."""
        # bos/eos excluded: we want the raw content length, matching how the
        # embedding endpoint counts input tokens.
        return len(_v1_sp.encode(text, bos=False, eos=False))

    # ───────────────────────────────────────────────────────────────────
    # Client selection (round-robin)
    # ───────────────────────────────────────────────────────────────────

    def _get_next_client(self) -> Mistral:
        """Get next client in round-robin fashion (thread-safe)"""
        with self._lock:
            client = next(self._client_cycle)
            # _client_index advances in lockstep with _client_cycle (both
            # mutated only inside this lock), so the modulo gives the index
            # of the key actually being used.
            idx = self._client_index % len(self.clients)
            logger.info(
                "MistralEmbedding: using API key index %d/%d",
                idx + 1,
                len(self.clients),
            )
            self._client_index += 1
            return client

    # ───────────────────────────────────────────────────────────────────
    # Redis cache helpers
    # ───────────────────────────────────────────────────────────────────

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key from text (sha256 keeps keys short and unique per text)."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()
        return f"embedding:{self.model_name}:{text_hash}"

    @staticmethod
    def _coerce_embedding(raw: Any) -> List[float]:
        """Validate a decoded payload and normalize every value to ``float``.

        Raises:
            ValueError: If the payload is not a list of numeric values.
        """
        if not isinstance(raw, list):
            raise ValueError("Embedding payload is not a list")
        normalized: List[float] = []
        for value in cast(List[object], raw):
            if not isinstance(value, (int, float)):
                raise ValueError("Embedding values must be numeric")
            normalized.append(float(value))
        return normalized

    def _get_from_cache(self, text: str) -> Optional[List[float]]:
        """Get embedding from Redis cache; returns None on miss or any cache error."""
        if not self.redis_client:
            return None
        try:
            cache_key = self._get_cache_key(text)
            cached = self.redis_client.get(cache_key)
            if cached is None:
                return None
            # redis-py may return bytes or str depending on decode_responses.
            if isinstance(cached, bytes):
                cached_text = cached.decode("utf-8", errors="ignore")
            elif isinstance(cached, str):
                cached_text = cached
            else:
                return None
            loaded = json.loads(cached_text)
            return self._coerce_embedding(loaded)
        except Exception:
            # Best-effort cache: a read failure must never break embedding,
            # but log it so cache problems remain diagnosable.
            logger.debug("Embedding cache read failed; recomputing", exc_info=True)
        return None

    def _set_in_cache(self, text: str, embedding: List[float]):
        """Store embedding in Redis cache (best-effort; failures are logged only)."""
        if not self.redis_client:
            return
        try:
            cache_key = self._get_cache_key(text)
            self.redis_client.setex(cache_key, self.cache_ttl, json.dumps(embedding))
        except Exception:
            # Cache write error, not critical — keep a trace for debugging.
            logger.debug("Embedding cache write failed", exc_info=True)

    # ───────────────────────────────────────────────────────────────────
    # Usage logging
    # ───────────────────────────────────────────────────────────────────

    @staticmethod
    def _log_usage(response: Any) -> None:
        """Log prompt_tokens from the API response for monitoring."""
        usage = getattr(response, "usage", None)
        if usage is None:
            return
        prompt_tokens = getattr(usage, "prompt_tokens", None)
        total_tokens = getattr(usage, "total_tokens", None)
        logger.debug(
            "Mistral embedding usage: prompt_tokens=%s total_tokens=%s",
            prompt_tokens,
            total_tokens,
        )

    # ───────────────────────────────────────────────────────────────────
    # Public API (EmbeddingProvider)
    # ───────────────────────────────────────────────────────────────────

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using next available client (round-robin) with cache"""
        cached = self._get_from_cache(text)
        if cached is not None:
            return cached

        client = self._get_next_client()
        response = client.embeddings.create(model=self.model_name, inputs=[text])
        self._log_usage(response)
        embedding = self._coerce_embedding(response.data[0].embedding)
        self._set_in_cache(text, embedding)
        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate batch embeddings using next available client (round-robin) with cache"""
        # Guard against an empty batch: avoids a pointless (and possibly
        # rejected) API call with an empty ``inputs`` list on the no-cache
        # path, and matches the cached path, which already yields [].
        if not texts:
            return []

        if not self.redis_client:
            client = self._get_next_client()
            response = client.embeddings.create(model=self.model_name, inputs=texts)
            self._log_usage(response)
            return [self._coerce_embedding(data.embedding) for data in response.data]

        # With cache: check what's already cached
        results: List[Optional[List[float]]] = [None] * len(texts)
        uncached_indices: List[int] = []
        uncached_texts: List[str] = []
        for i, text in enumerate(texts):
            cached = self._get_from_cache(text)
            if cached is not None:
                results[i] = cached
            else:
                uncached_indices.append(i)
                uncached_texts.append(text)

        # Compute all uncached entries with a single API call.
        if uncached_texts:
            client = self._get_next_client()
            response = client.embeddings.create(model=self.model_name, inputs=uncached_texts)
            self._log_usage(response)
            # response.data follows input order, so zip realigns each
            # embedding with its original position in ``texts``.
            for idx, embedding_data in zip(uncached_indices, response.data):
                embedding = self._coerce_embedding(embedding_data.embedding)
                results[idx] = embedding
                self._set_in_cache(texts[idx], embedding)

        # Every slot must be filled by now; a None means the API returned
        # fewer embeddings than inputs.
        final_results: List[List[float]] = []
        for embedding in results:
            if embedding is None:
                raise RuntimeError("Embedding provider returned an incomplete batch")
            final_results.append(embedding)
        return final_results

    def get_vector_size(self) -> int:
        """Return the embedding vector dimension configured for the provider."""
        return self.vector_size

    def get_num_keys(self) -> int:
        """Return the number of API keys participating in round-robin calls."""
        return len(self.api_keys)