diff --git a/Makefile b/Makefile
index e2957f8038e..b3d2b9cdd2b 100644
--- a/Makefile
+++ b/Makefile
@@ -41,7 +41,7 @@ VISIBILITY_DB ?= temporal_visibility
 # Always use "protolegacy" tag to allow disabling utf-8 validation on proto messages
 # during proto library transition.
 ALL_BUILD_TAGS := protolegacy,$(BUILD_TAG)
-ALL_TEST_TAGS := $(ALL_BUILD_TAGS),$(TEST_TAG)
+ALL_TEST_TAGS := $(ALL_BUILD_TAGS),test_dep,$(TEST_TAG)
 BUILD_TAG_FLAG := -tags $(ALL_BUILD_TAGS)
 TEST_TAG_FLAG := -tags $(ALL_TEST_TAGS)
 
@@ -347,7 +347,7 @@ lint-actions: $(ACTIONLINT)
 lint-code: $(GOLANGCI_LINT)
 	@printf $(COLOR) "Linting code..."
-	@$(GOLANGCI_LINT) run --verbose --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml
+	@$(GOLANGCI_LINT) run --verbose --build-tags $(ALL_TEST_TAGS) --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml
 
 fmt-imports: $(GCI) # Don't get confused, there is a single linter called gci, which is a part of the mega linter we use is called golangci-lint.
 	@printf $(COLOR) "Formatting imports..."
diff --git a/client/client_factory_mock.go b/client/client_factory_mock.go
index 00788b6554b..93246d97d40 100644
--- a/client/client_factory_mock.go
+++ b/client/client_factory_mock.go
@@ -46,6 +46,7 @@ import (
 	log "go.temporal.io/server/common/log"
 	membership "go.temporal.io/server/common/membership"
 	metrics "go.temporal.io/server/common/metrics"
+	testhooks "go.temporal.io/server/common/testing/testhooks"
 	gomock "go.uber.org/mock/gomock"
 	grpc "google.golang.org/grpc"
 )
@@ -187,15 +188,15 @@ func (m *MockFactoryProvider) EXPECT() *MockFactoryProviderMockRecorder {
 }
 
 // NewFactory mocks base method.
-func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory {
+func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, testHooks testhooks.TestHooks, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger)
+	ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger)
 	ret0, _ := ret[0].(Factory)
 	return ret0
 }
 
 // NewFactory indicates an expected call of NewFactory.
-func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call {
+func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger)
 }
diff --git a/client/clientfactory.go b/client/clientfactory.go
index e3c7838378e..85366422b91 100644
--- a/client/clientfactory.go
+++ b/client/clientfactory.go
@@ -44,6 +44,7 @@ import (
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
 	"go.temporal.io/server/common/primitives"
+	"go.temporal.io/server/common/testing/testhooks"
 	"google.golang.org/grpc"
 )
 
@@ -65,6 +66,7 @@ type (
 			monitor membership.Monitor,
 			metricsHandler metrics.Handler,
 			dc *dynamicconfig.Collection,
+			testHooks testhooks.TestHooks,
 			numberOfHistoryShards int32,
 			logger log.Logger,
 			throttledLogger log.Logger,
@@ -79,6 +81,7 @@ type (
 		monitor               membership.Monitor
 		metricsHandler        metrics.Handler
 		dynConfig             *dynamicconfig.Collection
+		testHooks             testhooks.TestHooks
 		numberOfHistoryShards int32
 		logger                log.Logger
 		throttledLogger       log.Logger
@@ -103,6 +106,7 @@ func (p *factoryProviderImpl) NewFactory(
 	monitor membership.Monitor,
 	metricsHandler metrics.Handler,
 	dc *dynamicconfig.Collection,
+	testHooks testhooks.TestHooks,
 	numberOfHistoryShards int32,
 	logger log.Logger,
 	throttledLogger log.Logger,
@@ -112,6 +116,7 @@ func (p *factoryProviderImpl) NewFactory(
 		monitor:               monitor,
 		metricsHandler:        metricsHandler,
 		dynConfig:             dc,
+		testHooks:             testHooks,
 		numberOfHistoryShards: numberOfHistoryShards,
 		logger:                logger,
 		throttledLogger:       throttledLogger,
@@ -159,7 +164,7 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout(
 		common.NewClientCache(keyResolver, clientProvider),
 		cf.metricsHandler,
 		cf.logger,
-		matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig),
+		matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig, cf.testHooks),
 	)
 
 	if cf.metricsHandler != nil {
diff --git a/client/matching/loadbalancer.go b/client/matching/loadbalancer.go
index ddc7095316e..ee08ee6fd31 100644
--- a/client/matching/loadbalancer.go
+++ b/client/matching/loadbalancer.go
@@ -30,6 +30,7 @@ import (
 
 	"go.temporal.io/server/common/dynamicconfig"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/tqid"
 )
 
@@ -57,11 +58,10 @@ type (
 	}
 
 	defaultLoadBalancer struct {
-		namespaceIDToName   func(id namespace.ID) (namespace.Name, error)
-		nReadPartitions     dynamicconfig.IntPropertyFnWithTaskQueueFilter
-		nWritePartitions    dynamicconfig.IntPropertyFnWithTaskQueueFilter
-		forceReadPartition  dynamicconfig.IntPropertyFn
-		forceWritePartition dynamicconfig.IntPropertyFn
+		namespaceIDToName func(id namespace.ID) (namespace.Name, error)
+		nReadPartitions   dynamicconfig.IntPropertyFnWithTaskQueueFilter
+		nWritePartitions  dynamicconfig.IntPropertyFnWithTaskQueueFilter
+		testHooks         testhooks.TestHooks
 
 		lock         sync.RWMutex
 		taskQueueLBs map[tqid.TaskQueue]*tqLoadBalancer
@@ -85,15 +85,14 @@ type (
 func NewLoadBalancer(
 	namespaceIDToName func(id namespace.ID) (namespace.Name, error),
 	dc *dynamicconfig.Collection,
+	testHooks testhooks.TestHooks,
 ) LoadBalancer {
 	lb := &defaultLoadBalancer{
-		namespaceIDToName:   namespaceIDToName,
-		nReadPartitions:     dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
-		nWritePartitions:    dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
-		forceReadPartition:  dynamicconfig.TestMatchingLBForceReadPartition.Get(dc),
-		forceWritePartition: dynamicconfig.TestMatchingLBForceWritePartition.Get(dc),
-		lock:                sync.RWMutex{},
-		taskQueueLBs:        make(map[tqid.TaskQueue]*tqLoadBalancer),
+		namespaceIDToName: namespaceIDToName,
+		nReadPartitions:   dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
+		nWritePartitions:  dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
+		testHooks:         testHooks,
+		taskQueueLBs:      make(map[tqid.TaskQueue]*tqLoadBalancer),
 	}
 	return lb
 }
@@ -101,7 +100,7 @@ func NewLoadBalancer(
 func (lb *defaultLoadBalancer) PickWritePartition(
 	taskQueue *tqid.TaskQueue,
 ) *tqid.NormalPartition {
-	if n := lb.forceWritePartition(); n >= 0 {
+	if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceWritePartition); ok {
 		return taskQueue.NormalPartition(n)
 	}
 
@@ -130,7 +129,11 @@ func (lb *defaultLoadBalancer) PickReadPartition(
 		partitionCount = lb.nReadPartitions(string(namespaceName), taskQueue.Name(), taskQueue.TaskType())
 	}
 
-	return tqlb.pickReadPartition(partitionCount, lb.forceReadPartition())
+	if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceReadPartition); ok {
+		return tqlb.forceReadPartition(partitionCount, n)
+	}
+
+	return tqlb.pickReadPartition(partitionCount)
 }
 
 func (lb *defaultLoadBalancer) getTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
@@ -157,18 +160,26 @@ func newTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
 	}
 }
 
-func (b *tqLoadBalancer) pickReadPartition(partitionCount int, forcedPartition int) *pollToken {
+func (b *tqLoadBalancer) pickReadPartition(partitionCount int) *pollToken {
 	b.lock.Lock()
 	defer b.lock.Unlock()
 
-	// ensure we reflect dynamic config change if it ever happens
-	b.ensurePartitionCountLocked(max(partitionCount, forcedPartition+1))
+	b.ensurePartitionCountLocked(partitionCount)
+	partitionID := b.pickReadPartitionWithFewestPolls(partitionCount)
 
-	partitionID := forcedPartition
+	b.pollerCounts[partitionID]++
 
-	if partitionID < 0 {
-		partitionID = b.pickReadPartitionWithFewestPolls(partitionCount)
+	return &pollToken{
+		TQPartition: b.taskQueue.NormalPartition(partitionID),
+		balancer:    b,
 	}
+}
+
+func (b *tqLoadBalancer) forceReadPartition(partitionCount, partitionID int) *pollToken {
+	b.lock.Lock()
+	defer b.lock.Unlock()
+
+	b.ensurePartitionCountLocked(max(partitionCount, partitionID+1))
 
 	b.pollerCounts[partitionID]++
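The force-read/force-write escape hatches move out of dynamic config and into typed test hooks, and the forced-read path gets its own method instead of a magic -1 sentinel. A minimal sketch of the new pattern from the test side (illustrative only; it assumes a build with the test_dep tag, since only that build has NewTestHooksImpl and a live registry):

```go
package matching_test

import (
	"testing"

	"go.temporal.io/server/common/testing/testhooks"
)

func TestForcedPartitionSketch(t *testing.T) {
	th := testhooks.NewTestHooksImpl()

	// Pin reads to partition 5 (an arbitrary value for illustration) and
	// undo the override when the test finishes.
	cleanup := testhooks.Set(th, testhooks.MatchingLBForceReadPartition, 5)
	defer cleanup()

	// This is the same check the new PickReadPartition performs: a present
	// hook means "force this partition", an absent one means "balance normally".
	if n, ok := testhooks.Get[int](th, testhooks.MatchingLBForceReadPartition); ok {
		t.Logf("polls would be routed to partition %d", n)
	}
}
```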
diff --git a/client/matching/loadbalancer_test.go b/client/matching/loadbalancer_test.go
index 65a2ae6421f..76a84253a16 100644
--- a/client/matching/loadbalancer_test.go
+++ b/client/matching/loadbalancer_test.go
@@ -63,22 +63,22 @@ func TestTQLoadBalancer(t *testing.T) {
 	tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
 
 	// pick 4 times, each partition picked would have one poller
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	p3 := tqlb.pickReadPartition(partitionCount, -1)
+	p3 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	// release one, and pick one, the newly picked one should have one poller
 	p3.Release()
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	// pick one again, this time it should have 2 pollers
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 }
 
@@ -89,27 +89,27 @@ func TestTQLoadBalancerForce(t *testing.T) {
 	tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
 
 	// pick 4 times, each partition picked would have one poller
-	p1 := tqlb.pickReadPartition(partitionCount, 1)
+	p1 := tqlb.forceReadPartition(partitionCount, 1)
 	assert.Equal(t, 1, p1.TQPartition.PartitionId())
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, 1)
+	tqlb.forceReadPartition(partitionCount, 1)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 
 	// when we don't force it should balance out
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 
 	// releasing the forced one and adding another should still be balanced
 	p1.Release()
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 3, maxPollerCount(tqlb))
 }
 
@@ -125,7 +125,7 @@ func TestLoadBalancerConcurrent(t *testing.T) {
 	for i := 0; i < concurrentCount; i++ {
 		go func() {
 			defer wg.Done()
-			tqlb.pickReadPartition(partitionCount, -1)
+			tqlb.pickReadPartition(partitionCount)
 		}()
 	}
 	wg.Wait()
@@ -142,23 +142,23 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) {
 	f, err := tqid.NewTaskQueueFamily("fake-namespace-id", "fake-taskqueue")
 	assert.NoError(t, err)
 	tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
-	p1 := tqlb.pickReadPartition(partitionCount, -1)
-	p2 := tqlb.pickReadPartition(partitionCount, -1)
+	p1 := tqlb.pickReadPartition(partitionCount)
+	p2 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	partitionCount += 2 // increase partition count
-	p3 := tqlb.pickReadPartition(partitionCount, -1)
-	p4 := tqlb.pickReadPartition(partitionCount, -1)
+	p3 := tqlb.pickReadPartition(partitionCount)
+	p4 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 
 	partitionCount -= 2 // reduce partition count
-	p5 := tqlb.pickReadPartition(partitionCount, -1)
-	p6 := tqlb.pickReadPartition(partitionCount, -1)
+	p5 := tqlb.pickReadPartition(partitionCount)
+	p6 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 
-	p7 := tqlb.pickReadPartition(partitionCount, -1)
+	p7 := tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 3, maxPollerCount(tqlb))
 
 	// release all of them and it should be ok.
@@ -170,11 +170,11 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) {
 	p6.Release()
 	p7.Release()
 
-	tqlb.pickReadPartition(partitionCount, -1)
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 1, maxPollerCount(tqlb))
 	assert.Equal(t, 1, maxPollerCount(tqlb))
-	tqlb.pickReadPartition(partitionCount, -1)
+	tqlb.pickReadPartition(partitionCount)
 	assert.Equal(t, 2, maxPollerCount(tqlb))
 }
diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go
index f793c037d18..388a53b4732 100644
--- a/common/dynamicconfig/constants.go
+++ b/common/dynamicconfig/constants.go
@@ -1260,23 +1260,6 @@ these log lines can be noisy, we want to be able to turn on and sample selective
 		1000,
 		`MatchingMaxTaskQueuesInDeployment represents the maximum number of task-queues that can be registed in a single deployment`,
 	)
-	// for matching testing only:
-
-	TestMatchingDisableSyncMatch = NewGlobalBoolSetting(
-		"test.matching.disableSyncMatch",
-		false,
-		`TestMatchingDisableSyncMatch forces tasks to go through the db once`,
-	)
-	TestMatchingLBForceReadPartition = NewGlobalIntSetting(
-		"test.matching.lbForceReadPartition",
-		-1,
-		`TestMatchingLBForceReadPartition forces polls to go to a specific partition`,
-	)
-	TestMatchingLBForceWritePartition = NewGlobalIntSetting(
-		"test.matching.lbForceWritePartition",
-		-1,
-		`TestMatchingLBForceWritePartition forces adds to go to a specific partition`,
-	)
 
 	// keys for history
diff --git a/common/resource/fx.go b/common/resource/fx.go
index 117d661a469..eb33b1883f4 100644
--- a/common/resource/fx.go
+++ b/common/resource/fx.go
@@ -63,6 +63,7 @@ import (
 	"go.temporal.io/server/common/sdk"
 	"go.temporal.io/server/common/searchattribute"
 	"go.temporal.io/server/common/telemetry"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/utf8validator"
 	"go.uber.org/fx"
 	"google.golang.org/grpc"
@@ -129,6 +130,7 @@ var Module = fx.Options(
 	deadlock.Module,
 	config.Module,
 	utf8validator.Module,
+	testhooks.Module,
 	fx.Invoke(func(*utf8validator.Validator) {}), // force this to be constructed even if not referenced elsewhere
 )
 
@@ -227,6 +229,7 @@ func ClientFactoryProvider(
 	membershipMonitor membership.Monitor,
 	metricsHandler metrics.Handler,
 	dynamicCollection *dynamicconfig.Collection,
+	testHooks testhooks.TestHooks,
 	persistenceConfig *config.Persistence,
 	logger log.SnTaggedLogger,
 	throttledLogger log.ThrottledLogger,
@@ -236,6 +239,7 @@ func ClientFactoryProvider(
 		membershipMonitor,
 		metricsHandler,
 		dynamicCollection,
+		testHooks,
 		persistenceConfig.NumHistoryShards,
 		logger,
 		throttledLogger,
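Registering testhooks.Module inside the shared resource module means any fx-constructed component can declare a TestHooks dependency the same way ClientFactoryProvider does above. A minimal sketch of that wiring (myHandler, newMyHandler, and AppModule are invented for illustration; only testhooks.Module and the TestHooks type come from this PR):

```go
package example

import (
	"go.temporal.io/server/common/testing/testhooks"
	"go.uber.org/fx"
)

// myHandler is a hypothetical component that wants access to test hooks.
type myHandler struct {
	testHooks testhooks.TestHooks
}

// fx injects whichever TestHooks implementation the build tag selected.
func newMyHandler(th testhooks.TestHooks) *myHandler {
	return &myHandler{testHooks: th}
}

var AppModule = fx.Options(
	testhooks.Module, // no-op implementation in production, live registry under test_dep
	fx.Provide(newMyHandler),
)
```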
diff --git a/common/testing/testhooks/key.go b/common/testing/testhooks/key.go
new file mode 100644
index 00000000000..eb8f32a10a5
--- /dev/null
+++ b/common/testing/testhooks/key.go
@@ -0,0 +1,32 @@
+// The MIT License
+//
+// Copyright (c) 2024 Temporal Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package testhooks
+
+type Key int
+
+const (
+	MatchingDisableSyncMatch Key = iota
+	MatchingLBForceReadPartition
+	MatchingLBForceWritePartition
+	UpdateWithStartInBetweenLockAndStart
+)
diff --git a/common/testing/testhooks/noop_impl.go b/common/testing/testhooks/noop_impl.go
new file mode 100644
index 00000000000..9179ceb6490
--- /dev/null
+++ b/common/testing/testhooks/noop_impl.go
@@ -0,0 +1,56 @@
+// The MIT License
+//
+// Copyright (c) 2024 Temporal Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+//go:build !test_dep
+
+package testhooks
+
+import "go.uber.org/fx"
+
+var Module = fx.Options(
+	fx.Provide(func() (_ TestHooks) { return }),
+)
+
+type (
+	// TestHooks (in production mode) is an empty struct just so the build works.
+	// See TestHooks in test_impl.go.
+	//
+	// TestHooks are an inherently unclean way of writing tests. They require mixing test-only
+	// concerns into production code. In general you should prefer other ways of writing tests
+	// wherever possible, and only use TestHooks sparingly, as a last resort.
+	TestHooks struct{}
+)
+
+// Get gets the value of a test hook. In production mode it always returns the zero value and
+// false, which hopefully the compiler will inline and remove the hook as dead code.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Get[T any](_ TestHooks, key Key) (T, bool) {
+	var zero T
+	return zero, false
+}
+
+// Call calls a func() hook if present.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Call(_ TestHooks, key Key) {
+}
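The package thus presents two shapes: without the test_dep tag, TestHooks is an empty struct and Get always yields the zero value, so hook checks should compile away; with the tag (which the Makefile change above adds to ALL_TEST_TAGS), the same calls consult a live registry. A hedged sketch of what a call site looks like under both builds (trySyncMatch is a stand-in for the real matching code, not the actual method):

```go
package example

import "go.temporal.io/server/common/testing/testhooks"

// trySyncMatch mirrors the call-site pattern this PR introduces in matching.
// In a production build (no test_dep) Get returns (false, false) and the
// branch is dead code; under -tags test_dep it reads the registry.
func trySyncMatch(th testhooks.TestHooks) bool {
	if disable, _ := testhooks.Get[bool](th, testhooks.MatchingDisableSyncMatch); disable {
		return false // force the task through the backlog once
	}
	return true // attempt sync match as usual
}
```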
diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go
new file mode 100644
index 00000000000..e50a6064e6e
--- /dev/null
+++ b/common/testing/testhooks/test_impl.go
@@ -0,0 +1,102 @@
+// The MIT License
+//
+// Copyright (c) 2024 Temporal Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+//go:build test_dep
+
+package testhooks
+
+import (
+	"sync"
+
+	"go.uber.org/fx"
+)
+
+var Module = fx.Options(
+	fx.Provide(NewTestHooksImpl),
+)
+
+type (
+	// TestHooks holds a registry of active test hooks. It should be obtained through fx and
+	// used with Get and Set.
+	//
+	// TestHooks are an inherently unclean way of writing tests. They require mixing test-only
+	// concerns into production code. In general you should prefer other ways of writing tests
+	// wherever possible, and only use TestHooks sparingly, as a last resort.
+	TestHooks interface {
+		// private accessors; access must go through package-level Get/Set
+		get(Key) (any, bool)
+		set(Key, any)
+		del(Key)
+	}
+
+	// testHooksImpl is an implementation of TestHooks.
+	testHooksImpl struct {
+		m sync.Map
+	}
+)
+
+// Get gets the value of a test hook from the registry.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Get[T any](th TestHooks, key Key) (T, bool) {
+	if val, ok := th.get(key); ok {
+		// this is only used in test so we want to panic on type mismatch:
+		return val.(T), ok // nolint:revive
+	}
+	var zero T
+	return zero, false
+}
+
+// Call calls a func() hook if present.
+//
+// TestHooks should be used very sparingly, see comment on TestHooks.
+func Call(th TestHooks, key Key) {
+	if hook, ok := Get[func()](th, key); ok {
+		hook()
+	}
+}
+
+// Set sets a test hook to a value and returns a cleanup function to unset it.
+// Calls to Set and the cleanup functions should form a stack.
+func Set[T any](th TestHooks, key Key, val T) func() {
+	th.set(key, val)
+	return func() { th.del(key) }
+}
+
+// NewTestHooksImpl returns a new instance of a test hook registry. This is provided and used
+// in the main "resource" module as a default, but in functional tests, it's overridden by an
+// explicitly constructed instance.
+func NewTestHooksImpl() TestHooks {
+	return &testHooksImpl{}
+}
+
+func (th *testHooksImpl) get(key Key) (any, bool) {
+	return th.m.Load(key)
+}
+
+func (th *testHooksImpl) set(key Key, val any) {
+	th.m.Store(key, val)
+}
+
+func (th *testHooksImpl) del(key Key) {
+	th.m.Delete(key)
+}
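Put together: a test seeds a hook with Set, the server reads it with Get or fires it with Call, and the returned cleanup unwinds the override. A minimal sketch of the func() hook flavor that the multioperation change below relies on (requires -tags test_dep; the test body is illustrative):

```go
package testhooks_test

import (
	"testing"

	"go.temporal.io/server/common/testing/testhooks"
)

func TestFuncHookSketch(t *testing.T) {
	th := testhooks.NewTestHooksImpl()

	fired := false
	cleanup := testhooks.Set(th, testhooks.UpdateWithStartInBetweenLockAndStart, func() {
		fired = true // a real test would start a racing workflow here
	})
	defer cleanup()

	// Server side: this is the line added to multiOp.Invoke below. It is a
	// no-op whenever the hook is unset or the binary lacks the test_dep tag.
	testhooks.Call(th, testhooks.UpdateWithStartInBetweenLockAndStart)

	if !fired {
		t.Fatal("hook should have run synchronously")
	}
}
```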
diff --git a/service/history/api/multioperation/api.go b/service/history/api/multioperation/api.go
index a28a628d271..6dbda379278 100644
--- a/service/history/api/multioperation/api.go
+++ b/service/history/api/multioperation/api.go
@@ -39,6 +39,7 @@ import (
 	"go.temporal.io/server/common/locks"
 	"go.temporal.io/server/common/namespace"
 	"go.temporal.io/server/common/persistence/visibility/manager"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/history/api"
 	"go.temporal.io/server/service/history/api/startworkflow"
 	"go.temporal.io/server/service/history/api/updateworkflow"
@@ -46,9 +47,7 @@ import (
 	"go.temporal.io/server/service/history/workflow"
 )
 
-var (
-	multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
-)
+var multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.")
 
 type (
 	// updateError is a wrapper to distinguish an update error from a start error.
@@ -62,6 +61,7 @@ type (
 		shardContext       shard.Context
 		namespaceId        namespace.ID
 		consistencyChecker api.WorkflowConsistencyChecker
+		testHooks          testhooks.TestHooks
 
 		updateReq *historyservice.UpdateWorkflowExecutionRequest
 		startReq  *historyservice.StartWorkflowExecutionRequest
@@ -79,6 +79,7 @@ func Invoke(
 	tokenSerializer common.TaskTokenSerializer,
 	visibilityManager manager.VisibilityManager,
 	matchingClient matchingservice.MatchingServiceClient,
+	testHooks testhooks.TestHooks,
 ) (*historyservice.ExecuteMultiOperationResponse, error) {
 	if len(req.Operations) != 2 {
 		return nil, serviceerror.NewInvalidArgument("expected exactly 2 operations")
@@ -98,6 +99,7 @@ func Invoke(
 		shardContext:       shardContext,
 		namespaceId:        namespace.ID(req.NamespaceId),
 		consistencyChecker: workflowConsistencyChecker,
+		testHooks:          testHooks,
 		updateReq:          updateReq,
 		startReq:           startReq,
 	}
@@ -157,6 +159,8 @@ func (mo *multiOp) Invoke(ctx context.Context) (*historyservice.ExecuteMultiOperationResponse, error) {
 		return mo.updateWorkflow(ctx, runningWorkflowLease)
 	}
 
+	testhooks.Call(mo.testHooks, testhooks.UpdateWithStartInBetweenLockAndStart)
+
 	// Workflow hasn't been started yet ...
 	resp, err := mo.startAndUpdateWorkflow(ctx)
 	var noStartErr *noStartError
diff --git a/service/history/history_engine.go b/service/history/history_engine.go
index b4339043594..ff7e2029c6d 100644
--- a/service/history/history_engine.go
+++ b/service/history/history_engine.go
@@ -55,6 +55,7 @@ import (
 	"go.temporal.io/server/common/primitives/timestamp"
 	"go.temporal.io/server/common/sdk"
 	"go.temporal.io/server/common/searchattribute"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/history/api"
 	"go.temporal.io/server/service/history/api/addtasks"
 	"go.temporal.io/server/service/history/api/deleteworkflow"
@@ -157,6 +158,7 @@ type (
 		replicationProgressCache replication.ProgressCache
 		syncStateRetriever       replication.SyncStateRetriever
 		outboundQueueCBPool      *circuitbreakerpool.OutboundQueueCircuitBreakerPool
+		testHooks                testhooks.TestHooks
 	}
 )
 
@@ -183,6 +185,7 @@ func NewEngineWithShardContext(
 	dlqWriter replication.DLQWriter,
 	commandHandlerRegistry *workflow.CommandHandlerRegistry,
 	outboundQueueCBPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool,
+	testHooks testhooks.TestHooks,
 ) shard.Engine {
 	currentClusterName := shard.GetClusterMetadata().GetCurrentClusterName()
 
@@ -232,6 +235,7 @@ func NewEngineWithShardContext(
 		replicationProgressCache: replicationProgressCache,
 		syncStateRetriever:       syncStateRetriever,
 		outboundQueueCBPool:      outboundQueueCBPool,
+		testHooks:                testHooks,
 	}
 
 	historyEngImpl.queueProcessors = make(map[tasks.Category]queues.Queue)
@@ -429,6 +433,7 @@ func (e *historyEngineImpl) ExecuteMultiOperation(
 		e.tokenSerializer,
 		e.persistenceVisibilityMgr,
 		e.matchingClient,
+		e.testHooks,
 	)
 }
diff --git a/service/history/history_engine_factory.go b/service/history/history_engine_factory.go
index de90aa4fb3c..a005be63929 100644
--- a/service/history/history_engine_factory.go
+++ b/service/history/history_engine_factory.go
@@ -32,6 +32,7 @@ import (
 	"go.temporal.io/server/common/persistence/visibility/manager"
 	"go.temporal.io/server/common/resource"
 	"go.temporal.io/server/common/sdk"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/history/api"
 	"go.temporal.io/server/service/history/circuitbreakerpool"
 	"go.temporal.io/server/service/history/configs"
@@ -68,6 +69,7 @@ type (
 		ReplicationDLQWriter   replication.DLQWriter
 		CommandHandlerRegistry *workflow.CommandHandlerRegistry
 		OutboundQueueCBPool    *circuitbreakerpool.OutboundQueueCircuitBreakerPool
+		TestHooks              testhooks.TestHooks
 	}
 
 	historyEngineFactory struct {
@@ -108,5 +110,6 @@ func (f *historyEngineFactory) CreateEngine(
 		f.ReplicationDLQWriter,
 		f.CommandHandlerRegistry,
 		f.OutboundQueueCBPool,
+		f.TestHooks,
 	)
 }
diff --git a/service/matching/config.go b/service/matching/config.go
index f6685526998..ec852c79c0e 100644
--- a/service/matching/config.go
+++ b/service/matching/config.go
@@ -47,7 +47,6 @@ type (
 		PersistenceDynamicRateLimitingParams dynamicconfig.TypedPropertyFn[dynamicconfig.DynamicRateLimitingParams]
 		PersistenceQPSBurstRatio             dynamicconfig.FloatPropertyFn
 		SyncMatchWaitDuration                dynamicconfig.DurationPropertyFnWithTaskQueueFilter
-		TestDisableSyncMatch                 dynamicconfig.BoolPropertyFn
 		RPS                                  dynamicconfig.IntPropertyFn
 		OperatorRPSRatio                     dynamicconfig.FloatPropertyFn
 		AlignMembershipChange                dynamicconfig.DurationPropertyFn
@@ -135,7 +134,6 @@ type (
 		BacklogNegligibleAge         func() time.Duration
 		MaxWaitForPollerBeforeFwd    func() time.Duration
 		QueryPollerUnavailableWindow func() time.Duration
-		TestDisableSyncMatch         func() bool
 		// Time to hold a poll request before returning an empty response if there are no tasks
 		LongPollExpirationInterval func() time.Duration
 		RangeSize                  int64
@@ -216,7 +214,6 @@ func NewConfig(
 		PersistenceDynamicRateLimitingParams: dynamicconfig.MatchingPersistenceDynamicRateLimitingParams.Get(dc),
 		PersistenceQPSBurstRatio:             dynamicconfig.PersistenceQPSBurstRatio.Get(dc),
 		SyncMatchWaitDuration:                dynamicconfig.MatchingSyncMatchWaitDuration.Get(dc),
-		TestDisableSyncMatch:                 dynamicconfig.TestMatchingDisableSyncMatch.Get(dc),
 		LoadUserData:                         dynamicconfig.MatchingLoadUserData.Get(dc),
 		HistoryMaxPageSize:                   dynamicconfig.MatchingHistoryMaxPageSize.Get(dc),
 		EnableDeployments:                    dynamicconfig.EnableDeployments.Get(dc),
@@ -311,7 +308,6 @@ func newTaskQueueConfig(tq *tqid.TaskQueue, config *Config, ns namespace.Name) *
 			return config.MaxWaitForPollerBeforeFwd(ns.String(), taskQueueName, taskType)
 		},
 		QueryPollerUnavailableWindow: config.QueryPollerUnavailableWindow,
-		TestDisableSyncMatch:         config.TestDisableSyncMatch,
 		LongPollExpirationInterval: func() time.Duration {
 			return config.LongPollExpirationInterval(ns.String(), taskQueueName, taskType)
 		},
diff --git a/service/matching/handler.go b/service/matching/handler.go
index c2ae8263a84..5a56de8457a 100644
--- a/service/matching/handler.go
+++ b/service/matching/handler.go
@@ -42,6 +42,7 @@ import (
 	"go.temporal.io/server/common/persistence"
 	"go.temporal.io/server/common/persistence/visibility/manager"
 	"go.temporal.io/server/common/resource"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/tqid"
 	"go.temporal.io/server/service/worker/deployment"
 	"google.golang.org/protobuf/proto"
@@ -88,6 +89,7 @@ func NewHandler(
 	namespaceReplicationQueue persistence.NamespaceReplicationQueue,
 	visibilityManager manager.VisibilityManager,
 	nexusEndpointManager persistence.NexusEndpointManager,
+	testHooks testhooks.TestHooks,
 ) *Handler {
 	handler := &Handler{
 		config: config,
@@ -110,6 +112,7 @@ func NewHandler(
 			namespaceReplicationQueue,
 			visibilityManager,
 			nexusEndpointManager,
+			testHooks,
 		),
 		namespaceRegistry: namespaceRegistry,
 	}
diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go
index c975703a9be..f8e46091b8c 100644
--- a/service/matching/matching_engine.go
+++ b/service/matching/matching_engine.go
@@ -69,6 +69,7 @@ import (
 	"go.temporal.io/server/common/resource"
 	serviceerrors "go.temporal.io/server/common/serviceerror"
 	"go.temporal.io/server/common/tasktoken"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/tqid"
 	"go.temporal.io/server/common/util"
 	"go.temporal.io/server/common/worker_versioning"
@@ -143,6 +144,7 @@ type (
 		partitions   map[tqid.PartitionKey]taskQueuePartitionManager
 		gaugeMetrics gaugeMetrics // per-namespace task queue counters
 		config       *Config
+		testHooks    testhooks.TestHooks
 		// queryResults maps query TaskID (which is a UUID generated in QueryWorkflow() call) to a channel
 		// that QueryWorkflow() will block on. The channel is unblocked either by worker sending response through
 		// RespondQueryTaskCompleted() or through an internal service error causing temporal to be unable to dispatch
@@ -203,6 +205,7 @@ func NewEngine(
 	namespaceReplicationQueue persistence.NamespaceReplicationQueue,
 	visibilityManager manager.VisibilityManager,
 	nexusEndpointManager persistence.NexusEndpointManager,
+	testHooks testhooks.TestHooks,
 ) Engine {
 	scopedMetricsHandler := metricsHandler.WithTags(metrics.OperationTag(metrics.MatchingEngineScope))
 	e := &matchingEngineImpl{
@@ -233,6 +236,7 @@ func NewEngine(
 			loadedPhysicalTaskQueueCount: make(map[taskQueueCounterKey]int),
 		},
 		config:             config,
+		testHooks:          testHooks,
 		queryResults:       collection.NewSyncMap[string, chan *queryResult](),
 		nexusResults:       collection.NewSyncMap[string, chan *nexusResult](),
 		outstandingPollers: collection.NewSyncMap[string, context.CancelFunc](),
diff --git a/service/matching/physical_task_queue_manager.go b/service/matching/physical_task_queue_manager.go
index bfddde6a953..a628bfb6654 100644
--- a/service/matching/physical_task_queue_manager.go
+++ b/service/matching/physical_task_queue_manager.go
@@ -50,6 +50,7 @@ import (
 	"go.temporal.io/server/common/log/tag"
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/worker_versioning"
 	"go.temporal.io/server/service/worker/deployment"
 	"google.golang.org/protobuf/types/known/durationpb"
@@ -529,7 +530,7 @@ func (c *physicalTaskQueueManagerImpl) TrySyncMatch(ctx context.Context, task *i
 		// request sent by history service
 		c.liveness.markAlive()
 		c.tasksAddedInIntervals.incrementTaskCount()
-		if c.config.TestDisableSyncMatch() {
+		if disable, _ := testhooks.Get[bool](c.partitionMgr.engine.testHooks, testhooks.MatchingDisableSyncMatch); disable {
 			return false, nil
 		}
 	}
diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go
index 971d5a81ce5..83587d4acbd 100644
--- a/tests/testcore/functional_test_base.go
+++ b/tests/testcore/functional_test_base.go
@@ -59,6 +59,7 @@ import (
 	"go.temporal.io/server/common/rpc"
 	"go.temporal.io/server/common/testing/historyrequire"
 	"go.temporal.io/server/common/testing/protorequire"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/testing/updateutils"
 	"go.temporal.io/server/environment"
 	"go.uber.org/fx"
@@ -493,6 +494,10 @@ func (s *FunctionalTestBase) OverrideDynamicConfig(setting dynamicconfig.Generic
 	return s.testCluster.host.overrideDynamicConfig(s.T(), setting.Key(), value)
 }
 
+func (s *FunctionalTestBase) InjectHook(key testhooks.Key, value any) (cleanup func()) {
+	return s.testCluster.host.injectHook(s.T(), key, value)
+}
+
 func (s *FunctionalTestBase) GetNamespaceID(namespace string) string {
 	namespaceResp, err := s.FrontendClient().DescribeNamespace(NewContext(), &workflowservice.DescribeNamespaceRequest{
 		Namespace: namespace,
@@ -525,20 +530,20 @@ func (s *FunctionalTestBase) RunTestWithMatchingBehavior(subtest func()) {
 			name, func() {
 				if forceTaskForward {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 13)
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingLBForceWritePartition, 11)
+					s.InjectHook(testhooks.MatchingLBForceWritePartition, 11)
 				} else {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 1)
 				}
 				if forcePollForward {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 13)
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingLBForceReadPartition, 5)
+					s.InjectHook(testhooks.MatchingLBForceReadPartition, 5)
 				} else {
 					s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 1)
 				}
 				if forceAsync {
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingDisableSyncMatch, true)
+					s.InjectHook(testhooks.MatchingDisableSyncMatch, true)
 				} else {
-					s.OverrideDynamicConfig(dynamicconfig.TestMatchingDisableSyncMatch, false)
+					s.InjectHook(testhooks.MatchingDisableSyncMatch, false)
 				}
 
 				subtest()
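RunTestWithMatchingBehavior now drives its forced-partition and forced-async variants through hooks rather than dynamic config, and individual tests can do the same via the new InjectHook helper. A hedged sketch (hookSuite and the test body are illustrative; the explicit defer is optional because injectHook also registers t.Cleanup, as shown in onebox.go below):

```go
package tests

import (
	"go.temporal.io/server/common/testing/testhooks"
	"go.temporal.io/server/tests/testcore"
)

// hookSuite is a hypothetical functional-test suite; only InjectHook and the
// hook keys come from this PR.
type hookSuite struct {
	testcore.FunctionalTestBase
}

func (s *hookSuite) TestForcedAsyncMatch() {
	// Force every task through the backlog instead of sync matching.
	cleanup := s.InjectHook(testhooks.MatchingDisableSyncMatch, true)
	defer cleanup()

	// ... start a worker and a workflow; all matches now take the async path ...
}
```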
diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go
index a5c034e1fca..bedac34aef3 100644
--- a/tests/testcore/onebox.go
+++ b/tests/testcore/onebox.go
@@ -71,6 +71,7 @@ import (
 	"go.temporal.io/server/common/sdk"
 	"go.temporal.io/server/common/searchattribute"
 	"go.temporal.io/server/common/telemetry"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/service/frontend"
 	"go.temporal.io/server/service/history"
 	"go.temporal.io/server/service/history/replication"
@@ -100,6 +101,7 @@ type (
 		matchingClient matchingservice.MatchingServiceClient
 
 		dcClient                 *dynamicconfig.MemoryClient
+		testHooks                testhooks.TestHooks
 		logger                   log.Logger
 		clusterMetadataConfig    *cluster.Config
 		persistenceConfig        config.Persistence
@@ -221,9 +223,11 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl {
 		tlsConfigProvider:     params.TLSConfigProvider,
 		captureMetricsHandler: params.CaptureMetricsHandler,
 		dcClient:              dynamicconfig.NewMemoryClient(),
-		serviceFxOptions:         params.ServiceFxOptions,
-		taskCategoryRegistry:     params.TaskCategoryRegistry,
-		hostsByProtocolByService: params.HostsByProtocolByService,
+		// If this doesn't build, make sure you're building with tags 'test_dep':
+		testHooks:                testhooks.NewTestHooksImpl(),
+		serviceFxOptions:         params.ServiceFxOptions,
+		taskCategoryRegistry:     params.TaskCategoryRegistry,
+		hostsByProtocolByService: params.HostsByProtocolByService,
 	}
 
 	for k, v := range dynamicConfigOverrides {
@@ -372,6 +376,7 @@ func (c *TemporalImpl) startFrontend() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(resource.DefaultSnTaggedLoggerProvider),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
 		fx.Provide(func() esclient.Client { return c.esClient }),
@@ -443,6 +448,7 @@ func (c *TemporalImpl) startHistory() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(resource.DefaultSnTaggedLoggerProvider),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
 		fx.Provide(func() esclient.Client { return c.esClient }),
@@ -495,6 +501,7 @@ func (c *TemporalImpl) startMatching() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
 		fx.Provide(func() esclient.Client { return c.esClient }),
 		fx.Provide(c.GetTLSConfigProvider),
@@ -558,6 +565,7 @@ func (c *TemporalImpl) startWorker() {
 		fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }),
 		fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }),
 		fx.Provide(func() dynamicconfig.Client { return c.dcClient }),
+		fx.Decorate(func() testhooks.TestHooks { return c.testHooks }),
 		fx.Provide(resource.DefaultSnTaggedLoggerProvider),
 		fx.Provide(func() esclient.Client { return c.esClient }),
 		fx.Provide(func() *esclient.Config { return c.esConfig }),
@@ -704,6 +712,7 @@ func (p *clientFactoryProvider) NewFactory(
 	monitor membership.Monitor,
 	metricsHandler metrics.Handler,
 	dc *dynamicconfig.Collection,
+	testHooks testhooks.TestHooks,
 	numberOfHistoryShards int32,
 	logger log.Logger,
 	throttledLogger log.Logger,
@@ -713,6 +722,7 @@ func (p *clientFactoryProvider) NewFactory(
 		monitor,
 		metricsHandler,
 		dc,
+		testHooks,
 		numberOfHistoryShards,
 		logger,
 		throttledLogger,
@@ -826,6 +836,12 @@ func (c *TemporalImpl) overrideDynamicConfig(t *testing.T, name dynamicconfig.Key, value any) func() {
 	return cleanup
 }
 
+func (c *TemporalImpl) injectHook(t *testing.T, key testhooks.Key, value any) func() {
+	cleanup := testhooks.Set(c.testHooks, key, value)
+	t.Cleanup(cleanup)
+	return cleanup
+}
+
 func mustPortFromAddress(addr string) httpPort {
 	_, port, err := net.SplitHostPort(addr)
 	if err != nil {
diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go
index 460c98f0518..ea1ad8debc8 100644
--- a/tests/update_workflow_test.go
+++ b/tests/update_workflow_test.go
@@ -50,6 +50,7 @@ import (
 	"go.temporal.io/server/common/metrics/metricstest"
 	"go.temporal.io/server/common/testing/protoutils"
 	"go.temporal.io/server/common/testing/taskpoller"
+	"go.temporal.io/server/common/testing/testhooks"
 	"go.temporal.io/server/common/testing/testvars"
 	"go.temporal.io/server/tests/testcore"
 	"google.golang.org/protobuf/types/known/durationpb"
@@ -5340,6 +5341,42 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() {
 		})
 	})
 
+	s.Run("workflow start conflict", func() {
+
+		s.Run("workflow id conflict policy fail: use-existing", func() {
+			tv := testvars.New(s.T())
+
+			startReq := startWorkflowReq(tv)
+			startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING
+			updateReq := s.updateWorkflowRequest(tv,
+				&updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED})
+
+			// simulate a race condition
+			s.InjectHook(testhooks.UpdateWithStartInBetweenLockAndStart, func() {
+				_, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq)
+				s.NoError(err)
+			})
+
+			uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq)
+
+			_, err := s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil
+				})
+			s.NoError(err)
+
+			_, err = s.TaskPoller.PollAndHandleWorkflowTask(tv,
+				func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) {
+					return &workflowservice.RespondWorkflowTaskCompletedRequest{
+						Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0]),
+					}, nil
+				})
+			s.NoError(err)
+
+			<-uwsCh
+		})
+	})
+
 	s.Run("return update rate limit error", func() {
 		// lower maximum total number of updates for testing purposes
 		maxTotalUpdates := 1