diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index b867766..f54a7f6 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Optional, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Optional, cast from . import _utils from ._abc import BaseChatWithTools @@ -7,7 +7,6 @@ from ._utils import ToolFunction if TYPE_CHECKING: - from anthropic import AsyncAnthropic from anthropic.types import ( ContentBlock, MessageParam, @@ -17,6 +16,7 @@ ToolResultBlockParam, ToolUseBlock, ) + from httpx import URL class AnthropicChat(BaseChatWithTools["CreateCompletion"]): @@ -29,10 +29,11 @@ def __init__( *, api_key: Optional[str] = None, model: "Model" = "claude-3-5-sonnet-20240620", - system_prompt: Optional[str] = None, max_tokens: int = 1024, + system_prompt: Optional[str] = None, tools: Iterable[ToolFunction] = (), - client: "AsyncAnthropic | None" = None, + base_url: Optional[str | URL] = None, + **kwargs: Any, ): """ Start a chat powered by Anthropic @@ -43,42 +44,39 @@ def __init__( Your Anthropic API key. model The model to use for the chat. - system_prompt - A system prompt to use for the chat. max_tokens The maximum number of tokens to generate for each response. + system_prompt + A system prompt to use for the chat. tools A list of tools (i.e., function calls) to use for the chat. - client - An `anthropic.AsyncAnthropic` client instance to use for the chat. - Use this to customize stuff like `base_url`, `timeout`, etc. + base_url + The base URL to use for requests. + kwargs + Additional keyword arguments to pass to the `anthropic.AsyncAnthropic` constructor. Raises ------ ImportError If the `anthropic` package is not installed. """ - self._model = model - self._system_prompt = system_prompt - self._max_tokens = max_tokens - for tool in tools: - self.register_tool(tool) - if client is None: - client = self._get_client() - if api_key is not None: - client.api_key = api_key - self.client = client - - def _get_client(self) -> "AsyncAnthropic": try: from anthropic import AsyncAnthropic - - return AsyncAnthropic() except ImportError: raise ImportError( f"The {self.__class__.__name__} class requires the `anthropic` package. " "Please install it with `pip install anthropic`." ) + self._model = model + self._system_prompt = system_prompt + self._max_tokens = max_tokens + for tool in tools: + self.register_tool(tool) + self.client = AsyncAnthropic( + api_key=api_key, + base_url=base_url, + **kwargs, + ) async def response_generator( self, diff --git a/chatlas/_google.py b/chatlas/_google.py index 38c44cb..611e1db 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -1,16 +1,11 @@ -from typing import TYPE_CHECKING, AsyncGenerator, Optional +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional from ._abc import BaseChat +from ._utils import ToolFunction if TYPE_CHECKING: - from google.generativeai.types import ( - ContentType, - FunctionLibraryType, - GenerationConfigType, - RequestOptionsType, - ) - from google.generativeai.types.content_types import ToolConfigType, protos - from google.generativeai.types.safety_types import SafetySettingOptions + from google.generativeai.types import RequestOptionsType + from google.generativeai.types.content_types import protos Content = protos.Content @@ -23,11 +18,10 @@ def __init__( *, api_key: Optional[str] = None, model: str = "gemini-1.5-flash", - system_prompt: Optional["ContentType"] = None, - tools: Optional["FunctionLibraryType"] = None, - tool_config: Optional["ToolConfigType"] = None, - safety_settings: Optional["SafetySettingOptions"] = None, - generation_config: Optional["GenerationConfigType"] = None, + system_prompt: Optional[str] = None, + tools: Optional[ToolFunction] = None, + api_endpoint: Optional[str] = None, + **kwargs: Any, ) -> None: """ Start a chat powered by Google Generative AI. @@ -42,12 +36,10 @@ def __init__( A system prompt to use for the chat. tools A list of tools (i.e., function calls) to use for the chat. - tool_config - Configuration for the tools (see `google.generativeai.GenerativeModel` for details). - safety_settings - Safety settings for the chat (see `google.generativeai.GenerativeModel` for details). - generation_config - Configuration for the generation process (see `google.generativeai.GenerativeModel` for details). + api_endpoint + The API endpoint to use. + kwargs + Additional keyword arguments to pass to the `google.generativeai.GenerativeModel` constructor. """ try: from google.generativeai import GenerativeModel @@ -60,15 +52,13 @@ def __init__( if api_key is not None: import google.generativeai as genai - genai.configure(api_key=api_key) + genai.configure(api_key=api_key, client_options={"api_endpoint": api_endpoint}) self.client = GenerativeModel( model_name=model, system_instruction=system_prompt, tools=tools, - tool_config=tool_config, - safety_settings=safety_settings, - generation_config=generation_config, + **kwargs, ) # https://github.com/google-gemini/cookbook/blob/main/quickstarts/Function_calling.ipynb diff --git a/chatlas/_ollama.py b/chatlas/_ollama.py index c73f96f..d3e6f09 100644 --- a/chatlas/_ollama.py +++ b/chatlas/_ollama.py @@ -1,5 +1,13 @@ import json -from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Optional, Sequence, cast +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Iterable, + Optional, + Sequence, + cast, +) from . import _utils from ._abc import BaseChatWithTools @@ -8,7 +16,7 @@ from ._utils import ToolFunction if TYPE_CHECKING: - from ollama import AsyncClient, Message + from ollama import Message from ollama._types import ChatResponse, Tool, ToolCall @@ -23,7 +31,8 @@ def __init__( model: Optional[str] = None, system_prompt: Optional[str] = None, tools: Iterable[ToolFunction] = (), - client: "AsyncClient | None" = None, + host: Optional[str] = None, + **kwargs: Any, ) -> None: """ Start a chat powered by Ollama. @@ -37,23 +46,17 @@ def __init__( tools A list of tools (i.e., function calls) to make available in the chat. - client - An `ollama.AsyncClient` instance to use for the chat. Use this to - customize stuff like `host`, `timeout`, etc. + kwargs + Additional keyword arguments to pass to the `ollama.AsyncClient` constructor. """ self._model = model self._system_prompt = system_prompt for tool in tools: self.register_tool(tool) - if client is None: - client = self._get_client() - self.client = client - - def _get_client(self) -> "AsyncClient": try: from ollama import AsyncClient - return AsyncClient() + self.client = AsyncClient(host=host, **kwargs) except ImportError: raise ImportError( f"The {self.__class__.__name__} class requires the `ollama` package. " diff --git a/chatlas/_openai.py b/chatlas/_openai.py index e3bb197..3a84539 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Optional, Sequence +from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Optional, Sequence from . import _utils from ._abc import BaseChatWithTools @@ -8,7 +8,7 @@ from ._utils import ToolFunction if TYPE_CHECKING: - from openai import AsyncOpenAI + from httpx import URL from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionMessageParam, @@ -31,7 +31,8 @@ def __init__( model: "ChatModel" = "gpt-4o", system_prompt: Optional[str] = None, tools: Iterable[ToolFunction] = (), - client: "AsyncOpenAI | None" = None, + base_url: Optional[str | URL] = None, + **kwargs: Any, ): """ Start a chat powered by OpenAI @@ -46,36 +47,35 @@ def __init__( A system prompt to use for the chat. tools A list of tools (i.e., function calls) to use for the chat. - client - An `openai.AsyncOpenAI` client instance to use for the chat. Use - this to customize stuff like `base_url`, `timeout`, etc. + base_url + The base URL to use for requests. + kwargs + Additional keyword arguments to pass to the `openai.AsyncOpenAI` constructor. Raises ------ ImportError If the `openai` package is not installed. """ - self._model = model - self._system_prompt = system_prompt - for tool in tools: - self.register_tool(tool) - if client is None: - client = self._get_client() - if api_key is not None: - client.api_key = api_key - self.client = client - - def _get_client(self) -> "AsyncOpenAI": try: from openai import AsyncOpenAI - - return AsyncOpenAI() except ImportError: raise ImportError( f"The {self.__class__.__name__} class requires the `openai` package. " "Install it with `pip install openai`." ) + self._model = model + self._system_prompt = system_prompt + for tool in tools: + self.register_tool(tool) + + self.client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + **kwargs, + ) + async def response_generator( self, user_input: str, diff --git a/chatlas/_utils.py b/chatlas/_utils.py index 56c2803..90c4ea2 100644 --- a/chatlas/_utils.py +++ b/chatlas/_utils.py @@ -30,7 +30,9 @@ ToolFunctionSync = Callable[..., Any] ToolFunctionAsync = Callable[..., Awaitable[Any]] -ToolFunction = Union[ToolFunctionSync, ToolFunctionAsync] +ToolFunction = Union[ + ToolFunctionSync, ToolFunctionAsync +] # TODO: support pydantic types? class ToolSchemaProperty(TypedDict, total=False):