Skip to content

Commit

Permalink
Create network policy for aoss source. (#3613)
Browse files Browse the repository at this point in the history
Signed-off-by: Adi Suresh <[email protected]>
  • Loading branch information
asuresh8 authored Nov 10, 2023
1 parent 2ea6edf commit c3c35da
Show file tree
Hide file tree
Showing 13 changed files with 466 additions and 52 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.opensearch.dataprepper.plugins.sink.opensearch;
package org.opensearch.dataprepper.plugins.common.opensearch;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package org.opensearch.dataprepper.plugins.common.opensearch;

import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.opensearch.ConnectionConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy;
import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;

import java.time.Duration;

public class ServerlessNetworkPolicyUpdaterFactory {
public static ServerlessNetworkPolicyUpdater create(
final AwsCredentialsSupplier awsCredentialsSupplier,
final ConnectionConfiguration connectionConfiguration
) {
return new ServerlessNetworkPolicyUpdater(getOpenSearchServerlessClient(
awsCredentialsSupplier, connectionConfiguration.createAwsCredentialsOptions()
));
}

public static ServerlessNetworkPolicyUpdater create(
final AwsCredentialsSupplier awsCredentialsSupplier,
final AwsAuthenticationConfiguration awsConfig
) {
final AwsCredentialsOptions awsCredentialsOptions = AwsCredentialsOptions.builder()
.withRegion(awsConfig.getAwsRegion())
.withStsRoleArn(awsConfig.getAwsStsRoleArn())
.withStsExternalId(awsConfig.getAwsStsExternalId())
.withStsHeaderOverrides(awsConfig.getAwsStsHeaderOverrides())
.build();

return new ServerlessNetworkPolicyUpdater(getOpenSearchServerlessClient(
awsCredentialsSupplier, awsCredentialsOptions
));
}

private static OpenSearchServerlessClient getOpenSearchServerlessClient(
final AwsCredentialsSupplier awsCredentialsSupplier,
final AwsCredentialsOptions awsCredentialsOptions
) {
return OpenSearchServerlessClient.builder()
.credentialsProvider(awsCredentialsSupplier.getProvider(awsCredentialsOptions))
.region(awsCredentialsOptions.getRegion())
.overrideConfiguration(ClientOverrideConfiguration.builder()
.retryPolicy(RetryPolicy.builder()
.backoffStrategy(FullJitterBackoffStrategy.builder()
.baseDelay(Duration.ofSeconds(10))
.maxBackoffTime(Duration.ofSeconds(60))
.build())
.numRetries(10)
.build())
.build())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.opensearch.dataprepper.plugins.common.opensearch;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration;
import org.opensearch.dataprepper.plugins.sink.opensearch.ConnectionConfiguration;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions;

import java.util.Optional;

public class ServerlessOptionsFactory {

public static Optional<ServerlessOptions> create(final ConnectionConfiguration connectionConfiguration) {
if (!connectionConfiguration.isServerless() ||
StringUtils.isBlank(connectionConfiguration.getServerlessNetworkPolicyName()) ||
StringUtils.isBlank(connectionConfiguration.getServerlessCollectionName()) ||
StringUtils.isBlank(connectionConfiguration.getServerlessVpceId())
) {
return Optional.empty();
}

return Optional.of(new ServerlessOptions(
connectionConfiguration.getServerlessNetworkPolicyName(),
connectionConfiguration.getServerlessCollectionName(),
connectionConfiguration.getServerlessVpceId()));
}

public static Optional<ServerlessOptions> create(final AwsAuthenticationConfiguration awsConfig) {
if (awsConfig == null || !awsConfig.isServerlessCollection()) {
return Optional.empty();
}

final ServerlessOptions serverlessOptions = awsConfig.getServerlessOptions();
if (serverlessOptions == null ||
StringUtils.isBlank(serverlessOptions.getNetworkPolicyName()) ||
StringUtils.isBlank(serverlessOptions.getCollectionName()) ||
StringUtils.isBlank(serverlessOptions.getVpceId())
) {
return Optional.empty();
}

return Optional.of(serverlessOptions);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,10 @@
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.apache.ProxyConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
Expand Down Expand Up @@ -162,19 +158,19 @@ Integer getConnectTimeout() {
return connectTimeout;
}

boolean isServerless() {
public boolean isServerless() {
return serverless;
}

String getServerlessNetworkPolicyName() {
public String getServerlessNetworkPolicyName() {
return serverlessNetworkPolicyName;
}

String getServerlessCollectionName() {
public String getServerlessCollectionName() {
return serverlessCollectionName;
}

String getServerlessVpceId() {
public String getServerlessVpceId() {
return serverlessVpceId;
}

Expand Down Expand Up @@ -333,7 +329,7 @@ private void attachSigV4(final RestClientBuilder restClientBuilder, AwsCredentia
});
}

private AwsCredentialsOptions createAwsCredentialsOptions() {
public AwsCredentialsOptions createAwsCredentialsOptions() {
final AwsCredentialsOptions awsCredentialsOptions = AwsCredentialsOptions.builder()
.withStsRoleArn(awsStsRoleArn)
.withStsExternalId(awsStsExternalId)
Expand Down Expand Up @@ -440,24 +436,6 @@ private OpenSearchTransport createOpenSearchTransport(final RestHighLevelClient
}
}

public OpenSearchServerlessClient createOpenSearchServerlessClient(final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = createAwsCredentialsOptions();

return OpenSearchServerlessClient.builder()
.credentialsProvider(awsCredentialsSupplier.getProvider(awsCredentialsOptions))
.region(Region.of(awsRegion))
.overrideConfiguration(ClientOverrideConfiguration.builder()
.retryPolicy(RetryPolicy.builder()
.backoffStrategy(FullJitterBackoffStrategy.builder()
.baseDelay(Duration.ofSeconds(10))
.maxBackoffTime(Duration.ofSeconds(60))
.build())
.numRetries(10)
.build())
.build())
.build();
}

private SdkHttpClient createSdkHttpClient() {
ApacheHttpClient.Builder apacheHttpClientBuilder = ApacheHttpClient.builder();
if (connectTimeout != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
import org.opensearch.dataprepper.model.sink.AbstractSink;
import org.opensearch.dataprepper.model.sink.Sink;
import org.opensearch.dataprepper.model.sink.SinkContext;
import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessOptionsFactory;
import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater;
import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdaterFactory;
import org.opensearch.dataprepper.plugins.dlq.DlqProvider;
import org.opensearch.dataprepper.plugins.dlq.DlqWriter;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.AccumulatingBulkRequest;
Expand All @@ -60,9 +63,9 @@
import org.opensearch.dataprepper.plugins.sink.opensearch.index.IndexTemplateAPIWrapperFactory;
import org.opensearch.dataprepper.plugins.sink.opensearch.index.IndexType;
import org.opensearch.dataprepper.plugins.sink.opensearch.index.TemplateStrategy;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;

import java.io.BufferedWriter;
import java.io.IOException;
Expand Down Expand Up @@ -549,18 +552,18 @@ public void shutdown() {
}

private void maybeUpdateServerlessNetworkPolicy() {
final ConnectionConfiguration connectionConfiguration = openSearchSinkConfig.getConnectionConfiguration();
if (connectionConfiguration.isServerless() &&
!StringUtils.isBlank(connectionConfiguration.getServerlessNetworkPolicyName()) &&
!StringUtils.isBlank(connectionConfiguration.getServerlessCollectionName()) &&
!StringUtils.isBlank(connectionConfiguration.getServerlessVpceId())
) {
final OpenSearchServerlessClient openSearchServerlessClient = connectionConfiguration.createOpenSearchServerlessClient(awsCredentialsSupplier);
final ServerlessNetworkPolicyUpdater networkPolicyUpdater = new ServerlessNetworkPolicyUpdater(openSearchServerlessClient);
final Optional<ServerlessOptions> maybeServerlessOptions = ServerlessOptionsFactory.create(
openSearchSinkConfig.getConnectionConfiguration());

if (maybeServerlessOptions.isPresent()) {
final ServerlessNetworkPolicyUpdater networkPolicyUpdater = ServerlessNetworkPolicyUpdaterFactory.create(
awsCredentialsSupplier, openSearchSinkConfig.getConnectionConfiguration()
);
networkPolicyUpdater.updateNetworkPolicy(
connectionConfiguration.getServerlessNetworkPolicyName(),
connectionConfiguration.getServerlessCollectionName(),
connectionConfiguration.getServerlessVpceId());
maybeServerlessOptions.get().getNetworkPolicyName(),
maybeServerlessOptions.get().getCollectionName(),
maybeServerlessOptions.get().getVpceId()
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,5 @@ public void stop() {
LOG.error("Interrupted while waiting for the search worker to terminate", e);
scheduledExecutorService.shutdownNow();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator;
import org.opensearch.dataprepper.model.source.coordinator.UsesSourceCoordination;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessOptionsFactory;
import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdater;
import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ServerlessOptions;
import org.opensearch.dataprepper.plugins.common.opensearch.ServerlessNetworkPolicyUpdaterFactory;
import org.opensearch.dataprepper.plugins.source.opensearch.metrics.OpenSearchSourcePluginMetrics;
import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.OpenSearchClientFactory;
import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor;
import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy;

import java.util.Objects;
import java.util.Optional;

@DataPrepperPlugin(name="opensearch", pluginType = Source.class, pluginConfigurationType = OpenSearchSourceConfiguration.class)
public class OpenSearchSource implements Source<Record<Event>>, UsesSourceCoordination {
Expand Down Expand Up @@ -55,12 +60,12 @@ public void start(final Buffer<Record<Event>> buffer) {
if (buffer == null) {
throw new IllegalStateException("Buffer provided is null");
}
maybeUpdateNetworkPolicy();
startProcess(openSearchSourceConfiguration, buffer);
}

private void startProcess(final OpenSearchSourceConfiguration openSearchSourceConfiguration,
final Buffer<Record<Event>> buffer) {

final OpenSearchClientFactory openSearchClientFactory = OpenSearchClientFactory.create(awsCredentialsSupplier);
final OpenSearchSourcePluginMetrics openSearchSourcePluginMetrics = OpenSearchSourcePluginMetrics.create(pluginMetrics);
final SearchAccessorStrategy searchAccessorStrategy = SearchAccessorStrategy.create(
Expand Down Expand Up @@ -93,4 +98,20 @@ public <T> void setSourceCoordinator(final SourceCoordinator<T> sourceCoordinato
public Class<?> getPartitionProgressStateClass() {
return OpenSearchIndexProgressState.class;
}

// VisibleForTesting
void maybeUpdateNetworkPolicy() {
final Optional<ServerlessOptions> maybeServerlessOptions = ServerlessOptionsFactory.create(
openSearchSourceConfiguration.getAwsAuthenticationOptions());
if (maybeServerlessOptions.isPresent()) {
final ServerlessNetworkPolicyUpdater networkPolicyUpdater = ServerlessNetworkPolicyUpdaterFactory.create(
awsCredentialsSupplier, openSearchSourceConfiguration.getAwsAuthenticationOptions()
);
networkPolicyUpdater.updateNetworkPolicy(
maybeServerlessOptions.get().getNetworkPolicyName(),
maybeServerlessOptions.get().getCollectionName(),
maybeServerlessOptions.get().getVpceId()
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public class AwsAuthenticationConfiguration {
@JsonProperty("serverless")
private Boolean serverless = false;

@JsonProperty("serverless_options")
private ServerlessOptions serverlessOptions;

public String getAwsStsRoleArn() {
return awsStsRoleArn;
}
Expand All @@ -51,5 +54,9 @@ public Map<String, String> getAwsStsHeaderOverrides() {
public Boolean isServerlessCollection() {
return serverless;
}

public ServerlessOptions getServerlessOptions() {
return serverlessOptions;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.opensearch.dataprepper.plugins.source.opensearch.configuration;

import com.fasterxml.jackson.annotation.JsonProperty;

public class ServerlessOptions {

@JsonProperty("network_policy_name")
private String networkPolicyName;

@JsonProperty("collection_name")
private String collectionName;

@JsonProperty("vpce_id")
private String vpceId;

public ServerlessOptions() {

}

public ServerlessOptions(String networkPolicyName, String collectionName, String vpceId) {
this.networkPolicyName = networkPolicyName;
this.collectionName = collectionName;
this.vpceId = vpceId;
}

public String getNetworkPolicyName() {
return networkPolicyName;
}

public String getCollectionName() {
return collectionName;
}

public String getVpceId() {
return vpceId;
}

}

Loading

0 comments on commit c3c35da

Please sign in to comment.