Skip to content

Commit

Permalink
Try with single big request with all keys
Browse files Browse the repository at this point in the history
  • Loading branch information
acevedosharp committed Dec 4, 2024
1 parent ddfcca1 commit 7899223
Showing 1 changed file with 70 additions and 99 deletions.
169 changes: 70 additions & 99 deletions go/internal/feast/onlinestore/cassandraonlinestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/feast-dev/feast/go/protos/feast/serving"
"google.golang.org/protobuf/types/known/timestamppb"
"os"
"strings"
"sync"
"time"

"github.com/feast-dev/feast/go/internal/feast/registry"
"github.com/feast-dev/feast/go/internal/feast/utils"
"github.com/feast-dev/feast/go/protos/feast/serving"
"github.com/feast-dev/feast/go/protos/feast/types"
"github.com/gocql/gocql"
"github.com/golang/protobuf/proto"
"github.com/rs/zerolog/log"

"google.golang.org/protobuf/types/known/timestamppb"
gocqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gocql/gocql"
)

Expand Down Expand Up @@ -209,16 +208,22 @@ func (c *CassandraOnlineStore) getFqTableName(tableName string) string {
return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName)
}

func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string) string {
func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nKeys int) string {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName)
}

keyPlaceholders := make([]string, nKeys)
for i := 0; i < nKeys; i++ {
keyPlaceholders[i] = "?"
}

return fmt.Sprintf(
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? AND "feature_name" IN (%s)`,
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`,
tableName,
strings.Join(keyPlaceholders, ","),
strings.Join(quotedFeatureNames, ","),
)
}
Expand Down Expand Up @@ -265,110 +270,76 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ

// Prepare the query
tableName := c.getFqTableName(featureViewName)
cqlStatement := c.getCQLStatement(tableName, featureNames)

var waitGroup sync.WaitGroup
waitGroup.Add(len(serializedEntityKeys))

errorsChannel := make(chan error, len(serializedEntityKeys))
for _, serializedEntityKey := range serializedEntityKeys {
go func(serEntityKey any) {
defer waitGroup.Done()

iter := c.session.Query(cqlStatement, serEntityKey).WithContext(ctx).Iter()

rowIdx := serializedEntityKeyToIndex[serializedEntityKey.(string)]

// fill the row with nulls if not found
if iter.NumRows() == 0 {
for _, featName := range featureNames {
results[rowIdx][featureNamesToIdx[featName]] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featName,
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}
}
return
}

scanner := iter.Scanner()
var entityKey string
var featureName string
var eventTs time.Time
var valueStr []byte
var deserializedValue types.Value
rowFeatures := make(map[string]FeatureData)
for scanner.Next() {
err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
if err != nil {
errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)")
return
}
if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil {
errorsChannel <- errors.New("error converting parsed Cassandra Value to types.Value")
return
}
cqlStatement := c.getCQLStatement(tableName, featureNames, len(entityKeys))

scanner := c.session.Query(cqlStatement, serializedEntityKeys...).Iter().Scanner()

// Process the results
var entityKey string
var featureName string
var eventTs time.Time
var valueStr []byte
var deserializedValue types.Value
for scanner.Next() {
err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
if err != nil {
return nil, errors.New("could not read row in query for (entity key, feature name, value, event ts)")
}
if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil {
return nil, errors.New("error converting parsed Cassandra Value to types.Value")
}

if deserializedValue.Val != nil {
// Convert the value to a FeatureData struct
rowFeatures[featureName] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
},
Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
Value: types.Value{
Val: deserializedValue.Val,
},
}
}
var featureValues FeatureData
if deserializedValue.Val != nil {
// Convert the value to a FeatureData struct
featureValues = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
},
Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
Value: types.Value{
Val: deserializedValue.Val,
},
}

if err := scanner.Err(); err != nil {
errorsChannel <- errors.New("failed to scan features: " + err.Error())
return
} else {
// Return FeatureData with a null value
featureValues = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}
}
// Add the FeatureData to the results
rowIndx := serializedEntityKeyToIndex[entityKey]
results[rowIndx][featureNamesToIdx[featureName]] = featureValues
}

for _, featName := range featureNames {
featureData, ok := rowFeatures[featName]
if !ok {
featureData = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featName,
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
for i := 0; i < len(entityKeys); i++ {
for j := 0; j < len(featureNames); j++ {
if results[i][j].Timestamp.GetSeconds() == 0 {
results[i][j] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureViewNames[j],
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
}
},
}
results[rowIdx][featureNamesToIdx[featName]] = featureData
}
}(serializedEntityKey)
}

// wait until all concurrent single-key queries are done
waitGroup.Wait()
close(errorsChannel)

var collectedErrors []error
for err := range errorsChannel {
if err != nil {
collectedErrors = append(collectedErrors, err)
}
}
if len(collectedErrors) > 0 {
return nil, errors.Join(collectedErrors...)
}

// wait until all concurrent single-key queries are done
return results, nil
}

Expand Down

0 comments on commit 7899223

Please sign in to comment.