Source code for lalandre_embedding.providers.local

"""
Local embedding provider
"""

import functools
import logging
from typing import Any, List, Optional, cast

from lalandre_core.config import get_config

from ..base import EmbeddingProvider, SupportsCacheSize

logger = logging.getLogger(__name__)


[docs] class LocalEmbedding(EmbeddingProvider, SupportsCacheSize): """Local embedding provider using sentence-transformers with in-memory LRU cache""" model: Any _cached_encode: Any # functools.lru_cache closure, or None # ─────────────────────────────────────────────────────────────────── # 1. Initialization # ─────────────────────────────────────────────────────────────────── def __init__( self, model_name: Optional[str] = None, device: Optional[str] = None, cache_dir: Optional[str] = None, normalize_embeddings: bool = True, enable_cache: Optional[bool] = None, cache_max_size: Optional[int] = None, ): config = get_config() resolved_model_name = model_name or config.embedding.model_name if not resolved_model_name: raise ValueError("Missing local embedding model name (set embedding.model_name in app_config.yaml)") resolved_device = device or config.embedding.device if not resolved_device: raise ValueError("Missing local embedding device (set embedding.device in app_config.yaml)") resolved_enable_cache = enable_cache if enable_cache is not None else config.embedding.enable_cache resolved_cache_max_size = cache_max_size if cache_max_size is not None else config.embedding.cache_max_size if resolved_cache_max_size <= 0: raise ValueError("embedding.cache_max_size must be > 0") self.model_name = resolved_model_name self.device = resolved_device self.normalize_embeddings = normalize_embeddings self.enable_cache = resolved_enable_cache self.cache_dir = cache_dir or config.embedding.cache_dir or config.models_cache_dir # Load model with cache directory if specified logger.info("Loading local embedding model: %s on %s", self.model_name, self.device) if self.cache_dir: logger.info("Using model cache directory: %s", self.cache_dir) from sentence_transformers import SentenceTransformer model: Any = SentenceTransformer(self.model_name, device=self.device, cache_folder=self.cache_dir) self.model = model vector_size = model.get_sentence_embedding_dimension() if vector_size is None: raise ValueError(f"Unable to determine embedding dimension for model: {self.model_name}") self.vector_size = int(vector_size) logger.info("Model loaded successfully. Vector dimension: %d", self.vector_size) # Setup in-memory LRU cache via functools.lru_cache on a closure if self.enable_cache: @functools.lru_cache(maxsize=resolved_cache_max_size) def _encode(text: str) -> tuple: arr = self.model.encode( text, convert_to_numpy=True, normalize_embeddings=self.normalize_embeddings, ) return tuple(arr.tolist()) self._cached_encode = _encode logger.info("In-memory embedding cache enabled (max size: %d)", resolved_cache_max_size) else: self._cached_encode = None # ─────────────────────────────────────────────────────────────────── # 2. Token estimation # ───────────────────────────────────────────────────────────────────
[docs] def estimate_tokens(self, text: str) -> Optional[int]: """Estimate token usage with the underlying sentence-transformers tokenizer.""" tokenizer = getattr(self.model, "tokenizer", None) if tokenizer is None: return None try: encoded: Any = tokenizer( text, add_special_tokens=True, truncation=False, return_attention_mask=False, return_token_type_ids=False, ) except Exception: return None encoded_payload = cast(dict[str, object], encoded) if isinstance(encoded, dict) else None if encoded_payload is None: return None raw_input_ids = encoded_payload.get("input_ids") if isinstance(raw_input_ids, list): input_ids = cast(list[object], raw_input_ids) if input_ids and isinstance(input_ids[0], list): first_sequence = cast(list[object], input_ids[0]) return len(first_sequence) return len(input_ids) return None
[docs] def get_max_input_tokens(self) -> Optional[int]: """Return the maximum sequence length supported by the local model.""" raw_limit = getattr(self.model, "max_seq_length", None) if not isinstance(raw_limit, int) or raw_limit <= 0: return None return raw_limit
# ─────────────────────────────────────────────────────────────────── # 4. Public API (EmbeddingProvider) # ───────────────────────────────────────────────────────────────────
[docs] def embed_text(self, text: str) -> List[float]: """Generate embedding with cache support""" if self._cached_encode is not None: return list(self._cached_encode(text)) arr: Any = self.model.encode( text, convert_to_numpy=True, normalize_embeddings=self.normalize_embeddings, ) return cast(List[float], arr.tolist())
[docs] def embed_batch(self, texts: List[str]) -> List[List[float]]: """Generate batch embeddings with cache support""" if self._cached_encode is None: embeddings: Any = self.model.encode( texts, convert_to_numpy=True, normalize_embeddings=self.normalize_embeddings, show_progress_bar=False, ) return cast(List[List[float]], embeddings.tolist()) # lru_cache handles hit/miss transparently per text return [list(self._cached_encode(text)) for text in texts]
[docs] def get_vector_size(self) -> int: """Return the embedding vector dimension exposed by the model.""" return self.vector_size
[docs] def get_cache_size(self) -> int: """Return the current number of cached embeddings.""" return self._cached_encode.cache_info().currsize if self._cached_encode is not None else 0