Skip to content

Commit

Permalink
refactor code based on comment
Browse files Browse the repository at this point in the history
Signed-off-by: Junwei Dai <[email protected]>
  • Loading branch information
Junwei Dai committed Jan 13, 2025
1 parent 44a1959 commit c3b9a88
Show file tree
Hide file tree
Showing 10 changed files with 468 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
Expand Down Expand Up @@ -44,7 +43,6 @@
import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS;
import static org.opensearch.flowframework.common.CommonValue.USE_CASE;
import static org.opensearch.flowframework.common.CommonValue.VALIDATION;
import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -89,7 +87,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);
boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false);
boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false);
TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null);
String useCase = request.param(USE_CASE);

// If provisioning, consume all other params and pass to provision transport action
Expand Down Expand Up @@ -229,8 +226,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
validation,
provision || updateFields,
params,
reprovision,
waitForCompletionTimeout
reprovision
);

return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String workflowId = request.param(WORKFLOW_ID);
TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null);
TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE);
try {
Map<String, String> params = parseParamsAndContent(request);
if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES;
import static org.opensearch.flowframework.util.ParseUtils.checkFilterByBackendRoles;
import static org.opensearch.flowframework.util.ParseUtils.getUserContext;
Expand Down Expand Up @@ -248,11 +249,14 @@ private void createExecute(WorkflowRequest request, User user, ActionListener<Wo
ActionListener.wrap(stateResponse -> {
logger.info("Creating state workflow doc: {}", globalContextResponse.getId());
if (request.isProvision()) {
String waitForTimeCompletion = request.getParams()
.getOrDefault(WAIT_FOR_COMPLETION_TIMEOUT, TimeValue.MINUS_ONE.toString());
WorkflowRequest workflowRequest = new WorkflowRequest(
globalContextResponse.getId(),
null,
request.getParams(),
request.getWaitForCompletionTimeout()
// todo : what is this setting name represent?
TimeValue.parseTimeValue(waitForTimeCompletion, "provision.timout")
);
logger.info(
"Provisioning parameter is set, continuing to provision workflow {}",
Expand All @@ -262,14 +266,18 @@ private void createExecute(WorkflowRequest request, User user, ActionListener<Wo
ProvisionWorkflowAction.INSTANCE,
workflowRequest,
ActionListener.wrap(provisionResponse -> {
listener.onResponse(
request.getWaitForCompletionTimeout() != null
? new WorkflowResponse(
if (workflowRequest.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) {
listener.onResponse(
new WorkflowResponse(provisionResponse.getWorkflowId())
);
} else {
listener.onResponse(
new WorkflowResponse(
provisionResponse.getWorkflowId(),
provisionResponse.getWorkflowState()
)
: new WorkflowResponse(provisionResponse.getWorkflowId())
);
);
}
}, exception -> {
String errorMessage = "Provisioning failed.";
logger.error(errorMessage, exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
Expand All @@ -32,6 +33,7 @@
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.WorkflowTimeoutUtility;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.plugins.PluginsService;
Expand Down Expand Up @@ -212,26 +214,26 @@ private void executeProvisionRequest(
),
ActionListener.wrap(updateResponse -> {
logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING);
if (request.getWaitForCompletionTimeout() != null) {
if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) {
executeWorkflowAsync(workflowId, provisionProcessSequence, listener);
} else {
executeWorkflowSync(
workflowId,
provisionProcessSequence,
listener,
request.getWaitForCompletionTimeout().getMillis()
);
} else {
executeWorkflowAsync(workflowId, provisionProcessSequence, listener);
}
// update last provisioned field in template
Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build();
flowFrameworkIndicesHandler.updateTemplateInGlobalContext(
request.getWorkflowId(),
newTemplate,
ActionListener.wrap(templateResponse -> {
if (request.getWaitForCompletionTimeout() != null) {
logger.info("Waiting for workflow completion");
} else {
if (request.getWaitForCompletionTimeout() == TimeValue.MINUS_ONE) {
listener.onResponse(new WorkflowResponse(request.getWorkflowId()));
} else {
logger.info("Waiting for workflow completion");
}
}, exception -> {
String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage(
Expand Down Expand Up @@ -312,101 +314,27 @@ private void executeWorkflowSync(
ActionListener<WorkflowResponse> listener,
long timeout
) {
PlainActionFuture<WorkflowResponse> workflowFuture = new PlainActionFuture<>();
AtomicBoolean isResponseSent = new AtomicBoolean(false);

CompletableFuture.runAsync(() -> {
try {
executeWorkflow(workflowSequence, workflowId, new ActionListener<>() {
@Override
public void onResponse(WorkflowResponse workflowResponse) {
if (isResponseSent.get()) {
logger.info("Ignoring onResponse for workflowId: {} as timeout already occurred", workflowId);
return;
}
isResponseSent.set(true);
workflowFuture.onResponse(null);
listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState()));
WorkflowTimeoutUtility.handleResponse(workflowId, workflowResponse, isResponseSent, listener);
}

@Override
public void onFailure(Exception e) {
if (isResponseSent.get()) {
logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId);
return;
}
isResponseSent.set(true);
workflowFuture.onFailure(
new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e))
);
listener.onFailure(
new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e))
);
WorkflowTimeoutUtility.handleFailure(workflowId, e, isResponseSent, listener);
}
}, true);
} catch (Exception ex) {
if (!isResponseSent.get()) {
isResponseSent.set(true);
workflowFuture.onFailure(
new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex))
);
listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex)));
}
WorkflowTimeoutUtility.handleFailure(workflowId, ex, isResponseSent, listener);
}
}, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL));

// Schedule timeout handler
scheduleTimeoutHandler(workflowId, listener, timeout, isResponseSent);
}

/**
* Schedules a timeout handler for workflow execution.
* This method starts a new task in the thread pool to wait for the specified timeout duration.
* If the workflow does not complete within the given timeout, it triggers a follow-up action
* to fetch the workflow's state and notify the listener.
*
* @param workflowId The unique identifier of the workflow being executed.
* @param listener The ActionListener to notify with the workflow's response or failure.
* @param timeout The maximum time (in milliseconds) to wait for the workflow to complete before timing out.
* @param isResponseSent An AtomicBoolean flag to ensure the response is sent only once.
*/
private void scheduleTimeoutHandler(
String workflowId,
ActionListener<WorkflowResponse> listener,
long timeout,
AtomicBoolean isResponseSent
) {
threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> {
try {
Thread.sleep(timeout);
if (isResponseSent.compareAndSet(false, true)) {
logger.warn("Workflow execution timed out for workflowId: {}", workflowId);
fetchWorkflowStateAfterTimeout(workflowId, listener);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
}

/**
* Fetches the workflow state after a timeout has occurred.
* This method sends a request to retrieve the current state of the workflow
* and notifies the listener with the updated state or an error if the request fails.
*
* @param workflowId The unique identifier of the workflow whose state needs to be fetched.
* @param listener The ActionListener to notify with the workflow's updated state or failure.
*/
private void fetchWorkflowStateAfterTimeout(String workflowId, ActionListener<WorkflowResponse> listener) {
client.execute(
GetWorkflowStateAction.INSTANCE,
new GetWorkflowStateRequest(workflowId, false),
ActionListener.wrap(
response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())),
exception -> listener.onFailure(
new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception))
)
)
);
WorkflowTimeoutUtility.scheduleTimeoutHandler(client, threadPool, workflowId, listener, timeout, isResponseSent);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public class WorkflowRequest extends ActionRequest {
* @param template the use case template which describes the workflow
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) {
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false, null);
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false, TimeValue.MINUS_ONE);
}

/**
Expand All @@ -86,7 +86,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template)
* @param params The parameters from the REST path
*/
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map<String, String> params) {
this(workflowId, template, new String[] { "all" }, true, params, false, null);
this(workflowId, template, new String[] { "all" }, true, params, false, TimeValue.MINUS_ONE);
}

/**
Expand All @@ -109,6 +109,26 @@ public WorkflowRequest(
this(workflowId, template, new String[] { "all" }, true, params, false, waitForCompletionTimeout);
}

/**
* Instantiates a new WorkflowRequest
* @param workflowId the documentId of the workflow
* @param template the use case template which describes the workflow
* @param validation flag to indicate if validation is necessary
* @param provisionOrUpdate provision or updateFields flag. Only one may be true, the presence of update_fields key in map indicates if updating fields, otherwise true means it's provisioning.
* @param params map of REST path params. If provisionOrUpdate is false, must be an empty map. If update_fields key is present, must be only key.
* @param reprovision flag to indicate if request is to reprovision
*/
public WorkflowRequest(
@Nullable String workflowId,
@Nullable Template template,
String[] validation,
boolean provisionOrUpdate,
Map<String, String> params,
boolean reprovision
) {
this(workflowId, template, validation, provisionOrUpdate, params, reprovision, TimeValue.MINUS_ONE);
}

/**
* Instantiates a new WorkflowRequest
* @param workflowId the documentId of the workflow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.transport;

import org.opensearch.Version;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -47,7 +48,10 @@ public WorkflowResponse(String workflowId) {
public WorkflowResponse(StreamInput in) throws IOException {
super(in);
this.workflowId = in.readString();
this.workflowState = in.readOptionalWriteable(WorkflowState::new);
// todo : change version to 2_19_0
if (in.getVersion().onOrAfter(Version.CURRENT)) {
this.workflowState = in.readOptionalWriteable(WorkflowState::new);
}

}

Expand Down Expand Up @@ -89,7 +93,10 @@ public WorkflowResponse(String workflowId, WorkflowState workflowState) {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(workflowId);
out.writeOptionalWriteable(workflowState);
// todo : change version to 2_19_0
if (out.getVersion().onOrAfter(Version.CURRENT)) {
out.writeOptionalWriteable(workflowState);
}
}

@Override
Expand Down
Loading

0 comments on commit c3b9a88

Please sign in to comment.