From 660e52307a58ba29e05a71c4698af14b3b24f4b2 Mon Sep 17 00:00:00 2001 From: Julien Perrochet Date: Fri, 20 Sep 2024 18:04:16 +0200 Subject: [PATCH] [dss/RID] GetISA: optional forUpdate flag to allow for early locking (#1117) --- pkg/rid/application/isa.go | 8 ++++---- pkg/rid/application/isa_test.go | 2 +- pkg/rid/repos/isa.go | 2 +- pkg/rid/store/cockroach/garbage_collector_test.go | 4 ++-- pkg/rid/store/cockroach/identification_service_area.go | 5 +++-- .../store/cockroach/identification_service_area_test.go | 6 +++--- pkg/rid/store/cockroach/identification_service_area_v3.go | 5 +++-- pkg/rid/store/cockroach/store_test.go | 2 +- pkg/sql/utils.go | 7 +++++++ 9 files changed, 25 insertions(+), 16 deletions(-) diff --git a/pkg/rid/application/isa.go b/pkg/rid/application/isa.go index 5abb8202d..dd9ad3fc1 100644 --- a/pkg/rid/application/isa.go +++ b/pkg/rid/application/isa.go @@ -38,7 +38,7 @@ func (a *app) GetISA(ctx context.Context, id dssmodels.ID) (*ridmodels.Identific if err != nil { return nil, stacktrace.Propagate(err, "Unable to interact with store") } - return repo.GetISA(ctx, id) + return repo.GetISA(ctx, id, false) } // SearchISAs for ISA within the volume bounds. @@ -64,7 +64,7 @@ func (a *app) DeleteISA(ctx context.Context, id dssmodels.ID, owner dssmodels.Ow ) // The following will automatically retry TXN retry errors. err := a.Store.Transact(ctx, func(repo repos.Repository) error { - old, err := repo.GetISA(ctx, id) + old, err := repo.GetISA(ctx, id, true) switch { case err != nil: return stacktrace.Propagate(err, "Error getting ISA") @@ -106,7 +106,7 @@ func (a *app) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServic // The following will automatically retry TXN retry errors. err := a.Store.Transact(ctx, func(repo repos.Repository) error { // ensure it doesn't exist yet - old, err := repo.GetISA(ctx, isa.ID) + old, err := repo.GetISA(ctx, isa.ID, false) if err != nil { return stacktrace.Propagate(err, "Error getting ISA") } @@ -141,7 +141,7 @@ func (a *app) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServic err := a.Store.Transact(ctx, func(repo repos.Repository) error { var err error - old, err := repo.GetISA(ctx, isa.ID) + old, err := repo.GetISA(ctx, isa.ID, true) switch { case err != nil: return stacktrace.Propagate(err, "Error getting ISA") diff --git a/pkg/rid/application/isa_test.go b/pkg/rid/application/isa_test.go index 2cf54f424..155162c7f 100644 --- a/pkg/rid/application/isa_test.go +++ b/pkg/rid/application/isa_test.go @@ -31,7 +31,7 @@ type isaStore struct { isas map[dssmodels.ID]*ridmodels.IdentificationServiceArea } -func (store *isaStore) GetISA(ctx context.Context, id dssmodels.ID) (*ridmodels.IdentificationServiceArea, error) { +func (store *isaStore) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { if isa, ok := store.isas[id]; ok { return isa, nil } diff --git a/pkg/rid/repos/isa.go b/pkg/rid/repos/isa.go index 26ef2026d..3016f6978 100644 --- a/pkg/rid/repos/isa.go +++ b/pkg/rid/repos/isa.go @@ -12,7 +12,7 @@ import ( // ISA is an interface to a storage layer for the ISA entity type ISA interface { // Returns nil, nil if not found - GetISA(ctx context.Context, id dssmodels.ID) (*ridmodels.IdentificationServiceArea, error) + GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) // DeleteISA deletes the IdentificationServiceArea identified by "id" and owned by "owner". // Returns the delete IdentificationServiceArea and all Subscriptions affected by the delete. diff --git a/pkg/rid/store/cockroach/garbage_collector_test.go b/pkg/rid/store/cockroach/garbage_collector_test.go index f525e8276..02452ca22 100644 --- a/pkg/rid/store/cockroach/garbage_collector_test.go +++ b/pkg/rid/store/cockroach/garbage_collector_test.go @@ -42,7 +42,7 @@ func TestDeleteExpiredISAs(t *testing.T) { require.NoError(t, err) require.NotNil(t, saOut) - ret, err := repo.GetISA(ctx, serviceArea.ID) + ret, err := repo.GetISA(ctx, serviceArea.ID, false) require.NoError(t, err) require.NotNil(t, ret) @@ -50,7 +50,7 @@ func TestDeleteExpiredISAs(t *testing.T) { err = gc.DeleteRIDExpiredRecords(ctx) require.NoError(t, err) - ret, err = repo.GetISA(ctx, serviceArea.ID) + ret, err = repo.GetISA(ctx, serviceArea.ID, false) require.NoError(t, err) require.Nil(t, ret) } diff --git a/pkg/rid/store/cockroach/identification_service_area.go b/pkg/rid/store/cockroach/identification_service_area.go index ba6fc1ce6..e5b7ff627 100644 --- a/pkg/rid/store/cockroach/identification_service_area.go +++ b/pkg/rid/store/cockroach/identification_service_area.go @@ -101,12 +101,13 @@ func (c *isaRepo) processOne(ctx context.Context, query string, args ...interfac // GetISA returns the isa identified by "id". // Returns nil, nil if not found -func (c *isaRepo) GetISA(ctx context.Context, id dssmodels.ID) (*ridmodels.IdentificationServiceArea, error) { +func (c *isaRepo) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { var query = fmt.Sprintf(` SELECT %s FROM identification_service_areas WHERE - id = $1`, isaFields) + id = $1 + %s`, isaFields, dssql.ForUpdate(forUpdate)) uid, err := id.PgUUID() if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") diff --git a/pkg/rid/store/cockroach/identification_service_area_test.go b/pkg/rid/store/cockroach/identification_service_area_test.go index 3196441b6..703b151f0 100644 --- a/pkg/rid/store/cockroach/identification_service_area_test.go +++ b/pkg/rid/store/cockroach/identification_service_area_test.go @@ -186,7 +186,7 @@ func TestStoreExpiredISA(t *testing.T) { require.NoError(t, err) require.Len(t, serviceAreas, 1) - ret, err := repo.GetISA(ctx, serviceArea.ID) + ret, err := repo.GetISA(ctx, serviceArea.ID, false) require.NoError(t, err) require.NotNil(t, ret) @@ -199,7 +199,7 @@ func TestStoreExpiredISA(t *testing.T) { require.Len(t, serviceAreas, 0) // A get should work even if it is expired. - ret, err = repo.GetISA(ctx, serviceArea.ID) + ret, err = repo.GetISA(ctx, serviceArea.ID, false) require.NoError(t, err) require.NotNil(t, ret) } @@ -222,7 +222,7 @@ func TestStoreDeleteISAs(t *testing.T) { // Delete the ISA. // Ensure a fresh Get, then delete still updates the sub indexes - isa, err = repo.GetISA(ctx, isa.ID) + isa, err = repo.GetISA(ctx, isa.ID, false) require.NoError(t, err) serviceAreaOut, err := repo.DeleteISA(ctx, isa) diff --git a/pkg/rid/store/cockroach/identification_service_area_v3.go b/pkg/rid/store/cockroach/identification_service_area_v3.go index 96d9c3c4f..121d18579 100644 --- a/pkg/rid/store/cockroach/identification_service_area_v3.go +++ b/pkg/rid/store/cockroach/identification_service_area_v3.go @@ -82,12 +82,13 @@ func (c *isaRepoV3) processOne(ctx context.Context, query string, args ...interf // GetISA returns the isa identified by "id". // Returns nil, nil if not found -func (c *isaRepoV3) GetISA(ctx context.Context, id dssmodels.ID) (*ridmodels.IdentificationServiceArea, error) { +func (c *isaRepoV3) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { var query = fmt.Sprintf(` SELECT %s FROM identification_service_areas WHERE - id = $1`, isaFieldsV3) + id = $1 + %s`, isaFieldsV3, dssql.ForUpdate(forUpdate)) uid, err := id.PgUUID() if err != nil { return nil, stacktrace.Propagate(err, "Failed to convert id to PgUUID") diff --git a/pkg/rid/store/cockroach/store_test.go b/pkg/rid/store/cockroach/store_test.go index 01a76d07c..08d235f0a 100644 --- a/pkg/rid/store/cockroach/store_test.go +++ b/pkg/rid/store/cockroach/store_test.go @@ -116,7 +116,7 @@ func TestTxnRetrier(t *testing.T) { repo, err := store.Interact(ctx) require.NoError(t, err) - isa, err := repo.GetISA(ctx, serviceArea.ID) + isa, err := repo.GetISA(ctx, serviceArea.ID, false) require.NoError(t, err) require.NotNil(t, isa) diff --git a/pkg/sql/utils.go b/pkg/sql/utils.go index 420107602..f026174be 100644 --- a/pkg/sql/utils.go +++ b/pkg/sql/utils.go @@ -26,3 +26,10 @@ func CellUnionToCellIdsWithValidation(cu s2.CellUnion) ([]int64, error) { } return pgCids, nil } + +func ForUpdate(forUpdate bool) string { + if forUpdate { + return "FOR UPDATE" + } + return "" +}