"""
Local embedding provider
"""
import functools
import logging
from collections.abc import Mapping
from typing import Any, List, Optional, cast
from lalandre_core.config import get_config
from ..base import EmbeddingProvider, SupportsCacheSize
logger = logging.getLogger(__name__)
class LocalEmbedding(EmbeddingProvider, SupportsCacheSize):
"""Local embedding provider using sentence-transformers with in-memory LRU cache"""
model: Any
    _cached_encode: Any  # functools.lru_cache-wrapped closure when caching is enabled, else None
# ───────────────────────────────────────────────────────────────────
# 1. Initialization
# ───────────────────────────────────────────────────────────────────
def __init__(
self,
model_name: Optional[str] = None,
device: Optional[str] = None,
cache_dir: Optional[str] = None,
normalize_embeddings: bool = True,
enable_cache: Optional[bool] = None,
cache_max_size: Optional[int] = None,
):
        config = get_config()
        # Explicit constructor arguments take precedence over app_config.yaml values.
resolved_model_name = model_name or config.embedding.model_name
if not resolved_model_name:
raise ValueError("Missing local embedding model name (set embedding.model_name in app_config.yaml)")
resolved_device = device or config.embedding.device
if not resolved_device:
raise ValueError("Missing local embedding device (set embedding.device in app_config.yaml)")
resolved_enable_cache = enable_cache if enable_cache is not None else config.embedding.enable_cache
resolved_cache_max_size = cache_max_size if cache_max_size is not None else config.embedding.cache_max_size
if resolved_cache_max_size <= 0:
raise ValueError("embedding.cache_max_size must be > 0")
self.model_name = resolved_model_name
self.device = resolved_device
self.normalize_embeddings = normalize_embeddings
self.enable_cache = resolved_enable_cache
self.cache_dir = cache_dir or config.embedding.cache_dir or config.models_cache_dir
# Load model with cache directory if specified
logger.info("Loading local embedding model: %s on %s", self.model_name, self.device)
if self.cache_dir:
logger.info("Using model cache directory: %s", self.cache_dir)
        # Imported lazily so that merely importing this module does not pull in
        # sentence-transformers (and torch) until a provider is constructed.
        from sentence_transformers import SentenceTransformer
model: Any = SentenceTransformer(self.model_name, device=self.device, cache_folder=self.cache_dir)
self.model = model
vector_size = model.get_sentence_embedding_dimension()
if vector_size is None:
raise ValueError(f"Unable to determine embedding dimension for model: {self.model_name}")
self.vector_size = int(vector_size)
logger.info("Model loaded successfully. Vector dimension: %d", self.vector_size)
# Setup in-memory LRU cache via functools.lru_cache on a closure
if self.enable_cache:
@functools.lru_cache(maxsize=resolved_cache_max_size)
            def _encode(text: str) -> tuple[float, ...]:
arr = self.model.encode(
text,
convert_to_numpy=True,
normalize_embeddings=self.normalize_embeddings,
)
                # Return an immutable tuple so cached values cannot be mutated by callers.
                return tuple(arr.tolist())
self._cached_encode = _encode
logger.info("In-memory embedding cache enabled (max size: %d)", resolved_cache_max_size)
else:
self._cached_encode = None
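    # Cache semantics (illustrative sketch, not executed here): each instance
    # owns its own LRU cache, so separate LocalEmbedding instances never share
    # cached vectors, and eviction is least-recently-used. For example:
    #   p = LocalEmbedding(enable_cache=True, cache_max_size=2)
    #   p.embed_text("a"); p.embed_text("a")  # second call is a cache hit
    #   p.embed_text("b"); p.embed_text("c")  # "a" is evicted (size limit 2)
    #   p.get_cache_size()                    # -> 2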
# ───────────────────────────────────────────────────────────────────
# 2. Token estimation
# ───────────────────────────────────────────────────────────────────
def estimate_tokens(self, text: str) -> Optional[int]:
"""Estimate token usage with the underlying sentence-transformers tokenizer."""
tokenizer = getattr(self.model, "tokenizer", None)
if tokenizer is None:
return None
try:
encoded: Any = tokenizer(
text,
add_special_tokens=True,
truncation=False,
return_attention_mask=False,
return_token_type_ids=False,
)
except Exception:
return None
        # HF tokenizers return a BatchEncoding, which subclasses UserDict rather
        # than dict, so check against Mapping to cover both shapes.
        if not isinstance(encoded, Mapping):
            return None
        raw_input_ids = encoded.get("input_ids")
        if isinstance(raw_input_ids, list):
            input_ids = cast(list[object], raw_input_ids)
            # A nested list means the tokenizer returned a batch of sequences;
            # count the first sequence.
            if input_ids and isinstance(input_ids[0], list):
                return len(cast(list[object], input_ids[0]))
            return len(input_ids)
return None
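    # Illustrative (exact counts depend on the model's tokenizer): a BERT-style
    # tokenizer encodes "hello world" as [CLS] hello world [SEP], so
    # estimate_tokens("hello world") would return 4.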
# ───────────────────────────────────────────────────────────────────
    # 3. Public API (EmbeddingProvider)
# ───────────────────────────────────────────────────────────────────
def embed_text(self, text: str) -> List[float]:
"""Generate embedding with cache support"""
if self._cached_encode is not None:
return list(self._cached_encode(text))
arr: Any = self.model.encode(
text,
convert_to_numpy=True,
normalize_embeddings=self.normalize_embeddings,
)
return cast(List[float], arr.tolist())
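    # Example property (assuming the default normalize_embeddings=True; `provider`
    # is a hypothetical configured instance):
    #   vec = provider.embed_text("hello")
    #   sum(x * x for x in vec)  # ~1.0: vectors come back with unit L2 norm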
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate batch embeddings with cache support"""
if self._cached_encode is None:
embeddings: Any = self.model.encode(
texts,
convert_to_numpy=True,
normalize_embeddings=self.normalize_embeddings,
show_progress_bar=False,
)
return cast(List[List[float]], embeddings.tolist())
        # lru_cache handles hit/miss transparently per text. Note the tradeoff:
        # the cached path encodes texts one at a time, giving up the model's
        # batched encoding in exchange for per-text cache hits.
return [list(self._cached_encode(text)) for text in texts]
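    # Illustrative usage (`provider` is a hypothetical configured instance):
    #   vecs = provider.embed_batch(["hello", "world"])
    #   vecs[0] matches provider.embed_text("hello") up to float round-off,
    #   since both paths run the same encode pipeline.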
def get_vector_size(self) -> int:
"""Return the embedding vector dimension exposed by the model."""
return self.vector_size
def get_cache_size(self) -> int:
"""Return the current number of cached embeddings."""
return self._cached_encode.cache_info().currsize if self._cached_encode is not None else 0
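

# ───────────────────────────────────────────────────────────────────
# Usage sketch (illustrative only). Assumes app_config.yaml supplies the
# embedding defaults read by get_config(), and that the model named below
# is available locally or downloadable; "all-MiniLM-L6-v2" is an example
# name, not a project requirement.
# ───────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    provider = LocalEmbedding(model_name="all-MiniLM-L6-v2", device="cpu")
    vector = provider.embed_text("hello world")
    print(f"dimension={provider.get_vector_size()} first3={vector[:3]}")
    batch = provider.embed_batch(["hello world", "goodbye world"])
    print(f"batch_size={len(batch)} cached={provider.get_cache_size()}")
    print(f"estimated_tokens={provider.estimate_tokens('hello world')}")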