Skip to content

Commit

Permalink
Add synchronous execution option to workflow provisioning
Browse files Browse the repository at this point in the history
Signed-off-by: Junwei Dai <[email protected]>
  • Loading branch information
Junwei Dai committed Jan 7, 2025
1 parent 0284554 commit 95085ff
Show file tree
Hide file tree
Showing 11 changed files with 476 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
*/
package org.opensearch.flowframework.common;

import org.opensearch.common.unit.TimeValue;

/**
* Representation of common values that are used across project
*/
Expand Down Expand Up @@ -55,6 +57,8 @@ private CommonValue() {}
/** The last provisioned time field */
public static final String LAST_PROVISIONED_TIME_FIELD = "last_provisioned_time";

public static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = TimeValue.timeValueSeconds(1);

/*
* Constants associated with Rest or Transport actions
*/
Expand All @@ -74,6 +78,8 @@ private CommonValue() {}
public static final String PROVISION_WORKFLOW = "provision";
/** The param name for update workflow field in create API */
public static final String UPDATE_WORKFLOW_FIELDS = "update_fields";
/** The param name for specifying the timeout duration in seconds to wait for workflow completion */
public static final String WAIT_FOR_COMPLETION_TIMEOUT = "wait_for_completion_timeout";
/** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */
public static final String WORKFLOW_STEP = "workflow_step";
/** The param name for default use case, used by the create workflow API */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
Expand Down Expand Up @@ -43,6 +44,7 @@
import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE;
import static org.opensearch.flowframework.common.CommonValue.VALIDATION;
import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -87,6 +89,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);
boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false);
boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false);
TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null);
String useCase = request.param(USE_CASE);

// If provisioning, consume all other params and pass to provision transport action
Expand Down Expand Up @@ -226,7 +229,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
validation,
provision || updateFields,
params,
reprovision
reprovision,
waitForCompletionTimeout
);

return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
Expand All @@ -33,6 +34,7 @@
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -73,6 +75,7 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String workflowId = request.param(WORKFLOW_ID);
TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null);
try {
Map<String, String> params = parseParamsAndContent(request);
if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) {
Expand All @@ -86,7 +89,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}
// Create request and provision
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params, waitForCompletionTimeout);
return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ private void createExecute(WorkflowRequest request, User user, ActionListener<Wo
WorkflowRequest workflowRequest = new WorkflowRequest(
globalContextResponse.getId(),
null,
request.getParams()
request.getParams(),
request.getWaitForCompletionTimeout()
);
logger.info(
"Provisioning parameter is set, continuing to provision workflow {}",
Expand All @@ -261,7 +262,18 @@ private void createExecute(WorkflowRequest request, User user, ActionListener<Wo
ProvisionWorkflowAction.INSTANCE,
workflowRequest,
ActionListener.wrap(provisionResponse -> {
listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId()));
if (request.getWaitForCompletionTimeout() != null) {
listener.onResponse(
new WorkflowResponse(
provisionResponse.getWorkflowId(),
provisionResponse.getWorkflowState()
)
);
} else {
listener.onResponse(
new WorkflowResponse(provisionResponse.getWorkflowId())
);
}
}, exception -> {
String errorMessage = "Provisioning failed.";
logger.error(errorMessage, exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD;
Expand Down Expand Up @@ -210,14 +212,27 @@ private void executeProvisionRequest(
),
ActionListener.wrap(updateResponse -> {
logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING);
executeWorkflowAsync(workflowId, provisionProcessSequence, listener);
if (request.getWaitForCompletionTimeout() != null) {
executeWorkflowSync(
workflowId,
provisionProcessSequence,
listener,
request.getWaitForCompletionTimeout().getMillis()
);
} else {
executeWorkflowAsync(workflowId, provisionProcessSequence, listener);
}
// update last provisioned field in template
Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build();
flowFrameworkIndicesHandler.updateTemplateInGlobalContext(
request.getWorkflowId(),
newTemplate,
ActionListener.wrap(templateResponse -> {
listener.onResponse(new WorkflowResponse(request.getWorkflowId()));
if (request.getWaitForCompletionTimeout() != null) {
logger.info("Waiting for workflow completion");
} else {
listener.onResponse(new WorkflowResponse(request.getWorkflowId()));
}
}, exception -> {
String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage(
"Failed to update use case template {}",
Expand Down Expand Up @@ -275,18 +290,105 @@ private void executeProvisionRequest(
*/
private void executeWorkflowAsync(String workflowId, List<ProcessNode> workflowSequence, ActionListener<WorkflowResponse> listener) {
try {
threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, workflowId); });
threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)
.execute(() -> { executeWorkflow(workflowSequence, workflowId, listener, false); });
} catch (Exception exception) {
listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception)));
}
}

/**
* Retrieves a thread from the provision thread pool to execute a workflow with a timeout mechanism.
* If the execution exceeds the specified timeout, it will return the current status of the workflow.
*
* @param workflowId The id of the workflow
* @param workflowSequence The sorted workflow to execute
* @param listener ActionListener for any failures or responses
* @param timeout The timeout duration in milliseconds
*/
private void executeWorkflowSync(
String workflowId,
List<ProcessNode> workflowSequence,
ActionListener<WorkflowResponse> listener,
long timeout
) {
PlainActionFuture<WorkflowResponse> workflowFuture = new PlainActionFuture<>();
AtomicBoolean isResponseSent = new AtomicBoolean(false);
CompletableFuture.runAsync(() -> {
try {
executeWorkflow(workflowSequence, workflowId, new ActionListener<>() {
@Override
public void onResponse(WorkflowResponse workflowResponse) {
if (isResponseSent.get()) {
logger.info("Ignoring onResponse for workflowId: {} as timeout already occurred", workflowId);
return;
}
isResponseSent.set(true);
workflowFuture.onResponse(null);
listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState()));
}

@Override
public void onFailure(Exception e) {
if (isResponseSent.get()) {
logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId);
return;
}
isResponseSent.set(true);
workflowFuture.onFailure(
new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e))
);
listener.onFailure(
new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e))
);
}
}, true);
} catch (Exception ex) {
if (!isResponseSent.get()) {
isResponseSent.set(true);
workflowFuture.onFailure(
new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex))
);
listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex)));
}
}
}, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL));

threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> {
try {
Thread.sleep(timeout);
if (isResponseSent.compareAndSet(false, true)) {
logger.warn("Workflow execution timed out for workflowId: {}", workflowId);
client.execute(
GetWorkflowStateAction.INSTANCE,
new GetWorkflowStateRequest(workflowId, false),
ActionListener.wrap(
response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())),
exception -> listener.onFailure(
new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception))
)
)
);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
}

/**
* Executes the given workflow sequence
* @param workflowSequence The topologically sorted workflow to execute
* @param workflowId The workflowId associated with the workflow that is executing
* @param listener The ActionListener to handle the workflow response or failure
* @param isSyncExecution Flag indicating whether the workflow should be executed synchronously (true) or asynchronously (false)
*/
private void executeWorkflow(List<ProcessNode> workflowSequence, String workflowId) {
private void executeWorkflow(
List<ProcessNode> workflowSequence,
String workflowId,
ActionListener<WorkflowResponse> listener,
boolean isSyncExecution
) {
String currentStepId = "";
try {
Map<String, PlainActionFuture<?>> workflowFutureMap = new LinkedHashMap<>();
Expand Down Expand Up @@ -324,6 +426,23 @@ private void executeWorkflow(List<ProcessNode> workflowSequence, String workflow
),
ActionListener.wrap(updateResponse -> {
logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED);
if (isSyncExecution) {
client.execute(
GetWorkflowStateAction.INSTANCE,
new GetWorkflowStateRequest(workflowId, false),
ActionListener.wrap(response -> {
listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState()));
}, exception -> {
String errorMessage = "Failed to get workflow state.";
logger.error(errorMessage, exception);
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)));
}
})
);
}
}, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); })
);
} catch (Exception ex) {
Expand Down
Loading

0 comments on commit 95085ff

Please sign in to comment.