diff --git a/pkg/rid/store/cockroach/identification_service_area.go b/pkg/rid/store/cockroach/identification_service_area.go index e5b7ff627..ecfb77966 100644 --- a/pkg/rid/store/cockroach/identification_service_area.go +++ b/pkg/rid/store/cockroach/identification_service_area.go @@ -9,14 +9,11 @@ import ( "github.com/interuss/dss/pkg/geo" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" - repos "github.com/interuss/dss/pkg/rid/repos" dssql "github.com/interuss/dss/pkg/sql" "github.com/interuss/stacktrace" - "github.com/coreos/go-semver/semver" "github.com/golang/geo/s2" "github.com/jackc/pgx/v5/pgtype" - "go.uber.org/zap" ) const ( @@ -24,28 +21,8 @@ const ( updateISAFields = "id, url, cells, starts_at, ends_at, writer, updated_at" ) -func NewISARepo(ctx context.Context, db dssql.Queryable, dbVersion semver.Version, logger *zap.Logger) repos.ISA { - if dbVersion.Compare(v400) >= 0 { - return &isaRepo{ - Queryable: db, - logger: logger, - } - } - return &isaRepoV3{ - Queryable: db, - logger: logger, - } -} - -// isaRepo is an implementation of the ISARepo for CRDB. -type isaRepo struct { - dssql.Queryable - - logger *zap.Logger -} - -func (c *isaRepo) process(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.IdentificationServiceArea, error) { - rows, err := c.Query(ctx, query, args...) +func (r *repo) fetchISAs(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.IdentificationServiceArea, error) { + rows, err := r.Query(ctx, query, args...) if err != nil { return nil, stacktrace.Propagate(err, fmt.Sprintf("Error in query: %s", query)) } @@ -85,8 +62,8 @@ func (c *isaRepo) process(ctx context.Context, query string, args ...interface{} return payload, nil } -func (c *isaRepo) processOne(ctx context.Context, query string, args ...interface{}) (*ridmodels.IdentificationServiceArea, error) { - isas, err := c.process(ctx, query, args...) +func (r *repo) fetchISA(ctx context.Context, query string, args ...interface{}) (*ridmodels.IdentificationServiceArea, error) { + isas, err := r.fetchISAs(ctx, query, args...) if err != nil { return nil, err // No need to Propagate this error as this stack layer does not add useful information } @@ -101,7 +78,7 @@ func (c *isaRepo) processOne(ctx context.Context, query string, args ...interfac // GetISA returns the isa identified by "id". // Returns nil, nil if not found -func (c *isaRepo) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { var query = fmt.Sprintf(` SELECT %s FROM identification_service_areas @@ -112,7 +89,7 @@ func (c *isaRepo) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) ( if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, query, uid) + return r.fetchISA(ctx, query, uid) } // InsertISA inserts the IdentificationServiceArea identified by "id" and owned @@ -122,7 +99,7 @@ func (c *isaRepo) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) ( // by it. // TODO: Simplify the logic to insert without a query, such that the insert fails // if there's an existing entity. -func (c *isaRepo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { var ( insertAreasQuery = fmt.Sprintf(` INSERT INTO @@ -147,7 +124,7 @@ func (c *isaRepo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationSe if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, insertAreasQuery, id, isa.Owner, isa.URL, cids, isa.StartTime, isa.EndTime, isa.Writer) + return r.fetchISA(ctx, insertAreasQuery, id, isa.Owner, isa.URL, cids, isa.StartTime, isa.EndTime, isa.Writer) } @@ -158,7 +135,7 @@ func (c *isaRepo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationSe // by it. // TODO: simplify the logic to just update, without the primary query. // Returns nil, nil if ID, version not found -func (c *isaRepo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { var ( updateAreasQuery = fmt.Sprintf(` UPDATE @@ -177,13 +154,13 @@ func (c *isaRepo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationSe if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, updateAreasQuery, id, isa.URL, cids, isa.StartTime, isa.EndTime, isa.Version.ToTimestamp(), isa.Writer) + return r.fetchISA(ctx, updateAreasQuery, id, isa.URL, cids, isa.StartTime, isa.EndTime, isa.Version.ToTimestamp(), isa.Writer) } // DeleteISA deletes the IdentificationServiceArea identified by "id" and owned by "owner". // Returns the delete IdentificationServiceArea and all Subscriptions affected by the delete. // Returns nil, nil if ID, version not found -func (c *isaRepo) DeleteISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) DeleteISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { var ( deleteQuery = fmt.Sprintf(` DELETE FROM @@ -198,13 +175,13 @@ func (c *isaRepo) DeleteISA(ctx context.Context, isa *ridmodels.IdentificationSe if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, deleteQuery, id, isa.Version.ToTimestamp()) + return r.fetchISA(ctx, deleteQuery, id, isa.Version.ToTimestamp()) } // SearchISAs searches IdentificationServiceArea // instances that intersect with "cells" and, if set, the temporal volume // defined by "earliest" and "latest". -func (c *isaRepo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { +func (r *repo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { var ( // TODO: make earliest and latest required (NOT NULL) and remove coalesce. // Make them real values (not pointers), on the model layer. @@ -230,13 +207,13 @@ func (c *isaRepo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest * return nil, stacktrace.NewError("Earliest start time is missing") } - return c.process(ctx, isasInCellsQuery, earliest, latest, dssql.CellUnionToCellIds(cells), dssmodels.MaxResultLimit) + return r.fetchISAs(ctx, isasInCellsQuery, earliest, latest, dssql.CellUnionToCellIds(cells), dssmodels.MaxResultLimit) } // ListExpiredISAs lists all expired ISAs based on writer. // Records expire if current time is minutes more than records' endTime. // The function queries both empty writer and null writer when passing empty string as a writer. -func (c *isaRepo) ListExpiredISAs(ctx context.Context, writer string) ([]*ridmodels.IdentificationServiceArea, error) { +func (r *repo) ListExpiredISAs(ctx context.Context, writer string) ([]*ridmodels.IdentificationServiceArea, error) { writerQuery := "'" + writer + "'" if len(writer) == 0 { writerQuery = "'' OR writer = NULL" @@ -255,5 +232,5 @@ func (c *isaRepo) ListExpiredISAs(ctx context.Context, writer string) ([]*ridmod LIMIT $1`, isaFields, expiredDurationInMin, writerQuery) ) - return c.process(ctx, isasInCellsQuery, dssmodels.MaxResultLimit) + return r.fetchISAs(ctx, isasInCellsQuery, dssmodels.MaxResultLimit) } diff --git a/pkg/rid/store/cockroach/identification_service_area_v3.go b/pkg/rid/store/cockroach/identification_service_area_v3.go deleted file mode 100644 index 121d18579..000000000 --- a/pkg/rid/store/cockroach/identification_service_area_v3.go +++ /dev/null @@ -1,215 +0,0 @@ -package cockroach - -import ( - "context" - "fmt" - "time" - - dsserr "github.com/interuss/dss/pkg/errors" - dssmodels "github.com/interuss/dss/pkg/models" - ridmodels "github.com/interuss/dss/pkg/rid/models" - dssql "github.com/interuss/dss/pkg/sql" - "github.com/interuss/stacktrace" - - "github.com/golang/geo/s2" - "go.uber.org/zap" -) - -const ( - isaFieldsV3 = "id, owner, url, cells, starts_at, ends_at, updated_at" - updateISAFieldsV3 = "id, url, cells, starts_at, ends_at, updated_at" -) - -// The purpose od isaRepoV3 is solely to support backwards compatibility -// It will be deleted from the codebase when all existing production deployments have been upgraded to 3.1.0+. -type isaRepoV3 struct { - dssql.Queryable - - logger *zap.Logger -} - -func (c *isaRepoV3) process(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.IdentificationServiceArea, error) { - rows, err := c.Query(ctx, query, args...) - if err != nil { - return nil, stacktrace.Propagate(err, fmt.Sprintf("Error in query: %s", query)) - } - defer rows.Close() - - var payload []*ridmodels.IdentificationServiceArea - var cids []int64 - - for rows.Next() { - i := new(ridmodels.IdentificationServiceArea) - - var updateTime time.Time - - err := rows.Scan( - &i.ID, - &i.Owner, - &i.URL, - &cids, - &i.StartTime, - &i.EndTime, - &updateTime, - ) - if err != nil { - return nil, stacktrace.Propagate(err, "Error scanning ISA row") - } - i.SetCells(cids) - i.Version = dssmodels.VersionFromTime(updateTime) - payload = append(payload, i) - } - if err := rows.Err(); err != nil { - return nil, stacktrace.Propagate(err, "Error in rows query result") - } - - return payload, nil -} - -func (c *isaRepoV3) processOne(ctx context.Context, query string, args ...interface{}) (*ridmodels.IdentificationServiceArea, error) { - isas, err := c.process(ctx, query, args...) - if err != nil { - return nil, err // No need to Propagate this error as this stack layer does not add useful information - } - if len(isas) > 1 { - return nil, stacktrace.NewError("Query returned %d identification_service_areas when only 0 or 1 was expected", len(isas)) - } - if len(isas) == 0 { - return nil, nil - } - return isas[0], nil -} - -// GetISA returns the isa identified by "id". -// Returns nil, nil if not found -func (c *isaRepoV3) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { - var query = fmt.Sprintf(` - SELECT %s FROM - identification_service_areas - WHERE - id = $1 - %s`, isaFieldsV3, dssql.ForUpdate(forUpdate)) - uid, err := id.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, query, uid) -} - -// InsertISA inserts the IdentificationServiceArea identified by "id" and owned -// by "owner", affecting "cells" in the time interval ["starts", "ends"]. -// -// Returns the created IdentificationServiceArea and all Subscriptions affected -// by it. -// TODO: Simplify the logic to insert without a query, such that the insert fails -// if there's an existing entity. -func (c *isaRepoV3) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - var ( - insertAreasQuery = fmt.Sprintf(` - INSERT INTO - identification_service_areas - (%s) - VALUES - ($1, $2, $3, $4, $5, $6, transaction_timestamp()) - RETURNING - %s`, isaFieldsV3, isaFieldsV3) - ) - - cids, err := dssql.CellUnionToCellIdsWithValidation(isa.Cells) - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert array to jackc/pgtype") - } - uid, err := isa.ID.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, insertAreasQuery, uid, isa.Owner, isa.URL, cids, isa.StartTime, isa.EndTime) -} - -// UpdateISA updates the IdentificationServiceArea identified by "id" and owned -// by "owner", affecting "cells" in the time interval ["starts", "ends"]. -// -// Returns the created IdentificationServiceArea and all Subscriptions affected -// by it. -// TODO: simplify the logic to just update, without the primary query. -// Returns nil, nil if ID, version not found -func (c *isaRepoV3) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - var ( - updateAreasQuery = fmt.Sprintf(` - UPDATE - identification_service_areas - SET (%s) = ($1, $2, $3, $4, $5, transaction_timestamp()) - WHERE id = $1 AND updated_at = $6 - RETURNING - %s`, updateISAFieldsV3, isaFieldsV3) - ) - - cids, err := dssql.CellUnionToCellIdsWithValidation(isa.Cells) - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert array to jackc/pgtype") - } - - uid, err := isa.ID.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, updateAreasQuery, uid, isa.URL, cids, isa.StartTime, isa.EndTime, isa.Version.ToTimestamp()) -} - -// DeleteISA deletes the IdentificationServiceArea identified by "id" and owned by "owner". -// Returns the delete IdentificationServiceArea and all Subscriptions affected by the delete. -// Returns nil, nil if ID, version not found -func (c *isaRepoV3) DeleteISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - var ( - deleteQuery = fmt.Sprintf(` - DELETE FROM - identification_service_areas - WHERE - id = $1 - AND - updated_at = $2 - RETURNING %s`, isaFieldsV3) - ) - uid, err := isa.ID.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, deleteQuery, uid, isa.Version.ToTimestamp()) -} - -// SearchISAs searches IdentificationServiceArea -// instances that intersect with "cells" and, if set, the temporal volume -// defined by "earliest" and "latest". -func (c *isaRepoV3) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { - var ( - // TODO: make earliest and latest required (NOT NULL) and remove coalesce. - // Make them real values (not pointers), on the model layer. - isasInCellsQuery = fmt.Sprintf(` - SELECT - %s - FROM - identification_service_areas - WHERE - ends_at >= $1 - AND - COALESCE(starts_at <= $2, true) - AND - cells && $3 - LIMIT $4`, isaFieldsV3) - ) - - if len(cells) == 0 { - return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing cell IDs for query") - } - - if earliest == nil { - return nil, stacktrace.NewError("Earliest start time is missing") - } - - return c.process(ctx, isasInCellsQuery, earliest, latest, dssql.CellUnionToCellIds(cells), dssmodels.MaxResultLimit) -} - -// ListExpiredISAs returns empty. We don't support thi function in store v3.0 because db doesn't have 'writer' field. -func (c *isaRepoV3) ListExpiredISAs(ctx context.Context, writer string) ([]*ridmodels.IdentificationServiceArea, error) { - return make([]*ridmodels.IdentificationServiceArea, 0), nil -} diff --git a/pkg/rid/store/cockroach/store.go b/pkg/rid/store/cockroach/store.go index 076ded7de..b0c55f902 100644 --- a/pkg/rid/store/cockroach/store.go +++ b/pkg/rid/store/cockroach/store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/interuss/dss/pkg/datastore/flags" + dssql "github.com/interuss/dss/pkg/sql" "time" "github.com/cockroachdb/cockroach-go/v2/crdb/crdbpgxv5" @@ -35,13 +36,12 @@ var ( // deadline is used // TODO: use this in other function calls DefaultTimeout = 10 * time.Second - - v400 = *semver.New("4.0.0") ) type repo struct { - repos.ISA - repos.Subscription + dssql.Queryable + clock clockwork.Clock + logger *zap.Logger } // Store is an implementation of store.Store using Cockroach DB as its backend @@ -101,14 +101,10 @@ func (s *Store) CheckCurrentMajorSchemaVersion(ctx context.Context) error { // Interact implements store.Interactor interface. func (s *Store) Interact(ctx context.Context) (repos.Repository, error) { logger := logging.WithValuesFromContext(ctx, s.logger) - storeVersion, err := s.GetVersion(ctx) - if err != nil { - return nil, stacktrace.Propagate(err, "Error determining database RID schema version") - } - return &repo{ - ISA: NewISARepo(ctx, s.db.Pool, *storeVersion, logger), - Subscription: NewISASubscriptionRepo(ctx, s.db.Pool, *storeVersion, logger, s.clock), + Queryable: s.db.Pool, + clock: s.clock, + logger: logger, }, nil } @@ -125,16 +121,13 @@ func (s *Store) Transact(ctx context.Context, f func(repo repos.Repository) erro ctx = crdb.WithMaxRetries(ctx, flags.ConnectParameters().MaxRetries) - storeVersion, err := s.GetVersion(ctx) - if err != nil { - return stacktrace.Propagate(err, "Error determining database RID schema version") - } return crdbpgx.ExecuteTx(ctx, s.db.Pool, pgx.TxOptions{}, func(tx pgx.Tx) error { // Is this recover still necessary? defer recoverRollbackRepanic(ctx, tx) return f(&repo{ - ISA: NewISARepo(ctx, tx, *storeVersion, logger), - Subscription: NewISASubscriptionRepo(ctx, tx, *storeVersion, logger, s.clock), + Queryable: tx, + clock: s.clock, + logger: logger, }) }) } diff --git a/pkg/rid/store/cockroach/store_test.go b/pkg/rid/store/cockroach/store_test.go index ed47ea93b..63b4290d7 100644 --- a/pkg/rid/store/cockroach/store_test.go +++ b/pkg/rid/store/cockroach/store_test.go @@ -210,29 +210,19 @@ func TestBasicTxn(t *testing.T) { tx1, err := store.db.Pool.Begin(ctx) require.NoError(t, err) s1 := &repo{ - ISA: &isaRepo{ - Queryable: tx1, - logger: logging.Logger, - }, - Subscription: &subscriptionRepo{ - Queryable: tx1, - logger: logging.Logger, - clock: DefaultClock, - }, + Queryable: tx1, + logger: logging.Logger, + clock: DefaultClock, } tx2, err := store.db.Pool.Begin(ctx) require.NoError(t, err) s2 := &repo{ - ISA: &isaRepo{ - Queryable: tx2, - logger: logging.Logger, - }, - Subscription: &subscriptionRepo{ - Queryable: tx2, - logger: logging.Logger, - clock: DefaultClock, - }, + + Queryable: tx2, + logger: logging.Logger, + + clock: DefaultClock, } subs, err := s1.SearchSubscriptions(ctx, subscription1.Cells) diff --git a/pkg/rid/store/cockroach/subcriptions_v3.go b/pkg/rid/store/cockroach/subcriptions_v3.go deleted file mode 100644 index d31c25da0..000000000 --- a/pkg/rid/store/cockroach/subcriptions_v3.go +++ /dev/null @@ -1,280 +0,0 @@ -package cockroach - -import ( - "context" - "fmt" - "time" - - dsserr "github.com/interuss/dss/pkg/errors" - dssmodels "github.com/interuss/dss/pkg/models" - ridmodels "github.com/interuss/dss/pkg/rid/models" - "github.com/jonboulle/clockwork" - - "github.com/golang/geo/s2" - dssql "github.com/interuss/dss/pkg/sql" - "github.com/interuss/stacktrace" - "go.uber.org/zap" -) - -const ( - subscriptionFieldsV3 = "id, owner, url, notification_index, cells, starts_at, ends_at, updated_at" - updateSubscriptionFieldsV3 = "id, url, notification_index, cells, starts_at, ends_at, updated_at" -) - -// subscriptions is an implementation of the SubscriptionRepo for CRDB. -type subscriptionRepoV3 struct { - dssql.Queryable - - clock clockwork.Clock - logger *zap.Logger -} - -// process a query that should return one or many subscriptions. -func (c *subscriptionRepoV3) process(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.Subscription, error) { - rows, err := c.Query(ctx, query, args...) - if err != nil { - return nil, stacktrace.Propagate(err, fmt.Sprintf("Error in query: %s", query)) - } - defer rows.Close() - - var payload []*ridmodels.Subscription - var cids []int64 - - for rows.Next() { - s := new(ridmodels.Subscription) - - var updateTime time.Time - - err := rows.Scan( - &s.ID, - &s.Owner, - &s.URL, - &s.NotificationIndex, - &cids, - &s.StartTime, - &s.EndTime, - &updateTime, - ) - if err != nil { - return nil, stacktrace.Propagate(err, "Error scanning Subscription row") - } - s.SetCells(cids) - s.Version = dssmodels.VersionFromTime(updateTime) - payload = append(payload, s) - } - if err := rows.Err(); err != nil { - return nil, stacktrace.Propagate(err, "Error in rows query result") - } - return payload, nil -} - -// processOne processes a query that should return exactly a single subscription. -func (c *subscriptionRepoV3) processOne(ctx context.Context, query string, args ...interface{}) (*ridmodels.Subscription, error) { - subs, err := c.process(ctx, query, args...) - if err != nil { - return nil, err // No need to Propagate this error as this stack layer does not add useful information - } - if len(subs) > 1 { - return nil, stacktrace.NewError("Query returned %d subscriptions when only 0 or 1 was expected", len(subs)) - } - if len(subs) == 0 { - return nil, nil - } - return subs[0], nil -} - -// MaxSubscriptionCountInCellsByOwner counts how many subscriptions the -// owner has in each one of these cells, and returns the number of subscriptions -// in the cell with the highest number of subscriptions. -func (c *subscriptionRepoV3) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { - // TODO:steeling this query is expensive. The standard defines the max sub - // per "area", but area is loosely defined. Since we may not have to be so - // strict we could keep this count in memory, (or in some other storage). - var query = ` - SELECT - IFNULL(MAX(subscriptions_per_cell_id), 0) - FROM ( - SELECT - COUNT(*) AS subscriptions_per_cell_id - FROM ( - SELECT unnest(cells) as cell_id - FROM subscriptions - WHERE owner = $1 - AND ends_at >= $2 - ) - WHERE - cell_id = ANY($3) - GROUP BY cell_id - )` - - pgCids := dssql.CellUnionToCellIds(cells) - - row := c.QueryRow(ctx, query, owner, c.clock.Now(), pgCids) - var ret int - err := row.Scan(&ret) - return ret, stacktrace.Propagate(err, "Error scanning subscription count row") -} - -// GetSubscription returns the subscription identified by "id". -// Returns nil, nil if not found -func (c *subscriptionRepoV3) GetSubscription(ctx context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { - // TODO(steeling) we should enforce startTime and endTime to not be null at the DB level. - var query = fmt.Sprintf(` - SELECT %s FROM subscriptions - WHERE id = $1`, subscriptionFieldsV3) - uid, err := id.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, query, uid) -} - -// UpdateSubscription updates the Subscription.. not yet implemented. -// Returns nil, nil if ID, version not found -func (c *subscriptionRepoV3) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { - var ( - updateQuery = fmt.Sprintf(` - UPDATE - subscriptions - SET (%s) = ($1, $2, $3, $4, $5, $6, transaction_timestamp()) - WHERE id = $1 AND updated_at = $7 - RETURNING - %s`, updateSubscriptionFieldsV3, subscriptionFieldsV3) - ) - - cids, err := dssql.CellUnionToCellIdsWithValidation(s.Cells) - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert array to jackc/pgtype") - } - - id, err := s.ID.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, updateQuery, - id, - s.URL, - s.NotificationIndex, - cids, - s.StartTime, - s.EndTime, - s.Version.ToTimestamp()) -} - -// InsertSubscription inserts subscription into the store and returns -// the resulting subscription including its ID. -func (c *subscriptionRepoV3) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { - var ( - insertQuery = fmt.Sprintf(` - INSERT INTO - subscriptions - (%s) - VALUES - ($1, $2, $3, $4, $5, $6, $7, transaction_timestamp()) - RETURNING - %s`, subscriptionFieldsV3, subscriptionFieldsV3) - ) - - cids, err := dssql.CellUnionToCellIdsWithValidation(s.Cells) - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert array to jackc/pgtype") - } - - uid, err := s.ID.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, insertQuery, - uid, - s.Owner, - s.URL, - s.NotificationIndex, - cids, - s.StartTime, - s.EndTime) -} - -// DeleteSubscription deletes the subscription identified by ID. -// It must be done in a txn and the version verified. -// Returns nil, nil if ID, version not found -func (c *subscriptionRepoV3) DeleteSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { - var ( - query = fmt.Sprintf(` - DELETE FROM - subscriptions - WHERE - id = $1 - AND updated_at = $2 - RETURNING %s`, subscriptionFieldsV3) - ) - id, err := s.ID.PgUUID() - if err != nil { - return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") - } - return c.processOne(ctx, query, id, s.Version.ToTimestamp()) -} - -// UpdateNotificationIdxsInCells incremement the notification for each sub in the given cells. -func (c *subscriptionRepoV3) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - var updateQuery = fmt.Sprintf(` - UPDATE subscriptions - SET notification_index = notification_index + 1 - WHERE - cells && $1 - AND ends_at >= $2 - RETURNING %s`, subscriptionFieldsV3) - - return c.process( - ctx, updateQuery, dssql.CellUnionToCellIds(cells), c.clock.Now()) -} - -// SearchSubscriptions returns all subscriptions in "cells". -func (c *subscriptionRepoV3) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - var ( - query = fmt.Sprintf(` - SELECT - %s - FROM - subscriptions - WHERE - cells && $1 - AND - ends_at >= $2 - LIMIT $3`, subscriptionFieldsV3) - ) - - if len(cells) == 0 { - return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") - } - - return c.process(ctx, query, dssql.CellUnionToCellIds(cells), c.clock.Now(), dssmodels.MaxResultLimit) -} - -// SearchSubscriptionsByOwner returns all subscriptions in "cells". -func (c *subscriptionRepoV3) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { - var ( - query = fmt.Sprintf(` - SELECT - %s - FROM - subscriptions - WHERE - cells && $1 - AND - subscriptions.owner = $2 - AND - ends_at >= $3 - LIMIT $4`, subscriptionFieldsV3) - ) - - if len(cells) == 0 { - return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") - } - - return c.process(ctx, query, dssql.CellUnionToCellIds(cells), owner, c.clock.Now(), dssmodels.MaxResultLimit) -} - -// ListExpiredSubscriptions returns empty. We don't support this function in store v3.0 because db doesn't have 'writer' field. -func (c *subscriptionRepoV3) ListExpiredSubscriptions(ctx context.Context, writer string) ([]*ridmodels.Subscription, error) { - return make([]*ridmodels.Subscription, 0), nil -} diff --git a/pkg/rid/store/cockroach/subscriptions.go b/pkg/rid/store/cockroach/subscriptions.go index 93c262a2d..b801b4a82 100644 --- a/pkg/rid/store/cockroach/subscriptions.go +++ b/pkg/rid/store/cockroach/subscriptions.go @@ -8,16 +8,11 @@ import ( dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" - repos "github.com/interuss/dss/pkg/rid/repos" dssql "github.com/interuss/dss/pkg/sql" "github.com/interuss/stacktrace" - "github.com/coreos/go-semver/semver" "github.com/golang/geo/s2" "github.com/jackc/pgx/v5/pgtype" - "github.com/jonboulle/clockwork" - - "go.uber.org/zap" ) const ( @@ -25,32 +20,9 @@ const ( updateSubscriptionFields = "id, url, notification_index, cells, starts_at, ends_at, writer, updated_at" ) -func NewISASubscriptionRepo(ctx context.Context, db dssql.Queryable, dbVersion semver.Version, logger *zap.Logger, clock clockwork.Clock) repos.Subscription { - if dbVersion.Compare(v400) >= 0 { - return &subscriptionRepo{ - Queryable: db, - logger: logger, - clock: clock, - } - } - return &subscriptionRepoV3{ - Queryable: db, - logger: logger, - clock: clock, - } -} - -// subscriptions is an implementation of the SubscriptionRepo for CRDB. -type subscriptionRepo struct { - dssql.Queryable - - clock clockwork.Clock - logger *zap.Logger -} - // process a query that should return one or many subscriptions. -func (c *subscriptionRepo) process(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.Subscription, error) { - rows, err := c.Query(ctx, query, args...) +func (r *repo) process(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.Subscription, error) { + rows, err := r.Query(ctx, query, args...) if err != nil { return nil, stacktrace.Propagate(err, fmt.Sprintf("Error in query: %s", query)) } @@ -92,8 +64,8 @@ func (c *subscriptionRepo) process(ctx context.Context, query string, args ...in } // processOne processes a query that should return exactly a single subscription. -func (c *subscriptionRepo) processOne(ctx context.Context, query string, args ...interface{}) (*ridmodels.Subscription, error) { - subs, err := c.process(ctx, query, args...) +func (r *repo) processOne(ctx context.Context, query string, args ...interface{}) (*ridmodels.Subscription, error) { + subs, err := r.process(ctx, query, args...) if err != nil { return nil, err // No need to Propagate this error as this stack layer does not add useful information } @@ -109,7 +81,7 @@ func (c *subscriptionRepo) processOne(ctx context.Context, query string, args .. // MaxSubscriptionCountInCellsByOwner counts how many subscriptions the // owner has in each one of these cells, and returns the number of subscriptions // in the cell with the highest number of subscriptions. -func (c *subscriptionRepo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { +func (r *repo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { // TODO:steeling this query is expensive. The standard defines the max sub // per "area", but area is loosely defined. Since we may not have to be so // strict we could keep this count in memory, (or in some other storage). @@ -130,7 +102,7 @@ func (c *subscriptionRepo) MaxSubscriptionCountInCellsByOwner(ctx context.Contex GROUP BY cell_id )` - row := c.QueryRow(ctx, query, owner, c.clock.Now(), dssql.CellUnionToCellIds(cells)) + row := r.QueryRow(ctx, query, owner, r.clock.Now(), dssql.CellUnionToCellIds(cells)) var ret int err := row.Scan(&ret) return ret, stacktrace.Propagate(err, "Error scanning subscription count row") @@ -138,7 +110,7 @@ func (c *subscriptionRepo) MaxSubscriptionCountInCellsByOwner(ctx context.Contex // GetSubscription returns the subscription identified by "id". // Returns nil, nil if not found -func (c *subscriptionRepo) GetSubscription(ctx context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { +func (r *repo) GetSubscription(ctx context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { // TODO(steeling) we should enforce startTime and endTime to not be null at the DB level. var query = fmt.Sprintf(` SELECT %s FROM subscriptions @@ -147,12 +119,12 @@ func (c *subscriptionRepo) GetSubscription(ctx context.Context, id dssmodels.ID) if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, query, uid) + return r.processOne(ctx, query, uid) } // UpdateSubscription updates the Subscription.. not yet implemented. // Returns nil, nil if ID, version not found -func (c *subscriptionRepo) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { +func (r *repo) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { var ( updateQuery = fmt.Sprintf(` UPDATE @@ -173,7 +145,7 @@ func (c *subscriptionRepo) UpdateSubscription(ctx context.Context, s *ridmodels. if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, updateQuery, + return r.processOne(ctx, updateQuery, id, s.URL, s.NotificationIndex, @@ -186,7 +158,7 @@ func (c *subscriptionRepo) UpdateSubscription(ctx context.Context, s *ridmodels. // InsertSubscription inserts subscription into the store and returns // the resulting subscription including its ID. -func (c *subscriptionRepo) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { +func (r *repo) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { var ( insertQuery = fmt.Sprintf(` INSERT INTO @@ -208,7 +180,7 @@ func (c *subscriptionRepo) InsertSubscription(ctx context.Context, s *ridmodels. if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, insertQuery, + return r.processOne(ctx, insertQuery, id, s.Owner, s.URL, @@ -222,7 +194,7 @@ func (c *subscriptionRepo) InsertSubscription(ctx context.Context, s *ridmodels. // DeleteSubscription deletes the subscription identified by ID. // It must be done in a txn and the version verified. // Returns nil, nil if ID, version not found -func (c *subscriptionRepo) DeleteSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { +func (r *repo) DeleteSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { var ( query = fmt.Sprintf(` DELETE FROM @@ -236,11 +208,11 @@ func (c *subscriptionRepo) DeleteSubscription(ctx context.Context, s *ridmodels. if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") } - return c.processOne(ctx, query, id, s.Version.ToTimestamp()) + return r.processOne(ctx, query, id, s.Version.ToTimestamp()) } // UpdateNotificationIdxsInCells incremement the notification for each sub in the given cells. -func (c *subscriptionRepo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { +func (r *repo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { var updateQuery = fmt.Sprintf(` UPDATE subscriptions SET notification_index = notification_index + 1 @@ -249,12 +221,12 @@ func (c *subscriptionRepo) UpdateNotificationIdxsInCells(ctx context.Context, ce AND ends_at >= $2 RETURNING %s`, subscriptionFields) - return c.process( - ctx, updateQuery, dssql.CellUnionToCellIds(cells), c.clock.Now()) + return r.process( + ctx, updateQuery, dssql.CellUnionToCellIds(cells), r.clock.Now()) } // SearchSubscriptions returns all subscriptions in "cells". -func (c *subscriptionRepo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { +func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { var ( query = fmt.Sprintf(` SELECT @@ -272,11 +244,11 @@ func (c *subscriptionRepo) SearchSubscriptions(ctx context.Context, cells s2.Cel return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - return c.process(ctx, query, dssql.CellUnionToCellIds(cells), c.clock.Now(), dssmodels.MaxResultLimit) + return r.process(ctx, query, dssql.CellUnionToCellIds(cells), r.clock.Now(), dssmodels.MaxResultLimit) } // SearchSubscriptionsByOwner returns all subscriptions in "cells". -func (c *subscriptionRepo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { +func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { var ( query = fmt.Sprintf(` SELECT @@ -296,13 +268,13 @@ func (c *subscriptionRepo) SearchSubscriptionsByOwner(ctx context.Context, cells return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - return c.process(ctx, query, dssql.CellUnionToCellIds(cells), owner, c.clock.Now(), dssmodels.MaxResultLimit) + return r.process(ctx, query, dssql.CellUnionToCellIds(cells), owner, r.clock.Now(), dssmodels.MaxResultLimit) } // ListExpiredSubscriptions lists all expired Subscriptions based on writer. // Records expire if current time is minutes more than records' endTime. // The function queries both empty writer and null writer when passing empty string as a writer. -func (c *subscriptionRepo) ListExpiredSubscriptions(ctx context.Context, writer string) ([]*ridmodels.Subscription, error) { +func (r *repo) ListExpiredSubscriptions(ctx context.Context, writer string) ([]*ridmodels.Subscription, error) { writerQuery := "'" + writer + "'" if len(writer) == 0 { writerQuery = "'' OR writer = NULL" @@ -320,5 +292,5 @@ func (c *subscriptionRepo) ListExpiredSubscriptions(ctx context.Context, writer (writer = %s)`, subscriptionFields, expiredDurationInMin, writerQuery) ) - return c.process(ctx, query) + return r.process(ctx, query) } diff --git a/pkg/rid/store/cockroach/subscriptions_test.go b/pkg/rid/store/cockroach/subscriptions_test.go index 55c187f51..d5fea21e5 100644 --- a/pkg/rid/store/cockroach/subscriptions_test.go +++ b/pkg/rid/store/cockroach/subscriptions_test.go @@ -16,7 +16,7 @@ import ( var ( // Ensure the struct conforms to the interface - _ repos.Subscription = &subscriptionRepo{} + _ repos.Subscription = &repo{} subscriptionsPool = []struct { name string input *ridmodels.Subscription