Skip to content

Commit

Permalink
Pass kwargs along to client constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Oct 10, 2024
1 parent 53f7460 commit 8a03d10
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 79 deletions.
44 changes: 21 additions & 23 deletions chatlas/_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import json
from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Optional, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Optional, cast

from . import _utils
from ._abc import BaseChatWithTools
from ._anthropic_types import CreateCompletion
from ._utils import ToolFunction

if TYPE_CHECKING:
from anthropic import AsyncAnthropic
from anthropic.types import (
ContentBlock,
MessageParam,
Expand All @@ -17,6 +16,7 @@
ToolResultBlockParam,
ToolUseBlock,
)
from httpx import URL


class AnthropicChat(BaseChatWithTools["CreateCompletion"]):
Expand All @@ -29,10 +29,11 @@ def __init__(
*,
api_key: Optional[str] = None,
model: "Model" = "claude-3-5-sonnet-20240620",
system_prompt: Optional[str] = None,
max_tokens: int = 1024,
system_prompt: Optional[str] = None,
tools: Iterable[ToolFunction] = (),
client: "AsyncAnthropic | None" = None,
base_url: Optional[str | URL] = None,
**kwargs: Any,
):
"""
Start a chat powered by Anthropic
Expand All @@ -43,42 +44,39 @@ def __init__(
Your Anthropic API key.
model
The model to use for the chat.
system_prompt
A system prompt to use for the chat.
max_tokens
The maximum number of tokens to generate for each response.
system_prompt
A system prompt to use for the chat.
tools
A list of tools (i.e., function calls) to use for the chat.
client
An `anthropic.AsyncAnthropic` client instance to use for the chat.
Use this to customize stuff like `base_url`, `timeout`, etc.
base_url
The base URL to use for requests.
kwargs
Additional keyword arguments to pass to the `anthropic.AsyncAnthropic` constructor.
Raises
------
ImportError
If the `anthropic` package is not installed.
"""
self._model = model
self._system_prompt = system_prompt
self._max_tokens = max_tokens
for tool in tools:
self.register_tool(tool)
if client is None:
client = self._get_client()
if api_key is not None:
client.api_key = api_key
self.client = client

def _get_client(self) -> "AsyncAnthropic":
try:
from anthropic import AsyncAnthropic

return AsyncAnthropic()
except ImportError:
raise ImportError(
f"The {self.__class__.__name__} class requires the `anthropic` package. "
"Please install it with `pip install anthropic`."
)
self._model = model
self._system_prompt = system_prompt
self._max_tokens = max_tokens
for tool in tools:
self.register_tool(tool)
self.client = AsyncAnthropic(
api_key=api_key,
base_url=base_url,
**kwargs,
)

async def response_generator(
self,
Expand Down
38 changes: 14 additions & 24 deletions chatlas/_google.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from typing import TYPE_CHECKING, AsyncGenerator, Optional
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

from ._abc import BaseChat
from ._utils import ToolFunction

if TYPE_CHECKING:
from google.generativeai.types import (
ContentType,
FunctionLibraryType,
GenerationConfigType,
RequestOptionsType,
)
from google.generativeai.types.content_types import ToolConfigType, protos
from google.generativeai.types.safety_types import SafetySettingOptions
from google.generativeai.types import RequestOptionsType
from google.generativeai.types.content_types import protos

Content = protos.Content

Expand All @@ -23,11 +18,10 @@ def __init__(
*,
api_key: Optional[str] = None,
model: str = "gemini-1.5-flash",
system_prompt: Optional["ContentType"] = None,
tools: Optional["FunctionLibraryType"] = None,
tool_config: Optional["ToolConfigType"] = None,
safety_settings: Optional["SafetySettingOptions"] = None,
generation_config: Optional["GenerationConfigType"] = None,
system_prompt: Optional[str] = None,
tools: Optional[ToolFunction] = None,
api_endpoint: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Start a chat powered by Google Generative AI.
Expand All @@ -42,12 +36,10 @@ def __init__(
A system prompt to use for the chat.
tools
A list of tools (i.e., function calls) to use for the chat.
tool_config
Configuration for the tools (see `google.generativeai.GenerativeModel` for details).
safety_settings
Safety settings for the chat (see `google.generativeai.GenerativeModel` for details).
generation_config
Configuration for the generation process (see `google.generativeai.GenerativeModel` for details).
api_endpoint
The API endpoint to use.
kwargs
Additional keyword arguments to pass to the `google.generativeai.GenerativeModel` constructor.
"""
try:
from google.generativeai import GenerativeModel
Expand All @@ -60,15 +52,13 @@ def __init__(
if api_key is not None:
import google.generativeai as genai

genai.configure(api_key=api_key)
genai.configure(api_key=api_key, client_options={"api_endpoint": api_endpoint})

self.client = GenerativeModel(
model_name=model,
system_instruction=system_prompt,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
generation_config=generation_config,
**kwargs,
)

# https://github.com/google-gemini/cookbook/blob/main/quickstarts/Function_calling.ipynb
Expand Down
27 changes: 15 additions & 12 deletions chatlas/_ollama.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import json
from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Optional, Sequence, cast
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Iterable,
Optional,
Sequence,
cast,
)

from . import _utils
from ._abc import BaseChatWithTools
Expand All @@ -8,7 +16,7 @@
from ._utils import ToolFunction

if TYPE_CHECKING:
from ollama import AsyncClient, Message
from ollama import Message
from ollama._types import ChatResponse, Tool, ToolCall


Expand All @@ -23,7 +31,8 @@ def __init__(
model: Optional[str] = None,
system_prompt: Optional[str] = None,
tools: Iterable[ToolFunction] = (),
client: "AsyncClient | None" = None,
host: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Start a chat powered by Ollama.
Expand All @@ -37,23 +46,17 @@ def __init__(
tools
A list of tools (i.e., function calls) to make available in the
chat.
client
An `ollama.AsyncClient` instance to use for the chat. Use this to
customize stuff like `host`, `timeout`, etc.
kwargs
Additional keyword arguments to pass to the `ollama.AsyncClient` constructor.
"""
self._model = model
self._system_prompt = system_prompt
for tool in tools:
self.register_tool(tool)
if client is None:
client = self._get_client()
self.client = client

def _get_client(self) -> "AsyncClient":
try:
from ollama import AsyncClient

return AsyncClient()
self.client = AsyncClient(host=host, **kwargs)
except ImportError:
raise ImportError(
f"The {self.__class__.__name__} class requires the `ollama` package. "
Expand Down
38 changes: 19 additions & 19 deletions chatlas/_openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Optional, Sequence
from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Optional, Sequence

from . import _utils
from ._abc import BaseChatWithTools
Expand All @@ -8,7 +8,7 @@
from ._utils import ToolFunction

if TYPE_CHECKING:
from openai import AsyncOpenAI
from httpx import URL
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
Expand All @@ -31,7 +31,8 @@ def __init__(
model: "ChatModel" = "gpt-4o",
system_prompt: Optional[str] = None,
tools: Iterable[ToolFunction] = (),
client: "AsyncOpenAI | None" = None,
base_url: Optional[str | URL] = None,
**kwargs: Any,
):
"""
Start a chat powered by OpenAI
Expand All @@ -46,36 +47,35 @@ def __init__(
A system prompt to use for the chat.
tools
A list of tools (i.e., function calls) to use for the chat.
client
An `openai.AsyncOpenAI` client instance to use for the chat. Use
this to customize stuff like `base_url`, `timeout`, etc.
base_url
The base URL to use for requests.
kwargs
Additional keyword arguments to pass to the `openai.AsyncOpenAI` constructor.
Raises
------
ImportError
If the `openai` package is not installed.
"""
self._model = model
self._system_prompt = system_prompt
for tool in tools:
self.register_tool(tool)
if client is None:
client = self._get_client()
if api_key is not None:
client.api_key = api_key
self.client = client

def _get_client(self) -> "AsyncOpenAI":
try:
from openai import AsyncOpenAI

return AsyncOpenAI()
except ImportError:
raise ImportError(
f"The {self.__class__.__name__} class requires the `openai` package. "
"Install it with `pip install openai`."
)

self._model = model
self._system_prompt = system_prompt
for tool in tools:
self.register_tool(tool)

self.client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
**kwargs,
)

async def response_generator(
self,
user_input: str,
Expand Down
4 changes: 3 additions & 1 deletion chatlas/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

ToolFunctionSync = Callable[..., Any]
ToolFunctionAsync = Callable[..., Awaitable[Any]]
ToolFunction = Union[ToolFunctionSync, ToolFunctionAsync]
ToolFunction = Union[
ToolFunctionSync, ToolFunctionAsync
] # TODO: support pydantic types?


class ToolSchemaProperty(TypedDict, total=False):
Expand Down

0 comments on commit 8a03d10

Please sign in to comment.