Skip to content

Commit

Permalink
Ensure that a client send error in the sidecar will result
Browse files Browse the repository at this point in the history
in the CSI driver stream being canceled.
  • Loading branch information
carlbraganza committed Dec 13, 2024
1 parent b943ffd commit 1954211
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 2 deletions.
178 changes: 178 additions & 0 deletions pkg/internal/server/grpc/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/container-storage-interface/spec/lib/go/csi"
snapshotv1 "github.com/kubernetes-csi/external-snapshotter/client/v8/apis/volumesnapshot/v1"
fakesnapshot "github.com/kubernetes-csi/external-snapshotter/client/v8/clientset/versioned/fake"
snapshotutils "github.com/kubernetes-csi/external-snapshotter/v8/pkg/utils"
Expand Down Expand Up @@ -433,3 +435,179 @@ func (th *testHarness) SetKlogVerbosity(verboseLevel int, uniquePrefix string) K
klog.InitFlags(fs)
}
}

// testSnapshotMetadataServerCtxPropagator is the snapshot metadata server
// served by the fake CSI driver in the context propagation tests. Its handlers
// run in lock-step with the test via the channels below and record what they
// observed so that the test can inspect the outcome afterwards.
type testSnapshotMetadataServerCtxPropagator struct {
	*csi.UnimplementedSnapshotMetadataServer

	// Signals raised by the handler (closed in sync).
	chanToCloseOnEntry      chan struct{} // closed as soon as the handler is entered
	chanToCloseBeforeReturn chan struct{} // closed after the second send attempt

	// Gates the handler blocks on (closed by the synchronize* methods).
	chanToWaitOnBeforeFirstResponse  chan struct{}
	chanToWaitOnBeforeSecondResponse chan struct{}

	// Observations recorded by the handler.
	handlerCalled                 bool            // set by sync
	streamCtx                     context.Context // stream context passed to sync
	send1Err                      error           // result of the first send
	send2Err                      error           // result of the second send
	streamGetMetadataAllocatedErr error           // not referenced in the visible code

	// test harness support
	th         *testHarness
	rth        *runtime.TestHarness
	grpcServer *Server
}

// newSMSHarnessForCtxPropagation builds a test harness in which the sidecar
// service is connected to a fake CSI driver that serves a
// testSnapshotMetadataServerCtxPropagator. The arrangement provides a
// deterministic way to observe a canceled context inside the fake CSI driver.
//
// The caller is expected to follow this sequence:
//   - Invoke either the GetMetadataAllocated or the GetMetadataDelta operation.
//     The fake CSI driver handler blocks until the caller synchronizes.
//   - Call synchronizeBeforeCancel, which lets the fake CSI driver handler send
//     the first response, after which it blocks again.
//   - Receive the first response (without an error).
//   - Cancel its context.
//   - Receive the second response (which must fail as it hasn't been sent yet).
//   - Call synchronizeAfterCancel. This wakes the fake CSI driver handler, which
//     spins until it detects the failed context and then attempts to send a
//     second response. Only then does control return to the caller.
//
// Afterwards the caller should inspect the fields of the returned server to
// check for correctness.
func newSMSHarnessForCtxPropagation(t *testing.T) (*testSnapshotMetadataServerCtxPropagator, *testHarness) {
	sms := &testSnapshotMetadataServerCtxPropagator{
		chanToCloseOnEntry:               make(chan struct{}),
		chanToCloseBeforeReturn:          make(chan struct{}),
		chanToWaitOnBeforeFirstResponse:  make(chan struct{}),
		chanToWaitOnBeforeSecondResponse: make(chan struct{}),
	}

	// Bring up the fake CSI driver (serving sms) with the runtime test harness.
	sms.rth = runtime.NewTestHarness().WithFakeKubeConfig(t).WithFakeCSIDriver(t, sms)
	driverRuntime := sms.rth.RuntimeForFakeCSIDriver(t)

	// Configure a local test harness that talks to the fake CSI driver.
	sms.th = newTestHarness()
	sms.th.DriverName = driverRuntime.DriverName
	sms.th.WithFakeClientAPIs()

	sidecarRuntime := sms.th.Runtime()
	sidecarRuntime.CSIConn = driverRuntime.CSIConn

	sms.grpcServer = sms.th.StartGRPCServer(t, sidecarRuntime)
	sms.grpcServer.CSIDriverIsReady()

	return sms, sms.th
}

// cleanup stops the sidecar gRPC server and then removes the fake kube config
// and terminates the fake CSI driver. It does nothing when the local test
// harness was never set.
func (s *testSnapshotMetadataServerCtxPropagator) cleanup(t *testing.T) {
	if s.th == nil {
		return
	}

	s.th.StopGRPCServer(t)

	fakeDriverHarness := s.rth
	fakeDriverHarness.RemoveFakeKubeConfig(t)
	fakeDriverHarness.TerminateFakeCSIDriver(t)
}

// sync is the body shared by the fake CSI driver stream handlers. It runs the
// handler in lock-step with the test:
//
//  1. signals entry by closing chanToCloseOnEntry;
//  2. waits on chanToWaitOnBeforeFirstResponse, then calls sendResp1 and
//     records its result in send1Err;
//  3. waits on chanToWaitOnBeforeSecondResponse, then polls until ctx reports
//     an error, calls sendResp2 and records its result in send2Err;
//  4. records handlerCalled and streamCtx, and finally signals completion by
//     closing chanToCloseBeforeReturn.
//
// Any of the channels may be nil, in which case that step is skipped.
// It always returns nil; send failures are only recorded.
func (s *testSnapshotMetadataServerCtxPropagator) sync(ctx context.Context, sendResp1, sendResp2 func() error) error {
	if s.chanToCloseOnEntry != nil {
		close(s.chanToCloseOnEntry)
	}

	if s.chanToWaitOnBeforeFirstResponse != nil {
		<-s.chanToWaitOnBeforeFirstResponse
	}

	// send the first response
	s.send1Err = sendResp1()

	if s.chanToWaitOnBeforeSecondResponse != nil {
		<-s.chanToWaitOnBeforeSecondResponse
	}

	// now wait for the client's canceled context to reach this context
	for ctx.Err() == nil {
		time.Sleep(time.Millisecond)
	}

	// send the next response
	s.send2Err = sendResp2()

	// Publish every observation BEFORE closing the completion channel. The
	// test reads these fields right after receiving from that channel, and
	// only writes made before the close are guaranteed to be visible to it.
	// Writing them after the close is a data race that can surface as
	// handlerCalled == false or a nil streamCtx in the test.
	s.handlerCalled = true
	s.streamCtx = ctx

	if s.chanToCloseBeforeReturn != nil {
		close(s.chanToCloseBeforeReturn)
	}

	return nil
}

// synchronizeBeforeCancel blocks until the fake CSI driver handler has been
// entered and then releases it to send the first response.
func (s *testSnapshotMetadataServerCtxPropagator) synchronizeBeforeCancel() {
	handlerEntered := s.chanToCloseOnEntry
	firstResponseGate := s.chanToWaitOnBeforeFirstResponse

	// wait for the handler in the fake CSI driver to start running
	<-handlerEntered

	// open the gate: the handler may now send the first response
	close(firstResponseGate)
}

// synchronizeAfterCancel releases the fake CSI driver handler so that it can
// proceed towards its second response, and then blocks until the handler
// signals that it is about to return.
func (s *testSnapshotMetadataServerCtxPropagator) synchronizeAfterCancel() {
	secondResponseGate := s.chanToWaitOnBeforeSecondResponse
	handlerDone := s.chanToCloseBeforeReturn

	// open the gate: the handler may now attempt the second response
	close(secondResponseGate)

	// block until the fake driver method signals completion
	<-handlerDone
}

// GetMetadataAllocated is the fake CSI driver handler for the allocated-blocks
// stream. It delegates to sync, offering two fixed-length single-block
// responses: the first at byte offset 0 and the second at byte offset 1024.
// The request itself is not inspected.
func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataAllocated(req *csi.GetMetadataAllocatedRequest, stream csi.SnapshotMetadata_GetMetadataAllocatedServer) error {
	const (
		volumeCapacity = 1024 * 1024 * 1024
		blockSize      = 1024
	)

	// sendBlock streams one response that carries only the given block.
	sendBlock := func(block *csi.BlockMetadata) error {
		resp := &csi.GetMetadataAllocatedResponse{
			BlockMetadataType:   csi.BlockMetadataType_FIXED_LENGTH,
			VolumeCapacityBytes: volumeCapacity,
			BlockMetadata:       []*csi.BlockMetadata{block},
		}
		return stream.Send(resp)
	}

	firstResponse := func() error {
		return sendBlock(&csi.BlockMetadata{ByteOffset: 0, SizeBytes: blockSize})
	}
	secondResponse := func() error {
		return sendBlock(&csi.BlockMetadata{ByteOffset: 1024, SizeBytes: blockSize})
	}

	return s.sync(stream.Context(), firstResponse, secondResponse)
}

// GetMetadataDelta is the fake CSI driver handler for the changed-blocks
// stream. It delegates to sync, offering two fixed-length single-block
// responses: the first at byte offset 0 and the second at byte offset 1024.
// The request itself is not inspected.
func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataDelta(req *csi.GetMetadataDeltaRequest, stream csi.SnapshotMetadata_GetMetadataDeltaServer) error {
	const (
		volumeCapacity = 1024 * 1024 * 1024
		blockSize      = 1024
	)

	// sendBlock streams one response that carries only the given block.
	sendBlock := func(block *csi.BlockMetadata) error {
		resp := &csi.GetMetadataDeltaResponse{
			BlockMetadataType:   csi.BlockMetadataType_FIXED_LENGTH,
			VolumeCapacityBytes: volumeCapacity,
			BlockMetadata:       []*csi.BlockMetadata{block},
		}
		return stream.Send(resp)
	}

	firstResponse := func() error {
		return sendBlock(&csi.BlockMetadata{ByteOffset: 0, SizeBytes: blockSize})
	}
	secondResponse := func() error {
		return sendBlock(&csi.BlockMetadata{ByteOffset: 1024, SizeBytes: blockSize})
	}

	return s.sync(stream.Context(), firstResponse, secondResponse)
}
9 changes: 8 additions & 1 deletion pkg/internal/server/grpc/get_metadata_allocated.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ import (
)

func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stream api.SnapshotMetadata_GetMetadataAllocatedServer) error {
ctx := s.getMetadataAllocatedContextWithLogger(req, stream)
// Create a cancelable context so that failure in sending to the client would
// cancel the context used to communicate with the CSI driver when the stack unwinds.
// Note: this may be unnecessary if failure on the client stream Send() is already
// propagated to the returned context in the gRPC runtime, but it is not documented
// as such, so this extra step ensures that this will indeed take place.
ctx, cancelFn := context.WithCancel(s.getMetadataAllocatedContextWithLogger(req, stream))
defer cancelFn()

if err := s.validateGetMetadataAllocatedRequest(req); err != nil {
klog.FromContext(ctx).Error(err, "validation failed")
Expand All @@ -55,6 +61,7 @@ func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stre
klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "snapshotId", csiReq.SnapshotId)
csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataAllocated(ctx, csiReq)
if err != nil {
klog.FromContext(ctx).Error(err, "csi.GetMetadataAllocated")
return err
}

Expand Down
41 changes: 41 additions & 0 deletions pkg/internal/server/grpc/get_metadata_allocated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,44 @@ func (f *fakeStreamServerSnapshotAllocated) Send(m *api.GetMetadataAllocatedResp
// verifyResponse reports whether the response held by the fake stream has the
// same String() rendering as expectedResponse.
func (f *fakeStreamServerSnapshotAllocated) verifyResponse(expectedResponse *api.GetMetadataAllocatedResponse) bool {
	got := f.response.String()
	want := expectedResponse.String()
	return got == want
}

// TestGetMetadataAllocatedClientErrorCancelsDriverStream verifies that when the
// client of the sidecar cancels its context in the middle of a
// GetMetadataAllocated stream, the cancellation reaches the stream context seen
// by the fake CSI driver handler and the driver's next Send fails.
func TestGetMetadataAllocatedClientErrorCancelsDriverStream(t *testing.T) {
	sms, th := newSMSHarnessForCtxPropagation(t)
	defer sms.cleanup(t)

	// create the cancelable client context
	ctx, cancelFn := context.WithCancel(context.Background())
	// Calling cancelFn twice is harmless; the defer guarantees the context is
	// released even when the test aborts before the explicit cancel below.
	defer cancelFn()

	// make the RPC call
	client := th.GRPCSnapshotMetadataClient(t)
	clientStream, err := client.GetMetadataAllocated(ctx, &api.GetMetadataAllocatedRequest{
		SecurityToken: th.SecurityToken,
		Namespace:     th.Namespace,
		SnapshotName:  "snap-1",
	})
	// Without a usable stream the rest of the test would dereference nil,
	// so stop here instead of merely recording the failure.
	if !assert.NoError(t, err) || !assert.NotNil(t, clientStream) {
		t.FailNow()
	}

	sms.synchronizeBeforeCancel()

	r1, e1 := clientStream.Recv() // get the first response
	assert.NoError(t, e1)
	assert.NotNil(t, r1)

	// the client cancels the context
	cancelFn()

	r2, e2 := clientStream.Recv() // fail to get the second response (not yet sent)
	assert.Error(t, e2)
	assert.ErrorContains(t, e2, context.Canceled.Error())
	assert.Nil(t, r2)

	sms.synchronizeAfterCancel()

	// Check the fake driver handler status
	assert.True(t, sms.handlerCalled)
	assert.NoError(t, sms.send1Err)
	// Guard the dereference so that a missing context is reported as a test
	// failure rather than a nil-pointer panic.
	if assert.NotNil(t, sms.streamCtx) {
		assert.ErrorIs(t, sms.streamCtx.Err(), context.Canceled)
	}
	assert.Error(t, sms.send2Err)
	assert.ErrorContains(t, sms.send2Err, context.Canceled.Error())
}
9 changes: 8 additions & 1 deletion pkg/internal/server/grpc/get_metadata_delta.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ import (
)

func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.SnapshotMetadata_GetMetadataDeltaServer) error {
ctx := s.getMetadataDeltaContextWithLogger(req, stream)
// Create a cancelable context so that failure in sending to the client would
// cancel the context used to communicate with the CSI driver when the stack unwinds.
// Note: this may be unnecessary if failure on the client stream Send() is already
// propagated to the returned context in the gRPC runtime, but it is not documented
// as such, so this extra step ensures that this will indeed take place.
ctx, cancelFn := context.WithCancel(s.getMetadataDeltaContextWithLogger(req, stream))
defer cancelFn()

if err := s.validateGetMetadataDeltaRequest(req); err != nil {
klog.FromContext(ctx).Error(err, "validation failed")
Expand All @@ -55,6 +61,7 @@ func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.S
klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "baseSnapshotId", csiReq.BaseSnapshotId, "targetSnapshotId", csiReq.TargetSnapshotId)
csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataDelta(ctx, csiReq)
if err != nil {
klog.FromContext(ctx).Error(err, "csi.GetMetadataDelta")
return err
}

Expand Down
42 changes: 42 additions & 0 deletions pkg/internal/server/grpc/get_metadata_delta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -985,3 +985,45 @@ func (f *fakeStreamServerSnapshotDelta) Send(m *api.GetMetadataDeltaResponse) er
// verifyResponse reports whether the response held by the fake stream has the
// same String() rendering as expectedResponse.
func (f *fakeStreamServerSnapshotDelta) verifyResponse(expectedResponse *api.GetMetadataDeltaResponse) bool {
	got := f.response.String()
	want := expectedResponse.String()
	return got == want
}

// TestGetMetadataDeltaClientErrorCancelsDriverStream verifies that when the
// client of the sidecar cancels its context in the middle of a
// GetMetadataDelta stream, the cancellation reaches the stream context seen by
// the fake CSI driver handler and the driver's next Send fails.
func TestGetMetadataDeltaClientErrorCancelsDriverStream(t *testing.T) {
	sms, th := newSMSHarnessForCtxPropagation(t)
	defer sms.cleanup(t)

	// create the cancelable client context
	ctx, cancelFn := context.WithCancel(context.Background())
	// Calling cancelFn twice is harmless; the defer guarantees the context is
	// released even when the test aborts before the explicit cancel below.
	defer cancelFn()

	// make the RPC call
	client := th.GRPCSnapshotMetadataClient(t)
	clientStream, err := client.GetMetadataDelta(ctx, &api.GetMetadataDeltaRequest{
		SecurityToken:      th.SecurityToken,
		Namespace:          th.Namespace,
		BaseSnapshotName:   "snap-1",
		TargetSnapshotName: "snap-2",
	})
	// Without a usable stream the rest of the test would dereference nil,
	// so stop here instead of merely recording the failure.
	if !assert.NoError(t, err) || !assert.NotNil(t, clientStream) {
		t.FailNow()
	}

	sms.synchronizeBeforeCancel()

	r1, e1 := clientStream.Recv() // get the first response
	assert.NoError(t, e1)
	assert.NotNil(t, r1)

	// the client cancels the context
	cancelFn()

	r2, e2 := clientStream.Recv() // fail to get the second response (not yet sent)
	assert.Error(t, e2)
	assert.ErrorContains(t, e2, context.Canceled.Error())
	assert.Nil(t, r2)

	sms.synchronizeAfterCancel()

	// Check the fake driver handler status
	assert.True(t, sms.handlerCalled)
	assert.NoError(t, sms.send1Err)
	// Guard the dereference so that a missing context is reported as a test
	// failure rather than a nil-pointer panic.
	if assert.NotNil(t, sms.streamCtx) {
		assert.ErrorIs(t, sms.streamCtx.Err(), context.Canceled)
	}
	assert.Error(t, sms.send2Err)
	assert.ErrorContains(t, sms.send2Err, context.Canceled.Error())
}

0 comments on commit 1954211

Please sign in to comment.