Skip to content

Commit

Permalink
Merge pull request #2326 from brave-intl/master
Browse files Browse the repository at this point in the history
production 2024-01-31_1
  • Loading branch information
husobee authored Jan 31, 2024
2 parents 91308ec + fc29a86 commit 2f85b54
Show file tree
Hide file tree
Showing 8 changed files with 724 additions and 46 deletions.
44 changes: 43 additions & 1 deletion services/skus/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func Router(
cr.Method(http.MethodGet, "/items/{itemID}/batches/{requestID}", metricsMwr("GetOrderCredsByID", getOrderCredsByID(svc, false)))

cr.Method(http.MethodPut, "/items/{itemID}/batches/{requestID}", metricsMwr("CreateOrderItemCreds", createItemCreds(svc)))
cr.Method(http.MethodDelete, "/items/{itemID}/batches/{requestID}", metricsMwr("DeleteOrderItemCreds", authMwr(deleteItemCreds(svc))))
})

return r
Expand Down Expand Up @@ -600,6 +601,47 @@ func CreateOrderCreds(svc *Service) handlers.AppHandler {
}
}

// deleteItemCreds handles requests for deleting credentials for an item.
func deleteItemCreds(svc *Service) handlers.AppHandler {
return func(w http.ResponseWriter, r *http.Request) *handlers.AppError {
ctx := r.Context()
lg := logging.Logger(ctx, "skus").With().Str("func", "deleteItemCreds").Logger()

orderID, err := uuid.FromString(chi.URLParamFromCtx(ctx, "orderID"))
if err != nil {
lg.Error().Err(err).Msg("failed to validate order id")
return handlers.ValidationError("request url parameter", map[string]interface{}{
"orderID": err.Error(),
})
}

reqID, err := uuid.FromString(chi.URLParamFromCtx(ctx, "reqID"))
if err != nil {
lg.Error().Err(err).Msg("failed to validate request id")
return handlers.ValidationError("request url parameter", map[string]interface{}{
"itemID": err.Error(),
})
}

isSigned := r.URL.Query().Get("isSigned") == "true"

if err := svc.DeleteOrderCreds(ctx, orderID, reqID, isSigned); err != nil {
lg.Error().Err(err).Msg("failed to delete the order credentials")

switch {
case errors.Is(err, model.ErrOrderNotFound), errors.Is(err, ErrOrderHasNoItems):
return handlers.WrapError(err, "order or item not found", http.StatusNotFound)
case errors.Is(err, errExceededMaxActiveOrderCreds):
return handlers.WrapError(err, errExceededMaxActiveOrderCreds.Error(), http.StatusUnprocessableEntity)
default:
return handlers.WrapError(err, "error deleting credentials", http.StatusInternalServerError)
}
}

return handlers.RenderContent(ctx, nil, w, http.StatusOK)
}
}

// createItemCredsRequest includes the blinded credentials to be signed.
type createItemCredsRequest struct {
BlindedCreds []string `json:"blindedCreds" valid:"base64"`
Expand Down Expand Up @@ -705,7 +747,7 @@ func DeleteOrderCreds(service *Service) handlers.AppHandler {
}

isSigned := r.URL.Query().Get("isSigned") == "true"
if err := service.DeleteOrderCreds(ctx, id, isSigned); err != nil {
if err := service.DeleteOrderCreds(ctx, id, uuid.Nil, isSigned); err != nil {
return handlers.WrapError(err, "Error deleting credentials", http.StatusBadRequest)
}

Expand Down
34 changes: 32 additions & 2 deletions services/skus/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ var (
errItemDoesNotExist model.Error = "order item does not exist for order"
errCredsAlreadySubmitted model.Error = "credentials already submitted"

errExceededMaxActiveOrderCreds model.Error = "maximum active order credentials exceeded"

defaultExpiresAt = time.Now().Add(17532 * time.Hour) // 2 years
retryPolicy = retrypolicy.DefaultRetry
dontRetryCodes = map[int]struct{}{
Expand Down Expand Up @@ -688,7 +690,7 @@ func (s *SigningOrderResultErrorHandler) Handle(ctx context.Context, message kaf
// - create repos for credentials;
// - move the corresponding methods there;
// - make those methods work on per-item basis.
func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, isSigned bool) error {
func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, reqID uuid.UUID, isSigned bool) error {
order, err := s.Datastore.GetOrder(orderID)
if err != nil {
return err
Expand Down Expand Up @@ -720,7 +722,7 @@ func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, isSig
}

if doTlv2 {
if err := s.Datastore.DeleteTimeLimitedV2OrderCredsByOrderTx(ctx, tx, orderID); err != nil {
if err := s.deleteTLV2(ctx, tx, order, reqID, time.Now()); err != nil {
return fmt.Errorf("error deleting time limited v2 order creds: %w", err)
}
}
Expand All @@ -736,6 +738,34 @@ func (s *Service) DeleteOrderCreds(ctx context.Context, orderID uuid.UUID, isSig
return nil
}

// maxTLV2ActiveItemCreds is the max number of credentials an item is allowed to have in the given day
const maxTLV2ActiveOrderCreds = 10

func (s *Service) deleteTLV2(ctx context.Context, dbi sqlx.ExtContext, order *model.Order, reqID uuid.UUID, now time.Time) error {

// Pass the request id as an "item id", which will allow for legacy credentials to be deleted.
// Otherwise, do not delete said credentials for multiple device support.
if !uuid.Equal(reqID, uuid.Nil) {
// check if we already have N active credentials on this item for the current day
activeCreds, err := s.Datastore.GetCountActiveOrderCreds(ctx, dbi, order.ID, now)
if err != nil {
return fmt.Errorf("failed to get count of active order credentials: %w", err)
}
if activeCreds > maxTLV2ActiveOrderCreds {
return errExceededMaxActiveOrderCreds
}
return s.Datastore.DeleteTimeLimitedV2OrderCredsByOrderTx(ctx, dbi, order.ID, reqID)
}

itemIDs := make([]uuid.UUID, 0, len(order.Items))
// Legacy, delete all items.
for i := range order.Items {
itemIDs = append(itemIDs, order.Items[i].ID)
}

return s.Datastore.DeleteTimeLimitedV2OrderCredsByOrderTx(ctx, dbi, order.ID, itemIDs...)
}

// checkNumBlindedCreds checks the number of submitted blinded credentials.
//
// The number of submitted credentials must not exceed:
Expand Down
110 changes: 110 additions & 0 deletions services/skus/credentials_noint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package skus

import (
"context"
"testing"
"time"

"github.com/golang/mock/gomock"
uuid "github.com/satori/go.uuid"
should "github.com/stretchr/testify/assert"
must "github.com/stretchr/testify/require"

"github.com/brave-intl/bat-go/services/skus/model"
)

func TestService_DeleteTLV2(t *testing.T) {
type tcGiven struct {
ord *model.Order
reqID uuid.UUID
}

type tcExpected struct {
args []gomock.Matcher
err error
}

type testCase struct {
name string
given tcGiven
exp tcExpected
}

now := time.Now()

tests := []testCase{
{
name: "request_id_specified",
given: tcGiven{
ord: &model.Order{ID: uuid.Must(uuid.FromString("a6b72f11-c886-49ee-b4f4-913eaa0984ae"))},
reqID: uuid.Must(uuid.FromString("d10cf2ae-30d8-4ada-965c-11c582968f26")),
},
exp: tcExpected{
args: []gomock.Matcher{
gomock.Eq(context.Background()),
gomock.Nil(), // dbi
gomock.Eq(uuid.Must(uuid.FromString("a6b72f11-c886-49ee-b4f4-913eaa0984ae"))),
gomock.Eq([]uuid.UUID{uuid.Must(uuid.FromString("d10cf2ae-30d8-4ada-965c-11c582968f26"))}),
gomock.Eq(now),
},
},
},

{
name: "request_id_nil",
given: tcGiven{
ord: &model.Order{
ID: uuid.Must(uuid.FromString("d7b4f524-ebe3-48c1-b60d-d6f548b25aae")),
Items: []model.OrderItem{
{ID: uuid.Must(uuid.FromString("3882ec99-73fb-476b-b176-1f4b40a9b767"))},
{ID: uuid.Must(uuid.FromString("ed437d36-182b-460f-8213-2ce3d4bb5c93"))},
{ID: uuid.Must(uuid.FromString("d3e62075-996f-4bed-bbc7-f6cd324b83e0"))},
},
},
reqID: uuid.Nil,
},
exp: tcExpected{
args: []gomock.Matcher{
gomock.Eq(context.Background()),
gomock.Nil(), // dbi
gomock.Eq(uuid.Must(uuid.FromString("d7b4f524-ebe3-48c1-b60d-d6f548b25aae"))),
gomock.Eq([]uuid.UUID{
uuid.Must(uuid.FromString("3882ec99-73fb-476b-b176-1f4b40a9b767")),
uuid.Must(uuid.FromString("ed437d36-182b-460f-8213-2ce3d4bb5c93")),
uuid.Must(uuid.FromString("d3e62075-996f-4bed-bbc7-f6cd324b83e0")),
}),
gomock.Eq(now),
},
},
},
}

for i := range tests {
tc := tests[i]

t.Run(tc.name, func(t *testing.T) {
must.Equal(t, 5, len(tc.exp.args))

ctrl := gomock.NewController(t)
defer ctrl.Finish()

ds := NewMockDatastore(ctrl)

svc := &Service{Datastore: ds}

ctx := context.Background()
if tc.given.reqID != uuid.Nil {
ds.EXPECT().GetCountActiveOrderCreds(
tc.exp.args[0], tc.exp.args[1], tc.exp.args[2], tc.exp.args[4],
).Return(0, nil)
}

ds.EXPECT().DeleteTimeLimitedV2OrderCredsByOrderTx(
tc.exp.args[0], tc.exp.args[1], tc.exp.args[2], tc.exp.args[3],
).Return(nil)

actual := svc.deleteTLV2(ctx, nil, tc.given.ord, tc.given.reqID, now)
should.Equal(t, tc.exp.err, actual)
})
}
}
29 changes: 26 additions & 3 deletions services/skus/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ type Datastore interface {
AreTimeLimitedV2CredsSubmitted(ctx context.Context, blindedCreds ...string) (bool, error)
GetTimeLimitedV2OrderCredsByOrder(orderID uuid.UUID) (*TimeLimitedV2Creds, error)
GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, ordID, itemID, reqID uuid.UUID) (*TimeLimitedV2Creds, error)
DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, tx *sqlx.Tx, orderID uuid.UUID) error
DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID, itemIDs ...uuid.UUID) error
GetTimeLimitedV2OrderCredsByOrderItem(itemID uuid.UUID) (*TimeLimitedV2Creds, error)
GetCountActiveOrderCreds(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID, now time.Time) (int, error)
InsertTimeLimitedV2OrderCredsTx(ctx context.Context, tx *sqlx.Tx, tlv2 TimeAwareSubIssuedCreds) error
InsertSigningOrderRequestOutbox(ctx context.Context, requestID uuid.UUID, orderID uuid.UUID, itemID uuid.UUID, signingOrderRequest SigningOrderRequest) error
GetSigningOrderRequestOutboxByRequestID(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID) (*SigningOrderRequestOutbox, error)
Expand Down Expand Up @@ -796,11 +797,20 @@ func (pg *Postgres) DeleteSingleUseOrderCredsByOrderTx(ctx context.Context, tx *

// DeleteTimeLimitedV2OrderCredsByOrderTx performs a hard delete for all time limited v2 order
// credentials for a given OrderID.
func (pg *Postgres) DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, tx *sqlx.Tx, orderID uuid.UUID) error {
_, err := tx.ExecContext(ctx, `delete from time_limited_v2_order_creds where order_id = $1`, orderID)
func (pg *Postgres) DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID, itemIDs ...uuid.UUID) error {
// Legacy, if the request id matches the item id on the tlv2 creds, we actually want to delete the record.
// Otherwise we will keep these credentials here for multiple device refresh capabilities.
const q = "DELETE FROM time_limited_v2_order_creds WHERE order_id = ? AND item_id in (?)"
query, args, err := sqlx.In(q, orderID, itemIDs)
if err != nil {
return fmt.Errorf("error creating delete query for order with item ids: %w", err)
}

query = dbi.Rebind(query)
if _, err := dbi.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("error deleting time limited v2 order creds: %w", err)
}

return nil
}

Expand Down Expand Up @@ -1026,6 +1036,19 @@ func (pg *Postgres) GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, o
return result, nil
}

// GetCountActiveOrderCreds returns the count of order creds currently active on an order.
func (pg *Postgres) GetCountActiveOrderCreds(ctx context.Context, dbi sqlx.ExtContext, orderID uuid.UUID, now time.Time) (int, error) {
const q = `SELECT COUNT(1) FROM time_limited_v2_order_creds
WHERE order_id = $1 AND valid_from < $2 AND valid_to > $2 GROUP BY request_id`

var activeCredCount int
if err := sqlx.GetContext(ctx, dbi, &activeCredCount, q, orderID, now); err != nil {
return 0, fmt.Errorf("error getting active credential count: %w", err)
}

return activeCredCount, nil
}

// GetTimeLimitedV2OrderCredsByOrderItem returns all the order credentials for a single order item.
func (pg *Postgres) GetTimeLimitedV2OrderCredsByOrderItem(itemID uuid.UUID) (*TimeLimitedV2Creds, error) {
query := `
Expand Down
19 changes: 19 additions & 0 deletions services/skus/datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,25 @@ func (suite *PostgresTestSuite) TestGetOrderByExternalID() {
}
}

func (suite *PostgresTestSuite) TestCountActiveOrderCreds_Success() {
env := os.Getenv("ENV")
ctx := context.WithValue(context.Background(), appctx.EnvironmentCTXKey, env)

ctx = context.WithValue(ctx, appctx.WhitelistSKUsCTXKey, []string{
devBraveFirewallVPNPremiumTimeLimited,
devBraveSearchPremiumYearTimeLimited,
})

creds := suite.createTimeLimitedV2OrderCreds(suite.T(), ctx, devBraveFirewallVPNPremiumTimeLimited, devBraveSearchPremiumYearTimeLimited)

actual, err := suite.storage.GetCountActiveOrderCreds(ctx, suite.storage.RawDB(), creds[0].OrderID, time.Now())
suite.Require().NoError(err)

const expected = 2

suite.Assert().Equal(expected, actual)
}

func (suite *PostgresTestSuite) TestGetTimeLimitedV2OrderCredsByOrder_Success() {
env := os.Getenv("ENV")
ctx := context.WithValue(context.Background(), appctx.EnvironmentCTXKey, env)
Expand Down
Loading

0 comments on commit 2f85b54

Please sign in to comment.