sqs source: json codec support to split sqs message into multiple events #5330

Open · wants to merge 3 commits into main · Changes from 1 commit
BulkSqsMessageHandler.java
@@ -0,0 +1,73 @@
/*
Member commented:

Please use the longer copyright header.

https://github.com/opensearch-project/data-prepper/blob/90575b1de56f82f44d1af36f31ff4b077a627bd7/CONTRIBUTING.md#license-headers

/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 *
*/

Contributor Author replied:

added to every file

* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.codec.InputCodec;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventMetadata;
import org.opensearch.dataprepper.model.record.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.sqs.model.Message;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeName;

import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.TimeoutException;

public class BulkSqsMessageHandler implements SqsMessageHandler {
    private static final Logger LOG = LoggerFactory.getLogger(BulkSqsMessageHandler.class);
    private final InputCodec codec;

    public BulkSqsMessageHandler(final InputCodec codec) {
        this.codec = codec;
    }

    @Override
    public void handleMessage(final Message message,
                              final String url,
                              final Buffer<Record<Event>> buffer,
                              final int bufferTimeoutMillis,
                              final AcknowledgementSet acknowledgementSet) {
        try {
            final String sqsBody = message.body();
            final ByteArrayInputStream inputStream = new ByteArrayInputStream(sqsBody.getBytes(StandardCharsets.UTF_8));
            codec.parse(inputStream, record -> {
                final Event event = record.getData();
                final EventMetadata eventMetadata = event.getMetadata();
                eventMetadata.setAttribute("queueUrl", url);
                // copy SQS system attributes onto each event, converting keys to lowerCamelCase
                for (Map.Entry<MessageSystemAttributeName, String> entry : message.attributes().entrySet()) {
                    final String originalKey = entry.getKey().toString();
                    final String lowerCamelCaseKey = originalKey.substring(0, 1).toLowerCase() + originalKey.substring(1);
                    eventMetadata.setAttribute(lowerCamelCaseKey, entry.getValue());
                }

                for (Map.Entry<String, MessageAttributeValue> entry : message.messageAttributes().entrySet()) {
Member (@dlvenable, Jan 14, 2025) commented:

A lot of this code is similar to the RawSqsMessageHandler. We should not duplicate this code.

I think the best option is to update the existing RawSqsMessageHandler to support an injectable message strategy.

It might look like:

interface MessageFieldStrategy {
  List<Event> parseEvents(String messageBody);
}

This has some advantages of being able to use buffer.writeAll for the whole batch and the attribute code is shared for all.

Contributor Author replied:

I removed the BulkSqsMessageHandler and added strategies instead. buffer.writeAll is also used now instead of buffer.write.
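
For readers following the refactor, here is a minimal sketch of what an injectable strategy might look like. Everything below is illustrative: the class names (StandardMessageFieldStrategy, CodecBulkMessageFieldStrategy) and the event construction are assumptions, not necessarily what the follow-up commits contain.

// Illustrative sketch only; not necessarily the code merged in later commits.
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.dataprepper.model.codec.InputCodec;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.JacksonEvent;

interface MessageFieldStrategy {
    // Turns one SQS message body into one or more events.
    List<Event> parseEvents(String messageBody);
}

// Raw behavior: the whole body becomes a single event.
class StandardMessageFieldStrategy implements MessageFieldStrategy {
    @Override
    public List<Event> parseEvents(final String messageBody) {
        return List.of(JacksonEvent.builder()
                .withEventType("DOCUMENT")
                .withData(Map.of("message", messageBody))
                .build());
    }
}

// Codec behavior: the configured InputCodec splits the body into many events.
class CodecBulkMessageFieldStrategy implements MessageFieldStrategy {
    private final InputCodec codec;

    CodecBulkMessageFieldStrategy(final InputCodec codec) {
        this.codec = codec;
    }

    @Override
    public List<Event> parseEvents(final String messageBody) {
        final List<Event> events = new ArrayList<>();
        try (InputStream inputStream = new ByteArrayInputStream(messageBody.getBytes(StandardCharsets.UTF_8))) {
            codec.parse(inputStream, record -> events.add(record.getData()));
        } catch (final IOException e) {
            throw new RuntimeException(e);
        }
        return events;
    }
}

A single RawSqsMessageHandler could then take a MessageFieldStrategy in its constructor, stamp the shared attributes onto each returned event, and hand the whole batch to buffer.writeAll.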

                    final String originalKey = entry.getKey();
                    final String lowerCamelCaseKey = originalKey.substring(0, 1).toLowerCase() + originalKey.substring(1);
                    eventMetadata.setAttribute(lowerCamelCaseKey, entry.getValue().stringValue());
Member commented:

A lot of this processing can be shared, which has the added benefit of avoiding extra memory. You can do this in the refactoring that I suggest above.

That is, before the loop over the List<Event>, create a Map<String, String> for all the attributes. Then re-use those values. This should reduce both compute and memory.

Contributor Author replied:

Great idea; the metadata/attributes of a bulk message remain the same for every event in that message.
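
A sketch of that suggestion, as methods inside the handler (helper names are assumed, and java.util.HashMap would need importing): build the attribute map once per SQS message, then apply the same map to every parsed event.

// Sketch: compute the metadata attributes once per message instead of once per event.
private Map<String, String> collectMetadataAttributes(final Message message, final String queueUrl) {
    final Map<String, String> attributes = new HashMap<>();
    attributes.put("queueUrl", queueUrl);
    for (Map.Entry<MessageSystemAttributeName, String> entry : message.attributes().entrySet()) {
        attributes.put(lowerCamelCase(entry.getKey().toString()), entry.getValue());
    }
    for (Map.Entry<String, MessageAttributeValue> entry : message.messageAttributes().entrySet()) {
        attributes.put(lowerCamelCase(entry.getKey()), entry.getValue().stringValue());
    }
    return attributes;
}

private static String lowerCamelCase(final String key) {
    return key.substring(0, 1).toLowerCase() + key.substring(1);
}

The per-event loop then reduces to attributes.forEach(eventMetadata::setAttribute), reusing one map for every event parsed from the message.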

                }

                if (acknowledgementSet != null) {
                    acknowledgementSet.add(event);
                }

                try {
                    buffer.write(record, bufferTimeoutMillis);
                } catch (TimeoutException e) {
                    throw new RuntimeException(e);
                }
            });
        } catch (final Exception e) {
            LOG.error("Error processing SQS message: {}", e.getMessage(), e);
            throw new RuntimeException(e);
        }
    }
}
QueueConfig.java
@@ -15,6 +15,7 @@

import org.hibernate.validator.constraints.time.DurationMax;
import org.hibernate.validator.constraints.time.DurationMin;
import org.opensearch.dataprepper.model.configuration.PluginModel;

public class QueueConfig {

@@ -62,6 +63,9 @@ public class QueueConfig {
    @DurationMax(seconds = 20)
    private Duration waitTime = DEFAULT_WAIT_TIME_SECONDS;

    @JsonProperty("codec")
    private PluginModel codec = null;

    public String getUrl() {
        return url;
    }
@@ -93,5 +97,10 @@ public Duration getWaitTime() {
    public Duration getPollDelay() {
        return pollDelay;
    }

    public PluginModel getCodec() {
        return codec;
    }

}

SqsService.java
@@ -9,6 +9,10 @@
import org.opensearch.dataprepper.common.concurrent.BackgroundThreadFactory;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.codec.InputCodec;
import org.opensearch.dataprepper.model.configuration.PluginModel;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
@@ -35,9 +39,9 @@ public class SqsService {
    static final double JITTER_RATE = 0.20;

    private final SqsSourceConfig sqsSourceConfig;
    private final SqsEventProcessor sqsEventProcessor;
    private final SqsClient sqsClient;
    private final PluginMetrics pluginMetrics;
    private final PluginFactory pluginFactory;
    private final AcknowledgementSetManager acknowledgementSetManager;
    private final List<ExecutorService> allSqsUrlExecutorServices;
    private final List<SqsWorker> sqsWorkers;
@@ -46,13 +50,13 @@ public class SqsService {
    public SqsService(final Buffer<Record<Event>> buffer,
                      final AcknowledgementSetManager acknowledgementSetManager,
                      final SqsSourceConfig sqsSourceConfig,
                      final SqsEventProcessor sqsEventProcessor,
                      final PluginMetrics pluginMetrics,
                      final PluginFactory pluginFactory,
                      final AwsCredentialsProvider credentialsProvider) {

        this.sqsSourceConfig = sqsSourceConfig;
        this.sqsEventProcessor = sqsEventProcessor;
        this.pluginMetrics = pluginMetrics;
        this.pluginFactory = pluginFactory;
        this.acknowledgementSetManager = acknowledgementSetManager;
        this.allSqsUrlExecutorServices = new ArrayList<>();
        this.sqsWorkers = new ArrayList<>();
@@ -70,8 +74,16 @@ public void start() {
        sqsSourceConfig.getQueues().forEach(queueConfig -> {
            String queueUrl = queueConfig.getUrl();
            String queueName = queueUrl.substring(queueUrl.lastIndexOf('/') + 1);

            int numWorkers = queueConfig.getNumWorkers();
            SqsEventProcessor sqsEventProcessor;
            // when a codec is configured for the queue, use the bulk handler to split
            // each SQS message into multiple events; otherwise keep the raw handler
            if (queueConfig.getCodec() != null) {
                final PluginModel codecConfiguration = queueConfig.getCodec();
                final PluginSetting codecPluginSettings = new PluginSetting(codecConfiguration.getPluginName(), codecConfiguration.getPluginSettings());
                final InputCodec codec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSettings);
                sqsEventProcessor = new SqsEventProcessor(new BulkSqsMessageHandler(codec));
            } else {
                sqsEventProcessor = new SqsEventProcessor(new RawSqsMessageHandler());
            }
            ExecutorService executorService = Executors.newFixedThreadPool(
                    numWorkers, BackgroundThreadFactory.defaultExecutorThreadFactory("sqs-source" + queueName));
            allSqsUrlExecutorServices.add(executorService);
@@ -80,10 +92,10 @@ public void start() {
                        buffer,
                        acknowledgementSetManager,
                        sqsClient,
                        sqsEventProcessor,
                        sqsSourceConfig,
                        queueConfig,
                        pluginMetrics,
                        sqsEventProcessor,
                        backoff))
                .collect(Collectors.toList());

SqsSource.java
@@ -12,6 +12,7 @@
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.source.Source;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
@@ -21,6 +22,7 @@
public class SqsSource implements Source<Record<Event>> {

    private final PluginMetrics pluginMetrics;
    private final PluginFactory pluginFactory;
    private final SqsSourceConfig sqsSourceConfig;
    private SqsService sqsService;
    private final AcknowledgementSetManager acknowledgementSetManager;
@@ -31,10 +33,12 @@ public class SqsSource implements Source<Record<Event>> {
    @DataPrepperPluginConstructor
    public SqsSource(final PluginMetrics pluginMetrics,
                     final SqsSourceConfig sqsSourceConfig,
                     final PluginFactory pluginFactory,
                     final AcknowledgementSetManager acknowledgementSetManager,
                     final AwsCredentialsSupplier awsCredentialsSupplier) {

        this.pluginMetrics = pluginMetrics;
        this.pluginFactory = pluginFactory;
        this.sqsSourceConfig = sqsSourceConfig;
        this.acknowledgementsEnabled = sqsSourceConfig.getAcknowledgements();
        this.acknowledgementSetManager = acknowledgementSetManager;
@@ -49,9 +53,7 @@ public void start(Buffer<Record<Event>> buffer) {
        }
        final AwsAuthenticationAdapter awsAuthenticationAdapter = new AwsAuthenticationAdapter(awsCredentialsSupplier, sqsSourceConfig);
        final AwsCredentialsProvider credentialsProvider = awsAuthenticationAdapter.getCredentialsProvider();
        final SqsMessageHandler rawSqsMessageHandler = new RawSqsMessageHandler();
        final SqsEventProcessor sqsEventProcessor = new SqsEventProcessor(rawSqsMessageHandler);
        sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, sqsEventProcessor, pluginMetrics, credentialsProvider);
        sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, pluginMetrics, pluginFactory, credentialsProvider);
        sqsService.start();
    }

SqsWorker.java
@@ -68,21 +68,20 @@ public class SqsWorker implements Runnable {
    public SqsWorker(final Buffer<Record<Event>> buffer,
                     final AcknowledgementSetManager acknowledgementSetManager,
                     final SqsClient sqsClient,
                     final SqsEventProcessor sqsEventProcessor,
                     final SqsSourceConfig sqsSourceConfig,
                     final QueueConfig queueConfig,
                     final PluginMetrics pluginMetrics,
                     final SqsEventProcessor sqsEventProcessor,
                     final Backoff backoff) {

        this.sqsClient = sqsClient;
        this.sqsEventProcessor = sqsEventProcessor;
        this.queueConfig = queueConfig;
        this.acknowledgementSetManager = acknowledgementSetManager;
        this.standardBackoff = backoff;
        this.endToEndAcknowledgementsEnabled = sqsSourceConfig.getAcknowledgements();
        this.buffer = buffer;
        this.bufferTimeoutMillis = (int) sqsSourceConfig.getBufferTimeout().toMillis();

        this.sqsEventProcessor = sqsEventProcessor;
        messageVisibilityTimesMap = new HashMap<>();
        failedAttemptCount = 0;
        sqsMessagesReceivedCounter = pluginMetrics.counter(SQS_MESSAGES_RECEIVED_METRIC_NAME);
BulkSqsMessageHandlerTest.java
@@ -0,0 +1,72 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.codec.InputCodec;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventMetadata;
import org.opensearch.dataprepper.model.record.Record;
import software.amazon.awssdk.services.sqs.model.Message;

import java.io.InputStream;
import java.util.function.Consumer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class BulkSqsMessageHandlerTest {

    private InputCodec mockCodec;
    private Buffer<Record<Event>> mockBuffer;
    private BulkSqsMessageHandler bulkSqsMessageHandler;
    private int bufferTimeoutMillis;

    @BeforeEach
    void setUp() {
        mockCodec = mock(InputCodec.class);
        mockBuffer = mock(Buffer.class);
        bulkSqsMessageHandler = new BulkSqsMessageHandler(mockCodec);
        bufferTimeoutMillis = 10000;
    }

    @Test
    void handleMessage_callsBufferWriteOnce() throws Exception {
        final Message message = Message.builder()
                .body("{\"someKey\":\"someValue\"}")
                .build();
        final String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue";

        doAnswer(invocation -> {
            @SuppressWarnings("unchecked")
            Consumer<Record<Event>> eventConsumer = invocation.getArgument(1);
            final Event mockEvent = mock(Event.class);
            final EventMetadata mockMetadata = mock(EventMetadata.class);
            when(mockEvent.getMetadata()).thenReturn(mockMetadata);
            when(mockMetadata.getEventType()).thenReturn("DOCUMENT");
            eventConsumer.accept(new Record<>(mockEvent));
            return null;
        }).when(mockCodec).parse(any(InputStream.class), any());

        bulkSqsMessageHandler.handleMessage(message, queueUrl, mockBuffer, bufferTimeoutMillis, null);

        ArgumentCaptor<Record<Event>> argumentCaptor = ArgumentCaptor.forClass(Record.class);
        verify(mockBuffer, times(1)).write(argumentCaptor.capture(), eq(bufferTimeoutMillis));
        Record<Event> capturedRecord = argumentCaptor.getValue();
        assertEquals(
                "DOCUMENT",
                capturedRecord.getData().getMetadata().getEventType(),
                "Event type should be 'DOCUMENT'"
        );
    }
}
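
As a companion illustration (not part of this PR's test file), the sketch below shows the splitting behavior the codec enables: a stand-in codec emits one event per line of the body, so a single SQS message carrying two JSON documents produces two buffered records. JacksonEvent.fromMessage and the extra java.io imports (BufferedReader, InputStreamReader, StandardCharsets) are assumed here.

// Hypothetical test sketch: one SQS message fans out into multiple events.
@Test
void handleMessage_splitsMultiDocumentBodyIntoMultipleEvents() throws Exception {
    doAnswer(invocation -> {
        final InputStream body = invocation.getArgument(0);
        final Consumer<Record<Event>> eventConsumer = invocation.getArgument(1);
        // stand-in for a real codec plugin: one event per newline-delimited document
        new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))
                .lines()
                .forEach(line -> eventConsumer.accept(new Record<>(JacksonEvent.fromMessage(line))));
        return null;
    }).when(mockCodec).parse(any(InputStream.class), any());

    final Message message = Message.builder()
            .body("{\"id\":1}\n{\"id\":2}")
            .build();

    bulkSqsMessageHandler.handleMessage(message,
            "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue",
            mockBuffer, bufferTimeoutMillis, null);

    // each parsed document was written to the buffer as its own record
    verify(mockBuffer, times(2)).write(any(Record.class), eq(bufferTimeoutMillis));
}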
QueueConfigTest.java
@@ -22,6 +22,7 @@ void testDefaultValues() {
        assertEquals(1, queueConfig.getNumWorkers(), "Number of workers should default to 1");
        assertNull(queueConfig.getMaximumMessages(), "Maximum messages should be null by default");
        assertEquals(Duration.ofSeconds(0), queueConfig.getPollDelay(), "Poll delay should default to 0 seconds");
        assertNull(queueConfig.getCodec(), "Codec should be null by default");
        assertNull(queueConfig.getVisibilityTimeout(), "Visibility timeout should be null by default");
        assertFalse(queueConfig.getVisibilityDuplicateProtection(), "Visibility duplicate protection should default to false");
        assertEquals(Duration.ofHours(2), queueConfig.getVisibilityDuplicateProtectionTimeout(),
SqsServiceTest.java
@@ -11,6 +11,7 @@
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
@@ -27,19 +28,19 @@

class SqsServiceTest {
    private SqsSourceConfig sqsSourceConfig;
    private SqsEventProcessor sqsEventProcessor;
    private SqsClient sqsClient;
    private PluginMetrics pluginMetrics;
    private PluginFactory pluginFactory;
    private AcknowledgementSetManager acknowledgementSetManager;
    private Buffer<Record<Event>> buffer;
    private AwsCredentialsProvider credentialsProvider;

    @BeforeEach
    void setUp() {
        sqsSourceConfig = mock(SqsSourceConfig.class);
        sqsEventProcessor = mock(SqsEventProcessor.class);
        sqsClient = mock(SqsClient.class, withSettings());
        pluginMetrics = mock(PluginMetrics.class);
        pluginFactory = mock(PluginFactory.class);
        acknowledgementSetManager = mock(AcknowledgementSetManager.class);
        buffer = mock(Buffer.class);
        credentialsProvider = mock(AwsCredentialsProvider.class);
@@ -55,7 +56,7 @@ void start_with_single_queue_starts_workers() {
        when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue");
        when(queueConfig.getNumWorkers()).thenReturn(2);
        when(sqsSourceConfig.getQueues()).thenReturn(List.of(queueConfig));
        SqsService sqsService = spy(new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, sqsEventProcessor, pluginMetrics, credentialsProvider));
        SqsService sqsService = spy(new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, pluginMetrics, pluginFactory, credentialsProvider));
        doReturn(sqsClient).when(sqsService).createSqsClient(credentialsProvider);
        sqsService.start(); // if no exception is thrown here, then workers have been started
    }
@@ -67,7 +68,7 @@ void stop_should_shutdown_executors_and_workers_and_close_client() throws InterruptedException {
        when(queueConfig.getNumWorkers()).thenReturn(1);
        when(sqsSourceConfig.getQueues()).thenReturn(List.of(queueConfig));
        SqsClient sqsClient = mock(SqsClient.class);
        SqsService sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, sqsEventProcessor, pluginMetrics, credentialsProvider) {
        SqsService sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, pluginMetrics, pluginFactory, credentialsProvider) {
            @Override
            SqsClient createSqsClient(final AwsCredentialsProvider credentialsProvider) {
                return sqsClient;
SqsSourceTest.java
@@ -12,6 +12,7 @@
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
@@ -27,6 +28,7 @@ class SqsSourceTest {
    private final String TEST_PIPELINE_NAME = "test_pipeline";
    private SqsSource sqsSource;
    private PluginMetrics pluginMetrics;
    private PluginFactory pluginFactory;
    private SqsSourceConfig sqsSourceConfig;
    private AcknowledgementSetManager acknowledgementSetManager;
    private AwsCredentialsSupplier awsCredentialsSupplier;
@@ -36,10 +38,11 @@
    @BeforeEach
    void setUp() {
        pluginMetrics = PluginMetrics.fromNames(PLUGIN_NAME, TEST_PIPELINE_NAME);
        pluginFactory = mock(PluginFactory.class);
        sqsSourceConfig = mock(SqsSourceConfig.class);
        acknowledgementSetManager = mock(AcknowledgementSetManager.class);
        awsCredentialsSupplier = mock(AwsCredentialsSupplier.class);
        sqsSource = new SqsSource(pluginMetrics, sqsSourceConfig, acknowledgementSetManager, awsCredentialsSupplier);
        sqsSource = new SqsSource(pluginMetrics, sqsSourceConfig, pluginFactory, acknowledgementSetManager, awsCredentialsSupplier);
        buffer = mock(Buffer.class);
    }
