Generic hooks for testing (#6938)
## What changed?
- Add a generic hook interface for fine-grained control of behavior in tests
- Use the hooks in matching tests that need to vary behavior (forcing the load balancer to target specific partitions and disabling sync match)
- Use the hooks to force a race condition in an update-with-start test
(by @stephanos)

## Why?
To write integration/functional tests that require tweaking the behavior of the code under test, without affecting non-test builds.
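
As a condensed sketch of the resulting pattern (mirroring the load balancer change below; the helper function here is illustrative, not code from this commit):

```go
package example

import (
	"go.temporal.io/server/common/testing/testhooks"
	"go.temporal.io/server/common/tqid"
)

// pickPartition shows the consumer-side shape of a hook point: the component
// holds an injected testhooks.TestHooks and asks for a typed value with the
// generic testhooks.Get. Outside of `-tags test_dep` builds (or when no test
// has set the hook) ok is false and the normal code path runs.
func pickPartition(hooks testhooks.TestHooks, tq *tqid.TaskQueue, defaultPartition int) *tqid.NormalPartition {
	if n, ok := testhooks.Get[int](hooks, testhooks.MatchingLBForceWritePartition); ok {
		// A test forced a specific partition.
		return tq.NormalPartition(n)
	}
	return tq.NormalPartition(defaultPartition)
}
```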

## Potential risks
Hooks are disabled by default, so there should be zero risk to
production code, and zero overhead (assuming the Go compiler can do very
basic inlining and dead code elimination).

The downside is that functional tests now have to be run with `-tags
test_dep` everywhere.
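
To make the zero-overhead claim concrete, here is a minimal sketch of the build-tag pattern, assuming the non-test build ships a stateless stub (illustrative only; the real files in `common/testing/testhooks` differ in detail):

```go
//go:build !test_dep

package testhooks

// Key identifies a single hook point. The real package defines concrete keys
// such as MatchingLBForceReadPartition and MatchingLBForceWritePartition.
type Key int

// TestHooks is the dependency that production code holds. In the non-test
// build it carries no state at all.
type TestHooks interface{}

// Get is the production stub: it never finds a value, so every hook branch in
// callers is guarded by a constant-false check. With basic inlining and dead
// code elimination, the hook point costs effectively nothing. A
// test_dep-tagged sibling file would instead look the key up in a registry
// that tests populate.
func Get[T any](TestHooks, Key) (T, bool) {
	var zero T
	return zero, false
}
```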

---------

Co-authored-by: Stephan Behnke <[email protected]>
dnr and stephanos authored Jan 18, 2025
1 parent ee654a1 commit f0e5891
Showing 20 changed files with 354 additions and 86 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -41,7 +41,7 @@ VISIBILITY_DB ?= temporal_visibility
# Always use "protolegacy" tag to allow disabling utf-8 validation on proto messages
# during proto library transition.
ALL_BUILD_TAGS := protolegacy,$(BUILD_TAG)
ALL_TEST_TAGS := $(ALL_BUILD_TAGS),$(TEST_TAG)
ALL_TEST_TAGS := $(ALL_BUILD_TAGS),test_dep,$(TEST_TAG)
BUILD_TAG_FLAG := -tags $(ALL_BUILD_TAGS)
TEST_TAG_FLAG := -tags $(ALL_TEST_TAGS)

@@ -347,7 +347,7 @@ lint-actions: $(ACTIONLINT)

lint-code: $(GOLANGCI_LINT)
@printf $(COLOR) "Linting code..."
@$(GOLANGCI_LINT) run --verbose --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml
@$(GOLANGCI_LINT) run --verbose --build-tags $(ALL_TEST_TAGS) --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml

fmt-imports: $(GCI) # Don't get confused, there is a single linter called gci, which is a part of the mega linter we use is called golangci-lint.
@printf $(COLOR) "Formatting imports..."
9 changes: 5 additions & 4 deletions client/client_factory_mock.go

Generated file; diff not rendered by default.

7 changes: 6 additions & 1 deletion client/clientfactory.go
@@ -44,6 +44,7 @@ import (
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/testing/testhooks"
"google.golang.org/grpc"
)

@@ -65,6 +66,7 @@ type (
monitor membership.Monitor,
metricsHandler metrics.Handler,
dc *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
numberOfHistoryShards int32,
logger log.Logger,
throttledLogger log.Logger,
@@ -79,6 +81,7 @@ type (
monitor membership.Monitor
metricsHandler metrics.Handler
dynConfig *dynamicconfig.Collection
testHooks testhooks.TestHooks
numberOfHistoryShards int32
logger log.Logger
throttledLogger log.Logger
@@ -103,6 +106,7 @@ func (p *factoryProviderImpl) NewFactory(
monitor membership.Monitor,
metricsHandler metrics.Handler,
dc *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
numberOfHistoryShards int32,
logger log.Logger,
throttledLogger log.Logger,
@@ -112,6 +116,7 @@ func (p *factoryProviderImpl) NewFactory(
monitor: monitor,
metricsHandler: metricsHandler,
dynConfig: dc,
testHooks: testHooks,
numberOfHistoryShards: numberOfHistoryShards,
logger: logger,
throttledLogger: throttledLogger,
@@ -159,7 +164,7 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout(
common.NewClientCache(keyResolver, clientProvider),
cf.metricsHandler,
cf.logger,
matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig),
matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig, cf.testHooks),
)

if cf.metricsHandler != nil {
51 changes: 31 additions & 20 deletions client/matching/loadbalancer.go
@@ -30,6 +30,7 @@ import (

"go.temporal.io/server/common/dynamicconfig"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/common/tqid"
)

@@ -57,11 +58,10 @@ type (
}

defaultLoadBalancer struct {
namespaceIDToName func(id namespace.ID) (namespace.Name, error)
nReadPartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
nWritePartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
forceReadPartition dynamicconfig.IntPropertyFn
forceWritePartition dynamicconfig.IntPropertyFn
namespaceIDToName func(id namespace.ID) (namespace.Name, error)
nReadPartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
nWritePartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
testHooks testhooks.TestHooks

lock sync.RWMutex
taskQueueLBs map[tqid.TaskQueue]*tqLoadBalancer
@@ -85,23 +85,22 @@ type (
func NewLoadBalancer(
namespaceIDToName func(id namespace.ID) (namespace.Name, error),
dc *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
) LoadBalancer {
lb := &defaultLoadBalancer{
namespaceIDToName: namespaceIDToName,
nReadPartitions: dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
nWritePartitions: dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
forceReadPartition: dynamicconfig.TestMatchingLBForceReadPartition.Get(dc),
forceWritePartition: dynamicconfig.TestMatchingLBForceWritePartition.Get(dc),
lock: sync.RWMutex{},
taskQueueLBs: make(map[tqid.TaskQueue]*tqLoadBalancer),
namespaceIDToName: namespaceIDToName,
nReadPartitions: dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
nWritePartitions: dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
testHooks: testHooks,
taskQueueLBs: make(map[tqid.TaskQueue]*tqLoadBalancer),
}
return lb
}

func (lb *defaultLoadBalancer) PickWritePartition(
taskQueue *tqid.TaskQueue,
) *tqid.NormalPartition {
if n := lb.forceWritePartition(); n >= 0 {
if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceWritePartition); ok {
return taskQueue.NormalPartition(n)
}

@@ -130,7 +129,11 @@ func (lb *defaultLoadBalancer) PickReadPartition(
partitionCount = lb.nReadPartitions(string(namespaceName), taskQueue.Name(), taskQueue.TaskType())
}

return tqlb.pickReadPartition(partitionCount, lb.forceReadPartition())
if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceReadPartition); ok {
return tqlb.forceReadPartition(partitionCount, n)
}

return tqlb.pickReadPartition(partitionCount)
}

func (lb *defaultLoadBalancer) getTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
Expand All @@ -157,18 +160,26 @@ func newTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
}
}

func (b *tqLoadBalancer) pickReadPartition(partitionCount int, forcedPartition int) *pollToken {
func (b *tqLoadBalancer) pickReadPartition(partitionCount int) *pollToken {
b.lock.Lock()
defer b.lock.Unlock()

// ensure we reflect dynamic config change if it ever happens
b.ensurePartitionCountLocked(max(partitionCount, forcedPartition+1))
b.ensurePartitionCountLocked(partitionCount)
partitionID := b.pickReadPartitionWithFewestPolls(partitionCount)

partitionID := forcedPartition
b.pollerCounts[partitionID]++

if partitionID < 0 {
partitionID = b.pickReadPartitionWithFewestPolls(partitionCount)
return &pollToken{
TQPartition: b.taskQueue.NormalPartition(partitionID),
balancer: b,
}
}

func (b *tqLoadBalancer) forceReadPartition(partitionCount, partitionID int) *pollToken {
b.lock.Lock()
defer b.lock.Unlock()

b.ensurePartitionCountLocked(max(partitionCount, partitionID+1))

b.pollerCounts[partitionID]++

54 changes: 27 additions & 27 deletions client/matching/loadbalancer_test.go
@@ -63,22 +63,22 @@ func TestTQLoadBalancer(t *testing.T) {
tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))

// pick 4 times, each partition picked would have one poller
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
p3 := tqlb.pickReadPartition(partitionCount, -1)
p3 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))

// release one, and pick one, the newly picked one should have one poller
p3.Release()
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))

// pick one again, this time it should have 2 pollers
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))
}

@@ -89,27 +89,27 @@ func TestTQLoadBalancerForce(t *testing.T) {
tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))

// pick 4 times, each partition picked would have one poller
p1 := tqlb.pickReadPartition(partitionCount, 1)
p1 := tqlb.forceReadPartition(partitionCount, 1)
assert.Equal(t, 1, p1.TQPartition.PartitionId())
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, 1)
tqlb.forceReadPartition(partitionCount, 1)
assert.Equal(t, 2, maxPollerCount(tqlb))

// when we don't force it should balance out
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))

// releasing the forced one and adding another should still be balanced
p1.Release()
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))

tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 3, maxPollerCount(tqlb))
}

@@ -125,7 +125,7 @@ func TestLoadBalancerConcurrent(t *testing.T) {
for i := 0; i < concurrentCount; i++ {
go func() {
defer wg.Done()
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
}()
}
wg.Wait()
@@ -142,23 +142,23 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) {
f, err := tqid.NewTaskQueueFamily("fake-namespace-id", "fake-taskqueue")
assert.NoError(t, err)
tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
p1 := tqlb.pickReadPartition(partitionCount, -1)
p2 := tqlb.pickReadPartition(partitionCount, -1)
p1 := tqlb.pickReadPartition(partitionCount)
p2 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
assert.Equal(t, 1, maxPollerCount(tqlb))

partitionCount += 2 // increase partition count
p3 := tqlb.pickReadPartition(partitionCount, -1)
p4 := tqlb.pickReadPartition(partitionCount, -1)
p3 := tqlb.pickReadPartition(partitionCount)
p4 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
assert.Equal(t, 1, maxPollerCount(tqlb))

partitionCount -= 2 // reduce partition count
p5 := tqlb.pickReadPartition(partitionCount, -1)
p6 := tqlb.pickReadPartition(partitionCount, -1)
p5 := tqlb.pickReadPartition(partitionCount)
p6 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))
assert.Equal(t, 2, maxPollerCount(tqlb))
p7 := tqlb.pickReadPartition(partitionCount, -1)
p7 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 3, maxPollerCount(tqlb))

// release all of them and it should be ok.
@@ -170,11 +170,11 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) {
p6.Release()
p7.Release()

tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))
}

17 changes: 0 additions & 17 deletions common/dynamicconfig/constants.go
@@ -1260,23 +1260,6 @@ these log lines can be noisy, we want to be able to turn on and sample selective
1000,
`MatchingMaxTaskQueuesInDeployment represents the maximum number of task-queues that can be registed in a single deployment`,
)
// for matching testing only:

TestMatchingDisableSyncMatch = NewGlobalBoolSetting(
"test.matching.disableSyncMatch",
false,
`TestMatchingDisableSyncMatch forces tasks to go through the db once`,
)
TestMatchingLBForceReadPartition = NewGlobalIntSetting(
"test.matching.lbForceReadPartition",
-1,
`TestMatchingLBForceReadPartition forces polls to go to a specific partition`,
)
TestMatchingLBForceWritePartition = NewGlobalIntSetting(
"test.matching.lbForceWritePartition",
-1,
`TestMatchingLBForceWritePartition forces adds to go to a specific partition`,
)

// keys for history

4 changes: 4 additions & 0 deletions common/resource/fx.go
@@ -63,6 +63,7 @@ import (
"go.temporal.io/server/common/sdk"
"go.temporal.io/server/common/searchattribute"
"go.temporal.io/server/common/telemetry"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/common/utf8validator"
"go.uber.org/fx"
"google.golang.org/grpc"
@@ -129,6 +130,7 @@ var Module = fx.Options(
deadlock.Module,
config.Module,
utf8validator.Module,
testhooks.Module,
fx.Invoke(func(*utf8validator.Validator) {}), // force this to be constructed even if not referenced elsewhere
)

@@ -227,6 +229,7 @@ func ClientFactoryProvider(
membershipMonitor membership.Monitor,
metricsHandler metrics.Handler,
dynamicCollection *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
persistenceConfig *config.Persistence,
logger log.SnTaggedLogger,
throttledLogger log.ThrottledLogger,
@@ -236,6 +239,7 @@ func ClientFactoryProvider(
membershipMonitor,
metricsHandler,
dynamicCollection,
testHooks,
persistenceConfig.NumHistoryShards,
logger,
throttledLogger,
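For context on the wiring above, an fx module like `testhooks.Module` could be as small as the following sketch; the no-op provider is an assumption about the non-test build, and the real module may differ:

```go
package testhooks

import "go.uber.org/fx"

// TestHooks matches the dependency type declared by ClientFactoryProvider above.
type TestHooks interface{}

// noopTestHooks is the stateless value assumed to be injected when the
// test_dep build tag is absent.
type noopTestHooks struct{}

// Module provides a TestHooks to the fx graph so that any constructor in the
// resource module can declare it as a parameter. Under `-tags test_dep` the
// provider would instead return a mutable registry for tests to populate.
var Module = fx.Options(
	fx.Provide(func() TestHooks { return noopTestHooks{} }),
)
```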
(Remaining file diffs not loaded.)
