Skip to content

Commit

Permalink
Add logic to create or update serverless network policy.
Browse files Browse the repository at this point in the history
Signed-off-by: Adi Suresh <[email protected]>
  • Loading branch information
asuresh8 committed Oct 17, 2023
1 parent ccbe50c commit d9923e0
Show file tree
Hide file tree
Showing 7 changed files with 651 additions and 1 deletion.
6 changes: 6 additions & 0 deletions data-prepper-plugins/opensearch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ if `exclude_keys` is set to ["message", "status"], the document written to OpenS
* `sts_role_arn` (Optional) : The STS role to assume for requests to AWS. Defaults to null, which will use the [standard SDK behavior for credentials](https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials.html).
* `sts_header_overrides` (Optional): A map of header overrides to make when assuming the IAM role for the sink plugin.
* `serverless` (Optional): A boolean flag to indicate the OpenSearch backend is Amazon OpenSearch Serverless. Default to `false`. Notice that [ISM policies.](https://opensearch.org/docs/latest/im-plugin/ism/policies/) is not supported in Amazon OpenSearch Serverless and thus any ISM related configuration value has no effect, i.e. `ism_policy_file`.
* `serverless_options` (Optional): Additional options you can specify when using serverless.

#### <a name="serverless_configuration">Serverless Configuration</a>
* `network_policy_name` (Optional): The serverless network policy name being used. If both `collection_name` and `vpce_id` are specified, then this network policy will be attempted to be created or update. On the managed OpenSearch Ingestion Service, the `collection_name` and `vpce_id` fields are automatically set.
* `collection_name` (Optional): The serverless collection name.
* `vpce_id` (Optional): The VPCE ID connected to Amazon OpenSearch Serverless.

## Metrics
### Management Disabled Index Type
Expand Down
2 changes: 1 addition & 1 deletion data-prepper-plugins/opensearch/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies {
implementation 'software.amazon.awssdk:arns'
implementation 'io.micrometer:micrometer-core'
implementation 'software.amazon.awssdk:s3'
implementation 'software.amazon.awssdk:opensearchserverless'
implementation libs.commons.lang3
implementation 'software.amazon.awssdk:apache-client'
testImplementation testLibs.junit.vintage
Expand Down Expand Up @@ -72,7 +73,6 @@ task integrationTest(type: Test) {
}
}


jacocoTestReport {
dependsOn test
reports {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.backoff.FullJitterBackoffStrategy;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.apache.ProxyConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
Expand Down Expand Up @@ -86,6 +90,10 @@ public class ConnectionConfiguration {
public static final String AWS_STS_HEADER_OVERRIDES = "aws_sts_header_overrides";
public static final String PROXY = "proxy";
public static final String SERVERLESS = "serverless";
public static final String SERVERLESS_OPTIONS = "serverless_options";
public static final String COLLECTION_NAME = "collection_name";
public static final String NETWORK_POLICY_NAME = "network_policy_name";
public static final String VPCE_ID = "vpce_id";
public static final String REQUEST_COMPRESSION_ENABLED = "enable_request_compression";

/**
Expand All @@ -109,6 +117,9 @@ public class ConnectionConfiguration {
private final Optional<String> proxy;
private final String pipelineName;
private final boolean serverless;
private final String serverlessNetworkPolicyName;
private final String serverlessCollectionName;
private final String serverlessVpceId;
private final boolean requestCompressionEnabled;

List<String> getHosts() {
Expand Down Expand Up @@ -155,6 +166,18 @@ boolean isServerless() {
return serverless;
}

String getServerlessNetworkPolicyName() {
return serverlessNetworkPolicyName;
}

String getServerlessCollectionName() {
return serverlessCollectionName;
}

String getServerlessVpceId() {
return serverlessVpceId;
}

boolean isRequestCompressionEnabled() {
return requestCompressionEnabled;
}
Expand All @@ -174,6 +197,9 @@ private ConnectionConfiguration(final Builder builder) {
this.awsStsHeaderOverrides = builder.awsStsHeaderOverrides;
this.proxy = builder.proxy;
this.serverless = builder.serverless;
this.serverlessNetworkPolicyName = builder.serverlessNetworkPolicyName;
this.serverlessCollectionName = builder.serverlessCollectionName;
this.serverlessVpceId = builder.serverlessVpceId;
this.requestCompressionEnabled = builder.requestCompressionEnabled;
this.pipelineName = builder.pipelineName;
}
Expand Down Expand Up @@ -212,6 +238,13 @@ public static ConnectionConfiguration readConnectionConfiguration(final PluginSe
builder.withAwsStsHeaderOverrides((Map<String, String>)awsOption.get(AWS_STS_HEADER_OVERRIDES.substring(4)));
builder.withServerless(OBJECT_MAPPER.convertValue(
awsOption.getOrDefault(SERVERLESS, false), Boolean.class));

Map<String, String> serverlessOptions = (Map<String, String>) awsOption.get(SERVERLESS_OPTIONS);
if (serverlessOptions != null && !serverlessOptions.isEmpty()) {
builder.withServerlessNetworkPolicyName((String)(serverlessOptions.getOrDefault(NETWORK_POLICY_NAME, null)));
builder.withServerlessCollectionName((String)(serverlessOptions.getOrDefault(COLLECTION_NAME, null)));
builder.withServerlessVpceId((String)(serverlessOptions.getOrDefault(VPCE_ID, null)));
}
} else {
builder.withServerless(false);
}
Expand Down Expand Up @@ -407,6 +440,24 @@ private OpenSearchTransport createOpenSearchTransport(final RestHighLevelClient
}
}

public OpenSearchServerlessClient createOpenSearchServerlessClient(final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = createAwsCredentialsOptions();

return OpenSearchServerlessClient.builder()
.credentialsProvider(awsCredentialsSupplier.getProvider(awsCredentialsOptions))
.region(Region.of(awsRegion))
.overrideConfiguration(ClientOverrideConfiguration.builder()
.retryPolicy(RetryPolicy.builder()
.backoffStrategy(FullJitterBackoffStrategy.builder()
.baseDelay(Duration.ofSeconds(10))
.maxBackoffTime(Duration.ofSeconds(60))
.build())
.numRetries(10)
.build())
.build())
.build();
}

private SdkHttpClient createSdkHttpClient() {
ApacheHttpClient.Builder apacheHttpClientBuilder = ApacheHttpClient.builder();
if (connectTimeout != null) {
Expand Down Expand Up @@ -475,6 +526,9 @@ public static class Builder {
private Optional<String> proxy = Optional.empty();
private String pipelineName;
private boolean serverless;
private String serverlessNetworkPolicyName;
private String serverlessCollectionName;
private String serverlessVpceId;
private boolean requestCompressionEnabled;

private void validateStsRoleArn(final String awsStsRoleArn) {
Expand Down Expand Up @@ -585,6 +639,21 @@ public Builder withServerless(boolean serverless) {
return this;
}

public Builder withServerlessNetworkPolicyName(final String serverlessNetworkPolicyName) {
this.serverlessNetworkPolicyName = serverlessNetworkPolicyName;
return this;
}

public Builder withServerlessCollectionName(final String serverlessCollectionName) {
this.serverlessCollectionName = serverlessCollectionName;
return this;
}

public Builder withServerlessVpceId(final String serverlessVpceId) {
this.serverlessVpceId = serverlessVpceId;
return this;
}

public Builder withRequestCompressionEnabled(final boolean requestCompressionEnabled) {
this.requestCompressionEnabled = requestCompressionEnabled;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.opensearch.dataprepper.plugins.sink.opensearch.index.TemplateStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;

import java.io.BufferedWriter;
import java.io.IOException;
Expand Down Expand Up @@ -243,6 +244,9 @@ private void doInitializeInternal() throws IOException {
bulkRequestSupplier,
pluginSetting);

// Attempt to update the serverless network policy if required argument are given.
maybeUpdateServerlessNetworkPolicy();

objectMapper = new ObjectMapper();
this.initialized = true;
LOG.info("Initialized OpenSearch sink");
Expand Down Expand Up @@ -505,4 +509,21 @@ public void shutdown() {
super.shutdown();
closeFiles();
}

private void maybeUpdateServerlessNetworkPolicy() {
final ConnectionConfiguration connectionConfiguration = openSearchSinkConfig.getConnectionConfiguration();
LOG.info(connectionConfiguration.toString());
if (connectionConfiguration.isServerless() &&
!StringUtils.isBlank(connectionConfiguration.getServerlessNetworkPolicyName()) &&
!StringUtils.isBlank(connectionConfiguration.getServerlessCollectionName()) &&
!StringUtils.isBlank(connectionConfiguration.getServerlessVpceId())
) {
final OpenSearchServerlessClient openSearchServerlessClient = connectionConfiguration.createOpenSearchServerlessClient(awsCredentialsSupplier);
final ServerlessNetworkPolicyUpdater networkPolicyUpdater = new ServerlessNetworkPolicyUpdater(openSearchServerlessClient);
networkPolicyUpdater.updateNetworkPolicy(
connectionConfiguration.getServerlessNetworkPolicyName(),
connectionConfiguration.getServerlessCollectionName(),
connectionConfiguration.getServerlessVpceId());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package org.opensearch.dataprepper.plugins.sink.opensearch;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;
import software.amazon.awssdk.services.opensearchserverless.model.CreateSecurityPolicyRequest;
import software.amazon.awssdk.services.opensearchserverless.model.GetSecurityPolicyRequest;
import software.amazon.awssdk.services.opensearchserverless.model.GetSecurityPolicyResponse;
import software.amazon.awssdk.services.opensearchserverless.model.ResourceNotFoundException;
import software.amazon.awssdk.services.opensearchserverless.model.SecurityPolicyDetail;
import software.amazon.awssdk.services.opensearchserverless.model.SecurityPolicyType;
import software.amazon.awssdk.services.opensearchserverless.model.UpdateSecurityPolicyRequest;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;

public class ServerlessNetworkPolicyUpdater {

static final String COLLECTION = "collection";
static final String DESCRIPTION = "Description";
static final String RESOURCE = "Resource";
static final String RESOURCE_TYPE = "ResourceType";
static final String RULES = "Rules";
static final String SOURCE_VPCES = "SourceVPCEs";

private static final Logger LOG = LoggerFactory.getLogger(ServerlessNetworkPolicyUpdater.class);

private final OpenSearchServerlessClient client;

public ServerlessNetworkPolicyUpdater(OpenSearchServerlessClient client) {
this.client = client;
}

public void updateNetworkPolicy(
final String networkPolicyName,
final String collectionName,
final String vpceId
) {
try {
final Document newStatement = createNetworkPolicyStatement(collectionName, vpceId);
final Optional<SecurityPolicyDetail> maybeNetworkPolicy = getNetworkPolicy(networkPolicyName);

if (maybeNetworkPolicy.isPresent()) {
final Document existingPolicy = maybeNetworkPolicy.get().policy();
final String policyVersion = maybeNetworkPolicy.get().policyVersion();
final List<Document> existingStatements = existingPolicy.asList();
if (hasAcceptablePolicy(existingStatements, collectionName, vpceId)) {
LOG.info("Policy statement already exists that matches collection and vpce id");
return;
}

final List<Document> statements = new ArrayList<>(existingStatements);
statements.add(newStatement);
final Document newPolicy = Document.fromList(statements);
updateNetworkPolicy(networkPolicyName, newPolicy, policyVersion);
} else {
final Document newPolicy = Document.fromList(List.of(newStatement));
createNetworkPolicy(networkPolicyName, newPolicy);
}
} catch (final Exception e) {
LOG.error("Failed to create or update network policy", e);
}
}

// VisibleForTesting
Optional<SecurityPolicyDetail> getNetworkPolicy(final String networkPolicyName) {
// Call the GetSecurityPolicy API
GetSecurityPolicyRequest getRequest = GetSecurityPolicyRequest.builder()
.name(networkPolicyName)
.type(SecurityPolicyType.NETWORK)
.build();

GetSecurityPolicyResponse response;
try {
response = client.getSecurityPolicy(getRequest);
} catch (final ResourceNotFoundException e) {
LOG.info("Could not find network policy {}", networkPolicyName);
return Optional.empty();
}

if (response.securityPolicyDetail() == null) {
LOG.info("Security policy exists but had no detail.");
return Optional.empty();
}

return Optional.of(response.securityPolicyDetail());
}

// VisibleForTesting
void createNetworkPolicy(final String networkPolicyName, final Document policy) {
final CreateSecurityPolicyRequest request = CreateSecurityPolicyRequest.builder()
.name(networkPolicyName)
.policy(policy.toString())
.type(SecurityPolicyType.NETWORK)
.build();

client.createSecurityPolicy(request);
}

// VisibleForTesting
void updateNetworkPolicy(final String networkPolicyName, final Document policy, final String policyVersion) {
final UpdateSecurityPolicyRequest request = UpdateSecurityPolicyRequest.builder()
.name(networkPolicyName)
.policy(policy.toString())
.type(SecurityPolicyType.NETWORK)
.policyVersion(policyVersion)
.build();

client.updateSecurityPolicy(request);
}

// VisibleForTesting
static Document createNetworkPolicyStatement(final String collectionName, final String vpceId) {
return Document.mapBuilder()
.putString(DESCRIPTION, "Created by Data Prepper")
.putList(RULES, List.of(Document.mapBuilder()
.putString(RESOURCE_TYPE, COLLECTION)
.putList(RESOURCE, List.of(Document.fromString(String.format("%s/%s", COLLECTION, collectionName))))
.build()))
.putList(SOURCE_VPCES, List.of(Document.fromString(vpceId)))
.build();
}

// VisibleForTesting
static boolean hasAcceptablePolicy(final List<Document> statements, final String collectionName, final String vpceId) {
for (final Document statement : statements) {
final Map<String, Document> statementFields = statement.asMap();
if (!statementFields.containsKey(SOURCE_VPCES) || !statementFields.containsKey(RULES)) {
continue;
}

// Check if the statement has the SourceVPCEs field that matches the given vpceId
boolean hasMatchingVpce = statementFields.get(SOURCE_VPCES).asList().stream()
.map(Document::asString)
.anyMatch(vpce -> vpce.equals(vpceId));

// Check if the statement has the Rules field with the ResourceType set to COLLECTION
// that matches (or covers) the given collectionName
boolean hasMatchingCollection = statementFields.get(RULES).asList().stream()
.filter(rule -> rule.asMap().get(RESOURCE_TYPE).asString().equals(COLLECTION))
.flatMap(rule -> rule.asMap().get(RESOURCE).asList().stream())
.map(Document::asString)
.anyMatch(collectionPattern -> matchesPattern(collectionPattern, String.format("%s/%s", COLLECTION, collectionName)));

// If both conditions are met, return true
if (hasMatchingVpce && hasMatchingCollection) {
return true;
}
}
return false;
}

// VisibleForTesting
static boolean matchesPattern(String pattern, String value) {
// Convert wildcard pattern to regex
String regex = "^" + Pattern.quote(pattern).replace("*", "\\E.*\\Q") + "$";
return value.matches(regex);
}

}
Loading

0 comments on commit d9923e0

Please sign in to comment.