Source code for lalandre_rag.llm.factory

"""
Factory utilities for RAG LLM clients.
"""

import logging
from dataclasses import dataclass
from typing import Any, Optional, cast

from lalandre_core.llm import (
    SharedKeyPoolChatModel,
    build_chat_model,
    normalize_base_url,
    normalize_provider,
    resolve_api_key,
)
from lalandre_core.utils.api_key_pool import APIKeyPool
from lalandre_core.utils.shared_key_pool import SharedKeyPoolProxy, build_clients_by_key
from llama_index.core.llms import LLM
from llama_index.llms.openai_like import OpenAILike

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class RAGLLMClients:
    """Bundled LLM clients used by RAG modes."""

    provider: str
    model_name: str
    chat_llm: Any
    llamaindex_llm: Optional[LLM]
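# Illustrative only (not part of the original module): a minimal sketch of how a
# caller might consume this bundle, assuming LlamaIndex's global ``Settings``
# object is the desired integration point. The helper name and wiring are
# hypothetical, not something this module provides.
def _example_apply_llamaindex_llm(clients: RAGLLMClients) -> None:
    """Wire the optional LlamaIndex client into the global LlamaIndex settings."""
    from llama_index.core import Settings  # deferred import; illustrative only

    # llamaindex_llm is None for providers without a LlamaIndex client
    # (e.g. openai_compatible below), so guard before assigning.
    if clients.llamaindex_llm is not None:
        Settings.llm = clients.llamaindex_llm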
def _build_llamaindex_mistral_llm(
    *,
    model_name: str,
    api_key: str,
    temperature: float,
    max_tokens: int,
    mistral_base_url: str,
    context_window: int,
) -> LLM:
    return cast(
        LLM,
        OpenAILike(
            model=model_name,
            api_key=api_key,
            api_base=mistral_base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            is_chat_model=True,
            context_window=context_window,
        ),
    )


class _SharedKeyPoolLlamaIndexLLM:
    """Dispatch LlamaIndex calls through a shared API key pool."""

    def __init__(
        self,
        *,
        key_pool: APIKeyPool,
        models_by_key: dict[str, LLM],
    ) -> None:
        self._proxy = SharedKeyPoolProxy(
            key_pool=key_pool,
            clients_by_key=models_by_key,
        )

    def __getattr__(self, name: str) -> Any:
        return getattr(self._proxy, name)

    def chat(self, *args: Any, **kwargs: Any) -> Any:
        return self._proxy.chat(*args, **kwargs)

    def complete(self, *args: Any, **kwargs: Any) -> Any:
        return self._proxy.complete(*args, **kwargs)

    def stream_chat(self, *args: Any, **kwargs: Any) -> Any:
        return self._proxy.stream_chat(*args, **kwargs)

    def stream_complete(self, *args: Any, **kwargs: Any) -> Any:
        return self._proxy.stream_complete(*args, **kwargs)
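# Illustrative only (not part of the original module): a self-contained sketch of
# the key-rotation dispatch pattern that the proxy above relies on. The real
# SharedKeyPoolProxy in lalandre_core may pick keys differently (e.g. by
# rate-limit or error state); this round-robin version is an assumption about
# the general technique, not that class's actual behavior or API.
class _ExampleRoundRobinProxy:
    """Rotate calls across one client per API key."""

    def __init__(self, clients_by_key: dict[str, Any]) -> None:
        self._clients = list(clients_by_key.values())
        self._next = 0

    def _pick(self) -> Any:
        client = self._clients[self._next % len(self._clients)]
        self._next += 1
        return client

    def chat(self, *args: Any, **kwargs: Any) -> Any:
        # Each call goes to the next client, spreading traffic across keys.
        return self._pick().chat(*args, **kwargs)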
def build_rag_llm_clients(
    *,
    provider: str,
    model_name: str,
    temperature: float,
    max_tokens: int,
    timeout_seconds: float,
    base_url: Optional[str],
    mistral_base_url: str,
    context_window: int,
    api_key: Optional[str],
    mistral_api_key: Optional[str],
    key_pool: Optional[APIKeyPool] = None,
) -> RAGLLMClients:
    """
    Build provider-specific clients for RAG.

    When *key_pool* is provided and contains >1 key, multiple underlying
    clients are created and dispatched through the shared pool.

    Supported providers: mistral, openai_compatible.
    """
    resolved_provider = normalize_provider(provider)
    resolved_model = model_name.strip()
    if not resolved_model:
        raise ValueError("LLM model is required.")

    effective_api_key = resolve_api_key(
        provider=resolved_provider,
        api_key=api_key,
        mistral_api_key=mistral_api_key,
    )

    if resolved_provider == "mistral":
        if key_pool is not None and len(key_pool) > 1:
            logger.info(
                "RAG LLM factory: building pooled clients with %d keys",
                len(key_pool),
            )
            chat_models = build_clients_by_key(
                key_pool=key_pool,
                factory=lambda key: build_chat_model(
                    provider=resolved_provider,
                    model=resolved_model,
                    api_key=key,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    timeout_seconds=timeout_seconds,
                ),
            )
            llama_models = build_clients_by_key(
                key_pool=key_pool,
                factory=lambda key: _build_llamaindex_mistral_llm(
                    model_name=resolved_model,
                    api_key=key,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    mistral_base_url=mistral_base_url,
                    context_window=context_window,
                ),
            )
            chat_llm = SharedKeyPoolChatModel(
                key_pool=key_pool,
                models_by_key=chat_models,
            )
            llama_llm = cast(
                LLM,
                _SharedKeyPoolLlamaIndexLLM(
                    key_pool=key_pool,
                    models_by_key=llama_models,
                ),
            )
        else:
            chat_llm = build_chat_model(
                provider=resolved_provider,
                model=resolved_model,
                api_key=effective_api_key,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout_seconds=timeout_seconds,
            )
            llama_llm = _build_llamaindex_mistral_llm(
                model_name=resolved_model,
                api_key=effective_api_key,
                temperature=temperature,
                max_tokens=max_tokens,
                mistral_base_url=mistral_base_url,
                context_window=context_window,
            )
        return RAGLLMClients(
            provider=resolved_provider,
            model_name=resolved_model,
            chat_llm=chat_llm,
            llamaindex_llm=llama_llm,
        )

    if resolved_provider == "openai_compatible":
        resolved_base_url = normalize_base_url(
            provider=resolved_provider,
            base_url=base_url or "",
        )
        if not resolved_base_url:
            raise ValueError(
                "LLM base URL is required for openai_compatible provider. "
                "Set generation.base_url."
            )
        chat_llm = build_chat_model(
            provider=resolved_provider,
            model=resolved_model,
            api_key=effective_api_key,
            base_url=resolved_base_url,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout_seconds=timeout_seconds,
        )
        return RAGLLMClients(
            provider=resolved_provider,
            model_name=resolved_model,
            chat_llm=chat_llm,
            llamaindex_llm=None,
        )

    raise ValueError(
        f"Unsupported LLM provider {provider!r}. "
        "Use one of: mistral, openai_compatible."
    )
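# Illustrative usage sketch (not part of the original module). The model name,
# base URL, and key below are placeholders; real callers would read them from
# configuration, and may also pass key_pool to enable pooled dispatch.
if __name__ == "__main__":
    clients = build_rag_llm_clients(
        provider="mistral",
        model_name="mistral-small-latest",
        temperature=0.2,
        max_tokens=1024,
        timeout_seconds=60.0,
        base_url=None,
        mistral_base_url="https://api.mistral.ai/v1",
        context_window=32000,
        api_key=None,
        mistral_api_key="<your-mistral-api-key>",
    )
    # For the mistral provider, both the chat client and the LlamaIndex client
    # are populated; openai_compatible leaves llamaindex_llm as None.
    logger.info(
        "Built %s clients for model %s", clients.provider, clients.model_name
    )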