From 1954211850e9fa94ead8f78da1d7ecbd442ced3f Mon Sep 17 00:00:00 2001 From: Carl Braganza Date: Tue, 10 Dec 2024 18:12:35 -0800 Subject: [PATCH] Ensure that a client send error in the sidecar will result in the CSI driver stream being canceled. --- pkg/internal/server/grpc/common_test.go | 178 ++++++++++++++++++ .../server/grpc/get_metadata_allocated.go | 9 +- .../grpc/get_metadata_allocated_test.go | 41 ++++ .../server/grpc/get_metadata_delta.go | 9 +- .../server/grpc/get_metadata_delta_test.go | 42 +++++ 5 files changed, 277 insertions(+), 2 deletions(-) diff --git a/pkg/internal/server/grpc/common_test.go b/pkg/internal/server/grpc/common_test.go index 682c1f64..89e70494 100644 --- a/pkg/internal/server/grpc/common_test.go +++ b/pkg/internal/server/grpc/common_test.go @@ -23,7 +23,9 @@ import ( "strconv" "strings" "testing" + "time" + "github.com/container-storage-interface/spec/lib/go/csi" snapshotv1 "github.com/kubernetes-csi/external-snapshotter/client/v8/apis/volumesnapshot/v1" fakesnapshot "github.com/kubernetes-csi/external-snapshotter/client/v8/clientset/versioned/fake" snapshotutils "github.com/kubernetes-csi/external-snapshotter/v8/pkg/utils" @@ -433,3 +435,179 @@ func (th *testHarness) SetKlogVerbosity(verboseLevel int, uniquePrefix string) K klog.InitFlags(fs) } } + +// test data structure for context propagation testing. 
+type testSnapshotMetadataServerCtxPropagator struct { + *csi.UnimplementedSnapshotMetadataServer + + chanToCloseBeforeReturn chan struct{} + chanToCloseOnEntry chan struct{} + chanToWaitOnBeforeFirstResponse chan struct{} + chanToWaitOnBeforeSecondResponse chan struct{} + handlerCalled bool + send1Err error + send2Err error + streamCtx context.Context + streamGetMetadataAllocatedErr error + + // test harness support + th *testHarness + rth *runtime.TestHarness + grpcServer *Server +} + +// newSMSHarnessForCtxPropagation sets up a test harness with the sidecar service +// connected to a fake CSI driver serving the testSnapshotMetadataServerCtxPropagator. +// The setup provides a mechanism to deterministically sense a canceled context +// in the fake CSI driver. +// +// The invoker should make the following call sequence: +// - Invoke one of the GetMetadataAllocated or GetMetadataDelta operations. +// The handler in the fake CSI driver blocks waiting for the invoker to synchronize. +// - Call synchronizeBeforeCancel, which enables the fake CSI driver handler to send +// the first response after which it blocks again. +// - Receive the first response (without an error) +// - Cancel its context +// - Receive the second response (which must fail as it hasn't been sent yet) +// - Call synchronizeAfterCancel. This wakes the fake CSI driver handler which then +// spins waiting to detect the failed context, after which it attempts to send +// a second response. Only then does control return to the invoker. +// +// After this the invoker should examine the SnapshotMetadataServer properties to +// check for correctness. 
+func newSMSHarnessForCtxPropagation(t *testing.T) (*testSnapshotMetadataServerCtxPropagator, *testHarness) { + s := &testSnapshotMetadataServerCtxPropagator{} + s.chanToCloseOnEntry = make(chan struct{}) + s.chanToCloseBeforeReturn = make(chan struct{}) + s.chanToWaitOnBeforeFirstResponse = make(chan struct{}) + s.chanToWaitOnBeforeSecondResponse = make(chan struct{}) + + // set up a fake csi driver with the runtime test harness + s.rth = runtime.NewTestHarness().WithFakeKubeConfig(t).WithFakeCSIDriver(t, s) + rrt := s.rth.RuntimeForFakeCSIDriver(t) + + // configure a local test harness to connect to the fake csi driver + s.th = newTestHarness() + s.th.DriverName = rrt.DriverName + s.th.WithFakeClientAPIs() + rt := s.th.Runtime() + rt.CSIConn = rrt.CSIConn + s.grpcServer = s.th.StartGRPCServer(t, rt) + s.grpcServer.CSIDriverIsReady() + + return s, s.th +} + +func (s *testSnapshotMetadataServerCtxPropagator) cleanup(t *testing.T) { + if s.th != nil { + s.th.StopGRPCServer(t) + s.rth.RemoveFakeKubeConfig(t) + s.rth.TerminateFakeCSIDriver(t) + } +} + +func (s *testSnapshotMetadataServerCtxPropagator) sync(ctx context.Context, sendResp1, sendResp2 func() error) error { + if s.chanToCloseOnEntry != nil { + close(s.chanToCloseOnEntry) + } + + if s.chanToWaitOnBeforeFirstResponse != nil { + <-s.chanToWaitOnBeforeFirstResponse + } + + // send the first response + s.send1Err = sendResp1() + + if s.chanToWaitOnBeforeSecondResponse != nil { + <-s.chanToWaitOnBeforeSecondResponse + } + + // now wait for the client's canceled context to reach this context + for ctx.Err() == nil { + time.Sleep(time.Millisecond) + } + + // send the next response + s.send2Err = sendResp2() + // publish the handler state before signaling completion, so the waiting test reads it race-free + s.handlerCalled = true + s.streamCtx = ctx + + if s.chanToCloseBeforeReturn != nil { + close(s.chanToCloseBeforeReturn) + } + + return nil +} + +func (s *testSnapshotMetadataServerCtxPropagator) synchronizeBeforeCancel() { + // synchronize with the fake CSI driver + <-s.chanToCloseOnEntry + + // the fake 
driver may now send the first response + close(s.chanToWaitOnBeforeFirstResponse) +} + +func (s *testSnapshotMetadataServerCtxPropagator) synchronizeAfterCancel() { + // the fake driver can now send the second response + close(s.chanToWaitOnBeforeSecondResponse) + + // wait for the fake driver method to complete + <-s.chanToCloseBeforeReturn +} + +func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataAllocated(req *csi.GetMetadataAllocatedRequest, stream csi.SnapshotMetadata_GetMetadataAllocatedServer) error { + return s.sync(stream.Context(), + func() error { + return stream.Send(&csi.GetMetadataAllocatedResponse{ + BlockMetadataType: csi.BlockMetadataType_FIXED_LENGTH, + VolumeCapacityBytes: 1024 * 1024 * 1024, + BlockMetadata: []*csi.BlockMetadata{ + { + ByteOffset: 0, + SizeBytes: 1024, + }, + }, + }) + }, + func() error { + return stream.Send(&csi.GetMetadataAllocatedResponse{ + BlockMetadataType: csi.BlockMetadataType_FIXED_LENGTH, + VolumeCapacityBytes: 1024 * 1024 * 1024, + BlockMetadata: []*csi.BlockMetadata{ + { + ByteOffset: 1024, + SizeBytes: 1024, + }, + }, + }) + }) +} + +func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataDelta(req *csi.GetMetadataDeltaRequest, stream csi.SnapshotMetadata_GetMetadataDeltaServer) error { + return s.sync(stream.Context(), + func() error { + return stream.Send(&csi.GetMetadataDeltaResponse{ + BlockMetadataType: csi.BlockMetadataType_FIXED_LENGTH, + VolumeCapacityBytes: 1024 * 1024 * 1024, + BlockMetadata: []*csi.BlockMetadata{ + { + ByteOffset: 0, + SizeBytes: 1024, + }, + }, + }) + }, + func() error { + return stream.Send(&csi.GetMetadataDeltaResponse{ + BlockMetadataType: csi.BlockMetadataType_FIXED_LENGTH, + VolumeCapacityBytes: 1024 * 1024 * 1024, + BlockMetadata: []*csi.BlockMetadata{ + { + ByteOffset: 1024, + SizeBytes: 1024, + }, + }, + }) + }) +} diff --git a/pkg/internal/server/grpc/get_metadata_allocated.go b/pkg/internal/server/grpc/get_metadata_allocated.go index ad141cbe..f0a30ef1 100644 --- 
a/pkg/internal/server/grpc/get_metadata_allocated.go +++ b/pkg/internal/server/grpc/get_metadata_allocated.go @@ -31,7 +31,13 @@ import ( ) func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stream api.SnapshotMetadata_GetMetadataAllocatedServer) error { - ctx := s.getMetadataAllocatedContextWithLogger(req, stream) + // Create a cancelable context so that failure in sending to the client would + // cancel the context used to communicate with the CSI driver when the stack unwinds. + // Note: this may be unnecessary if failure on the client stream Send() is already + // propagated to the returned context in the gRPC runtime, but it is not documented + // as such, so this extra step ensures that this will indeed take place. + ctx, cancelFn := context.WithCancel(s.getMetadataAllocatedContextWithLogger(req, stream)) + defer cancelFn() if err := s.validateGetMetadataAllocatedRequest(req); err != nil { klog.FromContext(ctx).Error(err, "validation failed") @@ -55,6 +61,7 @@ func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stre klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "snapshotId", csiReq.SnapshotId) csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataAllocated(ctx, csiReq) if err != nil { + klog.FromContext(ctx).Error(err, "csi.GetMetadataAllocated") return err } diff --git a/pkg/internal/server/grpc/get_metadata_allocated_test.go b/pkg/internal/server/grpc/get_metadata_allocated_test.go index 98634357..17c433ac 100644 --- a/pkg/internal/server/grpc/get_metadata_allocated_test.go +++ b/pkg/internal/server/grpc/get_metadata_allocated_test.go @@ -751,3 +751,44 @@ func (f *fakeStreamServerSnapshotAllocated) Send(m *api.GetMetadataAllocatedResp func (f *fakeStreamServerSnapshotAllocated) verifyResponse(expectedResponse *api.GetMetadataAllocatedResponse) bool { return f.response.String() == expectedResponse.String() } + +func 
TestGetMetadataAllocatedClientErrorCancelsDriverStream(t *testing.T) { + sms, th := newSMSHarnessForCtxPropagation(t) + defer sms.cleanup(t) + + // create the cancelable client context + ctx, cancelFn := context.WithCancel(context.Background()) + + // make the RPC call + client := th.GRPCSnapshotMetadataClient(t) + clientStream, err := client.GetMetadataAllocated(ctx, &api.GetMetadataAllocatedRequest{ + SecurityToken: th.SecurityToken, + Namespace: th.Namespace, + SnapshotName: "snap-1", + }) + assert.NoError(t, err) + assert.NotNil(t, clientStream) + + sms.synchronizeBeforeCancel() + + r1, e1 := clientStream.Recv() // get the first response + assert.NoError(t, e1) + assert.NotNil(t, r1) + + // the client cancels the context + cancelFn() + + r2, e2 := clientStream.Recv() // fail to get the second response (not yet sent) + assert.Error(t, e2) + assert.ErrorContains(t, e2, context.Canceled.Error()) + assert.Nil(t, r2) + + sms.synchronizeAfterCancel() + + // Check the fake driver handler status + assert.True(t, sms.handlerCalled) + assert.NoError(t, sms.send1Err) + assert.ErrorIs(t, sms.streamCtx.Err(), context.Canceled) + assert.Error(t, sms.send2Err) + assert.ErrorContains(t, sms.send2Err, context.Canceled.Error()) +} diff --git a/pkg/internal/server/grpc/get_metadata_delta.go b/pkg/internal/server/grpc/get_metadata_delta.go index 79724bb6..066d8d65 100644 --- a/pkg/internal/server/grpc/get_metadata_delta.go +++ b/pkg/internal/server/grpc/get_metadata_delta.go @@ -31,7 +31,13 @@ import ( ) func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.SnapshotMetadata_GetMetadataDeltaServer) error { - ctx := s.getMetadataDeltaContextWithLogger(req, stream) + // Create a cancelable context so that failure in sending to the client would + // cancel the context used to communicate with the CSI driver when the stack unwinds. 
+ // Note: this may be unnecessary if failure on the client stream Send() is already + // propagated to the returned context in the gRPC runtime, but it is not documented + // as such, so this extra step ensures that this will indeed take place. + ctx, cancelFn := context.WithCancel(s.getMetadataDeltaContextWithLogger(req, stream)) + defer cancelFn() if err := s.validateGetMetadataDeltaRequest(req); err != nil { klog.FromContext(ctx).Error(err, "validation failed") @@ -55,6 +61,7 @@ func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.S klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "baseSnapshotId", csiReq.BaseSnapshotId, "targetSnapshotId", csiReq.TargetSnapshotId) csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataDelta(ctx, csiReq) if err != nil { + klog.FromContext(ctx).Error(err, "csi.GetMetadataDelta") return err } diff --git a/pkg/internal/server/grpc/get_metadata_delta_test.go b/pkg/internal/server/grpc/get_metadata_delta_test.go index e6ff33fb..776f243d 100644 --- a/pkg/internal/server/grpc/get_metadata_delta_test.go +++ b/pkg/internal/server/grpc/get_metadata_delta_test.go @@ -985,3 +985,45 @@ func (f *fakeStreamServerSnapshotDelta) Send(m *api.GetMetadataDeltaResponse) er func (f *fakeStreamServerSnapshotDelta) verifyResponse(expectedResponse *api.GetMetadataDeltaResponse) bool { return f.response.String() == expectedResponse.String() } + +func TestGetMetadataDeltaClientErrorCancelsDriverStream(t *testing.T) { + sms, th := newSMSHarnessForCtxPropagation(t) + defer sms.cleanup(t) + + // create the cancelable client context + ctx, cancelFn := context.WithCancel(context.Background()) + + // make the RPC call + client := th.GRPCSnapshotMetadataClient(t) + clientStream, err := client.GetMetadataDelta(ctx, &api.GetMetadataDeltaRequest{ + SecurityToken: th.SecurityToken, + Namespace: th.Namespace, + BaseSnapshotName: "snap-1", + TargetSnapshotName: "snap-2", + }) + 
assert.NoError(t, err) + assert.NotNil(t, clientStream) + + sms.synchronizeBeforeCancel() + + r1, e1 := clientStream.Recv() // get the first response + assert.NoError(t, e1) + assert.NotNil(t, r1) + + // the client cancels the context + cancelFn() + + r2, e2 := clientStream.Recv() // fail to get the second response (not yet sent) + assert.Error(t, e2) + assert.ErrorContains(t, e2, context.Canceled.Error()) + assert.Nil(t, r2) + + sms.synchronizeAfterCancel() + + // Check the fake driver handler status + assert.True(t, sms.handlerCalled) + assert.NoError(t, sms.send1Err) + assert.ErrorIs(t, sms.streamCtx.Err(), context.Canceled) + assert.Error(t, sms.send2Err) + assert.ErrorContains(t, sms.send2Err, context.Canceled.Error()) +}