diff --git a/services/skus/controllers.go b/services/skus/controllers.go index 2ab672544..c1642d835 100644 --- a/services/skus/controllers.go +++ b/services/skus/controllers.go @@ -96,10 +96,17 @@ func Router( r.Route("/{orderID}/credentials", func(cr chi.Router) { cr.Use(NewCORSMwr(copts, http.MethodGet, http.MethodPost)) - cr.Method(http.MethodPost, "/", metricsMwr("CreateOrderCreds", CreateOrderCreds(svc))) cr.Method(http.MethodGet, "/", metricsMwr("GetOrderCreds", GetOrderCreds(svc))) - cr.Method(http.MethodGet, "/{itemID}", metricsMwr("GetOrderCredsByID", GetOrderCredsByID(svc))) + cr.Method(http.MethodPost, "/", metricsMwr("CreateOrderCreds", CreateOrderCreds(svc))) cr.Method(http.MethodDelete, "/", metricsMwr("DeleteOrderCreds", authMwr(DeleteOrderCreds(svc)))) + + // Handle the old endpoint while the new is being rolled out: + // - true: the handler uses itemID as the request id, which is the old mode; + // - false: the handler uses the requestID from the URI. + cr.Method(http.MethodGet, "/{itemID}", metricsMwr("GetOrderCredsByID", getOrderCredsByID(svc, true))) + cr.Method(http.MethodGet, "/items/{itemID}/batches/{requestID}", metricsMwr("GetOrderCredsByID", getOrderCredsByID(svc, false))) + + cr.Method(http.MethodPut, "/items/{itemID}/batches/{requestID}", metricsMwr("CreateOrderItemCreds", createItemCreds(svc))) }) return r @@ -535,7 +542,7 @@ func CreateAnonCardTransaction(service *Service) handlers.AppHandler { }) } -// CreateOrderCredsRequest includes the item ID and blinded credentials which to be signed +// CreateOrderCredsRequest includes the item ID and blinded credentials which to be signed. type CreateOrderCredsRequest struct { ItemID uuid.UUID `json:"itemId" valid:"-"` BlindedCreds []string `json:"blindedCreds" valid:"base64"` @@ -569,9 +576,65 @@ func CreateOrderCreds(svc *Service) handlers.AppHandler { ) } - requestID := uuid.NewV4() + // Use the itemID for the request id so the old credential uniqueness constraint remains enforced. + reqID := req.ItemID + + if err := svc.CreateOrderItemCredentials(ctx, *orderID.UUID(), req.ItemID, reqID, req.BlindedCreds); err != nil { + lg.Error().Err(err).Msg("failed to create the order credentials") + return handlers.WrapError(err, "Error creating order creds", http.StatusBadRequest) + } + + return handlers.RenderContent(ctx, nil, w, http.StatusOK) + } +} + +// createItemCredsRequest includes the blinded credentials to be signed. +type createItemCredsRequest struct { + BlindedCreds []string `json:"blindedCreds" valid:"base64"` +} + +// createItemCreds handles requests for creating credentials for an item. +func createItemCreds(svc *Service) handlers.AppHandler { + return func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() + lg := logging.Logger(ctx, "skus.createItemCreds") + + req := &createItemCredsRequest{} + if err := requestutils.ReadJSON(ctx, r.Body, req); err != nil { + lg.Error().Err(err).Msg("failed to read body payload") + return handlers.WrapError(err, "Error in request body", http.StatusBadRequest) + } + + if _, err := govalidator.ValidateStruct(req); err != nil { + lg.Error().Err(err).Msg("failed to validate struct") + return handlers.WrapValidationError(err) + } + + orderID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParamFromCtx(ctx, "orderID")); err != nil { + lg.Error().Err(err).Msg("failed to validate order id") + return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{ + "orderID": err.Error(), + }) + } + + itemID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, itemID, chi.URLParamFromCtx(ctx, "itemID")); err != nil { + lg.Error().Err(err).Msg("failed to validate item id") + return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{ + "itemID": err.Error(), + }) + } + + reqID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, reqID, chi.URLParamFromCtx(ctx, "requestID")); err != nil { + lg.Error().Err(err).Msg("failed to validate request id") + return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{ + "requestID": err.Error(), + }) + } - if err := svc.CreateOrderItemCredentials(ctx, *orderID.UUID(), req.ItemID, requestID, req.BlindedCreds); err != nil { + if err := svc.CreateOrderItemCredentials(ctx, *orderID.UUID(), *itemID.UUID(), *reqID.UUID(), req.BlindedCreds); err != nil { lg.Error().Err(err).Msg("failed to create the order credentials") return handlers.WrapError(err, "Error creating order creds", http.StatusBadRequest) } @@ -638,54 +701,67 @@ func DeleteOrderCreds(service *Service) handlers.AppHandler { } } -// GetOrderCredsByID is the handler for fetching order credentials by an item id -func GetOrderCredsByID(service *Service) handlers.AppHandler { +// getOrderCredsByID handles requests for fetching order credentials by an item id. +// +// Requests may come in via two endpoints: +// - /{itemID} – legacyMode, reqID == itemID +// - /items/{itemID}/batches/{requestID} – new mode, reqID == requestID. +// +// The legacy mode will be gone after confirming a successful rollout. +// +// TODO: Clean up the legacy mode. +func getOrderCredsByID(svc *Service, legacyMode bool) handlers.AppHandler { return handlers.AppHandler(func(w http.ResponseWriter, r *http.Request) *handlers.AppError { + ctx := r.Context() - // get the IDs from the URL - var ( - orderID = new(inputs.ID) - itemID = new(inputs.ID) - validationPayload = map[string]interface{}{} - err error - ) - - // decode and validate orderID url param - if err = inputs.DecodeAndValidateString( - context.Background(), orderID, chi.URLParam(r, "orderID")); err != nil { - validationPayload["orderID"] = err.Error() + orderID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, orderID, chi.URLParamFromCtx(ctx, "orderID")); err != nil { + return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{ + "orderID": err.Error(), + }) } - // decode and validate itemID url param - if err = inputs.DecodeAndValidateString( - context.Background(), itemID, chi.URLParam(r, "itemID")); err != nil { - validationPayload["itemID"] = err.Error() + itemID := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, itemID, chi.URLParamFromCtx(ctx, "itemID")); err != nil { + return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{ + "itemID": err.Error(), + }) } - // did we get any validation errors? - if len(validationPayload) > 0 { - return handlers.ValidationError( - "Error validating request url parameter", - validationPayload) + var reqID uuid.UUID + if legacyMode { + reqID = *itemID.UUID() + } else { + reqIDRaw := &inputs.ID{} + if err := inputs.DecodeAndValidateString(ctx, reqIDRaw, chi.URLParamFromCtx(ctx, "requestID")); err != nil { + return handlers.ValidationError("Error validating request url parameter", map[string]interface{}{ + "requestID": err.Error(), + }) + } + + reqID = *reqIDRaw.UUID() } - creds, status, err := service.GetItemCredentials(r.Context(), *orderID.UUID(), *itemID.UUID()) + creds, status, err := svc.GetItemCredentials(ctx, *orderID.UUID(), *itemID.UUID(), reqID) if err != nil { - if errors.Is(err, errSetRetryAfter) { - // error specifies a retry after period, add to response header - avg, err := service.Datastore.GetOutboxMovAvgDurationSeconds() - if err != nil { - return handlers.WrapError(err, "Error getting credential retry-after", status) - } - w.Header().Set("Retry-After", strconv.FormatInt(avg, 10)) - } else { + if !errors.Is(err, errSetRetryAfter) { return handlers.WrapError(err, "Error getting credentials", status) } + + // Add to response header as error specifies a retry after period. + avg, err := svc.Datastore.GetOutboxMovAvgDurationSeconds() + if err != nil { + return handlers.WrapError(err, "Error getting credential retry-after", status) + } + + w.Header().Set("Retry-After", strconv.FormatInt(avg, 10)) } + if creds == nil { - return handlers.RenderContent(r.Context(), map[string]interface{}{}, w, status) + return handlers.RenderContent(ctx, map[string]interface{}{}, w, status) } - return handlers.RenderContent(r.Context(), creds, w, status) + + return handlers.RenderContent(ctx, creds, w, status) }) } diff --git a/services/skus/controllers_test.go b/services/skus/controllers_test.go index 5312ef367..1cb5fa7a3 100644 --- a/services/skus/controllers_test.go +++ b/services/skus/controllers_test.go @@ -1435,7 +1435,7 @@ func (suite *ControllersTestSuite) TestExpiredTimeLimitedCred() { ValidFor: &valid, } - creds, status, err := suite.service.GetTimeLimitedCreds(ctx, order) + creds, status, err := suite.service.GetTimeLimitedCreds(ctx, order, uuid.Nil, uuid.Nil) suite.Require().True(creds == nil, "should not get creds back") suite.Require().True(status == http.StatusBadRequest, "should not get creds back") suite.Require().Error(err, "should get an error") diff --git a/services/skus/credentials.go b/services/skus/credentials.go index 53d47ec18..c17a8b1a2 100644 --- a/services/skus/credentials.go +++ b/services/skus/credentials.go @@ -601,7 +601,7 @@ func (s *SignedOrderCredentialsHandler) Handle(ctx context.Context, message kafk defer rollback() // Check to see if the signing request has not been deleted whilst signing the request. - sor, err := s.datastore.GetSigningOrderRequestOutboxByRequestIDTx(ctx, tx, requestID) + sor, err := s.datastore.GetSigningOrderRequestOutboxByRequestID(ctx, tx, requestID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("error get signing order credentials tx: %w", err) } diff --git a/services/skus/datastore.go b/services/skus/datastore.go index 5e49aceed..67937c96b 100644 --- a/services/skus/datastore.go +++ b/services/skus/datastore.go @@ -29,7 +29,8 @@ import ( const ( signingRequestBatchSize = 10 - errNotFound = model.Error("not found") + errNotFound = model.Error("not found") + errNoTLV2Creds = model.Error("no unexpired time-limited-v2 credentials found") ) // Datastore abstracts over the underlying datastore. @@ -81,11 +82,12 @@ type Datastore interface { InsertSignedOrderCredentialsTx(ctx context.Context, tx *sqlx.Tx, signedOrderResult *SigningOrderResult) error AreTimeLimitedV2CredsSubmitted(ctx context.Context, blindedCreds ...string) (bool, error) GetTimeLimitedV2OrderCredsByOrder(orderID uuid.UUID) (*TimeLimitedV2Creds, error) + GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, ordID, itemID, reqID uuid.UUID) (*TimeLimitedV2Creds, error) DeleteTimeLimitedV2OrderCredsByOrderTx(ctx context.Context, tx *sqlx.Tx, orderID uuid.UUID) error GetTimeLimitedV2OrderCredsByOrderItem(itemID uuid.UUID) (*TimeLimitedV2Creds, error) InsertTimeLimitedV2OrderCredsTx(ctx context.Context, tx *sqlx.Tx, tlv2 TimeAwareSubIssuedCreds) error InsertSigningOrderRequestOutbox(ctx context.Context, requestID uuid.UUID, orderID uuid.UUID, itemID uuid.UUID, signingOrderRequest SigningOrderRequest) error - GetSigningOrderRequestOutboxByRequestIDTx(ctx context.Context, tx *sqlx.Tx, requestID uuid.UUID) (*SigningOrderRequestOutbox, error) + GetSigningOrderRequestOutboxByRequestID(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID) (*SigningOrderRequestOutbox, error) GetSigningOrderRequestOutboxByOrder(ctx context.Context, orderID uuid.UUID) ([]SigningOrderRequestOutbox, error) GetSigningOrderRequestOutboxByOrderItem(ctx context.Context, itemID uuid.UUID) ([]SigningOrderRequestOutbox, error) DeleteSigningOrderRequestOutboxByOrderTx(ctx context.Context, tx *sqlx.Tx, orderID uuid.UUID) error @@ -996,6 +998,34 @@ func (pg *Postgres) GetTimeLimitedV2OrderCredsByOrder(orderID uuid.UUID) (*TimeL return &timeLimitedV2Creds, nil } +// GetTLV2Creds returns all the non expired tlv2 credentials for a given order, item and request ids. +// +// If no credentials have been found, the method returns errNoTLV2Creds. +func (pg *Postgres) GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, ordID, itemID, reqID uuid.UUID) (*TimeLimitedV2Creds, error) { + const q = `SELECT + order_id, item_id, issuer_id, blinded_creds, signed_creds, + batch_proof, public_key, valid_from, valid_to + FROM time_limited_v2_order_creds + WHERE order_id = $1 AND item_id = $2 AND request_id = $3 AND valid_to > now()` + + creds := make([]TimeAwareSubIssuedCreds, 0) + if err := sqlx.SelectContext(ctx, dbi, &creds, q, ordID, itemID, reqID); err != nil { + return nil, err + } + + if len(creds) == 0 { + return nil, errNoTLV2Creds + } + + result := &TimeLimitedV2Creds{ + OrderID: creds[0].OrderID, + IssuerID: creds[0].IssuerID, + Credentials: creds, + } + + return result, nil +} + // GetTimeLimitedV2OrderCredsByOrderItem returns all the order credentials for a single order item. func (pg *Postgres) GetTimeLimitedV2OrderCredsByOrderItem(itemID uuid.UUID) (*TimeLimitedV2Creds, error) { query := ` @@ -1083,29 +1113,19 @@ func (pg *Postgres) GetSigningOrderRequestOutboxByOrderItem(ctx context.Context, } // GetSigningOrderRequestOutboxByRequestID retrieves the SigningOrderRequestOutbox by requestID. +// // An error is returned if the result set is empty. -func (pg *Postgres) GetSigningOrderRequestOutboxByRequestID(ctx context.Context, requestID uuid.UUID) (*SigningOrderRequestOutbox, error) { - var signingRequestOutbox SigningOrderRequestOutbox - err := pg.RawDB().GetContext(ctx, &signingRequestOutbox, - `select request_id, order_id, item_id, completed_at, message_data - from signing_order_request_outbox where request_id = $1`, requestID) - if err != nil { - return nil, fmt.Errorf("error retrieving signing request from outbox: %w", err) - } - return &signingRequestOutbox, nil -} +func (pg *Postgres) GetSigningOrderRequestOutboxByRequestID(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID) (*SigningOrderRequestOutbox, error) { + const q = `SELECT request_id, order_id, item_id, completed_at, message_data + FROM signing_order_request_outbox + WHERE request_id = $1 FOR UPDATE` -// GetSigningOrderRequestOutboxByRequestIDTx retrieves the SigningOrderRequestOutbox by requestID. -// An error is returned if the result set is empty. -func (pg *Postgres) GetSigningOrderRequestOutboxByRequestIDTx(ctx context.Context, tx *sqlx.Tx, requestID uuid.UUID) (*SigningOrderRequestOutbox, error) { - var signingRequestOutbox SigningOrderRequestOutbox - err := tx.GetContext(ctx, &signingRequestOutbox, - `select request_id, order_id, item_id, completed_at, message_data - from signing_order_request_outbox where request_id = $1 for update`, requestID) - if err != nil { + result := &SigningOrderRequestOutbox{} + if err := sqlx.GetContext(ctx, dbi, result, q, reqID); err != nil { return nil, fmt.Errorf("error retrieving signing request from outbox: %w", err) } - return &signingRequestOutbox, nil + + return result, nil } // UpdateSigningOrderRequestOutboxTx updates a signing order request outbox message for the given requestID. @@ -1277,7 +1297,6 @@ func (pg *Postgres) InsertSignedOrderCredentialsTx(ctx context.Context, tx *sqlx } case timeLimitedV2: - if so.ValidTo.Value() == nil { return fmt.Errorf("error validTo for order creds orderID %s itemID %s is null: %w", metadata.OrderID, metadata.ItemID, err) diff --git a/services/skus/instrumented_datastore.go b/services/skus/instrumented_datastore.go index 6a4f5b893..f79130141 100644 --- a/services/skus/instrumented_datastore.go +++ b/services/skus/instrumented_datastore.go @@ -422,8 +422,8 @@ func (_d DatastoreWithPrometheus) GetSigningOrderRequestOutboxByOrderItem(ctx co return _d.base.GetSigningOrderRequestOutboxByOrderItem(ctx, itemID) } -// GetSigningOrderRequestOutboxByRequestIDTx implements Datastore -func (_d DatastoreWithPrometheus) GetSigningOrderRequestOutboxByRequestIDTx(ctx context.Context, tx *sqlx.Tx, requestID uuid.UUID) (sp1 *SigningOrderRequestOutbox, err error) { +// GetSigningOrderRequestOutboxByRequestID implements Datastore +func (_d DatastoreWithPrometheus) GetSigningOrderRequestOutboxByRequestID(ctx context.Context, dbi sqlx.QueryerContext, reqID uuid.UUID) (sp1 *SigningOrderRequestOutbox, err error) { _since := time.Now() defer func() { result := "ok" @@ -431,9 +431,9 @@ func (_d DatastoreWithPrometheus) GetSigningOrderRequestOutboxByRequestIDTx(ctx result = "error" } - datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "GetSigningOrderRequestOutboxByRequestIDTx", result).Observe(time.Since(_since).Seconds()) + datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "GetSigningOrderRequestOutboxByRequestID", result).Observe(time.Since(_since).Seconds()) }() - return _d.base.GetSigningOrderRequestOutboxByRequestIDTx(ctx, tx, requestID) + return _d.base.GetSigningOrderRequestOutboxByRequestID(ctx, dbi, reqID) } // GetSumForTransactions implements Datastore @@ -478,6 +478,20 @@ func (_d DatastoreWithPrometheus) GetTimeLimitedV2OrderCredsByOrderItem(itemID u return _d.base.GetTimeLimitedV2OrderCredsByOrderItem(itemID) } +// GetTLV2Creds implements Datastore +func (_d DatastoreWithPrometheus) GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, ordID, itemID, reqID uuid.UUID) (tp1 *TimeLimitedV2Creds, err error) { + _since := time.Now() + defer func() { + result := "ok" + if err != nil { + result = "error" + } + + datastoreDurationSummaryVec.WithLabelValues(_d.instanceName, "GetTLV2Creds", result).Observe(time.Since(_since).Seconds()) + }() + return _d.base.GetTLV2Creds(ctx, dbi, ordID, itemID, reqID) +} + // GetTransaction implements Datastore func (_d DatastoreWithPrometheus) GetTransaction(externalTransactionID string) (tp1 *Transaction, err error) { _since := time.Now() diff --git a/services/skus/mockdatastore.go b/services/skus/mockdatastore.go index 88cd749f3..9375051c8 100644 --- a/services/skus/mockdatastore.go +++ b/services/skus/mockdatastore.go @@ -446,19 +446,19 @@ func (mr *MockDatastoreMockRecorder) GetSigningOrderRequestOutboxByOrderItem(ctx return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningOrderRequestOutboxByOrderItem", reflect.TypeOf((*MockDatastore)(nil).GetSigningOrderRequestOutboxByOrderItem), ctx, itemID) } -// GetSigningOrderRequestOutboxByRequestIDTx mocks base method. -func (m *MockDatastore) GetSigningOrderRequestOutboxByRequestIDTx(ctx context.Context, tx *sqlx.Tx, requestID go_uuid.UUID) (*SigningOrderRequestOutbox, error) { +// GetSigningOrderRequestOutboxByRequestID mocks base method. +func (m *MockDatastore) GetSigningOrderRequestOutboxByRequestID(ctx context.Context, dbi sqlx.QueryerContext, reqID go_uuid.UUID) (*SigningOrderRequestOutbox, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSigningOrderRequestOutboxByRequestIDTx", ctx, tx, requestID) + ret := m.ctrl.Call(m, "GetSigningOrderRequestOutboxByRequestID", ctx, dbi, reqID) ret0, _ := ret[0].(*SigningOrderRequestOutbox) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSigningOrderRequestOutboxByRequestIDTx indicates an expected call of GetSigningOrderRequestOutboxByRequestIDTx. -func (mr *MockDatastoreMockRecorder) GetSigningOrderRequestOutboxByRequestIDTx(ctx, tx, requestID interface{}) *gomock.Call { +// GetSigningOrderRequestOutboxByRequestID indicates an expected call of GetSigningOrderRequestOutboxByRequestID. +func (mr *MockDatastoreMockRecorder) GetSigningOrderRequestOutboxByRequestID(ctx, dbi, reqID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningOrderRequestOutboxByRequestIDTx", reflect.TypeOf((*MockDatastore)(nil).GetSigningOrderRequestOutboxByRequestIDTx), ctx, tx, requestID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningOrderRequestOutboxByRequestID", reflect.TypeOf((*MockDatastore)(nil).GetSigningOrderRequestOutboxByRequestID), ctx, dbi, reqID) } // GetSumForTransactions mocks base method. @@ -506,6 +506,21 @@ func (mr *MockDatastoreMockRecorder) GetTimeLimitedV2OrderCredsByOrderItem(itemI return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTimeLimitedV2OrderCredsByOrderItem", reflect.TypeOf((*MockDatastore)(nil).GetTimeLimitedV2OrderCredsByOrderItem), itemID) } +// GetTLV2Creds mocks base method. +func (m *MockDatastore) GetTLV2Creds(ctx context.Context, dbi sqlx.QueryerContext, ordID, itemID, reqID go_uuid.UUID) (*TimeLimitedV2Creds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTLV2Creds", ctx, dbi, ordID, itemID, reqID) + ret0, _ := ret[0].(*TimeLimitedV2Creds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTLV2Creds indicates an expected call of GetTLV2Creds. +func (mr *MockDatastoreMockRecorder) GetTLV2Creds(ctx, dbi, ordID, itemID, reqID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLV2Creds", reflect.TypeOf((*MockDatastore)(nil).GetTLV2Creds), ctx, dbi, ordID, itemID, reqID) +} + // GetTransaction mocks base method. func (m *MockDatastore) GetTransaction(externalTransactionID string) (*Transaction, error) { m.ctrl.T.Helper() diff --git a/services/skus/model/model.go b/services/skus/model/model.go index e98df6ea9..83bd63f72 100644 --- a/services/skus/model/model.go +++ b/services/skus/model/model.go @@ -288,6 +288,20 @@ func (o *Order) NumIntervals() (int, error) { return result, nil } +// HasItem returns the item if found. +// +// It exposes a comma, ok API similar to a map. +// Today items are stored in a slice, but it might change to a map in the future. +func (o *Order) HasItem(id uuid.UUID) (*OrderItem, bool) { + for i := range o.Items { + if uuid.Equal(o.Items[i].ID, id) { + return &o.Items[i], true + } + } + + return nil, false +} + // OrderItem represents a particular order item. type OrderItem struct { ID uuid.UUID `json:"id" db:"id"` diff --git a/services/skus/model/model_test.go b/services/skus/model/model_test.go index 6cb2a0a87..237119485 100644 --- a/services/skus/model/model_test.go +++ b/services/skus/model/model_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/lib/pq" + uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" should "github.com/stretchr/testify/assert" must "github.com/stretchr/testify/require" @@ -574,6 +575,129 @@ func TestOrderItemList_TotalCost(t *testing.T) { } } +func TestOrder_HasItem(t *testing.T) { + type tcGiven struct { + order *model.Order + itemID uuid.UUID + } + + type tcExpected struct { + item *model.OrderItem + ok bool + } + + type testCase struct { + name string + given tcGiven + exp tcExpected + } + + tests := []testCase{ + { + name: "no_items_nothing_found", + given: tcGiven{ + order: &model.Order{}, + itemID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + }, + + { + name: "one_item_not_found", + given: tcGiven{ + order: &model.Order{ + Items: []model.OrderItem{ + { + ID: uuid.Must(uuid.FromString("dbc6416a-7713-4aa5-8968-56aef7ec0e81")), + }, + }, + }, + itemID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + }, + + { + name: "two_items_not_found", + given: tcGiven{ + order: &model.Order{ + Items: []model.OrderItem{ + { + ID: uuid.Must(uuid.FromString("dbc6416a-7713-4aa5-8968-56aef7ec0e81")), + }, + + { + ID: uuid.Must(uuid.FromString("4efbedfe-a598-43a4-a345-17653d6289e8")), + }, + }, + }, + itemID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + }, + + { + name: "one_item_found", + given: tcGiven{ + order: &model.Order{ + Items: []model.OrderItem{ + { + ID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + }, + }, + itemID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + exp: tcExpected{ + item: &model.OrderItem{ + ID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + ok: true, + }, + }, + + { + name: "many_items_found", + given: tcGiven{ + order: &model.Order{ + Items: []model.OrderItem{ + { + ID: uuid.Must(uuid.FromString("dbc6416a-7713-4aa5-8968-56aef7ec0e81")), + }, + + { + ID: uuid.Must(uuid.FromString("4efbedfe-a598-43a4-a345-17653d6289e8")), + }, + + { + ID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + }, + }, + itemID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + exp: tcExpected{ + item: &model.OrderItem{ + ID: uuid.Must(uuid.FromString("b5e3f3e4-0bd4-4fd5-a693-a50f4dbfd6ac")), + }, + ok: true, + }, + }, + } + + for i := range tests { + tc := tests[i] + + t.Run(tc.name, func(t *testing.T) { + item, ok := tc.given.order.HasItem(tc.given.itemID) + must.Equal(t, tc.exp.ok, ok) + + if !tc.exp.ok { + return + } + + should.Equal(t, tc.exp.item, item) + }) + } +} + func mustDecimalFromString(v string) decimal.Decimal { result, err := decimal.NewFromString(v) if err != nil { diff --git a/services/skus/service.go b/services/skus/service.go index 13774dce5..9625cc459 100644 --- a/services/skus/service.go +++ b/services/skus/service.go @@ -54,6 +54,9 @@ var ( errClosingResource = errors.New("error closing resource") errInvalidRadomURL = model.Error("service: invalid radom url") errGeminiClientNotConfigured = errors.New("service: gemini client not configured") + errLegacyOutboxNotFound = model.Error("error no order credentials have been submitted for signing") + errWrongOrderIDForRequestID = model.Error("signed request order id does not belong to request id") + errLegacySUCredsNotFound = model.Error("credentials do not exist") voteTopic = os.Getenv("ENV") + ".payment.vote" @@ -986,26 +989,39 @@ const ( var errInvalidCredentialType = errors.New("invalid credential type on order") -// GetItemCredentials - based on the order, get the associated credentials -func (s *Service) GetItemCredentials(ctx context.Context, orderID, itemID uuid.UUID) (interface{}, int, error) { - orderCreds, status, err := s.GetCredentials(ctx, orderID) +// GetItemCredentials returns credentials based on the order, item and request id. +func (s *Service) GetItemCredentials(ctx context.Context, orderID, itemID, reqID uuid.UUID) (interface{}, int, error) { + order, err := s.Datastore.GetOrder(orderID) if err != nil { - return nil, status, err + return nil, http.StatusNotFound, fmt.Errorf("failed to get order: %w", err) } - for _, oc := range orderCreds.([]OrderCreds) { - if uuid.Equal(oc.ID, itemID) { - return oc, status, nil - } + if order == nil { + return nil, http.StatusNotFound, fmt.Errorf("failed to get order: %w", err) + } + + item, ok := order.HasItem(itemID) + if !ok { + return nil, http.StatusNotFound, fmt.Errorf("failed to get item: %w", err) + } + + switch item.CredentialType { + case singleUse: + return s.GetSingleUseCreds(ctx, order.ID, itemID, reqID) + case timeLimited: + return s.GetTimeLimitedCreds(ctx, order, itemID, reqID) + case timeLimitedV2: + return s.GetTimeLimitedV2Creds(ctx, order.ID, itemID, reqID) + default: + return nil, http.StatusConflict, errInvalidCredentialType } - // order creds are not available yet - return nil, status, nil } -// GetCredentials - based on the order, get the associated credentials +// GetCredentials returns credentials on the order. +// +// This is a legacy method. +// For backward compatibility, similar to creating credentials, it uses item id as request id. func (s *Service) GetCredentials(ctx context.Context, orderID uuid.UUID) (interface{}, int, error) { - var credentialType string - order, err := s.Datastore.GetOrder(orderID) if err != nil { return nil, http.StatusNotFound, fmt.Errorf("failed to get order: %w", err) @@ -1015,104 +1031,94 @@ func (s *Service) GetCredentials(ctx context.Context, orderID uuid.UUID) (interf return nil, http.StatusNotFound, fmt.Errorf("failed to get order: %w", err) } - // look through order, find out what all the order item's credential types are - for i, v := range order.Items { - if i > 0 { - if v.CredentialType != credentialType { - // all the order items on the order need the same credential type - return nil, http.StatusConflict, fmt.Errorf("all items must have the same credential type") - } - } else { - credentialType = v.CredentialType - } + if len(order.Items) != 1 { + return nil, http.StatusConflict, model.Error("order must only have one item") } - switch credentialType { + itemID := order.Items[0].ID + + switch order.Items[0].CredentialType { case singleUse: - return s.GetSingleUseCreds(ctx, order) + return s.GetSingleUseCreds(ctx, order.ID, itemID, itemID) case timeLimited: - return s.GetTimeLimitedCreds(ctx, order) + return s.GetTimeLimitedCreds(ctx, order, itemID, itemID) case timeLimitedV2: - return s.GetTimeLimitedV2Creds(ctx, order) + return s.GetTimeLimitedV2Creds(ctx, order.ID, itemID, itemID) + default: + return nil, http.StatusConflict, errInvalidCredentialType } - return nil, http.StatusConflict, errInvalidCredentialType } -// GetSingleUseCreds returns all the single use credentials for a given order. +// GetSingleUseCreds returns single use credentials for a given order, item and request. +// // If the credentials have been submitted but not yet signed it returns a http.StatusAccepted and an empty body. // If the credentials have been signed it will return a http.StatusOK and the order credentials. -func (s *Service) GetSingleUseCreds(ctx context.Context, order *Order) ([]OrderCreds, int, error) { - if order == nil { - return nil, http.StatusBadRequest, fmt.Errorf("failed to create credentials, bad order") - } - - creds, err := s.Datastore.GetOrderCreds(order.ID, false) +func (s *Service) GetSingleUseCreds(ctx context.Context, orderID, itemID, reqID uuid.UUID) ([]OrderCreds, int, error) { + // Single use credentials retain the old semantics, only one request is ever allowed. + creds, err := s.Datastore.GetOrderCredsByItemID(orderID, itemID, false) if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("error getting credentials: %w", err) + return nil, http.StatusInternalServerError, fmt.Errorf("failed to get single use creds: %w", err) } - if len(creds) > 0 { + if creds != nil { // TODO: Issues #1541 remove once all creds using RunOrderJob have need processed - for i := 0; i < len(creds); i++ { - if creds[i].SignedCreds == nil { - return nil, http.StatusAccepted, nil - } + if creds.SignedCreds == nil { + return nil, http.StatusAccepted, nil } + // TODO: End - return creds, http.StatusOK, nil + return []OrderCreds{*creds}, http.StatusOK, nil } - outboxMessages, err := s.Datastore.GetSigningOrderRequestOutboxByOrder(ctx, order.ID) + outboxMessages, err := s.Datastore.GetSigningOrderRequestOutboxByRequestID(ctx, s.Datastore.RawDB(), reqID) if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("error getting credentials: %w", err) - } + if errors.Is(err, sql.ErrNoRows) { + return nil, http.StatusNotFound, errLegacySUCredsNotFound + } - if len(outboxMessages) == 0 { - return nil, http.StatusNotFound, fmt.Errorf("credentials do not exist") + return nil, http.StatusInternalServerError, fmt.Errorf("error getting outbox messages: %w", err) } - for _, m := range outboxMessages { - if m.CompletedAt == nil { - return nil, http.StatusAccepted, nil - } + if outboxMessages.CompletedAt == nil { + return nil, http.StatusAccepted, nil } - return creds, http.StatusOK, nil + return nil, http.StatusInternalServerError, model.Error("unreachable condition") } -// GetTimeLimitedV2Creds returns all the single use credentials for a given order. +// GetTimeLimitedV2Creds returns all the tlv2 credentials for a given order, item and request id. +// // If the credentials have been submitted but not yet signed it returns a http.StatusAccepted and an empty body. // If the credentials have been signed it will return a http.StatusOK and the time limited v2 credentials. -func (s *Service) GetTimeLimitedV2Creds(ctx context.Context, order *Order) ([]TimeAwareSubIssuedCreds, int, error) { - var resp = []TimeAwareSubIssuedCreds{} // browser api_request_helper does not understand "null" as json - if order == nil { - return resp, http.StatusBadRequest, fmt.Errorf("error order cannot be nil") - } - - outboxMessages, err := s.Datastore.GetSigningOrderRequestOutboxByOrder(ctx, order.ID) +// +// Browser's api_request_helper does not understand Go's nil slices, hence explicit empty slice is returned. +func (s *Service) GetTimeLimitedV2Creds(ctx context.Context, orderID, itemID, reqID uuid.UUID) ([]TimeAwareSubIssuedCreds, int, error) { + obmsg, err := s.Datastore.GetSigningOrderRequestOutboxByRequestID(ctx, s.Datastore.RawDB(), reqID) if err != nil { - return resp, http.StatusInternalServerError, fmt.Errorf("error getting outbox messages: %w", err) + if errors.Is(err, sql.ErrNoRows) { + return []TimeAwareSubIssuedCreds{}, http.StatusNotFound, errLegacyOutboxNotFound + } + + return []TimeAwareSubIssuedCreds{}, http.StatusInternalServerError, fmt.Errorf("error getting outbox messages: %w", err) } - if len(outboxMessages) == 0 { - return resp, http.StatusNotFound, errors.New("error no order credentials have been submitted for signing") + if !uuid.Equal(obmsg.OrderID, orderID) { + return []TimeAwareSubIssuedCreds{}, http.StatusBadRequest, errWrongOrderIDForRequestID } - for _, m := range outboxMessages { - if m.CompletedAt == nil { - // get average of last 10 outbox messages duration as the retry after - return resp, http.StatusAccepted, errSetRetryAfter - } + if obmsg.CompletedAt == nil { + // Get average of last 10 outbox messages duration as the retry after. + return []TimeAwareSubIssuedCreds{}, http.StatusAccepted, errSetRetryAfter } - creds, err := s.Datastore.GetTimeLimitedV2OrderCredsByOrder(order.ID) + creds, err := s.Datastore.GetTLV2Creds(ctx, s.Datastore.RawDB(), orderID, itemID, reqID) if err != nil { - return resp, http.StatusInternalServerError, fmt.Errorf("error getting credentials: %w", err) - } + if errors.Is(err, errNoTLV2Creds) { + // Credentials could be signed, but nothing to return as they are all expired. + return []TimeAwareSubIssuedCreds{}, http.StatusOK, nil + } - // Potentially we can have all creds signed but nothing to return as they are all expired. - if creds == nil { - return resp, http.StatusOK, nil + return []TimeAwareSubIssuedCreds{}, http.StatusInternalServerError, fmt.Errorf("error getting credentials: %w", err) } return creds.Credentials, http.StatusOK, nil @@ -1234,69 +1240,66 @@ func timeChunking(ctx context.Context, issuerID string, timeLimitedSecret crypto return credentials, nil } -// GetTimeLimitedCreds get an order's time limited creds -func (s *Service) GetTimeLimitedCreds(ctx context.Context, order *Order) ([]TimeLimitedCreds, int, error) { - if order == nil { - return nil, http.StatusBadRequest, fmt.Errorf("failed to create credentials, bad order") - } - - // is the order paid? +// GetTimeLimitedCreds returns get an order's time limited creds. +func (s *Service) GetTimeLimitedCreds(ctx context.Context, order *Order, itemID, reqID uuid.UUID) ([]TimeLimitedCreds, int, error) { if !order.IsPaid() || order.LastPaidAt == nil { - return nil, http.StatusBadRequest, fmt.Errorf("order is not paid, or invalid last paid at") + return nil, http.StatusBadRequest, model.Error("order is not paid, or invalid last paid at") } issuedAt := order.LastPaidAt - // if the order has an expiry, use that if order.ExpiresAt != nil { - // check if we are past expiration, if so issue nothing + // Check if it's past expiration, if so issue nothing. if time.Now().After(*order.ExpiresAt) { - return nil, http.StatusBadRequest, fmt.Errorf("order has expired") + return nil, http.StatusBadRequest, model.Error("order has expired") } } - var credentials []TimeLimitedCreds secret, err := s.GetActiveCredentialSigningKey(ctx, order.MerchantID) if err != nil { return nil, http.StatusInternalServerError, fmt.Errorf("failed to get merchant signing key: %w", err) } + timeLimitedSecret := cryptography.NewTimeLimitedSecret(secret) - for _, item := range order.Items { + item, ok := order.HasItem(itemID) + if !ok { + return nil, http.StatusBadRequest, model.Error("could not find specified item") + } - if item.ValidForISO == nil { - return nil, http.StatusBadRequest, fmt.Errorf("order item has no valid for time") - } - duration, err := timeutils.ParseDuration(*(item.ValidForISO)) - if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("unable to parse order duration for credentials") - } + if item.ValidForISO == nil { + return nil, http.StatusBadRequest, model.Error("order item has no valid for time") + } - if item.IssuanceIntervalISO == nil { - item.IssuanceIntervalISO = new(string) - *(item.IssuanceIntervalISO) = "P1D" - } - interval, err := timeutils.ParseDuration(*(item.IssuanceIntervalISO)) - if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("unable to parse issuance interval for credentials") - } + duration, err := timeutils.ParseDuration(*item.ValidForISO) + if err != nil { + return nil, http.StatusInternalServerError, model.Error("unable to parse order duration for credentials") + } - issuerID, err := encodeIssuerID(order.MerchantID, item.SKU) - if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("error encoding issuer: %w", err) - } + if item.IssuanceIntervalISO == nil { + item.IssuanceIntervalISO = ptrTo("P1D") + } - creds, err := timeChunking(ctx, issuerID, timeLimitedSecret, order.ID, item.ID, *issuedAt, *duration, *interval) - if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("failed to derive credential chunking: %w", err) - } - credentials = append(credentials, creds...) + interval, err := timeutils.ParseDuration(*(item.IssuanceIntervalISO)) + if err != nil { + return nil, http.StatusInternalServerError, model.Error("unable to parse issuance interval for credentials") } - if len(credentials) > 0 { - return credentials, http.StatusOK, nil + issuerID, err := encodeIssuerID(order.MerchantID, item.SKU) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("error encoding issuer: %w", err) } - return nil, http.StatusBadRequest, fmt.Errorf("failed to issue credentials") + + credentials, err := timeChunking(ctx, issuerID, timeLimitedSecret, order.ID, item.ID, *issuedAt, *duration, *interval) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to derive credential chunking: %w", err) + } + + if len(credentials) == 0 { + return nil, http.StatusBadRequest, model.Error("failed to issue credentials") + } + + return credentials, http.StatusOK, nil } type credential interface { @@ -1937,3 +1940,7 @@ type tlv1CredPresentation struct { IssuedAt string `json:"issuedAt"` ExpiresAt string `json:"expiresAt"` } + +func ptrTo[T any](v T) *T { + return &v +}