"""
Mistral embedding provider with multi-key round-robin support and Redis cache.
Token counting uses ``mistral-common`` with the v1 SentencePiece tokenizer
(the tokenizer family used by ``mistral-embed``).
"""
import hashlib
import itertools
import json
import logging
from threading import Lock
from typing import Any, Iterator, List, Optional, cast
from lalandre_core.config import get_config
from lalandre_core.utils import APIKeyPool
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistralai import Mistral
from ..base import EmbeddingProvider, SupportsNumKeys
# Module-level logger; %-style lazy args are used throughout this module.
logger = logging.getLogger(__name__)
# Documented input limit for mistral-embed.
# NOTE(review): not referenced anywhere in this file — presumably enforced by
# callers that use estimate_tokens(); confirm before removing.
_MISTRAL_MAX_INPUT_TOKENS = 8192
# mistral-embed uses the v1 SentencePiece tokenizer.
# Loaded once at module level (cheap, ~2 MB, no network).
_v1_tokenizer = MistralTokenizer.v1()
# Underlying SentencePiece tokenizer, used directly for token counting.
_v1_sp = _v1_tokenizer.instruct_tokenizer.tokenizer
# ═══════════════════════════════════════════════════════════════════════════
# Provider
# ═══════════════════════════════════════════════════════════════════════════
class MistralEmbedding(EmbeddingProvider, SupportsNumKeys):
    """Mistral embedding provider with multi-key round-robin support and Redis cache."""

    # ───────────────────────────────────────────────────────────────────
    # Initialization
    # ───────────────────────────────────────────────────────────────────
    def __init__(
        self,
        model_name: Optional[str] = None,
        redis_client: Any | None = None,
        cache_ttl: Optional[int] = None,
        key_pool: Optional[APIKeyPool] = None,
    ):
        """Initialize the provider.

        Args:
            model_name: Embedding model name; defaults to
                ``config.embedding.model_name``.
            redis_client: Optional Redis client used as a best-effort
                embedding cache. ``None`` disables caching.
            cache_ttl: Cache TTL in seconds; defaults to
                ``config.embedding.cache_ttl_seconds``.
            key_pool: Optional injected API key pool; defaults to keys read
                from the environment (MISTRAL_API_KEY, MISTRAL_API_KEY_2, ...).

        Raises:
            ValueError: If no model name is configured or no API key is found.
        """
        config = get_config()
        resolved_model_name = model_name if model_name is not None else config.embedding.model_name
        if not resolved_model_name:
            raise ValueError("embedding.model_name must be configured for Mistral embeddings")
        resolved_cache_ttl = cache_ttl if cache_ttl is not None else config.embedding.cache_ttl_seconds
        self.model_name: str = resolved_model_name
        self.vector_size = config.vector.vector_size
        self.redis_client: Any | None = redis_client
        self.cache_ttl: int = int(resolved_cache_ttl)
        # Use injected API key pool or build from env (supports 1 or multiple keys)
        self.key_pool = key_pool or APIKeyPool.from_env()
        self.api_keys = self.key_pool.keys
        if not self.api_keys:
            raise ValueError(
                "Mistral API key required. Set MISTRAL_API_KEY env var.\n"
                "For multiple keys (faster): MISTRAL_API_KEY, MISTRAL_API_KEY_2, ..., MISTRAL_API_KEY_10"
            )
        # One SDK client per key so round-robin spreads requests (and rate
        # limits) across all keys.
        self.clients: List[Mistral] = [Mistral(api_key=key) for key in self.api_keys]
        # Round-robin iterator over instantiated clients; access is guarded
        # by _lock so concurrent callers each get a distinct next client.
        self._client_cycle: Iterator[Mistral] = itertools.cycle(self.clients)
        self._lock = Lock()
        # Monotonic counter used only to derive the key index for logging.
        self._client_index = 0
        # Client->key mapping is 1:1 by construction; one DEBUG line replaces
        # the former per-key INFO loop, which only logged that trivial fact.
        logger.debug("MistralEmbedding initialized with %d client(s)", len(self.clients))
        # Log multi-key setup
        if len(self.api_keys) > 1:
            logger.info(
                "Mistral multi-key mode: %d API keys loaded (round-robin)",
                len(self.api_keys),
            )

    # ───────────────────────────────────────────────────────────────────
    # Token estimation
    # ───────────────────────────────────────────────────────────────────
    def estimate_tokens(self, text: str) -> Optional[int]:
        """Count tokens using the Mistral v1 SentencePiece tokenizer.

        Uses the module-level tokenizer (the family used by mistral-embed),
        without BOS/EOS markers, to approximate the API's accounting.
        """
        return len(_v1_sp.encode(text, bos=False, eos=False))

    # ───────────────────────────────────────────────────────────────────
    # Client selection (round-robin)
    # ───────────────────────────────────────────────────────────────────
    def _get_next_client(self) -> Mistral:
        """Return the next client in round-robin order (thread-safe).

        Logged at DEBUG: the previous per-call INFO line flooded production
        logs with one entry per embedding request.
        """
        with self._lock:
            client = next(self._client_cycle)
            idx = self._client_index % len(self.clients)
            logger.debug(
                "MistralEmbedding: using API key index %d/%d",
                idx + 1,
                len(self.clients),
            )
            self._client_index += 1
            return client

    # ───────────────────────────────────────────────────────────────────
    # Redis cache helpers
    # ───────────────────────────────────────────────────────────────────
    def _get_cache_key(self, text: str) -> str:
        """Build a deterministic cache key: model name + SHA-256 of the text."""
        text_hash = hashlib.sha256(text.encode()).hexdigest()
        return f"embedding:{self.model_name}:{text_hash}"

    @staticmethod
    def _coerce_embedding(raw: Any) -> List[float]:
        """Validate that ``raw`` is a list of numbers and normalize to floats.

        Raises:
            ValueError: If ``raw`` is not a list or contains non-numeric values.
        """
        if not isinstance(raw, list):
            raise ValueError("Embedding payload is not a list")
        normalized: List[float] = []
        for value in cast(List[object], raw):
            if not isinstance(value, (int, float)):
                raise ValueError("Embedding values must be numeric")
            normalized.append(float(value))
        return normalized

    def _get_from_cache(self, text: str) -> Optional[List[float]]:
        """Return the cached embedding for ``text`` or None on miss (best effort)."""
        if not self.redis_client:
            return None
        try:
            cache_key = self._get_cache_key(text)
            cached = self.redis_client.get(cache_key)
            if cached is None:
                return None
            # Redis clients may return bytes or str depending on configuration.
            if isinstance(cached, bytes):
                cached_text = cached.decode("utf-8", errors="ignore")
            elif isinstance(cached, str):
                cached_text = cached
            else:
                return None
            return self._coerce_embedding(json.loads(cached_text))
        except Exception:
            # Best-effort cache: a read failure must never break embedding,
            # but log it so cache problems stay diagnosable (was silent).
            logger.debug("Embedding cache read failed; will recompute", exc_info=True)
        return None

    def _set_in_cache(self, text: str, embedding: List[float]) -> None:
        """Store ``embedding`` in Redis with the configured TTL (best effort)."""
        if not self.redis_client:
            return
        try:
            cache_key = self._get_cache_key(text)
            self.redis_client.setex(cache_key, self.cache_ttl, json.dumps(embedding))
        except Exception:
            # Cache write failures are non-critical; log for diagnosis only
            # (was a silent pass).
            logger.debug("Embedding cache write failed", exc_info=True)

    # ───────────────────────────────────────────────────────────────────
    # Usage logging
    # ───────────────────────────────────────────────────────────────────
    @staticmethod
    def _log_usage(response: Any) -> None:
        """Log prompt/total token usage from an API response for monitoring.

        Tolerates responses with no ``usage`` attribute (logs nothing).
        """
        usage = getattr(response, "usage", None)
        if usage is None:
            return
        prompt_tokens = getattr(usage, "prompt_tokens", None)
        total_tokens = getattr(usage, "total_tokens", None)
        logger.debug(
            "Mistral embedding usage: prompt_tokens=%s total_tokens=%s",
            prompt_tokens,
            total_tokens,
        )

    # ───────────────────────────────────────────────────────────────────
    # Public API (EmbeddingProvider)
    # ───────────────────────────────────────────────────────────────────
    def embed_text(self, text: str) -> List[float]:
        """Embed a single text, serving from the Redis cache when possible.

        Args:
            text: Input text to embed.

        Returns:
            The embedding vector as a list of floats.
        """
        cached = self._get_from_cache(text)
        if cached is not None:
            return cached
        client = self._get_next_client()
        response = client.embeddings.create(model=self.model_name, inputs=[text])
        self._log_usage(response)
        embedding = self._coerce_embedding(response.data[0].embedding)
        self._set_in_cache(text, embedding)
        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts with per-text caching.

        Cached texts are served from Redis; only the remainder is sent to the
        API in a single request (one round-robin client per batch).

        Args:
            texts: Texts to embed, in order.

        Returns:
            One embedding per input text, in the same order.

        Raises:
            RuntimeError: If the API returns fewer embeddings than requested.
        """
        # Avoid an API round trip with an empty inputs list.
        if not texts:
            return []
        if not self.redis_client:
            client = self._get_next_client()
            response = client.embeddings.create(model=self.model_name, inputs=texts)
            self._log_usage(response)
            return [self._coerce_embedding(data.embedding) for data in response.data]
        # With cache: check what's already cached
        results: List[Optional[List[float]]] = [None] * len(texts)
        uncached_indices: List[int] = []
        uncached_texts: List[str] = []
        for i, text in enumerate(texts):
            cached = self._get_from_cache(text)
            if cached is not None:
                results[i] = cached
            else:
                uncached_indices.append(i)
                uncached_texts.append(text)
        # Compute uncached
        if uncached_texts:
            client = self._get_next_client()
            response = client.embeddings.create(model=self.model_name, inputs=uncached_texts)
            self._log_usage(response)
            # If the API returns fewer items than requested, zip truncates and
            # the None check below raises RuntimeError.
            for idx, embedding_data in zip(uncached_indices, response.data):
                embedding = self._coerce_embedding(embedding_data.embedding)
                results[idx] = embedding
                self._set_in_cache(texts[idx], embedding)
        final_results: List[List[float]] = []
        for embedding in results:
            if embedding is None:
                raise RuntimeError("Embedding provider returned an incomplete batch")
            final_results.append(embedding)
        return final_results

    def get_vector_size(self) -> int:
        """Return the embedding vector dimension configured for the provider."""
        return self.vector_size

    def get_num_keys(self) -> int:
        """Return the number of API keys participating in round-robin calls."""
        return len(self.api_keys)