diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/ServerlessNetworkPolicyUpdater.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdater.java similarity index 99% rename from data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/ServerlessNetworkPolicyUpdater.java rename to data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdater.java index ef427b0026..183ffd3419 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/ServerlessNetworkPolicyUpdater.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdater.java @@ -1,4 +1,4 @@ -package org.opensearch.dataprepper.plugins.sink.opensearch; +package org.opensearch.dataprepper.plugins.common.opensearch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterFactory.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterFactory.java new file mode 100644 index 0000000000..a4a227cf98 --- /dev/null +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterFactory.java @@ -0,0 +1,58 @@ +package org.opensearch.dataprepper.plugins.common.opensearch; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.sink.opensearch.ConnectionConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy; +import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient; + +import java.time.Duration; + +public class ServerlessNetworkPolicyUpdaterFactory { + public static ServerlessNetworkPolicyUpdater create( + final AwsCredentialsSupplier awsCredentialsSupplier, + final ConnectionConfiguration connectionConfiguration + ) { + return new ServerlessNetworkPolicyUpdater(getOpenSearchServerlessClient( + awsCredentialsSupplier, connectionConfiguration.createAwsCredentialsOptions() + )); + } + + public static ServerlessNetworkPolicyUpdater create( + final AwsCredentialsSupplier awsCredentialsSupplier, + final AwsAuthenticationConfiguration awsConfig + ) { + final AwsCredentialsOptions awsCredentialsOptions = AwsCredentialsOptions.builder() + .withRegion(awsConfig.getAwsRegion()) + .withStsRoleArn(awsConfig.getAwsStsRoleArn()) + .withStsExternalId(awsConfig.getAwsStsExternalId()) + .withStsHeaderOverrides(awsConfig.getAwsStsHeaderOverrides()) + .build(); + + return new ServerlessNetworkPolicyUpdater(getOpenSearchServerlessClient( + awsCredentialsSupplier, awsCredentialsOptions + )); + } + + private static OpenSearchServerlessClient getOpenSearchServerlessClient( + final AwsCredentialsSupplier awsCredentialsSupplier, + final AwsCredentialsOptions awsCredentialsOptions + ) { + return OpenSearchServerlessClient.builder() + .credentialsProvider(awsCredentialsSupplier.getProvider(awsCredentialsOptions)) + .region(awsCredentialsOptions.getRegion()) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .retryPolicy(RetryPolicy.builder() + .backoffStrategy(FullJitterBackoffStrategy.builder() + .baseDelay(Duration.ofSeconds(10)) + .maxBackoffTime(Duration.ofSeconds(60)) + .build()) + .numRetries(10) + .build()) + .build()) + .build(); + } +} diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessOptionsFactory.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessOptionsFactory.java new file mode 100644 index 0000000000..5204ca6575 --- /dev/null +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessOptionsFactory.java @@ -0,0 +1,44 @@ +package org.opensearch.dataprepper.plugins.common.opensearch; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; +import org.opensearch.dataprepper.plugins.sink.opensearch.ConnectionConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions; + +import java.util.Optional; + +public class ServerlessOptionsFactory { + + public static Optional create(final ConnectionConfiguration connectionConfiguration) { + if (!connectionConfiguration.isServerless() || + StringUtils.isBlank(connectionConfiguration.getServerlessNetworkPolicyName()) || + StringUtils.isBlank(connectionConfiguration.getServerlessCollectionName()) || + StringUtils.isBlank(connectionConfiguration.getServerlessVpceId()) + ) { + return Optional.empty(); + } + + return Optional.of(new ServerlessOptions( + connectionConfiguration.getServerlessNetworkPolicyName(), + connectionConfiguration.getServerlessCollectionName(), + connectionConfiguration.getServerlessVpceId())); + } + + public static Optional create(final AwsAuthenticationConfiguration awsConfig) { + if (awsConfig == null || !awsConfig.isServerlessCollection()) { + return Optional.empty(); + } + + final ServerlessOptions serverlessOptions = awsConfig.getServerlessOptions(); + if (serverlessOptions == null || + StringUtils.isBlank(serverlessOptions.getNetworkPolicyName()) || + StringUtils.isBlank(serverlessOptions.getCollectionName()) || + StringUtils.isBlank(serverlessOptions.getVpceId()) + ) { + return Optional.empty(); + } + + return Optional.of(serverlessOptions); + } + +} diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/ConnectionConfiguration.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/ConnectionConfiguration.java index 5600b699f6..4fb6b60bcb 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/ConnectionConfiguration.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/ConnectionConfiguration.java @@ -36,14 +36,10 @@ import software.amazon.awssdk.arns.Arn; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.signer.Aws4Signer; -import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; -import software.amazon.awssdk.core.retry.RetryPolicy; -import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy; import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.http.apache.ProxyConfiguration; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; @@ -162,19 +158,19 @@ Integer getConnectTimeout() { return connectTimeout; } - boolean isServerless() { + public boolean isServerless() { return serverless; } - String getServerlessNetworkPolicyName() { + public String getServerlessNetworkPolicyName() { return serverlessNetworkPolicyName; } - String getServerlessCollectionName() { + public String getServerlessCollectionName() { return serverlessCollectionName; } - String getServerlessVpceId() { + public String getServerlessVpceId() { return serverlessVpceId; } @@ -333,7 +329,7 @@ private void attachSigV4(final RestClientBuilder restClientBuilder, AwsCredentia }); } - private AwsCredentialsOptions createAwsCredentialsOptions() { + public AwsCredentialsOptions createAwsCredentialsOptions() { final AwsCredentialsOptions awsCredentialsOptions = AwsCredentialsOptions.builder() .withStsRoleArn(awsStsRoleArn) .withStsExternalId(awsStsExternalId) @@ -440,24 +436,6 @@ private OpenSearchTransport createOpenSearchTransport(final RestHighLevelClient } } - public OpenSearchServerlessClient createOpenSearchServerlessClient(final AwsCredentialsSupplier awsCredentialsSupplier) { - final AwsCredentialsOptions awsCredentialsOptions = createAwsCredentialsOptions(); - - return OpenSearchServerlessClient.builder() - .credentialsProvider(awsCredentialsSupplier.getProvider(awsCredentialsOptions)) - .region(Region.of(awsRegion)) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .retryPolicy(RetryPolicy.builder() - .backoffStrategy(FullJitterBackoffStrategy.builder() - .baseDelay(Duration.ofSeconds(10)) - .maxBackoffTime(Duration.ofSeconds(60)) - .build()) - .numRetries(10) - .build()) - .build()) - .build(); - } - private SdkHttpClient createSdkHttpClient() { ApacheHttpClient.Builder apacheHttpClientBuilder = ApacheHttpClient.builder(); if (connectTimeout != null) { diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/OpenSearchSink.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/OpenSearchSink.java index 2f3621496d..f7de43a527 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/OpenSearchSink.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/sink/opensearch/OpenSearchSink.java @@ -40,6 +40,9 @@ import org.opensearch.dataprepper.model.sink.AbstractSink; import org.opensearch.dataprepper.model.sink.Sink; import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessOptionsFactory; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdaterFactory; import org.opensearch.dataprepper.plugins.dlq.DlqProvider; import org.opensearch.dataprepper.plugins.dlq.DlqWriter; import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.AccumulatingBulkRequest; @@ -60,9 +63,9 @@ import org.opensearch.dataprepper.plugins.sink.opensearch.index.IndexTemplateAPIWrapperFactory; import org.opensearch.dataprepper.plugins.sink.opensearch.index.IndexType; import org.opensearch.dataprepper.plugins.sink.opensearch.index.TemplateStrategy; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient; import java.io.BufferedWriter; import java.io.IOException; @@ -549,18 +552,18 @@ public void shutdown() { } private void maybeUpdateServerlessNetworkPolicy() { - final ConnectionConfiguration connectionConfiguration = openSearchSinkConfig.getConnectionConfiguration(); - if (connectionConfiguration.isServerless() && - !StringUtils.isBlank(connectionConfiguration.getServerlessNetworkPolicyName()) && - !StringUtils.isBlank(connectionConfiguration.getServerlessCollectionName()) && - !StringUtils.isBlank(connectionConfiguration.getServerlessVpceId()) - ) { - final OpenSearchServerlessClient openSearchServerlessClient = connectionConfiguration.createOpenSearchServerlessClient(awsCredentialsSupplier); - final ServerlessNetworkPolicyUpdater networkPolicyUpdater = new ServerlessNetworkPolicyUpdater(openSearchServerlessClient); + final Optional maybeServerlessOptions = ServerlessOptionsFactory.create( + openSearchSinkConfig.getConnectionConfiguration()); + + if (maybeServerlessOptions.isPresent()) { + final ServerlessNetworkPolicyUpdater networkPolicyUpdater = ServerlessNetworkPolicyUpdaterFactory.create( + awsCredentialsSupplier, openSearchSinkConfig.getConnectionConfiguration() + ); networkPolicyUpdater.updateNetworkPolicy( - connectionConfiguration.getServerlessNetworkPolicyName(), - connectionConfiguration.getServerlessCollectionName(), - connectionConfiguration.getServerlessVpceId()); + maybeServerlessOptions.get().getNetworkPolicyName(), + maybeServerlessOptions.get().getCollectionName(), + maybeServerlessOptions.get().getVpceId() + ); } } diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java index 4c124afd92..6b512324f1 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchService.java @@ -120,6 +120,5 @@ public void stop() { LOG.error("Interrupted while waiting for the search worker to terminate", e); scheduledExecutorService.shutdownNow(); } - } } diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java index eaf3804c95..d4c91b57ef 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java @@ -16,12 +16,17 @@ import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; import org.opensearch.dataprepper.model.source.coordinator.UsesSourceCoordination; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessOptionsFactory; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdaterFactory; import org.opensearch.dataprepper.plugins.source.opensearch.metrics.OpenSearchSourcePluginMetrics; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.OpenSearchClientFactory; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy; import java.util.Objects; +import java.util.Optional; @DataPrepperPlugin(name="opensearch", pluginType = Source.class, pluginConfigurationType = OpenSearchSourceConfiguration.class) public class OpenSearchSource implements Source>, UsesSourceCoordination { @@ -55,12 +60,12 @@ public void start(final Buffer> buffer) { if (buffer == null) { throw new IllegalStateException("Buffer provided is null"); } + maybeUpdateNetworkPolicy(); startProcess(openSearchSourceConfiguration, buffer); } private void startProcess(final OpenSearchSourceConfiguration openSearchSourceConfiguration, final Buffer> buffer) { - final OpenSearchClientFactory openSearchClientFactory = OpenSearchClientFactory.create(awsCredentialsSupplier); final OpenSearchSourcePluginMetrics openSearchSourcePluginMetrics = OpenSearchSourcePluginMetrics.create(pluginMetrics); final SearchAccessorStrategy searchAccessorStrategy = SearchAccessorStrategy.create( @@ -93,4 +98,20 @@ public void setSourceCoordinator(final SourceCoordinator sourceCoordinato public Class getPartitionProgressStateClass() { return OpenSearchIndexProgressState.class; } + + // VisibleForTesting + void maybeUpdateNetworkPolicy() { + final Optional maybeServerlessOptions = ServerlessOptionsFactory.create( + openSearchSourceConfiguration.getAwsAuthenticationOptions()); + if (maybeServerlessOptions.isPresent()) { + final ServerlessNetworkPolicyUpdater networkPolicyUpdater = ServerlessNetworkPolicyUpdaterFactory.create( + awsCredentialsSupplier, openSearchSourceConfiguration.getAwsAuthenticationOptions() + ); + networkPolicyUpdater.updateNetworkPolicy( + maybeServerlessOptions.get().getNetworkPolicyName(), + maybeServerlessOptions.get().getCollectionName(), + maybeServerlessOptions.get().getVpceId() + ); + } + } } diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java index 5aa4b4cdd2..c03a5c8c73 100644 --- a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/AwsAuthenticationConfiguration.java @@ -32,6 +32,9 @@ public class AwsAuthenticationConfiguration { @JsonProperty("serverless") private Boolean serverless = false; + @JsonProperty("serverless_options") + private ServerlessOptions serverlessOptions; + public String getAwsStsRoleArn() { return awsStsRoleArn; } @@ -51,5 +54,9 @@ public Map getAwsStsHeaderOverrides() { public Boolean isServerlessCollection() { return serverless; } + + public ServerlessOptions getServerlessOptions() { + return serverlessOptions; + } } diff --git a/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/ServerlessOptions.java b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/ServerlessOptions.java new file mode 100644 index 0000000000..e23c8c2967 --- /dev/null +++ b/data-prepper-plugins/opensearch/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/configuration/ServerlessOptions.java @@ -0,0 +1,39 @@ +package org.opensearch.dataprepper.plugins.source.opensearch.configuration; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class ServerlessOptions { + + @JsonProperty("network_policy_name") + private String networkPolicyName; + + @JsonProperty("collection_name") + private String collectionName; + + @JsonProperty("vpce_id") + private String vpceId; + + public ServerlessOptions() { + + } + + public ServerlessOptions(String networkPolicyName, String collectionName, String vpceId) { + this.networkPolicyName = networkPolicyName; + this.collectionName = collectionName; + this.vpceId = vpceId; + } + + public String getNetworkPolicyName() { + return networkPolicyName; + } + + public String getCollectionName() { + return collectionName; + } + + public String getVpceId() { + return vpceId; + } + +} + diff --git a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterFactoryTest.java b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterFactoryTest.java new file mode 100644 index 0000000000..de092d1d66 --- /dev/null +++ b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterFactoryTest.java @@ -0,0 +1,96 @@ +package org.opensearch.dataprepper.plugins.common.opensearch; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.sink.opensearch.ConnectionConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient; +import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClientBuilder; + +import java.util.Collections; +import java.util.UUID; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class ServerlessNetworkPolicyUpdaterFactoryTest { + + private AwsCredentialsSupplier mockAwsCredentialsSupplier; + private AwsCredentialsProvider mockAwsCredentialsProvider; + private ConnectionConfiguration mockConnectionConfiguration; + private AwsAuthenticationConfiguration mockAwsAuthenticationConfiguration; + + @BeforeEach + void setUp() { + // Mock dependencies + mockAwsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + mockAwsCredentialsProvider = mock(AwsCredentialsProvider.class); + mockConnectionConfiguration = mock(ConnectionConfiguration.class); + mockAwsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class); + } + + @Test + void testCreateWithConnectionConfiguration() { + try (MockedStatic mockedClient = Mockito.mockStatic(OpenSearchServerlessClient.class)) { + // Mock the OpenSearchServerlessClient builder and its methods + OpenSearchServerlessClientBuilder builderMock = mock(OpenSearchServerlessClientBuilder.class); + OpenSearchServerlessClient clientMock = mock(OpenSearchServerlessClient.class); + + mockedClient.when(OpenSearchServerlessClient::builder).thenReturn(builderMock); + when(builderMock.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(builderMock); + when(builderMock.region(any(Region.class))).thenReturn(builderMock); + when(builderMock.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(builderMock); + when(builderMock.build()).thenReturn(clientMock); + + when(mockAwsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(mockAwsCredentialsProvider); + when(mockConnectionConfiguration.createAwsCredentialsOptions()).thenReturn(AwsCredentialsOptions.builder() + .withRegion(Region.AP_EAST_1) + .withStsRoleArn(UUID.randomUUID().toString()) + .withStsExternalId(UUID.randomUUID().toString()) + .withStsHeaderOverrides(Collections.emptyMap()) + .build()); + + // Call the method under test + ServerlessNetworkPolicyUpdater updater = ServerlessNetworkPolicyUpdaterFactory.create( + mockAwsCredentialsSupplier, + mockConnectionConfiguration + ); + } + } + + @Test + void testCreateWithAwsAuthenticationConfiguration() { + try (MockedStatic mockedClient = Mockito.mockStatic(OpenSearchServerlessClient.class)) { + // Mock the OpenSearchServerlessClient builder and its methods + OpenSearchServerlessClientBuilder builderMock = mock(OpenSearchServerlessClientBuilder.class); + OpenSearchServerlessClient clientMock = mock(OpenSearchServerlessClient.class); + + mockedClient.when(OpenSearchServerlessClient::builder).thenReturn(builderMock); + when(builderMock.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(builderMock); + when(builderMock.region(any(Region.class))).thenReturn(builderMock); + when(builderMock.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(builderMock); + when(builderMock.build()).thenReturn(clientMock); + + when(mockAwsCredentialsSupplier.getProvider(any(AwsCredentialsOptions.class))).thenReturn(mockAwsCredentialsProvider); + when(mockAwsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.AF_SOUTH_1); + when(mockAwsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(UUID.randomUUID().toString()); + when(mockAwsAuthenticationConfiguration.getAwsStsExternalId()).thenReturn(UUID.randomUUID().toString()); + when(mockAwsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap()); + + // Call the method under test + ServerlessNetworkPolicyUpdater updater = ServerlessNetworkPolicyUpdaterFactory.create( + mockAwsCredentialsSupplier, + mockAwsAuthenticationConfiguration + ); + + } + } +} \ No newline at end of file diff --git a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/sink/opensearch/ServerlessNetworkPolicyUpdaterTest.java b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterTest.java similarity index 95% rename from data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/sink/opensearch/ServerlessNetworkPolicyUpdaterTest.java rename to data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterTest.java index 7fed6b523a..21b5f21916 100644 --- a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/sink/opensearch/ServerlessNetworkPolicyUpdaterTest.java +++ b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessNetworkPolicyUpdaterTest.java @@ -1,4 +1,4 @@ -package org.opensearch.dataprepper.plugins.sink.opensearch; +package org.opensearch.dataprepper.plugins.common.opensearch; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -25,13 +25,13 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.dataprepper.plugins.sink.opensearch.ServerlessNetworkPolicyUpdater.COLLECTION; -import static org.opensearch.dataprepper.plugins.sink.opensearch.ServerlessNetworkPolicyUpdater.CREATED_BY_DATA_PREPPER; -import static org.opensearch.dataprepper.plugins.sink.opensearch.ServerlessNetworkPolicyUpdater.DESCRIPTION; -import static org.opensearch.dataprepper.plugins.sink.opensearch.ServerlessNetworkPolicyUpdater.RESOURCE; -import static org.opensearch.dataprepper.plugins.sink.opensearch.ServerlessNetworkPolicyUpdater.RESOURCE_TYPE; -import static org.opensearch.dataprepper.plugins.sink.opensearch.ServerlessNetworkPolicyUpdater.RULES; -import static org.opensearch.dataprepper.plugins.sink.opensearch.ServerlessNetworkPolicyUpdater.SOURCE_VPCES; +import static org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater.COLLECTION; +import static org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater.CREATED_BY_DATA_PREPPER; +import static org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater.DESCRIPTION; +import static org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater.RESOURCE; +import static org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater.RESOURCE_TYPE; +import static org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater.RULES; +import static org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater.SOURCE_VPCES; public class ServerlessNetworkPolicyUpdaterTest { diff --git a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessOptionsFactoryTest.java b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessOptionsFactoryTest.java new file mode 100644 index 0000000000..6e55355f42 --- /dev/null +++ b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/common/opensearch/ServerlessOptionsFactoryTest.java @@ -0,0 +1,117 @@ +package org.opensearch.dataprepper.plugins.common.opensearch; + +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions; +import org.opensearch.dataprepper.plugins.sink.opensearch.ConnectionConfiguration; + +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class ServerlessOptionsFactoryTest { + + @Test + void getOptionsShouldReturnEmptyForNonServerlessConnection() { + ConnectionConfiguration connectionConfiguration = mock(ConnectionConfiguration.class); + when(connectionConfiguration.isServerless()).thenReturn(false); + + Optional result = ServerlessOptionsFactory.create(connectionConfiguration); + + assertFalse(result.isPresent()); + } + + @Test + void getOptionsShouldReturnEmptyForBlankValuesInConnection() { + ConnectionConfiguration connectionConfiguration = mock(ConnectionConfiguration.class); + when(connectionConfiguration.isServerless()).thenReturn(true); + when(connectionConfiguration.getServerlessNetworkPolicyName()).thenReturn(" "); + when(connectionConfiguration.getServerlessCollectionName()).thenReturn(" "); + when(connectionConfiguration.getServerlessVpceId()).thenReturn(" "); + + Optional result = ServerlessOptionsFactory.create(connectionConfiguration); + + assertFalse(result.isPresent()); + } + + @Test + void getOptionsShouldReturnNonEmptyForValidConnectionConfiguration() { + ConnectionConfiguration connectionConfiguration = mock(ConnectionConfiguration.class); + when(connectionConfiguration.isServerless()).thenReturn(true); + when(connectionConfiguration.getServerlessNetworkPolicyName()).thenReturn("policyName"); + when(connectionConfiguration.getServerlessCollectionName()).thenReturn("collectionName"); + when(connectionConfiguration.getServerlessVpceId()).thenReturn("vpceId"); + + Optional result = ServerlessOptionsFactory.create(connectionConfiguration); + + assertTrue(result.isPresent()); + result.ifPresent(options -> { + assertEquals("policyName", options.getNetworkPolicyName()); + assertEquals("collectionName", options.getCollectionName()); + assertEquals("vpceId", options.getVpceId()); + }); + } + + @Test + void getOptionsShouldReturnEmptyForNullAwsConfig() { + Optional result = ServerlessOptionsFactory.create((AwsAuthenticationConfiguration) null); + + assertFalse(result.isPresent()); + } + + @Test + void getOptionsShouldReturnEmptyForNonServerlessAwsConfig() { + AwsAuthenticationConfiguration awsConfig = mock(AwsAuthenticationConfiguration.class); + when(awsConfig.isServerlessCollection()).thenReturn(false); + + Optional result = ServerlessOptionsFactory.create(awsConfig); + + assertFalse(result.isPresent()); + } + + @Test + void getOptionsShouldReturnEmptyForNullServerlessOptionsInAwsConfig() { + AwsAuthenticationConfiguration awsConfig = mock(AwsAuthenticationConfiguration.class); + when(awsConfig.isServerlessCollection()).thenReturn(true); + when(awsConfig.getServerlessOptions()).thenReturn(null); + + Optional result = ServerlessOptionsFactory.create(awsConfig); + + assertFalse(result.isPresent()); + } + + @Test + void getOptionsShouldReturnEmptyForBlankValuesInServerlessOptions() { + AwsAuthenticationConfiguration awsConfig = mock(AwsAuthenticationConfiguration.class); + ServerlessOptions serverlessOptions = new ServerlessOptions(" ", " ", " "); + + when(awsConfig.isServerlessCollection()).thenReturn(true); + when(awsConfig.getServerlessOptions()).thenReturn(serverlessOptions); + + Optional result = ServerlessOptionsFactory.create(awsConfig); + + assertFalse(result.isPresent()); + } + + @Test + void getOptionsShouldReturnNonEmptyForValidAwsConfig() { + AwsAuthenticationConfiguration awsConfig = mock(AwsAuthenticationConfiguration.class); + ServerlessOptions serverlessOptions = new ServerlessOptions("policyName", "collectionName", "vpceId"); + + when(awsConfig.isServerlessCollection()).thenReturn(true); + when(awsConfig.getServerlessOptions()).thenReturn(serverlessOptions); + + Optional result = ServerlessOptionsFactory.create(awsConfig); + + assertTrue(result.isPresent()); + result.ifPresent(options -> { + assertEquals("policyName", options.getNetworkPolicyName()); + assertEquals("collectionName", options.getCollectionName()); + assertEquals("vpceId", options.getVpceId()); + }); + } +} diff --git a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java index 4895ceb7cf..78e67cd98a 100644 --- a/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java +++ b/data-prepper-plugins/opensearch/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java @@ -18,14 +18,24 @@ import org.opensearch.dataprepper.model.plugin.PluginConfigObservable; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdaterFactory; +import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessOptionsFactory; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions; import org.opensearch.dataprepper.plugins.source.opensearch.metrics.OpenSearchSourcePluginMetrics; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.OpenSearchClientFactory; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy; +import java.util.Optional; +import java.util.UUID; + import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -37,6 +47,9 @@ public class OpenSearchSourceTest { @Mock private OpenSearchSourceConfiguration openSearchSourceConfiguration; + @Mock + private AwsAuthenticationConfiguration awsAuthenticationConfiguration; + @Mock private OpenSearchService openSearchService; @@ -67,6 +80,9 @@ public class OpenSearchSourceTest { @Mock private PluginConfigObservable pluginConfigObservable; + @Mock + private ServerlessNetworkPolicyUpdater serverlessNetworkPolicyUpdater; + private OpenSearchSource createObjectUnderTest() { return new OpenSearchSource( openSearchSourceConfiguration, awsCredentialsSupplier, acknowledgementSetManager, @@ -80,7 +96,6 @@ void start_with_null_buffer_throws_IllegalStateException() { @Test void start_with_non_null_buffer_does_not_throw() { - when(searchAccessorStrategy.getSearchAccessor()).thenReturn(searchAccessor); doNothing().when(openSearchService).start(); @@ -90,7 +105,8 @@ void start_with_non_null_buffer_does_not_throw() { try (final MockedStatic searchAccessorStrategyMockedStatic = mockStatic(SearchAccessorStrategy.class); final MockedStatic openSearchClientFactoryMockedStatic = mockStatic(OpenSearchClientFactory.class); final MockedStatic openSearchSourcePluginMetricsMockedStatic = mockStatic(OpenSearchSourcePluginMetrics.class); - final MockedStatic openSearchServiceMockedStatic = mockStatic(OpenSearchService.class)) { + final MockedStatic openSearchServiceMockedStatic = mockStatic(OpenSearchService.class); + final MockedStatic serverlessOptionsFactoryMockedStatic = mockStatic(ServerlessOptionsFactory.class)) { openSearchClientFactoryMockedStatic.when(() -> OpenSearchClientFactory.create(awsCredentialsSupplier)).thenReturn(openSearchClientFactory); searchAccessorStrategyMockedStatic.when(() -> SearchAccessorStrategy.create( openSearchSourceConfiguration, openSearchClientFactory, pluginConfigObservable)).thenReturn(searchAccessorStrategy); @@ -99,7 +115,43 @@ void start_with_non_null_buffer_does_not_throw() { openSearchServiceMockedStatic.when(() -> OpenSearchService.createOpenSearchService(searchAccessor, sourceCoordinator, openSearchSourceConfiguration, buffer, acknowledgementSetManager, openSearchSourcePluginMetrics)) .thenReturn(openSearchService); + serverlessOptionsFactoryMockedStatic.when(() -> ServerlessOptionsFactory.create(openSearchSourceConfiguration.getAwsAuthenticationOptions())).thenReturn(Optional.empty()); + + objectUnderTest.start(buffer); + } + } + + @Test + void start_with_non_null_buffer_serverless_options_does_not_throw() { + when(searchAccessorStrategy.getSearchAccessor()).thenReturn(searchAccessor); + doNothing().when(openSearchService).start(); + + final OpenSearchSource objectUnderTest = createObjectUnderTest(); + objectUnderTest.setSourceCoordinator(sourceCoordinator); + + try (final MockedStatic searchAccessorStrategyMockedStatic = mockStatic(SearchAccessorStrategy.class); + final MockedStatic openSearchClientFactoryMockedStatic = mockStatic(OpenSearchClientFactory.class); + final MockedStatic openSearchSourcePluginMetricsMockedStatic = mockStatic(OpenSearchSourcePluginMetrics.class); + final MockedStatic openSearchServiceMockedStatic = mockStatic(OpenSearchService.class); + final MockedStatic serverlessOptionsFactoryMockedStatic = mockStatic(ServerlessOptionsFactory.class); + final MockedStatic serverlessNetworkPolicyUpdaterFactoryMockedStatic = mockStatic(ServerlessNetworkPolicyUpdaterFactory.class)) { + openSearchClientFactoryMockedStatic.when(() -> OpenSearchClientFactory.create(awsCredentialsSupplier)).thenReturn(openSearchClientFactory); + searchAccessorStrategyMockedStatic.when(() -> SearchAccessorStrategy.create( + openSearchSourceConfiguration, openSearchClientFactory, pluginConfigObservable)).thenReturn(searchAccessorStrategy); + openSearchSourcePluginMetricsMockedStatic.when(() -> OpenSearchSourcePluginMetrics.create(pluginMetrics)).thenReturn(openSearchSourcePluginMetrics); + + openSearchServiceMockedStatic.when(() -> OpenSearchService.createOpenSearchService(searchAccessor, sourceCoordinator, openSearchSourceConfiguration, buffer, acknowledgementSetManager, openSearchSourcePluginMetrics)) + .thenReturn(openSearchService); + + when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration); + serverlessOptionsFactoryMockedStatic.when(() -> ServerlessOptionsFactory.create(awsAuthenticationConfiguration)).thenReturn(Optional.of(new ServerlessOptions(UUID.randomUUID().toString(), UUID.randomUUID().toString(), UUID.randomUUID().toString()))); + serverlessNetworkPolicyUpdaterFactoryMockedStatic.when(() -> ServerlessNetworkPolicyUpdaterFactory.create(any(AwsCredentialsSupplier.class), any(AwsAuthenticationConfiguration.class))).thenReturn(serverlessNetworkPolicyUpdater); + doNothing().when(serverlessNetworkPolicyUpdater).updateNetworkPolicy(any(), any(), any()); + objectUnderTest.start(buffer); + + verify(serverlessNetworkPolicyUpdater).updateNetworkPolicy(any(), any(), any()); } } + }