"""
Factory utilities for RAG LLM clients.
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional, cast
from lalandre_core.llm import (
SharedKeyPoolChatModel,
build_chat_model,
normalize_base_url,
normalize_provider,
resolve_api_key,
)
from lalandre_core.utils.api_key_pool import APIKeyPool
from lalandre_core.utils.shared_key_pool import SharedKeyPoolProxy, build_clients_by_key
from llama_index.core.llms import LLM
from llama_index.llms.openai_like import OpenAILike
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RAGLLMClients:
"""Bundled LLM clients used by RAG modes."""
provider: str
model_name: str
chat_llm: Any
llamaindex_llm: Optional[LLM]
def _build_llamaindex_mistral_llm(
*,
model_name: str,
api_key: str,
temperature: float,
max_tokens: int,
mistral_base_url: str,
context_window: int,
) -> LLM:
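    """Build a LlamaIndex ``OpenAILike`` client pointed at Mistral's
    OpenAI-compatible endpoint."""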
return cast(
LLM,
OpenAILike(
model=model_name,
api_key=api_key,
api_base=mistral_base_url,
temperature=temperature,
max_tokens=max_tokens,
is_chat_model=True,
context_window=context_window,
),
)
class _SharedKeyPoolLlamaIndexLLM:
"""Dispatch LlamaIndex calls through a shared API key pool."""
def __init__(
self,
*,
key_pool: APIKeyPool,
models_by_key: dict[str, LLM],
) -> None:
self._proxy = SharedKeyPoolProxy(
key_pool=key_pool,
clients_by_key=models_by_key,
)
def __getattr__(self, name: str) -> Any:
return getattr(self._proxy, name)
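    # Explicit wrappers keep the primary LLM entry points visible on the class
    # itself; ``__getattr__`` only covers attribute lookups that would
    # otherwise fail.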
def chat(self, *args: Any, **kwargs: Any) -> Any:
return self._proxy.chat(*args, **kwargs)
def complete(self, *args: Any, **kwargs: Any) -> Any:
return self._proxy.complete(*args, **kwargs)
def stream_chat(self, *args: Any, **kwargs: Any) -> Any:
return self._proxy.stream_chat(*args, **kwargs)
def stream_complete(self, *args: Any, **kwargs: Any) -> Any:
return self._proxy.stream_complete(*args, **kwargs)
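# Illustrative sketch only (not executed): given a pool with several keys and
# one LlamaIndex client per key, the wrapper above routes each call through
# the shared pool. ``build_llm_for`` and ``keys`` below are assumed names for
# illustration, not part of this module.
#
#     pooled_llm = _SharedKeyPoolLlamaIndexLLM(
#         key_pool=key_pool,
#         models_by_key={key: build_llm_for(key) for key in keys},
#     )
#     response = pooled_llm.complete("ping")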
def build_rag_llm_clients(
*,
provider: str,
model_name: str,
temperature: float,
max_tokens: int,
timeout_seconds: float,
base_url: Optional[str],
mistral_base_url: str,
context_window: int,
api_key: Optional[str],
mistral_api_key: Optional[str],
key_pool: Optional[APIKeyPool] = None,
) -> RAGLLMClients:
"""
    Build provider-specific clients for RAG.

    When *key_pool* is provided and contains more than one key, one underlying
    client is built per key and calls are dispatched through the shared pool.

    Supported providers: ``mistral`` and ``openai_compatible``. For
    ``openai_compatible``, only a chat client is built and ``llamaindex_llm``
    is ``None``.
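
    Example (illustrative only; the argument values below are assumptions,
    not defaults shipped with this function)::

        clients = build_rag_llm_clients(
            provider="mistral",
            model_name="mistral-large-latest",
            temperature=0.2,
            max_tokens=1024,
            timeout_seconds=60.0,
            base_url=None,
            mistral_base_url="https://api.mistral.ai/v1",
            context_window=32000,
            api_key=None,
            mistral_api_key="...",
        )
        rag_llm = clients.llamaindex_llm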
"""
resolved_provider = normalize_provider(provider)
resolved_model = model_name.strip()
if not resolved_model:
raise ValueError("LLM model is required.")
effective_api_key = resolve_api_key(
provider=resolved_provider,
api_key=api_key,
mistral_api_key=mistral_api_key,
)
if resolved_provider == "mistral":
if key_pool is not None and len(key_pool) > 1:
logger.info(
"RAG LLM factory: building pooled clients with %d keys",
len(key_pool),
)
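            # One chat client and one LlamaIndex client are built per API key;
            # the shared-pool wrappers below dispatch each call through the pool.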
chat_models = build_clients_by_key(
key_pool=key_pool,
factory=lambda key: build_chat_model(
provider=resolved_provider,
model=resolved_model,
api_key=key,
temperature=temperature,
max_tokens=max_tokens,
timeout_seconds=timeout_seconds,
),
)
llama_models = build_clients_by_key(
key_pool=key_pool,
factory=lambda key: _build_llamaindex_mistral_llm(
model_name=resolved_model,
api_key=key,
temperature=temperature,
max_tokens=max_tokens,
mistral_base_url=mistral_base_url,
context_window=context_window,
),
)
chat_llm = SharedKeyPoolChatModel(
key_pool=key_pool,
models_by_key=chat_models,
)
llama_llm = cast(
LLM,
_SharedKeyPoolLlamaIndexLLM(
key_pool=key_pool,
models_by_key=llama_models,
),
)
else:
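            # Single-key path: build one chat client and one LlamaIndex client
            # with the resolved API key.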
chat_llm = build_chat_model(
provider=resolved_provider,
model=resolved_model,
api_key=effective_api_key,
temperature=temperature,
max_tokens=max_tokens,
timeout_seconds=timeout_seconds,
)
llama_llm = _build_llamaindex_mistral_llm(
model_name=resolved_model,
api_key=effective_api_key,
temperature=temperature,
max_tokens=max_tokens,
mistral_base_url=mistral_base_url,
context_window=context_window,
)
return RAGLLMClients(
provider=resolved_provider,
model_name=resolved_model,
chat_llm=chat_llm,
llamaindex_llm=llama_llm,
)
if resolved_provider == "openai_compatible":
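        # OpenAI-compatible providers need an explicit base URL and only get a
        # chat client; no LlamaIndex client is built on this path.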
resolved_base_url = normalize_base_url(
provider=resolved_provider,
base_url=base_url or "",
)
if not resolved_base_url:
raise ValueError("LLM base URL is required for openai_compatible provider. Set generation.base_url.")
chat_llm = build_chat_model(
provider=resolved_provider,
model=resolved_model,
api_key=effective_api_key,
base_url=resolved_base_url,
temperature=temperature,
max_tokens=max_tokens,
timeout_seconds=timeout_seconds,
)
return RAGLLMClients(
provider=resolved_provider,
model_name=resolved_model,
chat_llm=chat_llm,
llamaindex_llm=None,
)
raise ValueError(f"Unsupported LLM provider {provider!r}. Use one of: mistral, openai_compatible.")