Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multiple device refresh improvements #2281

Merged
merged 12 commits into from
Jan 31, 2024
44 changes: 43 additions & 1 deletion services/skus/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func Router(
cr.Method(http.MethodGet, "/items/{itemID}/batches/{requestID}", metricsMwr("GetOrderCredsByID", getOrderCredsByID(svc, false)))

cr.Method(http.MethodPut, "/items/{itemID}/batches/{requestID}", metricsMwr("CreateOrderItemCreds", createItemCreds(svc)))
cr.Method(http.MethodDelete, "/items/{itemID}/batches/{requestID}", metricsMwr("DeleteOrderItemCreds", authMwr(deleteItemCreds(svc))))
})

return r
Expand Down Expand Up @@ -590,6 +591,47 @@ func CreateOrderCreds(svc *Service) handlers.AppHandler {
}
}

// deleteItemCreds handles requests for deleting credentials for an item.
func deleteItemCreds(svc *Service) handlers.AppHandler {
return func(w http.ResponseWriter, r *http.Request) *handlers.AppError {
ctx := r.Context()
lg := logging.Logger(ctx, "skus.deleteItemCreds")
husobee marked this conversation as resolved.
Show resolved Hide resolved

orderID := &inputs.ID{}
if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParamFromCtx(ctx, "orderID")); err != nil {
husobee marked this conversation as resolved.
Show resolved Hide resolved
lg.Error().Err(err).Msg("failed to validate order id")
return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{
"orderID": err.Error(),
})
}

itemID := &inputs.ID{}
if err := inputs.DecodeAndValidateString(ctx, itemID, chi.URLParamFromCtx(ctx, "itemID")); err != nil {
lg.Error().Err(err).Msg("failed to validate item id")
return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{
husobee marked this conversation as resolved.
Show resolved Hide resolved
"itemID": err.Error(),
})
}

reqID := &inputs.ID{}
if err := inputs.DecodeAndValidateString(ctx, reqID, chi.URLParamFromCtx(ctx, "requestID")); err != nil {
lg.Error().Err(err).Msg("failed to validate request id")
return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{
"requestID": err.Error(),
})
}

isSigned := r.URL.Query().Get("isSigned") == "true"

if err := svc.DeleteOrderCreds(ctx, *orderID.UUID(), *reqID.UUID(), isSigned); err != nil {
lg.Error().Err(err).Msg("failed to delete the order credentials")
return handlers.WrapError(err, "Error deleting credentials", http.StatusBadRequest)
husobee marked this conversation as resolved.
Show resolved Hide resolved
}

return handlers.RenderContent(ctx, nil, w, http.StatusOK)
}
}

// createItemCredsRequest includes the blinded credentials to be signed.
type createItemCredsRequest struct {
BlindedCreds []string `json:"blindedCreds" valid:"base64"`
Expand Down Expand Up @@ -695,7 +737,7 @@ func DeleteOrderCreds(service *Service) handlers.AppHandler {
}

isSigned := r.URL.Query().Get("isSigned") == "true"
if err := service.DeleteOrderCreds(ctx, id, isSigned); err != nil {
if err := service.DeleteOrderCreds(ctx, id, uuid.Nil, isSigned); err != nil {
return handlers.WrapError(err, "Error deleting credentials", http.StatusBadRequest)
}

Expand Down
34 changes: 32 additions & 2 deletions services/skus/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ var (
errItemDoesNotExist model.Error = "order item does not exist for order"
errCredsAlreadySubmitted model.Error = "credentials already submitted"

errMaxActiveOrderCredsExceeded model.Error = "maximum active order credentials exceeded"

defaultExpiresAt = time.Now().Add(17532 * time.Hour) // 2 years
retryPolicy = retrypolicy.DefaultRetry
dontRetryCodes = map[int]struct{}{
Expand Down Expand Up @@ -688,7 +690,7 @@ func (s *SigningOrderResultErrorHandler) Handle(ctx context.Context, message kaf
// - create repos for credentials;
// - move the corresponding methods there;
// - make those methods work on per-item basis.
func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, isSigned bool) error {
func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, reqID uuid.UUID, isSigned bool) error {
order, err := s.Datastore.GetOrder(orderID)
if err != nil {
return err
Expand Down Expand Up @@ -720,7 +722,7 @@ func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, isSig
}

if doTlv2 {
if err := s.Datastore.DeleteTimeLimitedV2OrderCredsByOrderTx(ctx, tx, orderID); err != nil {
if err := s.deleteTLV2(ctx, tx, order, reqID); err != nil {
return fmt.Errorf("error deleting time limited v2 order creds: %w", err)
}
}
Expand All @@ -736,6 +738,34 @@ func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, isSig
return nil
}

// maxTLV2ActiveItemCreds is the max number of credentials an item is allowed to have in the given day
const maxTLV2ActiveOrderCreds = 10

func (s *Service) deleteTLV2(ctx context.Context, dbi sqlx.ExtContext, order *model.Order, reqID uuid.UUID) error {

// Pass the request id as an "item id", which will allow for legacy credentials to be deleted.
// Otherwise, do not delete said credentials for multiple device support.
if !uuid.Equal(reqID, uuid.Nil) {
// check if we already have N active credentials on this item for the current day
activeCreds, err := s.Datastore.GetCountActiveOrderCreds(ctx, dbi, order.ID)
if err != nil {
return fmt.Errorf("failed to get count of active order credentials: %w", err)
}
if activeCreds > maxTLV2ActiveOrderCreds {
return errMaxActiveOrderCredsExceeded
}
return s.Datastore.DeleteTimeLimitedV2OrderCredsByOrderTx(ctx, dbi, order.ID, reqID)
}

itemIDs := make([]uuid.UUID, 0, len(order.Items))
// Legacy, delete all items.
for i := range order.Items {
itemIDs = append(itemIDs, order.Items[i].ID)
}

return s.Datastore.DeleteTimeLimitedV2OrderCredsByOrderTx(ctx, dbi, order.ID, itemIDs...)
}

// checkNumBlindedCreds checks the number of submitted blinded credentials.
//
// The number of submitted credentials must not exceed:
Expand Down
105 changes: 105 additions & 0 deletions services/skus/credentials_noint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package skus

import (
"context"
"testing"

"github.com/golang/mock/gomock"
uuid "github.com/satori/go.uuid"
should "github.com/stretchr/testify/assert"
must "github.com/stretchr/testify/require"

"github.com/brave-intl/bat-go/services/skus/model"
)

func TestService_DeleteTLV2(t *testing.T) {
type tcGiven struct {
ord *model.Order
reqID uuid.UUID
}

type tcExpected struct {
args []gomock.Matcher
err error
}

type testCase struct {
name string
given tcGiven
exp tcExpected
}

tests := []testCase{
{
name: "request_id_specified",
given: tcGiven{
ord: &model.Order{ID: uuid.Must(uuid.FromString("a6b72f11-c886-49ee-b4f4-913eaa0984ae"))},
reqID: uuid.Must(uuid.FromString("d10cf2ae-30d8-4ada-965c-11c582968f26")),
},
exp: tcExpected{
args: []gomock.Matcher{
gomock.Eq(context.Background()),
gomock.Nil(), // dbi
gomock.Eq(uuid.Must(uuid.FromString("a6b72f11-c886-49ee-b4f4-913eaa0984ae"))),
gomock.Eq([]uuid.UUID{uuid.Must(uuid.FromString("d10cf2ae-30d8-4ada-965c-11c582968f26"))}),
},
},
},

{
name: "request_id_nil",
given: tcGiven{
ord: &model.Order{
ID: uuid.Must(uuid.FromString("d7b4f524-ebe3-48c1-b60d-d6f548b25aae")),
Items: []model.OrderItem{
{ID: uuid.Must(uuid.FromString("3882ec99-73fb-476b-b176-1f4b40a9b767"))},
{ID: uuid.Must(uuid.FromString("ed437d36-182b-460f-8213-2ce3d4bb5c93"))},
{ID: uuid.Must(uuid.FromString("d3e62075-996f-4bed-bbc7-f6cd324b83e0"))},
},
},
reqID: uuid.Nil,
},
exp: tcExpected{
args: []gomock.Matcher{
gomock.Eq(context.Background()),
gomock.Nil(), // dbi
gomock.Eq(uuid.Must(uuid.FromString("d7b4f524-ebe3-48c1-b60d-d6f548b25aae"))),
gomock.Eq([]uuid.UUID{
uuid.Must(uuid.FromString("3882ec99-73fb-476b-b176-1f4b40a9b767")),
uuid.Must(uuid.FromString("ed437d36-182b-460f-8213-2ce3d4bb5c93")),
uuid.Must(uuid.FromString("d3e62075-996f-4bed-bbc7-f6cd324b83e0")),
}),
},
},
},
}

for i := range tests {
tc := tests[i]

t.Run(tc.name, func(t *testing.T) {
must.Equal(t, 4, len(tc.exp.args))

ctrl := gomock.NewController(t)
defer ctrl.Finish()

ds := NewMockDatastore(ctrl)

svc := &Service{Datastore: ds}

ctx := context.Background()
if tc.given.reqID != uuid.Nil {
ds.EXPECT().GetCountActiveOrderCreds(
tc.exp.args[0], tc.exp.args[1], tc.exp.args[2],
).Return(0, nil)
}

ds.EXPECT().DeleteTimeLimitedV2OrderCredsByOrderTx(
tc.exp.args[0], tc.exp.args[1], tc.exp.args[2], tc.exp.args[3],
).Return(nil)

actual := svc.deleteTLV2(ctx, nil, tc.given.ord, tc.given.reqID)
should.Equal(t, tc.exp.err, actual)
})
}
}
31 changes: 28 additions & 3 deletions services/skus/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ type Datastore interface {
AreTimeLimitedV2CredsSubmitted(ctx context.Context, blindedCreds ...string) (bool, error)
GetTimeLimitedV2OrderCredsByOrder(orderID uuid.UUID) (*TimeLimitedV2Creds, error)
GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, ordID, itemID, reqID uuid.UUID) (*TimeLimitedV2Creds, error)
DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, tx *sqlx.Tx, orderID uuid.UUID) error
DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID, itemIDs ...uuid.UUID) error
GetTimeLimitedV2OrderCredsByOrderItem(itemID uuid.UUID) (*TimeLimitedV2Creds, error)
GetCountActiveOrderCreds(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID) (int, error)
InsertTimeLimitedV2OrderCredsTx(ctx context.Context, tx *sqlx.Tx, tlv2 TimeAwareSubIssuedCreds) error
InsertSigningOrderRequestOutbox(ctx context.Context, requestID uuid.UUID, orderID uuid.UUID, itemID uuid.UUID, signingOrderRequest SigningOrderRequest) error
GetSigningOrderRequestOutboxByRequestID(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID) (*SigningOrderRequestOutbox, error)
Expand Down Expand Up @@ -796,11 +797,20 @@ func (pg *Postgres) DeleteSingleUseOrderCredsByOrderTx(ctx context.Context, tx *

// DeleteTimeLimitedV2OrderCredsByOrderTx performs a hard delete for all time limited v2 order
// credentials for a given OrderID.
func (pg *Postgres) DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, tx *sqlx.Tx, orderID uuid.UUID) error {
_, err := tx.ExecContext(ctx, `delete from time_limited_v2_order_creds where order_id = $1`, orderID)
func (pg *Postgres) DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID, itemIDs ...uuid.UUID) error {
// Legacy, if the request id matches the item id on the tlv2 creds, we actually want to delete the record.
// Otherwise we will keep these credentials here for multiple device refresh capabilities.
const q = "DELETE FROM time_limited_v2_order_creds WHERE order_id = ? AND item_id in (?)"
query, args, err := sqlx.In(q, orderID, itemIDs)
if err != nil {
return fmt.Errorf("error creating delete query for order with item ids: %w", err)
}

query = dbi.Rebind(query)
if _, err := dbi.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("error deleting time limited v2 order creds: %w", err)
}

return nil
}

Expand Down Expand Up @@ -1026,6 +1036,21 @@ func (pg *Postgres) GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, o
return result, nil
}

// GetCountActiveOrderCreds returns the count of order creds currently active on an order
func (pg *Postgres) GetCountActiveOrderCreds(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID) (int, error) {
query := `
pavelbrm marked this conversation as resolved.
Show resolved Hide resolved
select count(1) from time_limited_v2_order_creds
where order_id = $1 and now() > valid_from and valid_to > now() group by request_id
husobee marked this conversation as resolved.
Show resolved Hide resolved
`

var activeCredCount int
if err := sqlx.GetContext(ctx, dbi, &activeCredCount, query, orderID); err != nil {
return 0, fmt.Errorf("error getting active credential count: %w", err)
}

return activeCredCount, nil
}

// GetTimeLimitedV2OrderCredsByOrderItem returns all the order credentials for a single order item.
func (pg *Postgres) GetTimeLimitedV2OrderCredsByOrderItem(itemID uuid.UUID) (*TimeLimitedV2Creds, error) {
query := `
Expand Down
18 changes: 18 additions & 0 deletions services/skus/datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ func (suite *PostgresTestSuite) TestGetOrderByExternalID() {
}
}

func (suite *PostgresTestSuite) TestCountActiveOrderCreds_Success() {
clD11 marked this conversation as resolved.
Show resolved Hide resolved
env := os.Getenv("ENV")
ctx := context.WithValue(context.Background(), appctx.EnvironmentCTXKey, env)

// create paid order with two order items
ctx = context.WithValue(ctx, appctx.WhitelistSKUsCTXKey, []string{devBraveFirewallVPNPremiumTimeLimited,
devBraveSearchPremiumYearTimeLimited})

orderCredentials := suite.createTimeLimitedV2OrderCreds(suite.T(), ctx, devBraveFirewallVPNPremiumTimeLimited,
devBraveSearchPremiumYearTimeLimited)

// both order items have same orderID so can use the first element to retrieve all order creds
count, err := suite.storage.GetCountActiveOrderCreds(ctx, suite.storage.RawDB(), orderCredentials[0].OrderID)
husobee marked this conversation as resolved.
Show resolved Hide resolved
suite.Require().NoError(err)

suite.Assert().Equal(count, 2)
husobee marked this conversation as resolved.
Show resolved Hide resolved
}

func (suite *PostgresTestSuite) TestGetTimeLimitedV2OrderCredsByOrder_Success() {
env := os.Getenv("ENV")
ctx := context.WithValue(context.Background(), appctx.EnvironmentCTXKey, env)
Expand Down
Loading
Loading