Skip to content

Commit

Permalink
integrate transformation service into feature store
Browse files Browse the repository at this point in the history
  • Loading branch information
piket committed Oct 18, 2023
1 parent d28ac4a commit 2521da8
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 51 deletions.
2 changes: 1 addition & 1 deletion go/embedded/online_features.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (s *OnlineFeatureService) GetOnlineFeatures(

outputFields := make([]arrow.Field, 0)
outputColumns := make([]arrow.Array, 0)
pool := memory.NewCgoArrowAllocator()
pool := memory.NewGoAllocator()
for _, featureVector := range resp {
outputFields = append(outputFields,
arrow.Field{
Expand Down
10 changes: 8 additions & 2 deletions go/internal/feast/featurestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type FeatureStore struct {
registry *registry.Registry
onlineStore onlinestore.OnlineStore
transformationCallback transformation.TransformationCallback
transformationService *transformation.GrpcTransformationService
}

// A Features struct specifies a list of features to be retrieved from the online store. These features
Expand Down Expand Up @@ -54,12 +55,15 @@ func NewFeatureStore(config *registry.RepoConfig, callback transformation.Transf
if err != nil {
return nil, err
}
endpoint := "localhost:port" // TODO: replace with a config or real value
transformationService, _ := transformation.NewGrpcTransformationService(config, endpoint)

return &FeatureStore{
config: config,
registry: registry,
onlineStore: onlineStore,
transformationCallback: callback,
transformationService: transformationService,
}, nil
}

Expand Down Expand Up @@ -116,7 +120,7 @@ func (fs *FeatureStore) GetOnlineFeatures(
}

result := make([]*onlineserving.FeatureVector, 0)
arrowMemory := memory.NewCgoArrowAllocator()
arrowMemory := memory.NewGoAllocator()
featureViews := make([]*model.FeatureView, len(requestedFeatureViews))
index := 0
for _, featuresAndView := range requestedFeatureViews {
Expand Down Expand Up @@ -164,13 +168,15 @@ func (fs *FeatureStore) GetOnlineFeatures(
result = append(result, vectors...)
}

if fs.transformationCallback != nil {
if fs.transformationCallback != nil || fs.transformationService != nil {
onDemandFeatures, err := transformation.AugmentResponseWithOnDemandTransforms(
ctx,
requestedOnDemandFeatureViews,
requestData,
joinKeyToEntityValues,
result,
fs.transformationCallback,
fs.transformationService,
arrowMemory,
numRows,
fullFeatureNames,
Expand Down
41 changes: 30 additions & 11 deletions go/internal/feast/transformation/transformation.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transformation

import (
"context"
"errors"
"fmt"
"runtime"
Expand Down Expand Up @@ -32,11 +33,13 @@ Python function is expected to return number of rows added to the output record
type TransformationCallback func(ODFVName string, inputArrPtr, inputSchemaPtr, outArrPtr, outSchemaPtr uintptr, fullFeatureNames bool) int

func AugmentResponseWithOnDemandTransforms(
ctx context.Context,
onDemandFeatureViews []*model.OnDemandFeatureView,
requestData map[string]*prototypes.RepeatedValue,
entityRows map[string]*prototypes.RepeatedValue,
features []*onlineserving.FeatureVector,
transformationCallback TransformationCallback,
transformationService *GrpcTransformationService,
arrowMemory memory.Allocator,
numRows int,
fullFeatureNames bool,
Expand Down Expand Up @@ -68,17 +71,33 @@ func AugmentResponseWithOnDemandTransforms(
retrievedFeatures[vector.Name] = vector.Values
}

onDemandFeatures, err := CallTransformations(
odfv,
retrievedFeatures,
requestContextArrow,
transformationCallback,
numRows,
fullFeatureNames,
)
if err != nil {
ReleaseArrowContext(requestContextArrow)
return nil, err
var onDemandFeatures []*onlineserving.FeatureVector
if transformationService != nil {
onDemandFeatures, err = transformationService.GetTransformation(
ctx,
odfv,
retrievedFeatures,
requestContextArrow,
numRows,
fullFeatureNames,
)
if err != nil {
ReleaseArrowContext(requestContextArrow)
return nil, err
}
} else {
onDemandFeatures, err = CallTransformations(
odfv,
retrievedFeatures,
requestContextArrow,
transformationCallback,
numRows,
fullFeatureNames,
)
if err != nil {
ReleaseArrowContext(requestContextArrow)
return nil, err
}
}
result = append(result, onDemandFeatures...)

Expand Down
64 changes: 27 additions & 37 deletions go/internal/feast/transformation/transformation_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"github.com/feast-dev/feast/go/internal/feast/registry"
"google.golang.org/protobuf/types/known/timestamppb"
"strings"

Expand All @@ -14,50 +15,49 @@ import (
"github.com/feast-dev/feast/go/internal/feast/model"
"github.com/feast-dev/feast/go/internal/feast/onlineserving"
"github.com/feast-dev/feast/go/protos/feast/serving"
prototypes "github.com/feast-dev/feast/go/protos/feast/types"
"github.com/feast-dev/feast/go/types"
"google.golang.org/grpc"
"io"
)

type grpcTransformationService struct {
endpoint string
project string
type GrpcTransformationService struct {
project string
conn *grpc.ClientConn
client *serving.TransformationServiceClient
}

func (s *grpcTransformationService) GetTransformation(
func NewGrpcTransformationService(config *registry.RepoConfig, endpoint string) (*GrpcTransformationService, error) {
opts := make([]grpc.DialOption, 0)
opts = append(opts, grpc.WithDefaultCallOptions())

conn, err := grpc.Dial(endpoint, opts...)
if err != nil {
return nil, err
}
client := serving.NewTransformationServiceClient(conn)
return &GrpcTransformationService{ config.Project, conn, &client }, nil
}

func (s *GrpcTransformationService) Close() error {
return s.conn.Close()
}

func (s *GrpcTransformationService) GetTransformation(
ctx context.Context,
featureView *model.OnDemandFeatureView,
requestData map[string]*prototypes.RepeatedValue,
entityRows map[string]*prototypes.RepeatedValue,
features []*onlineserving.FeatureVector,
retrievedFeatures map[string]arrow.Array,
requestContext map[string]arrow.Array,
numRows int,
fullFeatureNames bool,
) ([]*onlineserving.FeatureVector, error) {
var err error
arrowMemory := memory.NewGoAllocator()

inputFields := make([]arrow.Field, 0)
inputColumns := make([]arrow.Array, 0)
for _, vector := range features {
inputFields = append(inputFields, arrow.Field{Name: vector.Name, Type: vector.Values.DataType()})
inputColumns = append(inputColumns, vector.Values)
}

for name, values := range requestData {
arr, err := types.ProtoValuesToArrowArray(values.Val, arrowMemory, numRows)
if err != nil {
return nil, err
}
for name, arr := range retrievedFeatures {
inputFields = append(inputFields, arrow.Field{Name: name, Type: arr.DataType()})
inputColumns = append(inputColumns, arr)
}

for name, values := range entityRows {
arr, err := types.ProtoValuesToArrowArray(values.Val, arrowMemory, numRows)
if err != nil {
return nil, err
}
for name, arr := range requestContext {
inputFields = append(inputFields, arrow.Field{Name: name, Type: arr.DataType()})
inputColumns = append(inputColumns, arr)
}
Expand All @@ -83,17 +83,7 @@ func (s *grpcTransformationService) GetTransformation(
TransformationInput: &transformationInput,
}

opts := make([]grpc.DialOption, 0)
opts = append(opts, grpc.WithDefaultCallOptions())

conn, err := grpc.Dial(s.endpoint, opts...)
if err != nil {
return nil, err
}
defer conn.Close()
client := serving.NewTransformationServiceClient(conn)

res, err := client.TransformFeatures(ctx, &req)
res, err := (*s.client).TransformFeatures(ctx, &req)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 2521da8

Please sign in to comment.