Skip to content

Commit

Permalink
Merge pull request #86 from carlbraganza/client-error-to-driver
Browse files Browse the repository at this point in the history
Client error propagation
  • Loading branch information
k8s-ci-robot authored Dec 19, 2024
2 parents 3048fed + 29b0c54 commit c81d89e
Show file tree
Hide file tree
Showing 9 changed files with 466 additions and 41 deletions.
218 changes: 217 additions & 1 deletion pkg/internal/server/grpc/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ import (
"net"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/container-storage-interface/spec/lib/go/csi"
snapshotv1 "github.com/kubernetes-csi/external-snapshotter/client/v8/apis/volumesnapshot/v1"
fakesnapshot "github.com/kubernetes-csi/external-snapshotter/client/v8/clientset/versioned/fake"
snapshotutils "github.com/kubernetes-csi/external-snapshotter/v8/pkg/utils"
Expand Down Expand Up @@ -58,6 +61,8 @@ type testHarness struct {
SecurityToken string
VolumeSnapshotClassName string

MaxStreamDur time.Duration

FakeKubeClient *fake.Clientset
FakeSnapshotClient *fakesnapshot.Clientset
FakeCBTClient *fakecbt.Clientset
Expand All @@ -80,6 +85,7 @@ func newTestHarness() *testHarness {
SecretNs: "secret-ns",
SecurityToken: "securityToken",
VolumeSnapshotClassName: "csi-snapshot-class",
MaxStreamDur: HandlerDefaultMaxStreamDuration,
}
}

Expand Down Expand Up @@ -111,7 +117,10 @@ func (th *testHarness) Runtime() *runtime.Runtime {

// ServerWithRuntime returns a Server whose configuration carries the given
// runtime and the harness's MaxStreamDur, together with a new health server.
// The returned Server's grpcServer field is left unset.
//
// Note: the server configuration is specified exactly once; a second
// `config:` key in the literal would be a duplicate-field compile error.
func (th *testHarness) ServerWithRuntime(t *testing.T, rt *runtime.Runtime) *Server {
	return &Server{
		config: ServerConfig{
			Runtime:      rt,
			MaxStreamDur: th.MaxStreamDur,
		},
		healthServer: newHealthServer(),
	}
}
Expand All @@ -123,6 +132,8 @@ func (th *testHarness) StartGRPCServer(t *testing.T, rt *runtime.Runtime) *Serve
th.grpcServer = grpc.NewServer()

s := th.ServerWithRuntime(t, rt)
assert.NotZero(t, s.config.MaxStreamDur)

s.grpcServer = th.grpcServer
api.RegisterSnapshotMetadataServer(s.grpcServer, s)
healthpb.RegisterHealthServer(s.grpcServer, s.healthServer)
Expand Down Expand Up @@ -433,3 +444,208 @@ func (th *testHarness) SetKlogVerbosity(verboseLevel int, uniquePrefix string) K
klog.InitFlags(fs)
}
}

// testSnapshotMetadataServerCtxPropagator is a fake CSI SnapshotMetadata server
// used to test how context errors propagate to the CSI driver. Its handlers
// are driven step by step from the test through the channels below; see
// sync(), synchronizeBeforeCancel() and synchronizeAfterCancel().
type testSnapshotMetadataServerCtxPropagator struct {
	*csi.UnimplementedSnapshotMetadataServer

	// Closed by the handler (sync): the first just before sync returns, the
	// second as soon as sync has been entered.
	chanToCloseBeforeReturn, chanToCloseOnEntry chan struct{}

	// Closed by the test: the first by synchronizeBeforeCancel() to release
	// the first response, the second by synchronizeAfterCancel() to release
	// the remainder of the handler.
	chanToWaitOnBeforeFirstResponse, chanToWaitOnBeforeSecondResponse chan struct{}

	// When true the handler waits for its stream context to report an error
	// before it attempts to send again.
	handlerWaitsForCtxError bool

	// mux guards the result fields that follow.
	mux           sync.Mutex
	handlerCalled bool // set when sync() is entered

	// send1Err is the result of the first send, send2Err the first failure of
	// the subsequent sends, and streamCtxErr the error observed on the stream
	// context (recorded only when handlerWaitsForCtxError is set).
	send1Err, send2Err, streamCtxErr error

	// Test harness support.
	th         *testHarness
	rth        *runtime.TestHarness
	grpcServer *Server
}

// newSMSHarnessForCtxPropagation builds a test harness in which the sidecar
// service is connected to a fake CSI driver that serves a
// testSnapshotMetadataServerCtxPropagator. It returns the propagator and the
// sidecar test harness.
//
// The setup makes it possible to _deterministically_ observe a canceled
// context in the fake CSI driver, which is otherwise an intrinsically racy
// thing to test. The context may end up canceled either because the
// application client canceled its own context, or because the sidecar timed
// out while transmitting data to the client.
//
// Mode selection:
//
//   - maxStreamDur > 0: the sidecar is configured with that maximum stream
//     duration (sidecar context cancellation).
//   - otherwise: the default maximum stream duration is kept and the fake
//     driver handler is told to wait for its stream context to report an
//     error (client context cancellation).
//
// Sequence of events:
//
//   - The application client calls the sidecar's GetMetadataAllocated() or
//     GetMetadataDelta() operation and obtains a stream to receive data from.
//
//   - The sidecar handler is invoked with a stream on which to send responses
//     to the application client. It wraps the stream context with a deadline
//     and defers the call to the cancellation function. It then makes the
//     same call on the fake CSI driver, obtaining a stream from which it reads
//     driver responses, and loops reading from the driver stream and
//     forwarding each response to the client.
//
//   - The fake CSI driver handler is invoked with a stream on which to send
//     responses to the sidecar. It blocks until the application client calls
//     synchronizeBeforeCancel().
//
//   - The application client calls synchronizeBeforeCancel(). This wakes the
//     fake CSI driver handler, which sends the first response and then blocks
//     again until synchronizeAfterCancel() is called.
//
//   - The response travels through the sidecar back to the application
//     client, which receives it without error.
//
// At this point an error is injected, one of:
//
//   - The application client cancels its context. Some time later the
//     canceled context is detected in the sidecar; this cannot be observed
//     directly, but the sidecar is expected to log the error.
//
//   - The application client does not read a response within the sidecar's
//     timeout, which cancels the sidecar context. The sidecar is expected to
//     log this either when the handler fails or when the send on the client
//     stream fails.
//
// Post-error synchronization:
//
//   - The application client calls synchronizeAfterCancel(), which blocks it
//     until the fake CSI driver handler is about to return.
//
//   - The fake CSI driver handler wakes up and, in client-cancellation mode,
//     waits until it detects the failed context. It then keeps sending
//     responses until a send fails.
//
//   - Just before the fake CSI driver handler returns, the application client
//     is unblocked and returns from synchronizeAfterCancel().
//
//   - The test should then examine the propagator's result fields (under its
//     mutex) to check for correctness.
func newSMSHarnessForCtxPropagation(t *testing.T, maxStreamDur time.Duration) (*testSnapshotMetadataServerCtxPropagator, *testHarness) {
	sms := &testSnapshotMetadataServerCtxPropagator{
		chanToCloseOnEntry:               make(chan struct{}),
		chanToCloseBeforeReturn:          make(chan struct{}),
		chanToWaitOnBeforeFirstResponse:  make(chan struct{}),
		chanToWaitOnBeforeSecondResponse: make(chan struct{}),
	}

	// Bring up the fake CSI driver using the runtime test harness.
	sms.rth = runtime.NewTestHarness().WithFakeKubeConfig(t).WithFakeCSIDriver(t, sms)
	driverRT := sms.rth.RuntimeForFakeCSIDriver(t)

	// Build the local (sidecar) harness that talks to the fake CSI driver.
	th := newTestHarness()
	sms.th = th

	// Two modes: the sidecar context is canceled (explicit timeout given),
	// or the client context is canceled.
	switch {
	case maxStreamDur > 0:
		th.MaxStreamDur = maxStreamDur
	default:
		sms.handlerWaitsForCtxError = true
	}

	th.DriverName = driverRT.DriverName
	th.WithFakeClientAPIs()

	sidecarRT := th.Runtime()
	sidecarRT.CSIConn = driverRT.CSIConn

	sms.grpcServer = th.StartGRPCServer(t, sidecarRT)
	sms.grpcServer.CSIDriverIsReady()

	return sms, th
}

// cleanup tears down what newSMSHarnessForCtxPropagation set up: it stops the
// sidecar gRPC server, removes the fake kube config and terminates the fake
// CSI driver. It does nothing when the sidecar harness was never assigned.
func (s *testSnapshotMetadataServerCtxPropagator) cleanup(t *testing.T) {
	if s.th == nil {
		return
	}

	s.th.StopGRPCServer(t)
	s.rth.RemoveFakeKubeConfig(t)
	s.rth.TerminateFakeCSIDriver(t)
}

// sync is the common body of the fake CSI driver's streaming handlers. It is
// stepped through by the test via the propagator's channels:
//
//  1. Record that the handler was called and close chanToCloseOnEntry.
//  2. Block until chanToWaitOnBeforeFirstResponse is closed (done by
//     synchronizeBeforeCancel()), then send the first response and store the
//     outcome in send1Err.
//  3. Block until chanToWaitOnBeforeSecondResponse is closed (done by
//     synchronizeAfterCancel()).
//  4. If handlerWaitsForCtxError is set, block until ctx is done and store
//     ctx.Err() in streamCtxErr.
//  5. Keep sending responses, pausing 10ms between successful sends, until a
//     send fails; the failure is stored in send2Err.
//  6. Close chanToCloseBeforeReturn to release synchronizeAfterCancel().
//
// ctx is the handler's stream context and sendResp sends one response on the
// handler's stream. sync always returns nil; the errors of interest are
// recorded on the propagator instead.
//
// The mutex is held for the whole call, including while blocked, so a caller
// that locks it after synchronizeAfterCancel() returns only proceeds once
// sync has fully returned. Because sync closes channels, it must run at most
// once per propagator.
func (s *testSnapshotMetadataServerCtxPropagator) sync(ctx context.Context, sendResp func() error) error {
	s.mux.Lock()
	defer s.mux.Unlock()

	s.handlerCalled = true

	// synchronizeBeforeCancel() is needed to proceed
	close(s.chanToCloseOnEntry)
	<-s.chanToWaitOnBeforeFirstResponse

	// send the first response
	s.send1Err = sendResp()

	// synchronizeAfterCancel() is needed to proceed
	<-s.chanToWaitOnBeforeSecondResponse

	if s.handlerWaitsForCtxError {
		// Wait for the client's canceled context to be detected. Blocking on
		// Done() wakes up as soon as the context ends, without polling.
		<-ctx.Done()

		s.streamCtxErr = ctx.Err()
	}

	// Send additional responses until an error is encountered. A send may
	// still succeed shortly after the context ends, hence the retry loop;
	// there is no reason to pause once a send has failed.
	for s.send2Err == nil {
		if s.send2Err = sendResp(); s.send2Err == nil {
			time.Sleep(10 * time.Millisecond)
		}
	}

	// allow the client blocked in synchronizeAfterCancel() to proceed
	close(s.chanToCloseBeforeReturn)

	return nil
}

// synchronizeBeforeCancel is called by the test before it injects an error.
// It blocks until the fake CSI driver handler has been entered and then
// releases the handler so that it sends its first response.
func (s *testSnapshotMetadataServerCtxPropagator) synchronizeBeforeCancel() {
	handlerEntered := s.chanToCloseOnEntry
	releaseFirstResponse := s.chanToWaitOnBeforeFirstResponse

	// Rendezvous with the fake CSI driver handler.
	<-handlerEntered

	// Let the handler go ahead with the first response.
	close(releaseFirstResponse)
}

// synchronizeAfterCancel is called by the test after it has injected an
// error. It releases the fake CSI driver handler to continue past its second
// wait point and then blocks until the handler signals it is about to return.
func (s *testSnapshotMetadataServerCtxPropagator) synchronizeAfterCancel() {
	releaseSecondResponse := s.chanToWaitOnBeforeSecondResponse
	handlerFinishing := s.chanToCloseBeforeReturn

	// Let the handler go ahead with its remaining sends.
	close(releaseSecondResponse)

	// Block until the handler method is done with its work.
	<-handlerFinishing
}

// GetMetadataAllocated is the fake CSI driver's GetMetadataAllocated handler.
// It delegates to sync() with the stream's context and a sender that emits one
// response per call. Each response describes a single fixed-length block of
// 1024 bytes; the block's byte offset starts at 1024 and grows by 1024 with
// every send. The request itself is not inspected.
func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataAllocated(req *csi.GetMetadataAllocatedRequest, stream csi.SnapshotMetadata_GetMetadataAllocatedServer) error {
	const blockSize = 1024

	var nextOffset int64
	sendNext := func() error {
		nextOffset += blockSize

		resp := &csi.GetMetadataAllocatedResponse{
			BlockMetadataType:   csi.BlockMetadataType_FIXED_LENGTH,
			VolumeCapacityBytes: 1024 * 1024 * 1024,
			BlockMetadata: []*csi.BlockMetadata{
				{ByteOffset: nextOffset, SizeBytes: blockSize},
			},
		}

		return stream.Send(resp)
	}

	return s.sync(stream.Context(), sendNext)
}

// GetMetadataDelta is the fake CSI driver's GetMetadataDelta handler. It
// delegates to sync() with the stream's context and a sender that emits one
// response per call. Each response describes a single fixed-length block of
// 1024 bytes; the block's byte offset starts at 1024 and grows by 1024 with
// every send. The request itself is not inspected.
func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataDelta(req *csi.GetMetadataDeltaRequest, stream csi.SnapshotMetadata_GetMetadataDeltaServer) error {
	const blockSize = 1024

	var nextOffset int64
	sendNext := func() error {
		nextOffset += blockSize

		resp := &csi.GetMetadataDeltaResponse{
			BlockMetadataType:   csi.BlockMetadataType_FIXED_LENGTH,
			VolumeCapacityBytes: 1024 * 1024 * 1024,
			BlockMetadata: []*csi.BlockMetadata{
				{ByteOffset: nextOffset, SizeBytes: blockSize},
			},
		}

		return stream.Send(resp)
	}

	return s.sync(stream.Context(), sendNext)
}
7 changes: 6 additions & 1 deletion pkg/internal/server/grpc/get_metadata_allocated.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ import (
)

func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stream api.SnapshotMetadata_GetMetadataAllocatedServer) error {
ctx := s.getMetadataAllocatedContextWithLogger(req, stream)
// Create a timeout context so that failure in either sending to the client or
// receiving from the CSI driver will ultimately abort the handler session.
// The context could also get canceled by the client.
ctx, cancelFn := context.WithTimeout(s.getMetadataAllocatedContextWithLogger(req, stream), s.config.MaxStreamDur)
defer cancelFn()

if err := s.validateGetMetadataAllocatedRequest(req); err != nil {
klog.FromContext(ctx).Error(err, "validation failed")
Expand All @@ -55,6 +59,7 @@ func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stre
klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "snapshotId", csiReq.SnapshotId)
csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataAllocated(ctx, csiReq)
if err != nil {
klog.FromContext(ctx).Error(err, "csi.GetMetadataAllocated")
return err
}

Expand Down
78 changes: 78 additions & 0 deletions pkg/internal/server/grpc/get_metadata_allocated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"testing"
"time"

"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -751,3 +752,80 @@ func (f *fakeStreamServerSnapshotAllocated) Send(m *api.GetMetadataAllocatedResp
// verifyResponse reports whether the response held by the fake stream has the
// same String() rendering as expectedResponse.
func (f *fakeStreamServerSnapshotAllocated) verifyResponse(expectedResponse *api.GetMetadataAllocatedResponse) bool {
	got := f.response.String()
	want := expectedResponse.String()

	return got == want
}

// TestGetMetadataAllocatedClientErrorHandling checks that a context error on
// a GetMetadataAllocated stream reaches the fake CSI driver handler. It covers
// two cases: the application client cancels its context, and the sidecar's
// maximum stream duration expires. The order of the calls below matters: the
// synchronize*() calls step the fake driver handler through its send points.
func TestGetMetadataAllocatedClientErrorHandling(t *testing.T) {
t.Run("client-cancels-context", func(t *testing.T) {
// A zero duration selects the mode in which the fake driver handler waits
// for its stream context to report an error.
sms, th := newSMSHarnessForCtxPropagation(t, 0)
defer sms.cleanup(t)

// create the cancelable application client context
ctx, cancelFn := context.WithCancel(context.Background())

// make the RPC call
client := th.GRPCSnapshotMetadataClient(t)
clientStream, err := client.GetMetadataAllocated(ctx, &api.GetMetadataAllocatedRequest{
SecurityToken: th.SecurityToken,
Namespace: th.Namespace,
SnapshotName: "snap-1",
})
assert.NoError(t, err)
assert.NotNil(t, clientStream)

// wait for the fake driver handler to start and let it send the first response
sms.synchronizeBeforeCancel()

r1, e1 := clientStream.Recv() // get the first response
assert.NoError(t, e1)
assert.NotNil(t, r1)

// the client cancels the context
cancelFn()

r2, e2 := clientStream.Recv() // fail because ctx is canceled
assert.Error(t, e2)
assert.ErrorContains(t, e2, context.Canceled.Error())
assert.Nil(t, r2)

// release the fake driver handler and wait until it is about to return
sms.synchronizeAfterCancel()

// Check the fake driver handler status
// (the handler holds the mutex until it returns, so this also waits for it)
sms.mux.Lock()
defer sms.mux.Unlock()
assert.True(t, sms.handlerCalled)
assert.NoError(t, sms.send1Err)
assert.ErrorIs(t, sms.streamCtxErr, context.Canceled)
assert.Error(t, sms.send2Err)
assert.ErrorContains(t, sms.send2Err, context.Canceled.Error())
})

t.Run("sidecar-deadline-exceeded", func(t *testing.T) {
// arrange for the sidecar to timeout quickly
sms, th := newSMSHarnessForCtxPropagation(t, time.Millisecond*10)
defer sms.cleanup(t)

// make the RPC call
// (the client context is never canceled in this case)
client := th.GRPCSnapshotMetadataClient(t)
clientStream, err := client.GetMetadataAllocated(context.Background(), &api.GetMetadataAllocatedRequest{
SecurityToken: th.SecurityToken,
Namespace: th.Namespace,
SnapshotName: "snap-1",
})
assert.NoError(t, err)
assert.NotNil(t, clientStream)

// wait for the fake driver handler to start and let it send the first response
sms.synchronizeBeforeCancel()

// do not attempt to receive anything

// release the fake driver handler; it keeps sending until a send fails
sms.synchronizeAfterCancel()

// Check the fake driver handler status
// (streamCtxErr is not checked: the handler only records it in the
// client-cancellation mode)
sms.mux.Lock()
defer sms.mux.Unlock()
assert.True(t, sms.handlerCalled)
assert.NoError(t, sms.send1Err)
assert.Error(t, sms.send2Err)
// it's a bit uncertain as to which context error we get in the handler
re := context.DeadlineExceeded.Error() + "|" + context.Canceled.Error()
assert.Regexp(t, re, sms.send2Err.Error())
})
}
7 changes: 6 additions & 1 deletion pkg/internal/server/grpc/get_metadata_delta.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ import (
)

func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.SnapshotMetadata_GetMetadataDeltaServer) error {
ctx := s.getMetadataDeltaContextWithLogger(req, stream)
// Create a timeout context so that failure in either sending to the client or
// receiving from the CSI driver will ultimately abort the handler session.
// The context could also get canceled by the client.
ctx, cancelFn := context.WithTimeout(s.getMetadataDeltaContextWithLogger(req, stream), s.config.MaxStreamDur)
defer cancelFn()

if err := s.validateGetMetadataDeltaRequest(req); err != nil {
klog.FromContext(ctx).Error(err, "validation failed")
Expand All @@ -55,6 +59,7 @@ func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.S
klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "baseSnapshotId", csiReq.BaseSnapshotId, "targetSnapshotId", csiReq.TargetSnapshotId)
csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataDelta(ctx, csiReq)
if err != nil {
klog.FromContext(ctx).Error(err, "csi.GetMetadataDelta")
return err
}

Expand Down
Loading

0 comments on commit c81d89e

Please sign in to comment.