From f0e589106f6bbb96514ab0802a7a10bcaba3f961 Mon Sep 17 00:00:00 2001
From: David Reiss
Date: Sat, 18 Jan 2025 01:21:04 +0000
Subject: [PATCH] Generic hooks for testing (#6938)

## What changed?
- Add a generic hook interface for fine-grained control of behavior in tests
- Use the hooks in matching's varying-behavior tests (to force the load balancer to target specific partitions and to disable sync match)
- Use the hooks to force a race condition in an update-with-start test (by @stephanos)

## Why?
To write integration/functional tests that require tweaking the behavior of code under test, without affecting non-test builds.

## Potential risks
Hooks are disabled by default, so there should be zero risk to production code, and zero overhead (assuming the Go compiler can do very basic inlining and dead code elimination). The downside is that functional tests now have to be run with `-tags test_dep` everywhere.

---------

Co-authored-by: Stephan Behnke
---
 Makefile                                      |   4 +-
 client/client_factory_mock.go                 |   9 +-
 client/clientfactory.go                       |   7 +-
 client/matching/loadbalancer.go               |  51 +++++----
 client/matching/loadbalancer_test.go          |  54 +++++-----
 common/dynamicconfig/constants.go             |  17 ---
 common/resource/fx.go                         |   4 +
 common/testing/testhooks/key.go               |  32 ++++++
 common/testing/testhooks/noop_impl.go         |  56 ++++++++++
 common/testing/testhooks/test_impl.go         | 102 ++++++++++++++++++
 service/history/api/multioperation/api.go     |  10 +-
 service/history/history_engine.go             |   5 +
 service/history/history_engine_factory.go     |   3 +
 service/matching/config.go                    |   4 -
 service/matching/handler.go                   |   3 +
 service/matching/matching_engine.go           |   4 +
 .../matching/physical_task_queue_manager.go   |   3 +-
 tests/testcore/functional_test_base.go        |  13 ++-
 tests/testcore/onebox.go                      |  22 +++-
 tests/update_workflow_test.go                 |  37 +++++++
 20 files changed, 354 insertions(+), 86 deletions(-)
 create mode 100644 common/testing/testhooks/key.go
 create mode 100644 common/testing/testhooks/noop_impl.go
 create mode 100644 common/testing/testhooks/test_impl.go

diff --git a/Makefile b/Makefile
index e2957f8038e..b3d2b9cdd2b 100644
--- a/Makefile
+++ b/Makefile
@@ -41,7 +41,7 @@ VISIBILITY_DB ?= temporal_visibility
 # Always use "protolegacy" tag to allow disabling utf-8 validation on proto messages
 # during proto library transition.
 ALL_BUILD_TAGS := protolegacy,$(BUILD_TAG)
-ALL_TEST_TAGS := $(ALL_BUILD_TAGS),$(TEST_TAG)
+ALL_TEST_TAGS := $(ALL_BUILD_TAGS),test_dep,$(TEST_TAG)
 BUILD_TAG_FLAG := -tags $(ALL_BUILD_TAGS)
 TEST_TAG_FLAG := -tags $(ALL_TEST_TAGS)
 
@@ -347,7 +347,7 @@ lint-actions: $(ACTIONLINT)
 lint-code: $(GOLANGCI_LINT)
	@printf $(COLOR) "Linting code..."
-	@$(GOLANGCI_LINT) run --verbose --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml
+	@$(GOLANGCI_LINT) run --verbose --build-tags $(ALL_TEST_TAGS) --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml
 
 fmt-imports: $(GCI) # Don't get confused, there is a single linter called gci, which is a part of the mega linter we use is called golangci-lint.
	@printf $(COLOR) "Formatting imports..."
diff --git a/client/client_factory_mock.go b/client/client_factory_mock.go
index 00788b6554b..93246d97d40 100644
--- a/client/client_factory_mock.go
+++ b/client/client_factory_mock.go
@@ -46,6 +46,7 @@ import (
 	log "go.temporal.io/server/common/log"
 	membership "go.temporal.io/server/common/membership"
 	metrics "go.temporal.io/server/common/metrics"
+	testhooks "go.temporal.io/server/common/testing/testhooks"
 	gomock "go.uber.org/mock/gomock"
 	grpc "google.golang.org/grpc"
 )
@@ -187,15 +188,15 @@ func (m *MockFactoryProvider) EXPECT() *MockFactoryProviderMockRecorder {
 }
 
 // NewFactory mocks base method.
-func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory {
+func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, testHooks testhooks.TestHooks, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger)
+	ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger)
 	ret0, _ := ret[0].(Factory)
 	return ret0
 }
 
 // NewFactory indicates an expected call of NewFactory.
-func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call {
+func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger)
 }
diff --git a/client/clientfactory.go b/client/clientfactory.go
index e3c7838378e..85366422b91 100644
--- a/client/clientfactory.go
+++ b/client/clientfactory.go
@@ -44,6 +44,7 @@ import (
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
 	"go.temporal.io/server/common/primitives"
+	"go.temporal.io/server/common/testing/testhooks"
 	"google.golang.org/grpc"
 )
@@ -65,6 +66,7 @@ type (
 			monitor membership.Monitor,
 			metricsHandler metrics.Handler,
 			dc *dynamicconfig.Collection,
+			testHooks testhooks.TestHooks,
 			numberOfHistoryShards int32,
 			logger log.Logger,
 			throttledLogger log.Logger,
@@ -79,6 +81,7 @@ type (
 		monitor membership.Monitor
 		metricsHandler metrics.Handler
 		dynConfig *dynamicconfig.Collection
+		testHooks testhooks.TestHooks
 		numberOfHistoryShards int32
 		logger log.Logger
 		throttledLogger log.Logger
@@ -103,6 +106,7 @@ func (p *factoryProviderImpl) NewFactory(
 	monitor membership.Monitor,
 	metricsHandler metrics.Handler,
 	dc *dynamicconfig.Collection,
+	testHooks testhooks.TestHooks,
 	numberOfHistoryShards int32,
 	logger log.Logger,
 	throttledLogger log.Logger,
@@ -112,6 +116,7 @@ func (p *factoryProviderImpl) NewFactory(
 		monitor:               monitor,
 		metricsHandler:        metricsHandler,
 		dynConfig:             dc,
+		testHooks:             testHooks,
 		numberOfHistoryShards: numberOfHistoryShards,
 		logger:                logger,
 		throttledLogger:       throttledLogger,
@@ -159,7 +164,7 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout(
 		common.NewClientCache(keyResolver, clientProvider),
 		cf.metricsHandler,
 		cf.logger,
-		matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig),
+		matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig, cf.testHooks),
 	)
 
 	if cf.metricsHandler != nil {
diff --git a/client/matching/loadbalancer.go b/client/matching/loadbalancer.go
index ddc7095316e..ee08ee6fd31 100644
--- a/client/matching/loadbalancer.go
+++ b/client/matching/loadbalancer.go
@@ -30,6 +30,7 @@ import (
 	"go.temporal.io/server/common/dynamicconfig"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/tqid"
 )
@@ -57,11 +58,10 @@ type (
 	}
 
 	defaultLoadBalancer struct {
-		namespaceIDToName   func(id namespace.ID) (namespace.Name, error)
-		nReadPartitions     dynamicconfig.IntPropertyFnWithTaskQueueFilter
-		nWritePartitions    dynamicconfig.IntPropertyFnWithTaskQueueFilter
-		forceReadPartition  dynamicconfig.IntPropertyFn
-		forceWritePartition dynamicconfig.IntPropertyFn
+		namespaceIDToName func(id namespace.ID) (namespace.Name, error)
+		nReadPartitions   dynamicconfig.IntPropertyFnWithTaskQueueFilter
+		nWritePartitions  dynamicconfig.IntPropertyFnWithTaskQueueFilter
+		testHooks         testhooks.TestHooks
 
 		lock sync.RWMutex
 		taskQueueLBs map[tqid.TaskQueue]*tqLoadBalancer
@@ -85,15 +85,14 @@ type (
 func NewLoadBalancer(
 	namespaceIDToName func(id namespace.ID) (namespace.Name, error),
 	dc *dynamicconfig.Collection,
+	testHooks testhooks.TestHooks,
 ) LoadBalancer {
 	lb := &defaultLoadBalancer{
-		namespaceIDToName:   namespaceIDToName,
-		nReadPartitions:     dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
-		nWritePartitions:    dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
-		forceReadPartition:  dynamicconfig.TestMatchingLBForceReadPartition.Get(dc),
-		forceWritePartition: dynamicconfig.TestMatchingLBForceWritePartition.Get(dc),
-		lock:                sync.RWMutex{},
-		taskQueueLBs:        make(map[tqid.TaskQueue]*tqLoadBalancer),
+		namespaceIDToName: namespaceIDToName,
+		nReadPartitions:   dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
+		nWritePartitions:  dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
+		testHooks:         testHooks,
+		taskQueueLBs:      make(map[tqid.TaskQueue]*tqLoadBalancer),
 	}
 	return lb
 }
@@ -101,7 +100,7 @@ func NewLoadBalancer(
 func (lb *defaultLoadBalancer) PickWritePartition(
 	taskQueue *tqid.TaskQueue,
 ) *tqid.NormalPartition {
-	if n := lb.forceWritePartition(); n >= 0 {
+	if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceWritePartition); ok {
 		return taskQueue.NormalPartition(n)
 	}
 
@@ -130,7 +129,11 @@ func (lb *defaultLoadBalancer) PickReadPartition(
 		partitionCount = lb.nReadPartitions(string(namespaceName), taskQueue.Name(), taskQueue.TaskType())
 	}
 
-	return tqlb.pickReadPartition(partitionCount, lb.forceReadPartition())
+	if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceReadPartition); ok {
+		return tqlb.forceReadPartition(partitionCount, n)
+	}
+
+	return tqlb.pickReadPartition(partitionCount)
 }
 
 func (lb *defaultLoadBalancer) getTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
@@ -157,18 +160,26 @@ func newTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
 	}
 }
 
-func (b *tqLoadBalancer) pickReadPartition(partitionCount int, forcedPartition int) *pollToken {
+func (b *tqLoadBalancer) pickReadPartition(partitionCount int) *pollToken {
 	b.lock.Lock()
 	defer b.lock.Unlock()
 
-	// ensure we reflect dynamic config change if it ever happens
-	b.ensurePartitionCountLocked(max(partitionCount, forcedPartition+1))
+	b.ensurePartitionCountLocked(partitionCount)
+	partitionID := b.pickReadPartitionWithFewestPolls(partitionCount)
 
-	partitionID := forcedPartition
+	b.pollerCounts[partitionID]++
 
-	if partitionID < 0 {
-		partitionID = b.pickReadPartitionWithFewestPolls(partitionCount)
+	return &pollToken{
+		TQPartition: b.taskQueue.NormalPartition(partitionID),
+		balancer:    b,
 	}
+}
+
+func (b *tqLoadBalancer) forceReadPartition(partitionCount, partitionID int) *pollToken {
+	b.lock.Lock()
+	defer b.lock.Unlock()
+
+	b.ensurePartitionCountLocked(max(partitionCount, partitionID+1))
 
 	b.pollerCounts[partitionID]++
 
diff --git a/client/matching/loadbalancer_test.go b/client/matching/loadbalancer_test.go
index 65a2ae6421f..76a84253a16 100644
--- a/client/matching/loadbalancer_test.go
+++ b/client/matching/loadbalancer_test.go
@@ -63,22 +63,22 @@ func TestTQLoadBalancer(t *testing.T) {
 	tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
 
 	// pick 4 times, each partition picked would have one poller
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	p3 := tqlb.pickReadPartition(partitionCount, -1)
+	p3 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	// release one, and pick one, the newly picked one should have one poller
 	p3.Release()
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	// pick one again, this time it should have 2 pollers
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 }
 
@@ -89,27 +89,27 @@ func TestTQLoadBalancerForce(t *testing.T) {
 	tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
 
 	// pick 4 times, each partition picked would have one poller
-	p1 := tqlb.pickReadPartition(partitionCount, 1)
+	p1 := tqlb.forceReadPartition(partitionCount, 1)
 	assert.Equal(t, 1, p1.TQPartition.PartitionId())
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, 1)
+	tqlb.forceReadPartition(partitionCount, 1)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 
 	// when we don't force it should balance out
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 
 	// releasing the forced one and adding another should still be balanced
 	p1.Release()
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 3, maxPollerCount(tqlb))
 }
 
@@ -125,7 +125,7 @@ func TestLoadBalancerConcurrent(t *testing.T) {
 	for i := 0; i < concurrentCount; i++ {
 		go func() {
 			defer wg.Done()
-			tqlb.pickReadPartition(partitionCount, -1)
+			tqlb.pickReadPartition(partitionCount)
 		}()
 	}
 	wg.Wait()
@@ -142,23 +142,23 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) {
 	f, err := tqid.NewTaskQueueFamily("fake-namespace-id", "fake-taskqueue")
 	assert.NoError(t, err)
 	tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
-	p1 := tqlb.pickReadPartition(partitionCount, -1)
-	p2 := tqlb.pickReadPartition(partitionCount, -1)
+	p1 := tqlb.pickReadPartition(partitionCount)
+	p2 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	partitionCount += 2 // increase partition count
-	p3 := tqlb.pickReadPartition(partitionCount, -1)
-	p4 := tqlb.pickReadPartition(partitionCount, -1)
+	p3 := tqlb.pickReadPartition(partitionCount)
+	p4 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	partitionCount -= 2 // reduce partition count
-	p5 := tqlb.pickReadPartition(partitionCount, -1)
-	p6 := tqlb.pickReadPartition(partitionCount, -1)
+	p5 := tqlb.pickReadPartition(partitionCount)
+	p6 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 	assert.Equal(t, 2, maxPollerCount(tqlb))
-	p7 := tqlb.pickReadPartition(partitionCount, -1)
+	p7 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 3, maxPollerCount(tqlb))
 
 	// release all of them and it should be ok.
 	p1.Release()
 	p2.Release()
 	p3.Release()
 	p4.Release()
 	p5.Release()
 	p6.Release()
 	p7.Release()
 
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 }
diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go
index f793c037d18..388a53b4732 100644
--- a/common/dynamicconfig/constants.go
+++ b/common/dynamicconfig/constants.go
@@ -1260,23 +1260,6 @@ these log lines can be noisy, we want to be able to turn on and sample selective
 		1000,
 		`MatchingMaxTaskQueuesInDeployment represents the maximum number of task-queues that can be registed in a single deployment`,
 	)
-	// for matching testing only:
-
-	TestMatchingDisableSyncMatch = NewGlobalBoolSetting(
-		"test.matching.disableSyncMatch",
-		false,
-		`TestMatchingDisableSyncMatch forces tasks to go through the db once`,
-	)
-	TestMatchingLBForceReadPartition = NewGlobalIntSetting(
-		"test.matching.lbForceReadPartition",
-		-1,
-		`TestMatchingLBForceReadPartition forces polls to go to a specific partition`,
-	)
-	TestMatchingLBForceWritePartition = NewGlobalIntSetting(
-		"test.matching.lbForceWritePartition",
-		-1,
-		`TestMatchingLBForceWritePartition forces adds to go to a specific partition`,
-	)
 
 	// keys for history
diff --git a/common/resource/fx.go b/common/resource/fx.go
index 117d661a469..eb33b1883f4 100644
--- a/common/resource/fx.go
+++ b/common/resource/fx.go
@@ -63,6 +63,7 @@ import (
 	"go.temporal.io/server/common/sdk"
 	"go.temporal.io/server/common/searchattribute"
 	"go.temporal.io/server/common/telemetry"
+	"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/common/utf8validator" "go.uber.org/fx" "google.golang.org/grpc" @@ -129,6 +130,7 @@ var Module = fx.Options( deadlock.Module, config.Module, utf8validator.Module, + testhooks.Module, fx.Invoke(func(*utf8validator.Validator) {}), // force this to be constructed even if not referenced elsewhere ) @@ -227,6 +229,7 @@ func ClientFactoryProvider( membershipMonitor membership.Monitor, metricsHandler metrics.Handler, dynamicCollection *dynamicconfig.Collection, + testHooks testhooks.TestHooks, persistenceConfig *config.Persistence, logger log.SnTaggedLogger, throttledLogger log.ThrottledLogger, @@ -236,6 +239,7 @@ func ClientFactoryProvider( membershipMonitor, metricsHandler, dynamicCollection, + testHooks, persistenceConfig.NumHistoryShards, logger, throttledLogger, diff --git a/common/testing/testhooks/key.go b/common/testing/testhooks/key.go new file mode 100644 index 00000000000..eb8f32a10a5 --- /dev/null +++ b/common/testing/testhooks/key.go @@ -0,0 +1,32 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package testhooks + +type Key int + +const ( + MatchingDisableSyncMatch Key = iota + MatchingLBForceReadPartition + MatchingLBForceWritePartition + UpdateWithStartInBetweenLockAndStart +) diff --git a/common/testing/testhooks/noop_impl.go b/common/testing/testhooks/noop_impl.go new file mode 100644 index 00000000000..9179ceb6490 --- /dev/null +++ b/common/testing/testhooks/noop_impl.go @@ -0,0 +1,56 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+//go:build !test_dep
+
+package testhooks
+
+import "go.uber.org/fx"
+
+var Module = fx.Options(
+	fx.Provide(func() (_ TestHooks) { return }),
+)
+
+type (
+	// TestHooks (in production mode) is an empty struct just so the build works.
+	// See TestHooks in test_impl.go.
+	//
+	// TestHooks are an inherently unclean way of writing tests. They require mixing test-only
+	// concerns into production code. In general you should prefer other ways of writing tests
+	// wherever possible, and only use TestHooks sparingly, as a last resort.
+	TestHooks struct{}
+)
+
+// Get gets the value of a test hook. In production mode it always returns the zero value and
+// false, which hopefully the compiler will inline and remove the hook as dead code.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Get[T any](_ TestHooks, key Key) (T, bool) {
+	var zero T
+	return zero, false
+}
+
+// Call calls a func() hook if present.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Call(_ TestHooks, key Key) {
+}
diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go
new file mode 100644
index 00000000000..e50a6064e6e
--- /dev/null
+++ b/common/testing/testhooks/test_impl.go
@@ -0,0 +1,102 @@
+// The MIT License
+//
+// Copyright (c) 2024 Temporal Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+//go:build test_dep
+
+package testhooks
+
+import (
+	"sync"
+
+	"go.uber.org/fx"
+)
+
+var Module = fx.Options(
+	fx.Provide(NewTestHooksImpl),
+)
+
+type (
+	// TestHooks holds a registry of active test hooks. It should be obtained through fx and
+	// used with Get and Set.
+	//
+	// TestHooks are an inherently unclean way of writing tests. They require mixing test-only
+	// concerns into production code. In general you should prefer other ways of writing tests
+	// wherever possible, and only use TestHooks sparingly, as a last resort.
+	TestHooks interface {
+		// private accessors; access must go through package-level Get/Set
+		get(Key) (any, bool)
+		set(Key, any)
+		del(Key)
+	}
+
+	// testHooksImpl is an implementation of TestHooks.
+	testHooksImpl struct {
+		m sync.Map
+	}
+)
+
+// Get gets the value of a test hook from the registry.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Get[T any](th TestHooks, key Key) (T, bool) {
+	if val, ok := th.get(key); ok {
+		// this is only used in test so we want to panic on type mismatch:
+		return val.(T), ok // nolint:revive
+	}
+	var zero T
+	return zero, false
+}
+
+// Call calls a func() hook if present.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Call(th TestHooks, key Key) {
+	if hook, ok := Get[func()](th, key); ok {
+		hook()
+	}
+}
+
+// Set sets a test hook to a value and returns a cleanup function to unset it.
+// Calls to Set and the cleanup functions should form a stack.
+func Set[T any](th TestHooks, key Key, val T) func() {
+	th.set(key, val)
+	return func() { th.del(key) }
+}
+
+// NewTestHooksImpl returns a new instance of a test hook registry. This is provided and used
+// in the main "resource" module as a default, but in functional tests, it's overridden by an
+// explicitly constructed instance.
+func NewTestHooksImpl() TestHooks {
+	return &testHooksImpl{}
+}
+
+func (th *testHooksImpl) get(key Key) (any, bool) {
+	return th.m.Load(key)
+}
+
+func (th *testHooksImpl) set(key Key, val any) {
+	th.m.Store(key, val)
+}
+
+func (th *testHooksImpl) del(key Key) {
+	th.m.Delete(key)
+}
diff --git a/service/history/api/multioperation/api.go b/service/history/api/multioperation/api.go
index a28a628d271..6dbda379278 100644
--- a/service/history/api/multioperation/api.go
+++ b/service/history/api/multioperation/api.go
@@ -39,6 +39,7 @@ import (
 	"go.temporal.io/server/common/locks"
 	"go.temporal.io/server/common/namespace"
 	"go.temporal.io/server/common/persistence/visibility/manager"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/history/api"
 	"go.temporal.io/server/service/history/api/startworkflow"
 	"go.temporal.io/server/service/history/api/updateworkflow"
@@ -46,9 +47,7 @@ import (
 	"go.temporal.io/server/service/history/workflow"
 )
 
-var (
-	multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
-)
+var multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
 
 type (
 	// updateError is a wrapper to distinguish an update error from a start error.
@@ -62,6 +61,7 @@ type (
 		shardContext shard.Context
 		namespaceId namespace.ID
 		consistencyChecker api.WorkflowConsistencyChecker
+		testHooks testhooks.TestHooks
 
 		updateReq *historyservice.UpdateWorkflowExecutionRequest
 		startReq *historyservice.StartWorkflowExecutionRequest
@@ -79,6 +79,7 @@ func Invoke(
 	tokenSerializer common.TaskTokenSerializer,
 	visibilityManager manager.VisibilityManager,
 	matchingClient matchingservice.MatchingServiceClient,
+	testHooks testhooks.TestHooks,
 ) (*historyservice.ExecuteMultiOperationResponse, error) {
 	if len(req.Operations) != 2 {
 		return nil, serviceerror.NewInvalidArgument("expected exactly 2 operations")
@@ -98,6 +99,7 @@ func Invoke(
 		shardContext: shardContext,
 		namespaceId: namespace.ID(req.NamespaceId),
 		consistencyChecker: workflowConsistencyChecker,
+		testHooks: testHooks,
 		updateReq: updateReq,
 		startReq: startReq,
 	}
@@ -157,6 +159,8 @@ func (mo *multiOp) Invoke(ctx context.Context) (*historyservice.ExecuteMultiOper
 		return mo.updateWorkflow(ctx, runningWorkflowLease)
 	}
 
+	testhooks.Call(mo.testHooks, testhooks.UpdateWithStartInBetweenLockAndStart)
+
 	// Workflow hasn't been started yet ...
 	resp, err := mo.startAndUpdateWorkflow(ctx)
 	var noStartErr *noStartError
diff --git a/service/history/history_engine.go b/service/history/history_engine.go
index b4339043594..ff7e2029c6d 100644
--- a/service/history/history_engine.go
+++ b/service/history/history_engine.go
@@ -55,6 +55,7 @@ import (
 	"go.temporal.io/server/common/primitives/timestamp"
 	"go.temporal.io/server/common/sdk"
 	"go.temporal.io/server/common/searchattribute"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/history/api"
 	"go.temporal.io/server/service/history/api/addtasks"
 	"go.temporal.io/server/service/history/api/deleteworkflow"
@@ -157,6 +158,7 @@ type (
 		replicationProgressCache replication.ProgressCache
 		syncStateRetriever replication.SyncStateRetriever
 		outboundQueueCBPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool
+		testHooks testhooks.TestHooks
 	}
 )
 
@@ -183,6 +185,7 @@ func NewEngineWithShardContext(
 	dlqWriter replication.DLQWriter,
 	commandHandlerRegistry *workflow.CommandHandlerRegistry,
 	outboundQueueCBPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool,
+	testHooks testhooks.TestHooks,
 ) shard.Engine {
 	currentClusterName := shard.GetClusterMetadata().GetCurrentClusterName()
@@ -232,6 +235,7 @@ func NewEngineWithShardContext(
 		replicationProgressCache: replicationProgressCache,
 		syncStateRetriever: syncStateRetriever,
 		outboundQueueCBPool: outboundQueueCBPool,
+		testHooks: testHooks,
 	}
 
 	historyEngImpl.queueProcessors = make(map[tasks.Category]queues.Queue)
@@ -429,6 +433,7 @@ func (e *historyEngineImpl) ExecuteMultiOperation(
 		e.tokenSerializer,
 		e.persistenceVisibilityMgr,
 		e.matchingClient,
+		e.testHooks,
 	)
 }
diff --git a/service/history/history_engine_factory.go b/service/history/history_engine_factory.go
index de90aa4fb3c..a005be63929 100644
--- a/service/history/history_engine_factory.go
+++ b/service/history/history_engine_factory.go
@@ -32,6 +32,7 @@ import (
 	"go.temporal.io/server/common/persistence/visibility/manager"
 	"go.temporal.io/server/common/resource"
 	"go.temporal.io/server/common/sdk"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/history/api"
 	"go.temporal.io/server/service/history/circuitbreakerpool"
 	"go.temporal.io/server/service/history/configs"
@@ -68,6 +69,7 @@ type (
 		ReplicationDLQWriter replication.DLQWriter
 		CommandHandlerRegistry *workflow.CommandHandlerRegistry
 		OutboundQueueCBPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool
+		TestHooks testhooks.TestHooks
 	}
 
 	historyEngineFactory struct {
@@ -108,5 +110,6 @@ func (f *historyEngineFactory) CreateEngine(
 		f.ReplicationDLQWriter,
 		f.CommandHandlerRegistry,
 		f.OutboundQueueCBPool,
+		f.TestHooks,
 	)
 }
diff --git a/service/matching/config.go b/service/matching/config.go
index f6685526998..ec852c79c0e 100644
--- a/service/matching/config.go
+++ b/service/matching/config.go
@@ -47,7 +46,6 @@ type (
 		PersistenceDynamicRateLimitingParams dynamicconfig.TypedPropertyFn[dynamicconfig.DynamicRateLimitingParams]
 		PersistenceQPSBurstRatio dynamicconfig.FloatPropertyFn
 		SyncMatchWaitDuration dynamicconfig.DurationPropertyFnWithTaskQueueFilter
-		TestDisableSyncMatch dynamicconfig.BoolPropertyFn
 		RPS dynamicconfig.IntPropertyFn
 		OperatorRPSRatio dynamicconfig.FloatPropertyFn
 		AlignMembershipChange dynamicconfig.DurationPropertyFn
@@ -135,7 +134,6 @@ type (
 		BacklogNegligibleAge func() time.Duration
 		MaxWaitForPollerBeforeFwd func() time.Duration
 		QueryPollerUnavailableWindow func() time.Duration
-		TestDisableSyncMatch func() bool
 		// Time to hold a poll request before returning an empty response if there are no tasks
 		LongPollExpirationInterval func() time.Duration
 		RangeSize int64
@@ -216,7 +214,6 @@ func NewConfig(
 		PersistenceDynamicRateLimitingParams: dynamicconfig.MatchingPersistenceDynamicRateLimitingParams.Get(dc),
 		PersistenceQPSBurstRatio: dynamicconfig.PersistenceQPSBurstRatio.Get(dc),
 		SyncMatchWaitDuration: dynamicconfig.MatchingSyncMatchWaitDuration.Get(dc),
-		TestDisableSyncMatch: dynamicconfig.TestMatchingDisableSyncMatch.Get(dc),
 		LoadUserData: dynamicconfig.MatchingLoadUserData.Get(dc),
 		HistoryMaxPageSize: dynamicconfig.MatchingHistoryMaxPageSize.Get(dc),
 		EnableDeployments: dynamicconfig.EnableDeployments.Get(dc),
@@ -311,7 +308,6 @@ func newTaskQueueConfig(tq *tqid.TaskQueue, config *Config, ns namespace.Name) *
 			return config.MaxWaitForPollerBeforeFwd(ns.String(), taskQueueName, taskType)
 		},
 		QueryPollerUnavailableWindow: config.QueryPollerUnavailableWindow,
-		TestDisableSyncMatch: config.TestDisableSyncMatch,
 		LongPollExpirationInterval: func() time.Duration {
 			return config.LongPollExpirationInterval(ns.String(), taskQueueName, taskType)
 		},
diff --git a/service/matching/handler.go b/service/matching/handler.go
index c2ae8263a84..5a56de8457a 100644
--- a/service/matching/handler.go
+++ b/service/matching/handler.go
@@ -42,6 +42,7 @@ import (
 	"go.temporal.io/server/common/persistence"
 	"go.temporal.io/server/common/persistence/visibility/manager"
 	"go.temporal.io/server/common/resource"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/tqid"
 	"go.temporal.io/server/service/worker/deployment"
 	"google.golang.org/protobuf/proto"
@@ -88,6 +89,7 @@ func NewHandler(
 	namespaceReplicationQueue persistence.NamespaceReplicationQueue,
 	visibilityManager manager.VisibilityManager,
 	nexusEndpointManager persistence.NexusEndpointManager,
+	testHooks testhooks.TestHooks,
 ) *Handler {
 	handler := &Handler{
 		config: config,
@@ -110,6 +112,7 @@ func NewHandler(
 			namespaceReplicationQueue,
 			visibilityManager,
 			nexusEndpointManager,
+			testHooks,
 		),
 		namespaceRegistry: namespaceRegistry,
 	}
diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go
index c975703a9be..f8e46091b8c 100644
--- a/service/matching/matching_engine.go
+++ b/service/matching/matching_engine.go
@@ -69,6 +69,7 @@ import (
 	"go.temporal.io/server/common/resource"
 	serviceerrors "go.temporal.io/server/common/serviceerror"
 	"go.temporal.io/server/common/tasktoken"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/tqid"
 	"go.temporal.io/server/common/util"
 	"go.temporal.io/server/common/worker_versioning"
@@ -143,6 +144,7 @@ type (
 		partitions map[tqid.PartitionKey]taskQueuePartitionManager
 		gaugeMetrics gaugeMetrics // per-namespace task queue counters
 		config *Config
+		testHooks testhooks.TestHooks
 
 		// queryResults maps query TaskID (which is a UUID generated in QueryWorkflow() call) to a channel
 		// that QueryWorkflow() will block on. The channel is unblocked either by worker sending response through
 		// RespondQueryTaskCompleted() or through an internal service error causing temporal to be unable to dispatch
@@ -203,6 +205,7 @@ func NewEngine(
 	namespaceReplicationQueue persistence.NamespaceReplicationQueue,
 	visibilityManager manager.VisibilityManager,
 	nexusEndpointManager persistence.NexusEndpointManager,
+	testHooks testhooks.TestHooks,
 ) Engine {
 	scopedMetricsHandler := metricsHandler.WithTags(metrics.OperationTag(metrics.MatchingEngineScope))
 	e := &matchingEngineImpl{
@@ -233,6 +236,7 @@ func NewEngine(
 			loadedPhysicalTaskQueueCount: make(map[taskQueueCounterKey]int),
 		},
 		config: config,
+		testHooks: testHooks,
 		queryResults: collection.NewSyncMap[string, chan *queryResult](),
 		nexusResults: collection.NewSyncMap[string, chan *nexusResult](),
 		outstandingPollers: collection.NewSyncMap[string, context.CancelFunc](),
diff --git a/service/matching/physical_task_queue_manager.go b/service/matching/physical_task_queue_manager.go
index bfddde6a953..a628bfb6654 100644
--- a/service/matching/physical_task_queue_manager.go
+++ b/service/matching/physical_task_queue_manager.go
@@ -50,6 +50,7 @@ import (
 	"go.temporal.io/server/common/log/tag"
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/worker_versioning"
 	"go.temporal.io/server/service/worker/deployment"
 	"google.golang.org/protobuf/types/known/durationpb"
@@ -529,7 +530,7 @@ func (c *physicalTaskQueueManagerImpl) TrySyncMatch(ctx context.Context, task *i
 		// request sent by history service
 		c.liveness.markAlive()
 		c.tasksAddedInIntervals.incrementTaskCount()
-		if c.config.TestDisableSyncMatch() {
+		if disable, _ := testhooks.Get[bool](c.partitionMgr.engine.testHooks, testhooks.MatchingDisableSyncMatch); disable {
 			return false, nil
 		}
 	}
diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go
index 971d5a81ce5..83587d4acbd 100644
--- a/tests/testcore/functional_test_base.go
+++ b/tests/testcore/functional_test_base.go
@@ -59,6 +59,7 @@ import (
 	"go.temporal.io/server/common/rpc"
 	"go.temporal.io/server/common/testing/historyrequire"
 	"go.temporal.io/server/common/testing/protorequire"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/testing/updateutils"
 	"go.temporal.io/server/environment"
 	"go.uber.org/fx"
@@ -493,6 +494,10 @@ func (s *FunctionalTestBase) OverrideDynamicConfig(setting dynamicconfig.Generic
 	return s.testCluster.host.overrideDynamicConfig(s.T(), setting.Key(), value)
 }
 
+func (s *FunctionalTestBase) InjectHook(key testhooks.Key, value any) (cleanup func()) {
+	return s.testCluster.host.injectHook(s.T(), key, value)
+}
+
 func (s *FunctionalTestBase) GetNamespaceID(namespace string) string {
 	namespaceResp, err := s.FrontendClient().DescribeNamespace(NewContext(), &workflowservice.DescribeNamespaceRequest{
 		Namespace: namespace,
@@ -525,20 +530,20 @@ func (s *FunctionalTestBase) RunTestWithMatchingBehavior(subtest func()) {
 			name, func() {
 				if forceTaskForward {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 13)
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingLBForceWritePartition, 11)
+					s.InjectHook(testhooks.MatchingLBForceWritePartition, 11)
 				} else {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 1)
 				}
 				if forcePollForward {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 13)
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingLBForceReadPartition, 5)
+					s.InjectHook(testhooks.MatchingLBForceReadPartition, 5)
 				} else {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 1)
 				}
 				if forceAsync {
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingDisableSyncMatch, true)
+					s.InjectHook(testhooks.MatchingDisableSyncMatch, true)
 				} else {
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingDisableSyncMatch, false)
+					s.InjectHook(testhooks.MatchingDisableSyncMatch, false)
 				}
 
 				subtest()
diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go
index a5c034e1fca..bedac34aef3 100644
--- a/tests/testcore/onebox.go
+++ b/tests/testcore/onebox.go
@@ -71,6 +71,7 @@ import (
 	"go.temporal.io/server/common/sdk"
 	"go.temporal.io/server/common/searchattribute"
 	"go.temporal.io/server/common/telemetry"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/frontend"
 	"go.temporal.io/server/service/history"
 	"go.temporal.io/server/service/history/replication"
@@ -100,6 +101,7 @@ type (
 		matchingClient matchingservice.MatchingServiceClient
 
 		dcClient *dynamicconfig.MemoryClient
+		testHooks testhooks.TestHooks
 		logger log.Logger
 		clusterMetadataConfig *cluster.Config
 		persistenceConfig config.Persistence
@@ -221,9 +223,11 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl {
 		tlsConfigProvider: params.TLSConfigProvider,
 		captureMetricsHandler: params.CaptureMetricsHandler,
 		dcClient: dynamicconfig.NewMemoryClient(),
-		serviceFxOptions:         params.ServiceFxOptions,
-		taskCategoryRegistry:     params.TaskCategoryRegistry,
-		hostsByProtocolByService: params.HostsByProtocolByService,
+		// If this doesn't build, make sure you're building with tags 'test_dep':
+		testHooks:                testhooks.NewTestHooksImpl(),
+		serviceFxOptions:         params.ServiceFxOptions,
+		taskCategoryRegistry:     params.TaskCategoryRegistry,
+		hostsByProtocolByService: params.HostsByProtocolByService,
 	}
 
 	for k, v := range dynamicConfigOverrides {
@@ -372,6 +376,7 @@ func (c *TemporalImpl) startFrontend() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(resource.DefaultSnTaggedLoggerProvider),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
 		fx.Provide(func() esclient.Client { return c.esClient }),
@@ -443,6 +448,7 @@ func (c *TemporalImpl) startHistory() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(resource.DefaultSnTaggedLoggerProvider),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
 		fx.Provide(func() esclient.Client { return c.esClient }),
@@ -495,6 +501,7 @@ func (c *TemporalImpl) startMatching() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
 		fx.Provide(func() esclient.Client { return c.esClient }),
 		fx.Provide(c.GetTLSConfigProvider),
@@ -558,6 +565,7 @@ func (c *TemporalImpl) startWorker() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(resource.DefaultSnTaggedLoggerProvider),
 		fx.Provide(func() esclient.Client { return c.esClient }),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
@@ -704,6 +712,7 @@ func (p *clientFactoryProvider) NewFactory(
 	monitor membership.Monitor,
 	metricsHandler metrics.Handler,
 	dc *dynamicconfig.Collection,
+	testHooks testhooks.TestHooks,
 	numberOfHistoryShards int32,
 	logger log.Logger,
 	throttledLogger log.Logger,
@@ -713,6 +722,7 @@ func (p *clientFactoryProvider) NewFactory(
 		monitor,
 		metricsHandler,
 		dc,
+		testHooks,
 		numberOfHistoryShards,
 		logger,
 		throttledLogger,
@@ -826,6 +836,12 @@ func (c *TemporalImpl) overrideDynamicConfig(t *testing.T, name dynamicconfig.Ke
 	return cleanup
 }
 
+func (c *TemporalImpl) injectHook(t *testing.T, key testhooks.Key, value any) func() {
+	cleanup := testhooks.Set(c.testHooks, key, value)
+	t.Cleanup(cleanup)
+	return cleanup
+}
+
 func mustPortFromAddress(addr string) httpPort {
 	_, port, err := net.SplitHostPort(addr)
 	if err != nil {
diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go
index 460c98f0518..ea1ad8debc8 100644
--- a/tests/update_workflow_test.go
+++ b/tests/update_workflow_test.go
@@ -50,6 +50,7 @@ import (
 	"go.temporal.io/server/common/metrics/metricstest"
 	"go.temporal.io/server/common/testing/protoutils"
 	"go.temporal.io/server/common/testing/taskpoller"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/testing/testvars"
 	"go.temporal.io/server/tests/testcore"
 	"google.golang.org/protobuf/types/known/durationpb"
@@ -5340,6 +5341,42 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() {
 		})
 	})
 
+	s.Run("workflow start conflict", func() {
+
+		s.Run("workflow id conflict policy fail: use-existing", func() {
+			tv := testvars.New(s.T())
+
+			startReq := startWorkflowReq(tv)
+			startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING
+			updateReq := s.updateWorkflowRequest(tv,
+				&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED})
+
+			// simulate a race condition
+			s.InjectHook(testhooks.UpdateWithStartInBetweenLockAndStart, func() {
+				_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
+				s.NoError(err)
+			})
+
+			uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)
+
+			_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil
+				})
+			s.NoError(err)
+
+			_, err = s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{
+						Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0]),
+					}, nil
+				})
+			s.NoError(err)
+
+			<-uwsCh
+		})
+	})
+
 	s.Run("return update rate limit error", func() {
 		// lower maximum total number of updates for testing purposes
 		maxTotalUpdates := 1
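
For reference, a minimal sketch of how the `testhooks` API added in this patch fits together, assuming a build with `-tags test_dep` (the `pickPartition` and `exampleTest` names are illustrative only, not part of the diff):

```go
//go:build test_dep

package example

import "go.temporal.io/server/common/testing/testhooks"

// pickPartition is a hypothetical production-side call site. In a build
// without -tags test_dep, Get always returns (zero, false), so this branch
// is dead code that the compiler can eliminate.
func pickPartition(th testhooks.TestHooks, fallback int) int {
	if n, ok := testhooks.Get[int](th, testhooks.MatchingLBForceReadPartition); ok {
		return n // a test forced this partition
	}
	return fallback
}

// exampleTest is the test-side counterpart: inject a value, exercise the
// code under test, then unset the hook via the returned cleanup function.
func exampleTest() {
	th := testhooks.NewTestHooksImpl()
	cleanup := testhooks.Set(th, testhooks.MatchingLBForceReadPartition, 5)
	defer cleanup()

	_ = pickPartition(th, 0) // returns 5 while the hook is set
}
```

In functional tests the same pattern goes through `FunctionalTestBase.InjectHook`, which forwards to `testhooks.Set` on the cluster's shared `TestHooks` instance and registers the cleanup with `t.Cleanup`, as the `onebox.go` and `update_workflow_test.go` changes above show.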