Commit

make client pluggable by external code

olgavrou committed Nov 30, 2023
1 parent 6ce9a6e commit 473b280
Showing 2 changed files with 23 additions and 64 deletions.
2 changes: 2 additions & 0 deletions autogen/oai/__init__.py
@@ -8,6 +8,7 @@
    config_list_from_json,
    config_list_from_dotenv,
)
from autogen.oai.client import Client

__all__ = [
    "OpenAIWrapper",
@@ -19,4 +20,5 @@
    "config_list_from_models",
    "config_list_from_json",
    "config_list_from_dotenv",
    "Client",
]
85 changes: 21 additions & 64 deletions autogen/oai/client.py
@@ -6,7 +6,7 @@
import logging
import inspect
from flaml.automl.logger import logger_formatter
from types import SimpleNamespace
from abc import ABC, abstractmethod

from autogen.oai.openai_utils import get_key, oai_price1k
from autogen.token_count_utils import count_token
@@ -24,12 +24,6 @@
    ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
    OpenAI = object

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
    ERROR = None
except ImportError:
    ERROR = ImportError("Please install transformers and diskcache to use autogen.RLClientWrapper.")

logger = logging.getLogger(__name__)
if not logger.handlers:
    # Add the console handler.
@@ -38,6 +32,14 @@
    logger.addHandler(_ch)


def import_class_from_path(path, class_name):
    import importlib.util
    spec = importlib.util.spec_from_file_location("module.name", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    cls = getattr(module, class_name)
    return cls
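
This helper is what makes the client pluggable: it loads a class by name from an arbitrary Python file supplied in the config. A minimal usage sketch (the file path and class name below are hypothetical, not part of the commit):

    # Load a user-defined client class from an external file and instantiate it.
    CustomClient = import_class_from_path("/path/to/my_client.py", "MyClient")
    client = CustomClient({"model": "my-model"})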

def template_formatter(
    template: str | Callable | None,
    context: Optional[Dict] = None,
@@ -125,66 +127,18 @@ def create(self, client, client_id, is_last, create_config: Dict, extra_kwargs:
return response
return None


class Client(ABC):

class RLClient:
    def __init__(self, config: Dict):
        import torch

        self.device = (
            ("cuda" if torch.cuda.is_available() else "cpu") if config.get("device", None) is None else config["device"]
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config["local_model"], load_in_8bit=True, use_fast=False)
        self.model = AutoModelForCausalLM.from_pretrained(config["local_model"]).to(self.device)
        # get max_length from config or set to 1000
        self.max_length = config.get("max_length", 1000)
        self.gen_config_params = config.get("params", {})
        # correct max_length in self.params
        self.gen_config_params["max_length"] = self.max_length
        self.gen_config_params["eos_token_id"] = self.tokenizer.eos_token_id
        self.gen_config_params["pad_token_id"] = self.tokenizer.eos_token_id
        print(f"Loaded model {config['local_model']} to {self.device}")

    @abstractmethod
    def create(self, params):
        if params.get("stream", False) and "messages" in params and "functions" not in params:
            raise NotImplementedError("Local models do not support streaming or functions")
        else:
            response_contents = [""] * params.get("n", 1)
            finish_reasons = [""] * params.get("n", 1)
            completion_tokens = 0
        pass

            response = SimpleNamespace()
            inputs = self.tokenizer.apply_chat_template(
                params["messages"], return_tensors="pt", add_generation_prompt=True
            ).to(self.device)

            inputs_length = inputs.shape[-1]
            # copy gen config params
            gen_config_params = self.gen_config_params.copy()
            # add inputs_length to max_length
            gen_config_params["max_length"] += inputs_length
            generation_config = GenerationConfig(**gen_config_params)

            response.choices = []

            for _ in range(len(response_contents)):
                outputs = self.model.generate(inputs, generation_config=generation_config)
                # Decode only the newly generated text, excluding the prompt
                text = self.tokenizer.decode(outputs[0, inputs_length:], skip_special_tokens=True)
                choice = SimpleNamespace()
                choice.message = SimpleNamespace()
                choice.message.content = text
                choice.message.function_call = None
                response.choices.append(choice)

            return response

    def cost(self, response) -> float:
        """Calculate the cost of the response."""
        return 0
    @abstractmethod
    def cost(self, response):
        pass
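
With Client now an abstract base class, an external client only has to subclass it and implement create and cost. A minimal sketch of such a plug-in client, mirroring the SimpleNamespace response shape the removed RLClient produced (EchoClient and my_client.py are illustrative names, not part of this commit):

    # my_client.py -- hypothetical external module loaded via import_class_from_path
    from types import SimpleNamespace
    from autogen.oai.client import Client

    class EchoClient(Client):
        def __init__(self, config):
            # The full config dict for this entry is passed straight through.
            self.config = config

        def create(self, params):
            # Echo the last message back as a single completion choice.
            response = SimpleNamespace()
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = params["messages"][-1]["content"]
            choice.message.function_call = None
            response.choices = [choice]
            return response

        def cost(self, response):
            # A local/dummy client incurs no API cost.
            return 0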



class OpenAIClient:
class OpenAIClient(Client):
    def __init__(self, config: Dict):
        self.client = OpenAI(**config)

@@ -372,8 +326,11 @@ def _client(self, config, openai_config):
        """
        openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
        self._process_for_azure(openai_config, config)
        if "local_model" in config:
            return RLClient(config)
        if "custom_client" in config and "custom_client_code_path" in config:
            custom_client = config["custom_client"]
            custom_client_code_path = config["custom_client_code_path"]
            CustomClient = import_class_from_path(custom_client_code_path, custom_client)
            return CustomClient(config)

        else:
            return OpenAIClient(openai_config)
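
After this change, a config entry that names a custom client class and the file that defines it is routed through import_class_from_path instead of OpenAIClient. A sketch of how a caller might wire this up, assuming the hypothetical EchoClient above (the values shown are illustrative, and other keys in the entry are simply forwarded to the custom client's config):

    from autogen import OpenAIWrapper

    config_list = [
        {
            "model": "echo-model",  # illustrative, forwarded to the custom client
            "custom_client": "EchoClient",
            "custom_client_code_path": "/path/to/my_client.py",
        }
    ]
    wrapper = OpenAIWrapper(config_list=config_list)
    response = wrapper.create(messages=[{"role": "user", "content": "hello"}])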
