# Source code for lalandre_core.llm.langchain

"""
Unified LangChain ChatModel factory.
"""

import itertools
import logging
from typing import Any, Iterator, List, Mapping, Optional

from langchain_core.runnables import Runnable
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI

from lalandre_core.utils.api_key_pool import APIKeyPool
from lalandre_core.utils.shared_key_pool import SharedKeyPoolProxy

logger = logging.getLogger(__name__)


class PooledChatModel(Runnable):
    """Round-robin wrapper over multiple ChatModel instances.

    Each call (invoke/stream/batch) is dispatched to the next model in
    a repeating cycle, spreading requests evenly across the pool.
    """

    def __init__(self, models: List[Any]) -> None:
        """Create the pool.

        Args:
            models: Non-empty list of ChatModel instances to rotate over.

        Raises:
            ValueError: If *models* is empty.
        """
        if not models:
            raise ValueError("PooledChatModel requires at least one model")
        # Endless iterator that yields the models in order, wrapping around.
        self._round_robin = itertools.cycle(models)
        logger.info("PooledChatModel: initialized with %d model(s)", len(models))

    def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Invoke the next model in the pool with one request."""
        target = next(self._round_robin)
        return target.invoke(input, config, **kwargs)

    def stream(self, input: Any, config: Any = None, **kwargs: Any) -> Iterator[Any]:
        """Stream one response from the next model in the pool."""
        target = next(self._round_robin)
        return target.stream(input, config, **kwargs)

    def batch(self, inputs: List[Any], config: Any = None, **kwargs: Any) -> List[Any]:
        """Process one batch with the next model in the pool."""
        target = next(self._round_robin)
        return target.batch(inputs, config, **kwargs)
class SharedKeyPoolChatModel(Runnable):
    """Dispatch each LangChain call through a shared API key pool.

    All calls are forwarded to a SharedKeyPoolProxy, which picks the
    per-key client to use for each request.
    """

    def __init__(
        self,
        *,
        key_pool: APIKeyPool,
        models_by_key: Mapping[str, Any],
    ) -> None:
        """Wire the model map into a shared-key-pool proxy.

        Args:
            key_pool: Pool that decides which API key handles each call.
            models_by_key: Mapping of API key -> ChatModel instance.
        """
        self._dispatcher = SharedKeyPoolProxy(
            key_pool=key_pool,
            clients_by_key=models_by_key,
        )
        logger.info(
            "SharedKeyPoolChatModel: initialized with %d model(s)",
            len(models_by_key),
        )

    def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Invoke the model selected by the shared API key pool."""
        return self._dispatcher.invoke(input, config, **kwargs)

    def stream(self, input: Any, config: Any = None, **kwargs: Any) -> Iterator[Any]:
        """Stream a response from the model selected by the shared API key pool."""
        return self._dispatcher.stream(input, config, **kwargs)

    def batch(self, inputs: List[Any], config: Any = None, **kwargs: Any) -> List[Any]:
        """Execute a batch call through the model selected by the shared API key pool."""
        return self._dispatcher.batch(inputs, config, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Fall through to the proxy for anything not defined here, so the
        # wrapper stays transparent to callers expecting proxy attributes.
        return getattr(self._dispatcher, name)
def build_chat_model(
    *,
    provider: str,
    model: str,
    api_key: str,
    base_url: str = "",
    temperature: float = 0.0,
    max_tokens: Optional[int] = None,
    timeout_seconds: Optional[float] = None,
) -> Any:
    """Build a LangChain ChatModel (ChatMistralAI or ChatOpenAI).

    Returns the raw ChatModel instance — callers can wrap it (e.g.
    LangchainLLMWrapper, StrOutputParser) as needed.

    Args:
        provider: Either "mistral" or "openai_compatible".
        model: Provider-specific model identifier.
        api_key: Credential forwarded to the underlying client.
        base_url: Endpoint for the openai_compatible provider; unused for
            mistral.
        temperature: Sampling temperature (default 0.0).
        max_tokens: Completion-token cap; omitted when None so the client's
            own default applies.
        timeout_seconds: Request timeout; omitted when None so the client's
            own default applies.

    Raises:
        ValueError: If *provider* is not a supported value.
    """
    # Options common to both providers, assembled once instead of being
    # duplicated per branch (the original repeated this block verbatim,
    # inviting copy-paste drift).
    common: dict[str, Any] = {
        "model": model,
        "temperature": temperature,
    }
    if max_tokens is not None:
        common["max_tokens"] = max_tokens
    if timeout_seconds is not None:
        common["timeout"] = timeout_seconds

    if provider == "mistral":
        return ChatMistralAI(mistral_api_key=api_key, **common)
    if provider == "openai_compatible":
        return ChatOpenAI(api_key=api_key, base_url=base_url, **common)
    raise ValueError(f"Unsupported LLM provider {provider!r}. Use one of: mistral, openai_compatible.")