Skip to content

Commit

Permalink
Fix Azure endpoint properties and baseUrl computation
Browse files Browse the repository at this point in the history
  • Loading branch information
edeandrea committed Jan 26, 2024
1 parent 8f886a3 commit 76ff91b
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 57 deletions.
16 changes: 16 additions & 0 deletions openai/azure-openai/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@
<artifactId>quarkus-langchain4j-openai-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
/**
* Represents an OpenAI language model, hosted on Azure, that has a chat completion interface, such as gpt-3.5-turbo.
* <p>
* Mandatory parameters for initialization are: baseUrl, apiVersion and apiKey.
* Mandatory parameters for initialization are: {@code apiVersion}, {@code apiKey}, and either {@code endpoint} OR
* {@code resourceName} and {@code deploymentId}.
* <p>
* There are two primary authentication methods to access Azure OpenAI:
* <p>
Expand All @@ -51,7 +52,7 @@ public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstima
private final Integer maxRetries;
private final Tokenizer tokenizer;

public AzureOpenAiChatModel(String baseUrl,
public AzureOpenAiChatModel(String endpoint,
String apiVersion,
String apiKey,
Tokenizer tokenizer,
Expand All @@ -69,7 +70,7 @@ public AzureOpenAiChatModel(String baseUrl,
timeout = getOrDefault(timeout, ofSeconds(60));

this.client = OpenAiClient.builder()
.baseUrl(ensureNotBlank(baseUrl, "baseUrl"))
.baseUrl(ensureNotBlank(endpoint, "endpoint"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
Expand Down Expand Up @@ -143,7 +144,7 @@ public static Builder builder() {

public static class Builder {

private String baseUrl;
private String endpoint;
private String apiVersion;
private String apiKey;
private Tokenizer tokenizer;
Expand All @@ -159,14 +160,14 @@ public static class Builder {
private Boolean logResponses;

/**
* Sets the Azure OpenAI base URL. This is a mandatory parameter.
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
*
* @param baseUrl The Azure OpenAI base URL in the format:
* https://{resource}.openai.azure.com/openai/deployments/{deployment}
* @param endpoint The Azure OpenAI endpoint in the format:
* https://{resource-name}.openai.azure.com/openai/deployments/{deployment-name}
* @return builder
*/
public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}

Expand Down Expand Up @@ -248,8 +249,7 @@ public Builder logResponses(Boolean logResponses) {
}

public AzureOpenAiChatModel build() {
return new AzureOpenAiChatModel(
baseUrl,
return new AzureOpenAiChatModel(endpoint,
apiVersion,
apiKey,
tokenizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
/**
* Represents an OpenAI embedding model, hosted on Azure, such as text-embedding-ada-002.
* <p>
* Mandatory parameters for initialization are: baseUrl, apiVersion and apiKey.
* Mandatory parameters for initialization are: {@code apiVersion}, {@code apiKey}, and either {@code endpoint} OR
* {@code resourceName} and {@code deploymentId}.
* <p>
* There are two primary authentication methods to access Azure OpenAI:
* <p>
Expand All @@ -46,7 +47,7 @@ public class AzureOpenAiEmbeddingModel implements EmbeddingModel, TokenCountEsti
private final Integer maxRetries;
private final Tokenizer tokenizer;

public AzureOpenAiEmbeddingModel(String baseUrl,
public AzureOpenAiEmbeddingModel(String endpoint,
String apiVersion,
String apiKey,
Tokenizer tokenizer,
Expand All @@ -59,7 +60,7 @@ public AzureOpenAiEmbeddingModel(String baseUrl,
timeout = getOrDefault(timeout, ofSeconds(60));

this.client = OpenAiClient.builder()
.baseUrl(ensureNotBlank(baseUrl, "baseUrl"))
.baseUrl(ensureNotBlank(endpoint, "endpoint"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
Expand Down Expand Up @@ -130,7 +131,7 @@ public static Builder builder() {

public static class Builder {

private String baseUrl;
private String endpoint;
private String apiVersion;
private String apiKey;
private Tokenizer tokenizer;
Expand All @@ -141,14 +142,14 @@ public static class Builder {
private Boolean logResponses;

/**
* Sets the Azure OpenAI base URL. This is a mandatory parameter.
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
*
* @param baseUrl The Azure OpenAI base URL in the format:
* https://{resource}.openai.azure.com/openai/deployments/{deployment}
* @param endpoint The Azure OpenAI endpoint in the format:
* https://{resource-name}.openai.azure.com/openai/deployments/{deployment-id}
* @return builder
*/
public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}

Expand Down Expand Up @@ -205,8 +206,7 @@ public Builder logResponses(Boolean logResponses) {
}

public AzureOpenAiEmbeddingModel build() {
return new AzureOpenAiEmbeddingModel(
baseUrl,
return new AzureOpenAiEmbeddingModel(endpoint,
apiVersion,
apiKey,
tokenizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
* Represents an OpenAI language model, hosted on Azure, that has a chat completion interface, such as gpt-3.5-turbo.
* The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}.
* <p>
* Mandatory parameters for initialization are: baseUrl, apiVersion and apiKey.
* Mandatory parameters for initialization are: {@code apiVersion}, {@code apiKey}, and either {@code endpoint} OR
* {@code resourceName} and {@code deploymentId}.
* <p>
* There are two primary authentication methods to access Azure OpenAI:
* <p>
Expand All @@ -56,7 +57,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel
private final Double frequencyPenalty;
private final Tokenizer tokenizer;

public AzureOpenAiStreamingChatModel(String baseUrl,
public AzureOpenAiStreamingChatModel(String endpoint,
String apiVersion,
String apiKey,
Tokenizer tokenizer,
Expand All @@ -73,7 +74,7 @@ public AzureOpenAiStreamingChatModel(String baseUrl,
timeout = getOrDefault(timeout, ofSeconds(60));

this.client = OpenAiClient.builder()
.baseUrl(ensureNotBlank(baseUrl, "baseUrl"))
.baseUrl(ensureNotBlank(endpoint, "endpoint"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
Expand Down Expand Up @@ -178,7 +179,7 @@ public static Builder builder() {

public static class Builder {

private String baseUrl;
private String endpoint;
private String apiVersion;
private String apiKey;
private Tokenizer tokenizer;
Expand All @@ -193,14 +194,14 @@ public static class Builder {
private Boolean logResponses;

/**
* Sets the Azure OpenAI base URL. This is a mandatory parameter.
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
*
* @param baseUrl The Azure OpenAI base URL in the format:
* https://{resource}.openai.azure.com/openai/deployments/{deployment}
* @param endpoint The Azure OpenAI endpoint in the format:
* https://{resource-name}.openai.azure.com/openai/deployments/{deployment-id}
* @return builder
*/
public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}

Expand Down Expand Up @@ -277,8 +278,7 @@ public Builder logResponses(Boolean logResponses) {
}

public AzureOpenAiStreamingChatModel build() {
return new AzureOpenAiStreamingChatModel(
baseUrl,
return new AzureOpenAiStreamingChatModel(endpoint,
apiVersion,
apiKey,
tokenizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,33 @@

import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault;

import java.util.ArrayList;
import java.util.function.Supplier;

import io.quarkus.runtime.ShutdownContext;
import io.quarkus.runtime.annotations.Recorder;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiEmbeddingModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiStreamingChatModel;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig;
import io.quarkiverse.langchain4j.openai.QuarkusOpenAiClient;
import io.quarkus.runtime.ShutdownContext;
import io.quarkus.runtime.annotations.Recorder;
import io.smallrye.config.ConfigValidationException;
import io.smallrye.config.ConfigValidationException.Problem;

@Recorder
public class AzureOpenAiRecorder {
static final String AZURE_ENDPOINT_URL_PATTERN = "https://%s.openai.azure.com/openai/deployments/%s";

public Supplier<?> chatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
public Supplier<ChatLanguageModel> chatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
ChatModelConfig chatModelConfig = runtimeConfig.chatModel();
var builder = AzureOpenAiChatModel.builder()
.baseUrl(getBaseUrl(runtimeConfig))
.endpoint(getEndpoint(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.apiVersion(runtimeConfig.apiVersion())
.timeout(runtimeConfig.timeout())
Expand All @@ -39,16 +47,16 @@ public Supplier<?> chatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {

return new Supplier<>() {
@Override
public Object get() {
public ChatLanguageModel get() {
return builder.build();
}
};
}

public Supplier<?> streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
public Supplier<StreamingChatLanguageModel> streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
ChatModelConfig chatModelConfig = runtimeConfig.chatModel();
var builder = AzureOpenAiStreamingChatModel.builder()
.baseUrl(getBaseUrl(runtimeConfig))
.endpoint(getEndpoint(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.apiVersion(runtimeConfig.apiVersion())
.timeout(runtimeConfig.timeout())
Expand All @@ -66,16 +74,16 @@ public Supplier<?> streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig

return new Supplier<>() {
@Override
public Object get() {
public StreamingChatLanguageModel get() {
return builder.build();
}
};
}

public Supplier<?> embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
public Supplier<EmbeddingModel> embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
EmbeddingModelConfig embeddingModelConfig = runtimeConfig.embeddingModel();
var builder = AzureOpenAiEmbeddingModel.builder()
.baseUrl(getBaseUrl(runtimeConfig))
.endpoint(getEndpoint(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.apiVersion(runtimeConfig.apiVersion())
.timeout(runtimeConfig.timeout())
Expand All @@ -85,18 +93,44 @@ public Supplier<?> embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig) {

return new Supplier<>() {
@Override
public Object get() {
public EmbeddingModel get() {
return builder.build();
}
};
}

private String getBaseUrl(Langchain4jAzureOpenAiConfig runtimeConfig) {
var baseUrl = runtimeConfig.baseUrl();
/**
 * Resolves the Azure OpenAI endpoint to use.
 * <p>
 * An explicitly configured {@code endpoint} wins (trimmed); when it is absent or
 * blank, the endpoint is constructed from {@code resource-name} and
 * {@code deployment-id} instead.
 */
static String getEndpoint(Langchain4jAzureOpenAiConfig runtimeConfig) {
    var configuredEndpoint = runtimeConfig.endpoint();

    if (configuredEndpoint.isPresent()) {
        var trimmed = configuredEndpoint.get().trim();

        // After trim(), a whitespace-only value has become empty, so this is
        // equivalent to the original trim-then-isBlank filter.
        if (!trimmed.isEmpty()) {
            return trimmed;
        }
    }

    return constructEndpointFromConfig(runtimeConfig);
}

/**
 * Builds the Azure OpenAI endpoint URL from {@code resource-name} and
 * {@code deployment-id} using {@code AZURE_ENDPOINT_URL_PATTERN}.
 *
 * @throws ConfigValidationException if either property is missing; both missing
 *         properties are reported together so the user can fix them in one pass
 */
private static String constructEndpointFromConfig(Langchain4jAzureOpenAiConfig runtimeConfig) {
    var resourceName = runtimeConfig.resourceName();
    var deploymentId = runtimeConfig.deploymentId();

    if (resourceName.isEmpty() || deploymentId.isEmpty()) {
        // Was `new ArrayList<>()`, which `var` infers as ArrayList<Object> and
        // only compiled by accident via the generic toArray(T[]). Type it
        // explicitly so the element type is checked at compile time.
        var configProblems = new ArrayList<Problem>();

        if (resourceName.isEmpty()) {
            configProblems.add(createConfigProblem("resource-name"));
        }

        if (deploymentId.isEmpty()) {
            configProblems.add(createConfigProblem("deployment-id"));
        }

        // toArray(new Problem[0]) is the preferred idiom over pre-sizing the array.
        throw new ConfigValidationException(configProblems.toArray(new Problem[0]));
    }

    return String.format(AZURE_ENDPOINT_URL_PATTERN, resourceName.get(), deploymentId.get());
}

return !baseUrl.trim().isEmpty() ? baseUrl
: String.format("https://%s.openai.azure.com/openai/deployments/%s", runtimeConfig.resourceName(),
runtimeConfig.deploymentId());
/**
 * Creates a config-validation problem for a missing
 * {@code quarkus.langchain4j.azure-openai.*} property.
 *
 * @param key the unqualified property name (e.g. {@code resource-name})
 * @return a problem whose message mirrors SmallRye Config's SRCFG00014
 *         missing-property wording
 */
private static ConfigValidationException.Problem createConfigProblem(String key) {
    var message = String.format(
            "SRCFG00014: The config property quarkus.langchain4j.azure-openai.%s is required but it could not be found in any config source",
            key);
    return new ConfigValidationException.Problem(message);
}

public void cleanUp(ShutdownContext shutdown) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,35 @@
public interface Langchain4jAzureOpenAiConfig {

/**
* The name of your Azure OpenAI Resource
* The name of your Azure OpenAI Resource. You're required to first deploy a model before you can make calls.
* <p>
* This and {@code quarkus.langchain4j.azure-openai.deployment-id} are required if
* {@code quarkus.langchain4j.azure-openai.endpoint} is not set.
* If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read.
* </p>
*/
String resourceName();
Optional<String> resourceName();

/**
* The name of your model deployment. You're required to first deploy a model before you can make calls.
* The id of your model deployment. You're required to first deploy a model before you can make calls.
* <p>
* This and {@code quarkus.langchain4j.azure-openai.resource-name} are required if
* {@code quarkus.langchain4j.azure-openai.endpoint} is not set.
* If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read.
* </p>
*/
String deploymentId();
Optional<String> deploymentId();

/**
* The base url for the Azure OpenAI resource. Defaults to
* {@code https://${quarkus.langchain4j.azure-openai.resource-name}.openai.azure.com/openai/deployments/${quarkus.langchain4j.azure-openai.deployment-id}}.
* The endpoint for the Azure OpenAI resource.
* <p>
* If not specified, then {@code quarkus.langchain4j.azure-openai.resource-name} and
* {@code quarkus.langchain4j.azure-openai.deployment-id} are required.
* In this case the endpoint will be set to
* {@code https://${quarkus.langchain4j.azure-openai.resource-name}.openai.azure.com/openai/deployments/${quarkus.langchain4j.azure-openai.deployment-id}}
* </p>
*/
@WithDefault("https://${quarkus.langchain4j.azure-openai.resource-name}.openai.azure.com/openai/deployments/${quarkus.langchain4j.azure-openai.deployment-id}")
String baseUrl();
Optional<String> endpoint();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
Expand Down
Loading

0 comments on commit 76ff91b

Please sign in to comment.