"""
Unified LangChain ChatModel factory.
"""
import itertools
import logging
import math
from typing import Any, Iterator, List, Mapping, Optional

from langchain_core.runnables import Runnable
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI

from lalandre_core.utils.api_key_pool import APIKeyPool
from lalandre_core.utils.shared_key_pool import SharedKeyPoolProxy
# Module-scoped logger, named after this module so log routing can target it.
logger: logging.Logger = logging.getLogger(__name__)
class PooledChatModel(Runnable):
    """Round-robin wrapper over multiple ChatModel instances.

    Each call (``invoke``, ``stream``, ``batch`` and their async counterparts
    ``ainvoke``, ``astream``, ``abatch``) is dispatched in full to the next
    model of the pool. Models rotate in the order they were supplied.
    """

    def __init__(self, models: List[Any]) -> None:
        """Create the pool.

        Args:
            models: Chat models to rotate over. Any iterable is accepted. It is
                snapshotted into a tuple, so mutating the caller's container
                afterwards does not change the pool.

        Raises:
            ValueError: If ``models`` is empty or ``None``.
        """
        # Snapshot before cycling: ``itertools.cycle`` reads its source lazily
        # during the first pass, so appends/removals on the caller's list would
        # otherwise leak into the first rotation only. The ``if models`` guard
        # keeps ``None`` on the ValueError path instead of a TypeError.
        self._models = tuple(models) if models else ()
        if not self._models:
            raise ValueError("PooledChatModel requires at least one model")
        self._cycle = itertools.cycle(self._models)
        logger.info("PooledChatModel: initialized with %d model(s)", len(self._models))

    def _next_model(self) -> Any:
        """Return the next model in round-robin order."""
        return next(self._cycle)

    def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Invoke the next model in the pool with one request."""
        return self._next_model().invoke(input, config, **kwargs)

    def stream(self, input: Any, config: Any = None, **kwargs: Any) -> Iterator[Any]:
        """Stream one response from the next model in the pool."""
        return self._next_model().stream(input, config, **kwargs)

    def batch(self, inputs: List[Any], config: Any = None, **kwargs: Any) -> List[Any]:
        """Process one batch with the next model in the pool."""
        return self._next_model().batch(inputs, config, **kwargs)

    async def ainvoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Asynchronously invoke the next model in the pool with one request."""
        return await self._next_model().ainvoke(input, config, **kwargs)

    def astream(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Asynchronously stream one response from the next model in the pool.

        Returns the selected model's async iterator directly, so chunks are
        forwarded as the model produces them. The inherited
        ``Runnable.astream`` would instead yield a single awaited ``ainvoke``
        result.
        """
        return self._next_model().astream(input, config, **kwargs)

    async def abatch(self, inputs: List[Any], config: Any = None, **kwargs: Any) -> List[Any]:
        """Asynchronously process one batch with the next model in the pool."""
        return await self._next_model().abatch(inputs, config, **kwargs)
class SharedKeyPoolChatModel(Runnable):
    """Dispatch each LangChain call through a shared API key pool.

    Model selection is delegated to a ``SharedKeyPoolProxy``. Attribute
    lookups that neither this class nor ``Runnable`` resolve are forwarded to
    that proxy.
    """

    def __init__(
        self,
        *,
        key_pool: APIKeyPool,
        models_by_key: Mapping[str, Any],
    ) -> None:
        """Wrap ``models_by_key`` in a proxy driven by ``key_pool``.

        Args:
            key_pool: Pool handed to ``SharedKeyPoolProxy`` as ``key_pool``.
            models_by_key: Mapping handed to the proxy as ``clients_by_key``.
                Presumably one chat model per API key; confirm against
                ``SharedKeyPoolProxy``.
        """
        self._proxy = SharedKeyPoolProxy(
            key_pool=key_pool,
            clients_by_key=models_by_key,
        )
        logger.info(
            "SharedKeyPoolChatModel: initialized with %d model(s)",
            len(models_by_key),
        )

    def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Invoke the model selected by the shared API key pool."""
        return self._proxy.invoke(input, config, **kwargs)

    def stream(self, input: Any, config: Any = None, **kwargs: Any) -> Iterator[Any]:
        """Stream a response from the model selected by the shared API key pool."""
        return self._proxy.stream(input, config, **kwargs)

    def batch(self, inputs: List[Any], config: Any = None, **kwargs: Any) -> List[Any]:
        """Execute a batch call through the model selected by the shared API key pool."""
        return self._proxy.batch(inputs, config, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Forward attribute lookups not found normally to the underlying proxy.

        Raises:
            AttributeError: If the attribute is missing on the proxy, or if
                ``_proxy`` itself has not been set (e.g. on an instance created
                via ``__new__`` by ``copy`` or ``pickle``).
        """
        # Read ``_proxy`` straight from the instance dict. Using ``self._proxy``
        # here would re-enter ``__getattr__`` when the attribute is missing and
        # recurse until RecursionError.
        try:
            proxy = self.__dict__["_proxy"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(proxy, name)
def build_chat_model(
    *,
    provider: str,
    model: str,
    api_key: str,
    base_url: str = "",
    temperature: float = 0.0,
    max_tokens: Optional[int] = None,
    timeout_seconds: Optional[float] = None,
) -> Any:
    """Build a LangChain ChatModel (ChatMistralAI or ChatOpenAI).

    Returns the raw ChatModel instance — callers can wrap it
    (e.g. LangchainLLMWrapper, StrOutputParser) as needed.

    Args:
        provider: ``"mistral"`` or ``"openai_compatible"``. Matched
            case-insensitively and ignoring surrounding whitespace.
        model: Model name forwarded to the client.
        api_key: API key forwarded to the client.
        base_url: Endpoint for the ``openai_compatible`` provider. An empty
            value means "not set" and is not forwarded. Ignored for ``mistral``.
        temperature: Sampling temperature forwarded to the client.
        max_tokens: Optional completion-token cap; omitted when ``None``.
        timeout_seconds: Optional request timeout; omitted when ``None``.
            For ``mistral`` it is rounded up to whole seconds.

    Returns:
        A ``ChatMistralAI`` or ``ChatOpenAI`` instance.

    Raises:
        ValueError: If ``provider`` is not a supported provider name.
    """
    normalized = provider.strip().lower() if isinstance(provider, str) else provider

    # Optional settings are only forwarded when set, so each client otherwise
    # keeps its own default.
    optional: dict[str, Any] = {}
    if max_tokens is not None:
        optional["max_tokens"] = max_tokens

    if normalized == "mistral":
        if timeout_seconds is not None:
            # ChatMistralAI declares ``timeout`` as an int field; a fractional
            # float fails its validation, so round up to whole seconds.
            optional["timeout"] = math.ceil(timeout_seconds)
        return ChatMistralAI(
            model=model,
            mistral_api_key=api_key,
            temperature=temperature,
            **optional,
        )

    if normalized == "openai_compatible":
        if timeout_seconds is not None:
            optional["timeout"] = timeout_seconds
        if base_url:
            # Empty means "not set": omit it and let ChatOpenAI resolve its
            # endpoint through its own defaults.
            optional["base_url"] = base_url
        return ChatOpenAI(
            model=model,
            api_key=api_key,
            temperature=temperature,
            **optional,
        )

    raise ValueError(f"Unsupported LLM provider {provider!r}. Use one of: mistral, openai_compatible.")