Skip to content

Commit

Permalink
Refactor flyteadmin to pass proto structs as pointers (#5717)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sovietaced authored Sep 5, 2024
1 parent 5b6bd52 commit 5f69589
Show file tree
Hide file tree
Showing 198 changed files with 1,472 additions and 1,733 deletions.
5 changes: 3 additions & 2 deletions flyteadmin/.golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ linters-settings:
- prefix(github.com/flyteorg)
skip-generated: true
issues:
exclude:
- copylocks
exclude-rules:
- path: pkg/workflowengine/impl/prepare_execution.go
text: "copies lock"
32 changes: 16 additions & 16 deletions flyteadmin/dataproxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (s Service) CreateDownloadLink(ctx context.Context, req *service.CreateDown
// Lookup task, node, workflow execution
var nativeURL string
if nodeExecutionIDEnvelope, casted := req.GetSource().(*service.CreateDownloadLinkRequest_NodeExecutionId); casted {
node, err := s.nodeExecutionManager.GetNodeExecution(ctx, admin.NodeExecutionGetRequest{
node, err := s.nodeExecutionManager.GetNodeExecution(ctx, &admin.NodeExecutionGetRequest{
Id: nodeExecutionIDEnvelope.NodeExecutionId,
})

Expand Down Expand Up @@ -309,9 +309,9 @@ func (s Service) validateResolveArtifactRequest(req *service.GetDataRequest) err

// GetCompleteTaskExecutionID returns the task execution identifier for the task execution with the Task ID filled in.
// The one coming from the node execution doesn't have this as this is not data encapsulated in the flyte url.
func (s Service) GetCompleteTaskExecutionID(ctx context.Context, taskExecID core.TaskExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
func (s Service) GetCompleteTaskExecutionID(ctx context.Context, taskExecID *core.TaskExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {

taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, &admin.TaskExecutionListRequest{
NodeExecutionId: taskExecID.GetNodeExecutionId(),
Limit: 1,
Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(int(taskExecID.RetryAttempt))),
Expand All @@ -326,9 +326,9 @@ func (s Service) GetCompleteTaskExecutionID(ctx context.Context, taskExecID core
return taskExec.Id, nil
}

func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
NodeExecutionId: &nodeExecID,
func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID *core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, &admin.TaskExecutionListRequest{
NodeExecutionId: nodeExecID,
Limit: 1,
Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(attempt)),
})
Expand All @@ -342,11 +342,11 @@ func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID
return taskExec.Id, nil
}

func (s Service) GetDataFromNodeExecution(ctx context.Context, nodeExecID core.NodeExecutionIdentifier, ioType common.ArtifactType, name string) (
func (s Service) GetDataFromNodeExecution(ctx context.Context, nodeExecID *core.NodeExecutionIdentifier, ioType common.ArtifactType, name string) (
*service.GetDataResponse, error) {

resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, admin.NodeExecutionGetDataRequest{
Id: &nodeExecID,
resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, &admin.NodeExecutionGetDataRequest{
Id: nodeExecID,
})
if err != nil {
return nil, err
Expand All @@ -361,7 +361,7 @@ func (s Service) GetDataFromNodeExecution(ctx context.Context, nodeExecID core.N
// Assume deck, and create a download link request
dlRequest := service.CreateDownloadLinkRequest{
ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: &nodeExecID},
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: nodeExecID},
}
resp, err := s.CreateDownloadLink(ctx, &dlRequest)
if err != nil {
Expand Down Expand Up @@ -391,12 +391,12 @@ func (s Service) GetDataFromNodeExecution(ctx context.Context, nodeExecID core.N
}, nil
}

func (s Service) GetDataFromTaskExecution(ctx context.Context, taskExecID core.TaskExecutionIdentifier, ioType common.ArtifactType, name string) (
func (s Service) GetDataFromTaskExecution(ctx context.Context, taskExecID *core.TaskExecutionIdentifier, ioType common.ArtifactType, name string) (
*service.GetDataResponse, error) {

var lm *core.LiteralMap
reqT := admin.TaskExecutionGetDataRequest{
Id: &taskExecID,
reqT := &admin.TaskExecutionGetDataRequest{
Id: taskExecID,
}
resp, err := s.taskExecutionManager.GetTaskExecutionData(ctx, reqT)
if err != nil {
Expand Down Expand Up @@ -445,13 +445,13 @@ func (s Service) GetData(ctx context.Context, req *service.GetDataRequest) (
}

if execution.NodeExecID != nil {
return s.GetDataFromNodeExecution(ctx, *execution.NodeExecID, execution.IOType, execution.LiteralName)
return s.GetDataFromNodeExecution(ctx, execution.NodeExecID, execution.IOType, execution.LiteralName)
} else if execution.PartialTaskExecID != nil {
taskExecID, err := s.GetCompleteTaskExecutionID(ctx, *execution.PartialTaskExecID)
taskExecID, err := s.GetCompleteTaskExecutionID(ctx, execution.PartialTaskExecID)
if err != nil {
return nil, err
}
return s.GetDataFromTaskExecution(ctx, *taskExecID, execution.IOType, execution.LiteralName)
return s.GetDataFromTaskExecution(ctx, taskExecID, execution.IOType, execution.LiteralName)
}

return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse get data request %v", req)
Expand Down
16 changes: 8 additions & 8 deletions flyteadmin/dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func TestCreateUploadLocationMore(t *testing.T) {
func TestCreateDownloadLink(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request *admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
return &admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
DeckUri: "s3://something/something",
Expand Down Expand Up @@ -282,14 +282,14 @@ func TestService_GetData(t *testing.T) {
}

nodeExecutionManager.SetGetNodeExecutionDataFunc(
func(ctx context.Context, request admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) {
func(ctx context.Context, request *admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) {
return &admin.NodeExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
}, nil
},
)
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: []*admin.TaskExecution{
{
Expand All @@ -315,7 +315,7 @@ func TestService_GetData(t *testing.T) {
},
}, nil
})
taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request *admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
return &admin.TaskExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
Expand Down Expand Up @@ -388,10 +388,10 @@ func TestService_Error(t *testing.T) {
assert.NoError(t, err)

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return nil, errors.NewFlyteAdminErrorf(1, "not found")
})
nodeExecID := core.NodeExecutionIdentifier{
nodeExecID := &core.NodeExecutionIdentifier{
NodeId: "n0",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "proj",
Expand All @@ -404,13 +404,13 @@ func TestService_Error(t *testing.T) {
})

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: nil,
Token: "",
}, nil
})
nodeExecID := core.NodeExecutionIdentifier{
nodeExecID := &core.NodeExecutionIdentifier{
NodeId: "n0",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "proj",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution

// This is a rough copy of the ListTaskExecutions function in TaskExecutionManager. It can be deprecated once we move the processing out of Admin itself.
// Just return the highest retry attempt.
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, &nodeExecutionID)
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, nodeExecutionID)

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID)
if err != nil {
Expand Down Expand Up @@ -283,7 +283,7 @@ func (c *CloudEventWrappedPublisher) TransformNodeExecutionEvent(ctx context.Con
var taskExecID *core.TaskExecutionIdentifier
var typedInterface *core.TypedInterface

lte, err := c.getLatestTaskExecutions(ctx, *rawEvent.Id)
lte, err := c.getLatestTaskExecutions(ctx, rawEvent.Id)
if err != nil {
logger.Errorf(ctx, "failed to get latest task execution for node exec id [%+v] with err: %v", rawEvent.Id, err)
return nil, err
Expand Down Expand Up @@ -353,7 +353,7 @@ func (c *CloudEventWrappedPublisher) Publish(ctx context.Context, notificationTy
phase = e.Phase.String()
eventTime = e.OccurredAt.AsTime()

dummyNodeExecutionID := core.NodeExecutionIdentifier{
dummyNodeExecutionID := &core.NodeExecutionIdentifier{
NodeId: "end-node",
ExecutionId: e.ExecutionId,
}
Expand All @@ -378,7 +378,7 @@ func (c *CloudEventWrappedPublisher) Publish(ctx context.Context, notificationTy
if e.ParentNodeExecutionId == nil {
return fmt.Errorf("parent node execution id is nil for task execution [%+v]", e)
}
eventSource = common.FlyteURLKeyFromNodeExecutionIDRetry(*e.ParentNodeExecutionId,
eventSource = common.FlyteURLKeyFromNodeExecutionIDRetry(e.ParentNodeExecutionId,
int(e.RetryAttempt))
finalMsg, err = c.TransformTaskExecutionEvent(ctx, e)
if err != nil {
Expand All @@ -392,7 +392,7 @@ func (c *CloudEventWrappedPublisher) Publish(ctx context.Context, notificationTy
phase = e.Phase.String()
eventTime = e.OccurredAt.AsTime()
eventID = fmt.Sprintf("%v.%v", executionID, phase)
eventSource = common.FlyteURLKeyFromNodeExecutionID(*msgType.Event.Id)
eventSource = common.FlyteURLKeyFromNodeExecutionID(msgType.Event.Id)
finalMsg, err = c.TransformNodeExecutionEvent(ctx, e)
if err != nil {
logger.Errorf(ctx, "Failed to transform node execution event with error: %v", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
// events, node execution processing doesn't have to wait on these to be committed.
type nodeExecutionEventWriter struct {
db repositoryInterfaces.Repository
events chan admin.NodeExecutionEventRequest
events chan *admin.NodeExecutionEventRequest
}

func (w *nodeExecutionEventWriter) Write(event admin.NodeExecutionEventRequest) {
func (w *nodeExecutionEventWriter) Write(event *admin.NodeExecutionEventRequest) {
w.events <- event
}

Expand All @@ -40,6 +40,6 @@ func (w *nodeExecutionEventWriter) Run() {
func NewNodeExecutionEventWriter(db repositoryInterfaces.Repository, bufferSize int) interfaces.NodeExecutionEventWriter {
return &nodeExecutionEventWriter{
db: db,
events: make(chan admin.NodeExecutionEventRequest, bufferSize),
events: make(chan *admin.NodeExecutionEventRequest, bufferSize),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func TestNodeExecutionEventWriter(t *testing.T) {
db := mocks.NewMockRepository()

event := admin.NodeExecutionEventRequest{
event := &admin.NodeExecutionEventRequest{
RequestId: "request_id",
Event: &event2.NodeExecutionEvent{
Id: &core.NodeExecutionIdentifier{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
// events, workflow execution processing doesn't have to wait on these to be committed.
type workflowExecutionEventWriter struct {
db repositoryInterfaces.Repository
events chan admin.WorkflowExecutionEventRequest
events chan *admin.WorkflowExecutionEventRequest
}

func (w *workflowExecutionEventWriter) Write(event admin.WorkflowExecutionEventRequest) {
func (w *workflowExecutionEventWriter) Write(event *admin.WorkflowExecutionEventRequest) {
w.events <- event
}

Expand All @@ -40,6 +40,6 @@ func (w *workflowExecutionEventWriter) Run() {
func NewWorkflowExecutionEventWriter(db repositoryInterfaces.Repository, bufferSize int) interfaces.WorkflowExecutionEventWriter {
return &workflowExecutionEventWriter{
db: db,
events: make(chan admin.WorkflowExecutionEventRequest, bufferSize),
events: make(chan *admin.WorkflowExecutionEventRequest, bufferSize),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func TestWorkflowExecutionEventWriter(t *testing.T) {
db := mocks.NewMockRepository()

event := admin.WorkflowExecutionEventRequest{
event := &admin.WorkflowExecutionEventRequest{
RequestId: "request_id",
Event: &event2.WorkflowExecutionEvent{
ExecutionId: &core.WorkflowExecutionIdentifier{
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/async/events/interfaces/node_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import (

type NodeExecutionEventWriter interface {
Run()
Write(nodeExecutionEvent admin.NodeExecutionEventRequest)
Write(nodeExecutionEvent *admin.NodeExecutionEventRequest)
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import (

type WorkflowExecutionEventWriter interface {
Run()
Write(workflowExecutionEvent admin.WorkflowExecutionEventRequest)
Write(workflowExecutionEvent *admin.WorkflowExecutionEventRequest)
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 5f69589

Please sign in to comment.