Skip to content

Commit

Permalink
[Dataflow Streaming] Support to receive multiple work items in a single StreamingGetWorkResponseChunk
Browse files Browse the repository at this point in the history
  • Loading branch information
arunpandianp committed Jan 6, 2025
1 parent 5eed396 commit 4379de8
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ public interface DataflowStreamingPipelineOptions extends PipelineOptions {

void setUseSeparateWindmillHeartbeatStreams(Boolean value);

/**
 * Whether GetWork streams request multiple work items in a single response chunk.
 *
 * <p>No {@code @Default} annotation is declared on this getter.
 */
@Description("If true, GetWorkStreams will request multiple work items in a response chunk.")
boolean getWindmillMultipleItemsInGetWorkResponse();

/** Setter counterpart of {@link #getWindmillMultipleItemsInGetWorkResponse()}. */
void setWindmillMultipleItemsInGetWorkResponse(boolean value);

@Description("The number of streams to use for GetData requests.")
@Default.Integer(1)
int getWindmillGetDataStreamCount();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,8 @@ private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactory
.setSendKeyedGetDataRequests(
!options.isEnableStreamingEngine()
|| DataflowRunner.hasExperiment(
options, "streaming_engine_disable_new_heartbeat_requests"));
options, "streaming_engine_disable_new_heartbeat_requests"))
.setMultipleItemsInGetWorkResponse(options.getWindmillMultipleItemsInGetWorkResponse());
}

private static JobHeader createJobHeader(DataflowWorkerHarnessOptions options, long clientId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;
Expand All @@ -43,35 +44,46 @@
*/
@NotThreadSafe
final class GetWorkResponseChunkAssembler {

private static final Logger LOG = LoggerFactory.getLogger(GetWorkResponseChunkAssembler.class);

private final GetWorkTimingInfosTracker workTimingInfosTracker;
private @Nullable ComputationMetadata metadata;
private ByteString data;
private long bufferedSize;
private long remainingSize;

GetWorkResponseChunkAssembler() {
// The tracker is handed System::currentTimeMillis (presumably as its clock — confirm).
workTimingInfosTracker = new GetWorkTimingInfosTracker(System::currentTimeMillis);
// Start with an empty byte buffer and no computation metadata.
data = ByteString.EMPTY;
// NOTE(review): this view is a rendered diff without +/- markers; the bufferedSize assignment
// below looks like a pre-change line — confirm against the actual file.
bufferedSize = 0;
metadata = null;
}

/**
 * Appends the serialized work item bytes carried by {@code chunk} to the {@link #data} byte
 * buffer and returns every WorkItem that was completed by this chunk.
 *
 * <p>A chunk may carry several serialized work items. Every entry except the last is treated as
 * ending a WorkItem; the last entry ends one only when {@code
 * chunk.getRemainingBytesForWorkItem()} is 0. The returned list is empty when no WorkItem was
 * completed, or when a completed buffer failed to parse.
 *
 * <p>NOTE(review): this view is a rendered diff without +/- markers, so the method shows what
 * appear to be both the previous {@code Optional}-returning lines and the new {@code
 * List}-returning lines — confirm against the actual file.
 */
Optional<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
List<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
// Metadata carried by a chunk replaces whatever metadata was stored before.
if (chunk.hasComputationMetadata()) {
metadata = ComputationMetadata.fromProto(chunk.getComputationMetadata());
}

data = data.concat(chunk.getSerializedWorkItem());
bufferedSize += chunk.getSerializedWorkItem().size();
workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList());

// If the entire WorkItem has been received, assemble the WorkItem.
return chunk.getRemainingBytesForWorkItem() == 0 ? flushToWorkItem() : Optional.empty();
List<AssembledWorkItem> response = new ArrayList<>();
for (int i = 0; i < chunk.getSerializedWorkItemList().size(); i++) {
data = data.concat(chunk.getSerializedWorkItemList().get(i));
// Only the last entry of the chunk may be a partial WorkItem; its outstanding byte count is
// the chunk-level remaining-bytes value. Earlier entries have nothing outstanding.
if (i == chunk.getSerializedWorkItemList().size() - 1) {
remainingSize = chunk.getRemainingBytesForWorkItem();
} else {
remainingSize = 0;
}
if (remainingSize == 0) {
flushToWorkItem().ifPresent(response::add);
}
}
// Timing infos are cleared only when the chunk ends with no bytes outstanding.
if (remainingSize == 0) {
workTimingInfosTracker.reset();
}
return response;
}

/**
Expand All @@ -81,25 +93,26 @@ Optional<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk)
*/
private Optional<AssembledWorkItem> flushToWorkItem() {
// Parses the buffered bytes into a WorkItem and pairs it with the stored metadata and the
// tracker's latency attributions. Returns an empty Optional when parsing fails.
try {
long size = data.size();
return Optional.of(
AssembledWorkItem.create(
WorkItem.parseFrom(data.newInput()),
// presumably throws when no computation metadata has been received yet — confirm
Preconditions.checkNotNull(metadata),
workTimingInfosTracker.getLatencyAttributions(),
// NOTE(review): rendered diff without +/- markers; the next two lines look like the
// pre-change and post-change size arguments — confirm against the actual file.
bufferedSize));
size));
} catch (IOException e) {
// A parse failure is logged here and surfaces to the caller as Optional.empty() below.
LOG.error("Failed to parse work item from stream: ", e);
} finally {
// The byte buffer is cleared whether or not parsing succeeded.
workTimingInfosTracker.reset();
data = ByteString.EMPTY;
bufferedSize = 0;
}

return Optional.empty();
}

@AutoValue
abstract static class ComputationMetadata {

private static ComputationMetadata fromProto(
Windmill.ComputationWorkItemMetadata metadataProto) {
return new AutoValue_GetWorkResponseChunkAssembler_ComputationMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,6 @@ final class GrpcDirectGetWorkStream
extends AbstractWindmillStream<StreamingGetWorkRequest, StreamingGetWorkResponseChunk>
implements GetWorkStream {
private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class);
// Request extension asking for zero items and zero bytes; passed to trySend by
// sendHealthCheck().
private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST =
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
Windmill.StreamingGetWorkRequestExtension.newBuilder()
.setMaxItems(0)
.setMaxBytes(0)
.build())
.build();

private final GetWorkBudgetTracker budgetTracker;
private final GetWorkRequest requestHeader;
Expand All @@ -88,6 +80,8 @@ final class GrpcDirectGetWorkStream
*/
private final ConcurrentMap<Long, GetWorkResponseChunkAssembler> workItemAssemblers;

private final boolean multipleItemsInGetWorkResponse;

private GrpcDirectGetWorkStream(
String backendWorkerToken,
Function<
Expand All @@ -99,6 +93,7 @@ private GrpcDirectGetWorkStream(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean multipleItemsInGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
HeartbeatSender heartbeatSender,
GetDataClient getDataClient,
Expand Down Expand Up @@ -127,6 +122,7 @@ private GrpcDirectGetWorkStream(
.setItems(requestHeader.getMaxItems())
.setBytes(requestHeader.getMaxBytes())
.build());
this.multipleItemsInGetWorkResponse = multipleItemsInGetWorkResponse;
}

static GrpcDirectGetWorkStream create(
Expand All @@ -140,6 +136,7 @@ static GrpcDirectGetWorkStream create(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean multipleItemsInGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
HeartbeatSender heartbeatSender,
GetDataClient getDataClient,
Expand All @@ -153,6 +150,7 @@ static GrpcDirectGetWorkStream create(
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
multipleItemsInGetWorkResponse,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
Expand Down Expand Up @@ -184,6 +182,7 @@ private void maybeSendRequestExtension(GetWorkBudget extension) {
Windmill.StreamingGetWorkRequestExtension.newBuilder()
.setMaxItems(extension.items())
.setMaxBytes(extension.bytes()))
.setSupportsMultipleWorkItemsInChunk(multipleItemsInGetWorkResponse)
.build();
lastRequest.set(request);
budgetTracker.recordBudgetRequested(extension);
Expand All @@ -209,6 +208,7 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
.setMaxItems(initialGetWorkBudget.items())
.setMaxBytes(initialGetWorkBudget.bytes())
.build())
.setSupportsMultipleWorkItemsInChunk(multipleItemsInGetWorkResponse)
.build();
lastRequest.set(request);
budgetTracker.recordBudgetRequested(initialGetWorkBudget);
Expand All @@ -231,7 +231,15 @@ public void appendSpecificHtml(PrintWriter writer) {

/**
 * Sends a request extension asking for zero items and zero bytes, tagged with this stream's
 * {@code multipleItemsInGetWorkResponse} setting.
 *
 * <p>NOTE(review): this view is a rendered diff without +/- markers; the two {@code trySend}
 * calls below look like the pre-change and post-change versions — confirm against the actual
 * file.
 */
@Override
public void sendHealthCheck() throws WindmillStreamShutdownException {
trySend(HEALTH_CHECK_REQUEST);
// Built per call because the request carries the per-stream multiple-items flag.
trySend(
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
Windmill.StreamingGetWorkRequestExtension.newBuilder()
.setMaxItems(0)
.setMaxBytes(0)
.build())
.setSupportsMultipleWorkItemsInChunk(multipleItemsInGetWorkResponse)
.build());
}

@Override
Expand All @@ -243,7 +251,7 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) {
workItemAssemblers
.computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
.append(chunk)
.ifPresent(this::consumeAssembledWorkItem);
.forEach(this::consumeAssembledWorkItem);
}

private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ final class GrpcGetWorkStream
private final Map<Long, GetWorkResponseChunkAssembler> workItemAssemblers;
private final AtomicLong inflightMessages;
private final AtomicLong inflightBytes;
private final boolean multipleItemsInGetWorkResponse;

private GrpcGetWorkStream(
String backendWorkerToken,
Expand All @@ -64,6 +65,7 @@ private GrpcGetWorkStream(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean multipleItemsInGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver receiver) {
super(
Expand All @@ -81,6 +83,7 @@ private GrpcGetWorkStream(
this.workItemAssemblers = new ConcurrentHashMap<>();
this.inflightMessages = new AtomicLong();
this.inflightBytes = new AtomicLong();
this.multipleItemsInGetWorkResponse = multipleItemsInGetWorkResponse;
}

public static GrpcGetWorkStream create(
Expand All @@ -94,6 +97,7 @@ public static GrpcGetWorkStream create(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean multipleItemsInGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver receiver) {
return new GrpcGetWorkStream(
Expand All @@ -104,6 +108,7 @@ public static GrpcGetWorkStream create(
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
multipleItemsInGetWorkResponse,
getWorkThrottleTimer,
receiver);
}
Expand All @@ -115,6 +120,7 @@ private void sendRequestExtension(long moreItems, long moreBytes) {
StreamingGetWorkRequestExtension.newBuilder()
.setMaxItems(moreItems)
.setMaxBytes(moreBytes))
.setSupportsMultipleWorkItemsInChunk(multipleItemsInGetWorkResponse)
.build();

executeSafely(
Expand All @@ -132,7 +138,11 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
workItemAssemblers.clear();
inflightMessages.set(request.getMaxItems());
inflightBytes.set(request.getMaxBytes());
trySend(StreamingGetWorkRequest.newBuilder().setRequest(request).build());
trySend(
StreamingGetWorkRequest.newBuilder()
.setSupportsMultipleWorkItemsInChunk(multipleItemsInGetWorkResponse)
.setRequest(request)
.build());
}

@Override
Expand All @@ -157,6 +167,7 @@ public void sendHealthCheck() throws WindmillStreamShutdownException {
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build())
.setSupportsMultipleWorkItemsInChunk(multipleItemsInGetWorkResponse)
.build());
}

Expand All @@ -166,7 +177,7 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) {
workItemAssemblers
.computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
.append(chunk)
.ifPresent(this::consumeAssembledWorkItem);
.forEach(this::consumeAssembledWorkItem);
}

private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ static GrpcWindmillServer newTestInstance(
.setSendKeyedGetDataRequests(sendKeyedGetDataRequests)
.setHealthCheckIntervalMillis(
testOptions.getWindmillServiceStreamingRpcHealthCheckPeriodMs())
.setMultipleItemsInGetWorkResponse(
testOptions.getWindmillMultipleItemsInGetWorkResponse())
.build();

return new GrpcWindmillServer(testOptions, windmillStreamFactory, dispatcherClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider {
// If true, then active work refreshes will be sent as KeyedGetDataRequests. Otherwise, use the
// newer ComputationHeartbeatRequests.
private final boolean sendKeyedGetDataRequests;
private final boolean multipleItemsInGetWorkResponse;
private final Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses;

private GrpcWindmillStreamFactory(
Expand All @@ -99,6 +100,7 @@ private GrpcWindmillStreamFactory(
int streamingRpcBatchLimit,
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
boolean multipleItemsInGetWorkResponse,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier) {
this.jobHeader = jobHeader;
Expand All @@ -115,6 +117,7 @@ private GrpcWindmillStreamFactory(
.backoff());
this.streamRegistry = ConcurrentHashMap.newKeySet();
this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
this.multipleItemsInGetWorkResponse = multipleItemsInGetWorkResponse;
this.processHeartbeatResponses = processHeartbeatResponses;
this.streamIdGenerator = new AtomicLong();
}
Expand All @@ -126,6 +129,7 @@ static GrpcWindmillStreamFactory create(
int streamingRpcBatchLimit,
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
boolean multipleItemsInGetWorkResponse,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier,
int healthCheckIntervalMillis) {
Expand All @@ -136,6 +140,7 @@ static GrpcWindmillStreamFactory create(
streamingRpcBatchLimit,
windmillMessagesBetweenIsReadyChecks,
sendKeyedGetDataRequests,
multipleItemsInGetWorkResponse,
processHeartbeatResponses,
maxBackOffSupplier);

Expand Down Expand Up @@ -209,6 +214,7 @@ public GetWorkStream createGetWorkStream(
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
multipleItemsInGetWorkResponse,
getWorkThrottleTimer,
processWorkItem);
}
Expand All @@ -229,6 +235,7 @@ public GetWorkStream createDirectGetWorkStream(
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
multipleItemsInGetWorkResponse,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
Expand Down Expand Up @@ -356,6 +363,8 @@ Builder setProcessHeartbeatResponses(

Builder setHealthCheckIntervalMillis(int healthCheckIntervalMillis);

Builder setMultipleItemsInGetWorkResponse(boolean enabled);

GrpcWindmillStreamFactory build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem
.setInputDataWatermark(1L)
.setDependentRealtimeInputWatermark(1L)
.build())
.setSerializedWorkItem(workItem.toByteString())
.addSerializedWorkItem(workItem.toByteString())
.setRemainingBytesForWorkItem(0)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ public void onNext(StreamingGetWorkRequest request) {
StreamingGetWorkResponseChunk.Builder builder =
StreamingGetWorkResponseChunk.newBuilder()
.setStreamId(id)
.setSerializedWorkItem(serializedResponse.substring(i, end))
.addSerializedWorkItem(serializedResponse.substring(i, end))
.setRemainingBytesForWorkItem(serializedResponse.size() - end);

if (i == 0) {
Expand Down Expand Up @@ -1166,7 +1166,7 @@ public void onNext(StreamingGetWorkRequest request) {
StreamingGetWorkResponseChunk.Builder builder =
StreamingGetWorkResponseChunk.newBuilder()
.setStreamId(id)
.setSerializedWorkItem(serializedResponse)
.addSerializedWorkItem(serializedResponse)
.setRemainingBytesForWorkItem(0)
.setComputationMetadata(
ComputationWorkItemMetadata.newBuilder()
Expand Down
Loading

0 comments on commit 4379de8

Please sign in to comment.