Fix Azure endpoint properties and baseUrl computation
edeandrea committed Jan 29, 2024
1 parent 206a325 commit 8e9966f
Showing 6 changed files with 31 additions and 29 deletions.
File 1 of 6: Javadoc for the Azure OpenAI chat model
@@ -26,7 +26,7 @@
* Represents an OpenAI language model, hosted on Azure, that has a chat completion interface, such as gpt-3.5-turbo.
* <p>
* Mandatory parameters for initialization are: {@code apiVersion}, {@code apiKey}, and either {@code endpoint} OR
- * {@code resourceName} and {@code deploymentId}.
+ * {@code resourceName} and {@code deploymentName}.
* <p>
* There are two primary authentication methods to access Azure OpenAI:
* <p>
File 2 of 6: Javadoc for the Azure OpenAI embedding model
@@ -26,7 +26,7 @@
* Represents an OpenAI embedding model, hosted on Azure, such as text-embedding-ada-002.
* <p>
* Mandatory parameters for initialization are: {@code apiVersion}, {@code apiKey}, and either {@code endpoint} OR
- * {@code resourceName} and {@code deploymentId}.
+ * {@code resourceName} and {@code deploymentName}.
* <p>
* There are two primary authentication methods to access Azure OpenAI:
* <p>
File 3 of 6: Javadoc for the Azure OpenAI streaming chat model
@@ -32,7 +32,7 @@
* The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}.
* <p>
* Mandatory parameters for initialization are: {@code apiVersion}, {@code apiKey}, and either {@code endpoint} OR
- * {@code resourceName} and {@code deploymentId}.
+ * {@code resourceName} and {@code deploymentName}.
* <p>
* There are two primary authentication methods to access Azure OpenAI:
* <p>
File 4 of 6: AzureOpenAiRecorder
@@ -5,6 +5,9 @@
import java.util.ArrayList;
import java.util.function.Supplier;

+ import io.quarkus.runtime.ShutdownContext;
+ import io.quarkus.runtime.annotations.Recorder;
+
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -15,8 +18,6 @@
import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig;
import io.quarkiverse.langchain4j.openai.QuarkusOpenAiClient;
- import io.quarkus.runtime.ShutdownContext;
- import io.quarkus.runtime.annotations.Recorder;
import io.smallrye.config.ConfigValidationException;
import io.smallrye.config.ConfigValidationException.Problem;

@@ -99,31 +100,32 @@ public EmbeddingModel get() {
}

static String getEndpoint(Langchain4jAzureOpenAiConfig runtimeConfig) {
- return runtimeConfig.endpoint()
- .map(String::trim)
- .filter(endpoint -> !endpoint.isBlank())
- .orElseGet(() -> constructEndpointFromConfig(runtimeConfig));
+ var endpoint = runtimeConfig.endpoint();
+
+ return (endpoint.isPresent() && !endpoint.get().trim().isBlank()) ?
+ endpoint.get() :
+ constructEndpointFromConfig(runtimeConfig);
}

private static String constructEndpointFromConfig(Langchain4jAzureOpenAiConfig runtimeConfig) {
var resourceName = runtimeConfig.resourceName();
- var deploymentId = runtimeConfig.deploymentId();
+ var deploymentName = runtimeConfig.deploymentName();

- if (resourceName.isEmpty() || deploymentId.isEmpty()) {
+ if (resourceName.isEmpty() || deploymentName.isEmpty()) {
var configProblems = new ArrayList<>();

if (resourceName.isEmpty()) {
configProblems.add(createConfigProblem("resource-name"));
}

- if (deploymentId.isEmpty()) {
- configProblems.add(createConfigProblem("deployment-id"));
+ if (deploymentName.isEmpty()) {
+ configProblems.add(createConfigProblem("deployment-name"));
}

throw new ConfigValidationException(configProblems.toArray(new Problem[configProblems.size()]));
}

- return String.format(AZURE_ENDPOINT_URL_PATTERN, resourceName.get(), deploymentId.get());
+ return String.format(AZURE_ENDPOINT_URL_PATTERN, resourceName.get(), deploymentName.get());
}

private static ConfigValidationException.Problem createConfigProblem(String key) {
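The last hunk above also reworks how the endpoint is resolved: a configured, non-blank endpoint is used as-is; otherwise the URL is constructed from resource-name and deployment-name, and a config problem is reported for whichever of the two is missing. Below is a minimal standalone sketch of that resolution order. The class and method names are illustrative only, the URL pattern is assumed to match AZURE_ENDPOINT_URL_PATTERN as documented in the config interface that follows, and the real recorder collects ConfigValidationException problems rather than throwing an IllegalStateException.

import java.util.Optional;

// Sketch only: names here are illustrative, not the extension's actual API.
class EndpointResolutionSketch {

    // Assumed to match the pattern documented for the endpoint property below.
    static final String AZURE_ENDPOINT_URL_PATTERN = "https://%s.openai.azure.com/openai/deployments/%s";

    static String resolve(Optional<String> endpoint, Optional<String> resourceName, Optional<String> deploymentName) {
        // A configured, non-blank endpoint takes precedence.
        if (endpoint.isPresent() && !endpoint.get().trim().isBlank()) {
            return endpoint.get();
        }
        // Otherwise both resource-name and deployment-name are required.
        if (resourceName.isEmpty() || deploymentName.isEmpty()) {
            throw new IllegalStateException("Set endpoint, or both resource-name and deployment-name");
        }
        return String.format(AZURE_ENDPOINT_URL_PATTERN, resourceName.get(), deploymentName.get());
    }
}

With this sketch, resolve(Optional.of("https://example.openai.azure.com/"), Optional.empty(), Optional.empty()) returns the explicit endpoint, while resolve(Optional.empty(), Optional.of("my-resource"), Optional.of("my-deployment")) returns the constructed URL.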
File 5 of 6: Langchain4jAzureOpenAiConfig
@@ -17,30 +17,30 @@ public interface Langchain4jAzureOpenAiConfig {
/**
* The name of your Azure OpenAI Resource. You're required to first deploy a model before you can make calls.
* <p>
- * This and {@code quarkus.langchain4j.azure-openai.deployment-id} are required if
+ * This and {@code quarkus.langchain4j.azure-openai.deployment-name} are required if
* {@code quarkus.langchain4j.azure-openai.endpoint} is not set.
* If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read.
* </p>
*/
Optional<String> resourceName();

/**
- * The id of your model deployment. You're required to first deploy a model before you can make calls.
+ * The name of your model deployment. You're required to first deploy a model before you can make calls.
* <p>
* This and {@code quarkus.langchain4j.azure-openai.resource-name} are required if
* {@code quarkus.langchain4j.azure-openai.endpoint} is not set.
* If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read.
* </p>
*/
- Optional<String> deploymentId();
+ Optional<String> deploymentName();

/**
* The endpoint for the Azure OpenAI resource.
* <p>
* If not specified, then {@code quarkus.langchain4j.azure-openai.resource-name} and
- * {@code quarkus.langchain4j.azure-openai.deployment-id} are required.
+ * {@code quarkus.langchain4j.azure-openai.deployment-name} are required.
* In this case the endpoint will be set to
- * {@code https://${quarkus.langchain4j.azure-openai.resource-name}.openai.azure.com/openai/deployments/${quarkus.langchain4j.azure-openai.deployment-id}}
+ * {@code https://${quarkus.langchain4j.azure-openai.resource-name}.openai.azure.com/openai/deployments/${quarkus.langchain4j.azure-openai.deployment-name}}
* </p>
*/
Optional<String> endpoint();
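As a worked example of the fallback documented above, assume endpoint is left unset and the hypothetical values my-resource and my-deployment are configured; the derived endpoint would then be (JShell-style snippet):

// Pattern as documented in the Javadoc above; the resource and deployment values are hypothetical.
String pattern = "https://%s.openai.azure.com/openai/deployments/%s";
String derived = String.format(pattern, "my-resource", "my-deployment");
// derived is "https://my-resource.openai.azure.com/openai/deployments/my-deployment"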
File 6 of 6: AzureOpenAiRecorder endpoint resolution tests
@@ -35,7 +35,7 @@ void noEndpointConfigSet() {
assertThat(configValidationException.getProblem(1))
.isNotNull()
.extracting(Problem::getMessage)
- .isEqualTo(String.format(CONFIG_ERROR_MESSAGE_TEMPLATE, "deployment-id"));
+ .isEqualTo(String.format(CONFIG_ERROR_MESSAGE_TEMPLATE, "deployment-name"));
}

@Test
@@ -53,14 +53,14 @@ void onlyResourceNameSet() {
assertThat(configValidationException.getProblem(0))
.isNotNull()
.extracting(Problem::getMessage)
- .isEqualTo(String.format(CONFIG_ERROR_MESSAGE_TEMPLATE, "deployment-id"));
+ .isEqualTo(String.format(CONFIG_ERROR_MESSAGE_TEMPLATE, "deployment-name"));
}

@Test
- void onlyDeploymentIdSet() {
- doReturn(Optional.of("deployment-id"))
+ void onlyDeploymentNameSet() {
+ doReturn(Optional.of("deployment-name"))
.when(this.config)
- .deploymentId();
+ .deploymentName();

var configValidationException = catchThrowableOfType(() -> AzureOpenAiRecorder.getEndpoint(this.config),
ConfigValidationException.class);
@@ -86,18 +86,18 @@ void endpointSet() {
}

@Test
- void resourceNameAndDeploymentIdSet() {
+ void resourceNameAndDeploymentNameSet() {
doReturn(Optional.of("resourceName"))
.when(this.config)
.resourceName();

doReturn(Optional.of("deploymentId"))
doReturn(Optional.of("deploymentName"))
.when(this.config)
- .deploymentId();
+ .deploymentName();

assertThat(AzureOpenAiRecorder.getEndpoint(this.config))
.isNotNull()
- .isEqualTo(String.format(AzureOpenAiRecorder.AZURE_ENDPOINT_URL_PATTERN, "resourceName", "deploymentId"));
+ .isEqualTo(String.format(AzureOpenAiRecorder.AZURE_ENDPOINT_URL_PATTERN, "resourceName", "deploymentName"));
}

static class Config implements Langchain4jAzureOpenAiConfig {
@@ -107,7 +107,7 @@ public Optional<String> resourceName() {
}

@Override
- public Optional<String> deploymentId() {
+ public Optional<String> deploymentName() {
return Optional.empty();
}

