Skip to content

Commit

Permalink
Add warnings for deprecated azure oai config changes (#4317)
Browse files Browse the repository at this point in the history
* Add warnings for deprecated azure oai config changes

* Update docs and usages, simplify capabilities
  • Loading branch information
jackgerrits authored Nov 25, 2024
1 parent 341417e commit b2ae4d1
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,12 @@
"token_provider = get_bearer_token_provider(DefaultAzureCredential(), \"https://cognitiveservices.azure.com/.default\")\n",
"\n",
"az_model_client = AzureOpenAIChatCompletionClient(\n",
" model=\"{your-azure-deployment}\",\n",
" azure_deployment=\"{your-azure-deployment}\",\n",
" model=\"{model-name, such as gpt-4o}\",\n",
" api_version=\"2024-06-01\",\n",
" azure_endpoint=\"https://{your-custom-endpoint}.openai.azure.com/\",\n",
" azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.\n",
" # api_key=\"sk-...\", # For key-based authentication.\n",
" model_capabilities={\n",
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,11 @@ token_provider = get_bearer_token_provider(
)

client = AzureOpenAIChatCompletionClient(
model="{your-azure-deployment}",
azure_deployment="{your-azure-deployment}",
model="{model-name, such as gpt-4o}",
api_version="2024-02-01",
azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/",
azure_ad_token_provider=token_provider,
model_capabilities={
"vision":True,
"function_calling":True,
"json_output":True,
}
)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -79,15 +79,11 @@
"\n",
"# Create the client with type-checked environment variables\n",
"client = AzureOpenAIChatCompletionClient(\n",
" model=get_env_variable(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
" azure_deployment=get_env_variable(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
" model=get_env_variable(\"AZURE_OPENAI_MODEL\"),\n",
" api_version=get_env_variable(\"AZURE_OPENAI_API_VERSION\"),\n",
" azure_endpoint=get_env_variable(\"AZURE_OPENAI_ENDPOINT\"),\n",
" api_key=get_env_variable(\"AZURE_OPENAI_API_KEY\"),\n",
" model_capabilities={\n",
" \"vision\": False,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -294,16 +294,12 @@
"token_provider = get_bearer_token_provider(DefaultAzureCredential(), \"https://cognitiveservices.azure.com/.default\")\n",
"\n",
"az_model_client = AzureOpenAIChatCompletionClient(\n",
" model=\"{your-azure-deployment}\",\n",
" azure_deployment=\"{your-azure-deployment}\",\n",
" model=\"{model-name, such as gpt-4o}\",\n",
" api_version=\"2024-06-01\",\n",
" azure_endpoint=\"https://{your-custom-endpoint}.openai.azure.com/\",\n",
" azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.\n",
" # api_key=\"sk-...\", # For key-based authentication.\n",
" model_capabilities={\n",
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" },\n",
")"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,29 @@ def _azure_openai_client_from_config(config: Mapping[str, Any]) -> AsyncAzureOpe
# Take a copy
copied_config = dict(config).copy()

import warnings

if "azure_deployment" not in copied_config and "model" in copied_config:
warnings.warn(
"Previous behavior of using the model name as the deployment name is deprecated and will be removed in 0.4. Please specify azure_deployment",
stacklevel=2,
)

if "azure_endpoint" not in copied_config and "base_url" in copied_config:
warnings.warn(
"Previous behavior of using the base_url as the endpoint is deprecated and will be removed in 0.4. Please specify azure_endpoint",
stacklevel=2,
)

# Do some fixups
copied_config["azure_deployment"] = copied_config.get("azure_deployment", config.get("model"))
if copied_config["azure_deployment"] is not None:
copied_config["azure_deployment"] = copied_config["azure_deployment"].replace(".", "")
if "." in copied_config["azure_deployment"]:
warnings.warn(
"Previous behavior stripping '.' from the deployment name is deprecated and will be removed in 0.4",
stacklevel=2,
)
copied_config["azure_deployment"] = copied_config["azure_deployment"].replace(".", "")
copied_config["azure_endpoint"] = copied_config.get("azure_endpoint", copied_config.pop("base_url", None))

# Shave down the config to just the AzureOpenAIChatCompletionClient kwargs
Expand Down Expand Up @@ -331,9 +350,7 @@ def __init__(
model_capabilities: Optional[ModelCapabilities] = None,
):
self._client = client
if model_capabilities is None and isinstance(client, AsyncAzureOpenAI):
raise ValueError("AzureOpenAIChatCompletionClient requires explicit model capabilities")
elif model_capabilities is None:
if model_capabilities is None:
self._model_capabilities = _model_info.get_capabilities(create_args["model"])
else:
self._model_capabilities = model_capabilities
Expand Down Expand Up @@ -963,7 +980,7 @@ class AzureOpenAIChatCompletionClient(BaseOpenAIChatCompletionClient):
api_version (str): The API version to use. **Required for Azure models.**
azure_ad_token (str): The Azure AD token to use. Provide this or `azure_ad_token_provider` for token-based authentication.
azure_ad_token_provider (Callable[[], Awaitable[str]]): The Azure AD token provider to use. Provide this or `azure_ad_token` for token-based authentication.
model_capabilities (ModelCapabilities): The capabilities of the model. **Required for Azure models.**
model_capabilities (ModelCapabilities): The capabilities of the model if default resolved values are not correct.
api_key (optional, str): The API key to use, use this if you are using key based authentication. It is optional if you are using Azure AD token based authentication or `AZURE_OPENAI_API_KEY` environment variable.
timeout (optional, int): The timeout for the request in seconds.
max_retries (optional, int): The maximum number of retries to attempt.
Expand All @@ -990,26 +1007,19 @@ class AzureOpenAIChatCompletionClient(BaseOpenAIChatCompletionClient):
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
az_model_client = AzureOpenAIChatCompletionClient(
model="{your-azure-deployment}",
azure_deployment="{your-azure-deployment}",
model="{deployed-model, such as 'gpt-4o'}",
api_version="2024-06-01",
azure_endpoint="https://{your-custom-endpoint}.openai.azure.com/",
azure_ad_token_provider=token_provider, # Optional if you choose key-based authentication.
# api_key="sk-...", # For key-based authentication. `AZURE_OPENAI_API_KEY` environment variable can also be used instead.
model_capabilities={
"vision": True,
"function_calling": True,
"json_output": True,
},
)
See `here <https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/managed-identity#chat-completions>`_ for how to use the Azure client directly or for more info.
"""

def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for OpenAIChatCompletionClient")

model_capabilities: Optional[ModelCapabilities] = None
copied_args = dict(kwargs).copy()
if "model_capabilities" in kwargs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
api_key: str
timeout: Union[float, None]
max_retries: int
model_capabilities: ModelCapabilities
"""What functionality the model supports, determined by default from model name but is overriden if value passed."""


# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities


class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
Expand All @@ -47,8 +47,6 @@ class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False)
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]


__all__ = ["AzureOpenAIClientConfiguration", "OpenAIClientConfiguration"]
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def test_openai_chat_completion_client() -> None:
@pytest.mark.asyncio
async def test_azure_openai_chat_completion_client() -> None:
client = AzureOpenAIChatCompletionClient(
azure_deployment="gpt-4o-1",
model="gpt-4o",
api_key="api_key",
api_version="2020-08-04",
Expand Down

0 comments on commit b2ae4d1

Please sign in to comment.