From 29b0c5491a8aeb87762e1c2e28266a3b1ff5f1ad Mon Sep 17 00:00:00 2001 From: Carl Braganza Date: Tue, 10 Dec 2024 18:12:35 -0800 Subject: [PATCH] Ensure that a client errors will result in stream cancellation. --- pkg/internal/server/grpc/common_test.go | 218 +++++++++++++++++- .../server/grpc/get_metadata_allocated.go | 7 +- .../grpc/get_metadata_allocated_test.go | 78 +++++++ .../server/grpc/get_metadata_delta.go | 7 +- .../server/grpc/get_metadata_delta_test.go | 80 +++++++ pkg/internal/server/grpc/server.go | 13 ++ pkg/internal/server/grpc/server_test.go | 9 +- pkg/sidecar/sidecar.go | 82 ++++--- pkg/sidecar/sidecar_test.go | 13 +- 9 files changed, 466 insertions(+), 41 deletions(-) diff --git a/pkg/internal/server/grpc/common_test.go b/pkg/internal/server/grpc/common_test.go index 682c1f64..be4877b2 100644 --- a/pkg/internal/server/grpc/common_test.go +++ b/pkg/internal/server/grpc/common_test.go @@ -22,8 +22,11 @@ import ( "net" "strconv" "strings" + "sync" "testing" + "time" + "github.com/container-storage-interface/spec/lib/go/csi" snapshotv1 "github.com/kubernetes-csi/external-snapshotter/client/v8/apis/volumesnapshot/v1" fakesnapshot "github.com/kubernetes-csi/external-snapshotter/client/v8/clientset/versioned/fake" snapshotutils "github.com/kubernetes-csi/external-snapshotter/v8/pkg/utils" @@ -58,6 +61,8 @@ type testHarness struct { SecurityToken string VolumeSnapshotClassName string + MaxStreamDur time.Duration + FakeKubeClient *fake.Clientset FakeSnapshotClient *fakesnapshot.Clientset FakeCBTClient *fakecbt.Clientset @@ -80,6 +85,7 @@ func newTestHarness() *testHarness { SecretNs: "secret-ns", SecurityToken: "securityToken", VolumeSnapshotClassName: "csi-snapshot-class", + MaxStreamDur: HandlerDefaultMaxStreamDuration, } } @@ -111,7 +117,10 @@ func (th *testHarness) Runtime() *runtime.Runtime { func (th *testHarness) ServerWithRuntime(t *testing.T, rt *runtime.Runtime) *Server { return &Server{ - config: ServerConfig{Runtime: rt}, + config: ServerConfig{ + Runtime: rt, + MaxStreamDur: th.MaxStreamDur, + }, healthServer: newHealthServer(), } } @@ -123,6 +132,8 @@ func (th *testHarness) StartGRPCServer(t *testing.T, rt *runtime.Runtime) *Serve th.grpcServer = grpc.NewServer() s := th.ServerWithRuntime(t, rt) + assert.NotZero(t, s.config.MaxStreamDur) + s.grpcServer = th.grpcServer api.RegisterSnapshotMetadataServer(s.grpcServer, s) healthpb.RegisterHealthServer(s.grpcServer, s.healthServer) @@ -433,3 +444,208 @@ func (th *testHarness) SetKlogVerbosity(verboseLevel int, uniquePrefix string) K klog.InitFlags(fs) } } + +// test data structure for context propagation testing. +type testSnapshotMetadataServerCtxPropagator struct { + *csi.UnimplementedSnapshotMetadataServer + + chanToCloseBeforeReturn chan struct{} + chanToCloseOnEntry chan struct{} + chanToWaitOnBeforeFirstResponse chan struct{} + chanToWaitOnBeforeSecondResponse chan struct{} + + handlerWaitsForCtxError bool + + // the mux protects this block of variables + mux sync.Mutex + handlerCalled bool + send1Err error + send2Err error + streamCtxErr error + + // test harness support + th *testHarness + rth *runtime.TestHarness + grpcServer *Server +} + +// newSMSHarnessForCtxPropagation sets up a test harness with the sidecar service +// connected to a fake CSI driver serving the testSnapshotMetadataServerCtxPropagator. +// The setup provides a mechanism to _deterministically_ sense a canceled context +// in the fake CSI driver, an intrinsically racy operation. The canceled context could +// have arisen because the client canceled its context, or the sidecar timedout in +// transmitting data to the client. +// +// The mechanism works as follows: +// +// - The application client calls one of the sidecar's GetAllocatedAllocated() or +// GetMetadataDelta() operations, which return a stream from which data can be received. +// +// - The sidecar handler gets invoked with a stream on which to send responses to the +// application client. It wraps the stream context with a deadline and defers the +// cancellation function invocation. +// It then makes the same call on the fake CSI driver and receives a stream from +// which to read responses from the driver. It blocks in a loop reading from +// the fake CSI driver stream and sending the response to the client. +// +// - The handler in the fake CSI driver gets invoked with a stream on which to send the +// responses to the sidecar. It will block waiting on the application client +// to call synchronizeBeforeCancel(). +// +// - The application client calls synchronizeBeforeCancel(), which wakes up the fake +// CSI driver handler and it sends the first response. The fake CSI driver handler +// blocks again waiting for the invoker to call synchronizeAfterCancel(). +// +// - The response is routed through the sidecar back to the application client. +// +// - The application client receives the first response without error. +// +// At this point we inject an error. Either one of: +// +// - The application client cancels its context. At some point after this the +// canceled context is detected in the sidecar; we cannot actually detect +// this but expect the error to be logged by the sidecar. +// +// - The application client does not read a response within the sidecar's timeout +// which will trigger a cancellation of the sidecar context. We could expect to see this +// logged by the side car, either when the handler fails or the client stream send fails. +// +// Post error synchronization: +// +// - The application client calls synchronizeAfterCancel() which blocks it until the +// fake CSI driver handler returns. +// +// - The fake CSI driver handler wakes up and then loops waiting to detect +// the failed context. When it breaks out of this loop it attempts to send a +// second response, which must fail because the client has canceled its context. +// +// - When the fake CSI driver handler returns the invoker gets unblocked and returns +// from its call to synchronizeAfterCancel(). +// +// - After this the invoker should examine the SnapshotMetadataServer properties +// to check for correctness. +func newSMSHarnessForCtxPropagation(t *testing.T, maxStreamDur time.Duration) (*testSnapshotMetadataServerCtxPropagator, *testHarness) { + s := &testSnapshotMetadataServerCtxPropagator{} + s.chanToCloseOnEntry = make(chan struct{}) + s.chanToCloseBeforeReturn = make(chan struct{}) + s.chanToWaitOnBeforeFirstResponse = make(chan struct{}) + s.chanToWaitOnBeforeSecondResponse = make(chan struct{}) + + // set up a fake csi driver with the runtime test harness + s.rth = runtime.NewTestHarness().WithFakeKubeConfig(t).WithFakeCSIDriver(t, s) + rrt := s.rth.RuntimeForFakeCSIDriver(t) + + // configure a local test harness to connect to the fake csi driver + s.th = newTestHarness() + // 2 modes: client context canceled or sidecar context canceled + if maxStreamDur > 0 { + s.th.MaxStreamDur = maxStreamDur + } else { + s.handlerWaitsForCtxError = true + } + s.th.DriverName = rrt.DriverName + s.th.WithFakeClientAPIs() + rt := s.th.Runtime() + rt.CSIConn = rrt.CSIConn + s.grpcServer = s.th.StartGRPCServer(t, rt) + s.grpcServer.CSIDriverIsReady() + + return s, s.th +} + +func (s *testSnapshotMetadataServerCtxPropagator) cleanup(t *testing.T) { + if s.th != nil { + s.th.StopGRPCServer(t) + s.rth.RemoveFakeKubeConfig(t) + s.rth.TerminateFakeCSIDriver(t) + } +} + +func (s *testSnapshotMetadataServerCtxPropagator) sync(ctx context.Context, sendResp func() error) error { + s.mux.Lock() + defer s.mux.Unlock() + + s.handlerCalled = true + + // synchronizeBeforeCancel() is needed to proceed + close(s.chanToCloseOnEntry) + <-s.chanToWaitOnBeforeFirstResponse + + // send the first response + s.send1Err = sendResp() + + // synchronizeAfterCancel() is needed to proceed + <-s.chanToWaitOnBeforeSecondResponse + + if s.handlerWaitsForCtxError { + // wait for the client's canceled context to be detected + for ctx.Err() == nil { + time.Sleep(time.Millisecond) + } + + s.streamCtxErr = ctx.Err() + } + + // send additional responses until an error is encountered + for s.send2Err == nil { + s.send2Err = sendResp() + time.Sleep(time.Millisecond * 10) + } + + // allow the client blocked in synchronizeAfterCancel() to proceed + close(s.chanToCloseBeforeReturn) + + return nil +} + +func (s *testSnapshotMetadataServerCtxPropagator) synchronizeBeforeCancel() { + // synchronize with the fake CSI driver + <-s.chanToCloseOnEntry + + // the fake driver may now send the first response + close(s.chanToWaitOnBeforeFirstResponse) +} + +func (s *testSnapshotMetadataServerCtxPropagator) synchronizeAfterCancel() { + // the fake driver can now send the second response + close(s.chanToWaitOnBeforeSecondResponse) + + // wait for the fake driver method to complete + <-s.chanToCloseBeforeReturn +} + +func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataAllocated(req *csi.GetMetadataAllocatedRequest, stream csi.SnapshotMetadata_GetMetadataAllocatedServer) error { + var byteOffset int64 + return s.sync(stream.Context(), + func() error { + byteOffset += 1024 + return stream.Send(&csi.GetMetadataAllocatedResponse{ + BlockMetadataType: csi.BlockMetadataType_FIXED_LENGTH, + VolumeCapacityBytes: 1024 * 1024 * 1024, + BlockMetadata: []*csi.BlockMetadata{ + { + ByteOffset: byteOffset, + SizeBytes: 1024, + }, + }, + }) + }) +} + +func (s *testSnapshotMetadataServerCtxPropagator) GetMetadataDelta(req *csi.GetMetadataDeltaRequest, stream csi.SnapshotMetadata_GetMetadataDeltaServer) error { + var byteOffset int64 + return s.sync(stream.Context(), + func() error { + byteOffset += 1024 + return stream.Send(&csi.GetMetadataDeltaResponse{ + BlockMetadataType: csi.BlockMetadataType_FIXED_LENGTH, + VolumeCapacityBytes: 1024 * 1024 * 1024, + BlockMetadata: []*csi.BlockMetadata{ + { + ByteOffset: byteOffset, + SizeBytes: 1024, + }, + }, + }) + }) +} diff --git a/pkg/internal/server/grpc/get_metadata_allocated.go b/pkg/internal/server/grpc/get_metadata_allocated.go index ad141cbe..35f84e7a 100644 --- a/pkg/internal/server/grpc/get_metadata_allocated.go +++ b/pkg/internal/server/grpc/get_metadata_allocated.go @@ -31,7 +31,11 @@ import ( ) func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stream api.SnapshotMetadata_GetMetadataAllocatedServer) error { - ctx := s.getMetadataAllocatedContextWithLogger(req, stream) + // Create a timeout context so that failure in either sending to the client or + // receiving from the CSI driver will ultimately abort the handler session. + // The context could also get canceled by the client. + ctx, cancelFn := context.WithTimeout(s.getMetadataAllocatedContextWithLogger(req, stream), s.config.MaxStreamDur) + defer cancelFn() if err := s.validateGetMetadataAllocatedRequest(req); err != nil { klog.FromContext(ctx).Error(err, "validation failed") @@ -55,6 +59,7 @@ func (s *Server) GetMetadataAllocated(req *api.GetMetadataAllocatedRequest, stre klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "snapshotId", csiReq.SnapshotId) csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataAllocated(ctx, csiReq) if err != nil { + klog.FromContext(ctx).Error(err, "csi.GetMetadataAllocated") return err } diff --git a/pkg/internal/server/grpc/get_metadata_allocated_test.go b/pkg/internal/server/grpc/get_metadata_allocated_test.go index 98634357..c1a708d8 100644 --- a/pkg/internal/server/grpc/get_metadata_allocated_test.go +++ b/pkg/internal/server/grpc/get_metadata_allocated_test.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "testing" + "time" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" @@ -751,3 +752,80 @@ func (f *fakeStreamServerSnapshotAllocated) Send(m *api.GetMetadataAllocatedResp func (f *fakeStreamServerSnapshotAllocated) verifyResponse(expectedResponse *api.GetMetadataAllocatedResponse) bool { return f.response.String() == expectedResponse.String() } + +func TestGetMetadataAllocatedClientErrorHandling(t *testing.T) { + t.Run("client-cancels-context", func(t *testing.T) { + sms, th := newSMSHarnessForCtxPropagation(t, 0) + defer sms.cleanup(t) + + // create the cancelable application client context + ctx, cancelFn := context.WithCancel(context.Background()) + + // make the RPC call + client := th.GRPCSnapshotMetadataClient(t) + clientStream, err := client.GetMetadataAllocated(ctx, &api.GetMetadataAllocatedRequest{ + SecurityToken: th.SecurityToken, + Namespace: th.Namespace, + SnapshotName: "snap-1", + }) + assert.NoError(t, err) + assert.NotNil(t, clientStream) + + sms.synchronizeBeforeCancel() + + r1, e1 := clientStream.Recv() // get the first response + assert.NoError(t, e1) + assert.NotNil(t, r1) + + // the client cancels the context + cancelFn() + + r2, e2 := clientStream.Recv() // fail because ctx is canceled + assert.Error(t, e2) + assert.ErrorContains(t, e2, context.Canceled.Error()) + assert.Nil(t, r2) + + sms.synchronizeAfterCancel() + + // Check the fake driver handler status + sms.mux.Lock() + defer sms.mux.Unlock() + assert.True(t, sms.handlerCalled) + assert.NoError(t, sms.send1Err) + assert.ErrorIs(t, sms.streamCtxErr, context.Canceled) + assert.Error(t, sms.send2Err) + assert.ErrorContains(t, sms.send2Err, context.Canceled.Error()) + }) + + t.Run("sidecar-deadline-exceeded", func(t *testing.T) { + // arrange for the sidecar to timeout quickly + sms, th := newSMSHarnessForCtxPropagation(t, time.Millisecond*10) + defer sms.cleanup(t) + + // make the RPC call + client := th.GRPCSnapshotMetadataClient(t) + clientStream, err := client.GetMetadataAllocated(context.Background(), &api.GetMetadataAllocatedRequest{ + SecurityToken: th.SecurityToken, + Namespace: th.Namespace, + SnapshotName: "snap-1", + }) + assert.NoError(t, err) + assert.NotNil(t, clientStream) + + sms.synchronizeBeforeCancel() + + // do not attempt to receive anything + + sms.synchronizeAfterCancel() + + // Check the fake driver handler status + sms.mux.Lock() + defer sms.mux.Unlock() + assert.True(t, sms.handlerCalled) + assert.NoError(t, sms.send1Err) + assert.Error(t, sms.send2Err) + // its a bit uncertain as to which context error we get in the handler + re := context.DeadlineExceeded.Error() + "|" + context.Canceled.Error() + assert.Regexp(t, re, sms.send2Err.Error()) + }) +} diff --git a/pkg/internal/server/grpc/get_metadata_delta.go b/pkg/internal/server/grpc/get_metadata_delta.go index 79724bb6..42b1422a 100644 --- a/pkg/internal/server/grpc/get_metadata_delta.go +++ b/pkg/internal/server/grpc/get_metadata_delta.go @@ -31,7 +31,11 @@ import ( ) func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.SnapshotMetadata_GetMetadataDeltaServer) error { - ctx := s.getMetadataDeltaContextWithLogger(req, stream) + // Create a timeout context so that failure in either sending to the client or + // receiving from the CSI driver will ultimately abort the handler session. + // The context could also get canceled by the client. + ctx, cancelFn := context.WithTimeout(s.getMetadataDeltaContextWithLogger(req, stream), s.config.MaxStreamDur) + defer cancelFn() if err := s.validateGetMetadataDeltaRequest(req); err != nil { klog.FromContext(ctx).Error(err, "validation failed") @@ -55,6 +59,7 @@ func (s *Server) GetMetadataDelta(req *api.GetMetadataDeltaRequest, stream api.S klog.FromContext(ctx).V(HandlerTraceLogLevel).Info("calling CSI driver", "baseSnapshotId", csiReq.BaseSnapshotId, "targetSnapshotId", csiReq.TargetSnapshotId) csiStream, err := csi.NewSnapshotMetadataClient(s.csiConnection()).GetMetadataDelta(ctx, csiReq) if err != nil { + klog.FromContext(ctx).Error(err, "csi.GetMetadataDelta") return err } diff --git a/pkg/internal/server/grpc/get_metadata_delta_test.go b/pkg/internal/server/grpc/get_metadata_delta_test.go index e6ff33fb..3c0848ea 100644 --- a/pkg/internal/server/grpc/get_metadata_delta_test.go +++ b/pkg/internal/server/grpc/get_metadata_delta_test.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "testing" + "time" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" @@ -985,3 +986,82 @@ func (f *fakeStreamServerSnapshotDelta) Send(m *api.GetMetadataDeltaResponse) er func (f *fakeStreamServerSnapshotDelta) verifyResponse(expectedResponse *api.GetMetadataDeltaResponse) bool { return f.response.String() == expectedResponse.String() } + +func TestGetMetadataDeltaClientErrorHandling(t *testing.T) { + t.Run("client-cancels-context", func(t *testing.T) { + sms, th := newSMSHarnessForCtxPropagation(t, 0) + defer sms.cleanup(t) + + // create the cancelable application client context + ctx, cancelFn := context.WithCancel(context.Background()) + + // make the RPC call + client := th.GRPCSnapshotMetadataClient(t) + clientStream, err := client.GetMetadataDelta(ctx, &api.GetMetadataDeltaRequest{ + SecurityToken: th.SecurityToken, + Namespace: th.Namespace, + BaseSnapshotName: "snap-1", + TargetSnapshotName: "snap-2", + }) + assert.NoError(t, err) + assert.NotNil(t, clientStream) + + sms.synchronizeBeforeCancel() + + r1, e1 := clientStream.Recv() // get the first response + assert.NoError(t, e1) + assert.NotNil(t, r1) + + // the client cancels the context + cancelFn() + + r2, e2 := clientStream.Recv() // fail because ctx is canceled + assert.Error(t, e2) + assert.ErrorContains(t, e2, context.Canceled.Error()) + assert.Nil(t, r2) + + sms.synchronizeAfterCancel() + + // Check the fake driver handler status + sms.mux.Lock() + defer sms.mux.Unlock() + assert.True(t, sms.handlerCalled) + assert.NoError(t, sms.send1Err) + assert.ErrorIs(t, sms.streamCtxErr, context.Canceled) + assert.Error(t, sms.send2Err) + assert.ErrorContains(t, sms.send2Err, context.Canceled.Error()) + }) + + t.Run("sidecar-deadline-exceeded", func(t *testing.T) { + // arrange for the sidecar to timeout quickly + sms, th := newSMSHarnessForCtxPropagation(t, time.Millisecond*10) + defer sms.cleanup(t) + + // make the RPC call + client := th.GRPCSnapshotMetadataClient(t) + clientStream, err := client.GetMetadataDelta(context.Background(), &api.GetMetadataDeltaRequest{ + SecurityToken: th.SecurityToken, + Namespace: th.Namespace, + BaseSnapshotName: "snap-1", + TargetSnapshotName: "snap-2", + }) + assert.NoError(t, err) + assert.NotNil(t, clientStream) + + sms.synchronizeBeforeCancel() + + // do not attempt to receive anything + + sms.synchronizeAfterCancel() + + // Check the fake driver handler status + sms.mux.Lock() + defer sms.mux.Unlock() + assert.True(t, sms.handlerCalled) + assert.NoError(t, sms.send1Err) + assert.Error(t, sms.send2Err) + // its a bit uncertain as to which context error we get in the handler + re := context.DeadlineExceeded.Error() + "|" + context.Canceled.Error() + assert.Regexp(t, re, sms.send2Err.Error()) + }) +} diff --git a/pkg/internal/server/grpc/server.go b/pkg/internal/server/grpc/server.go index c5611d3f..815f84f9 100644 --- a/pkg/internal/server/grpc/server.go +++ b/pkg/internal/server/grpc/server.go @@ -23,6 +23,7 @@ import ( "strconv" "sync" "sync/atomic" + "time" snapshot "github.com/kubernetes-csi/external-snapshotter/client/v8/clientset/versioned" "google.golang.org/grpc" @@ -40,10 +41,18 @@ import ( const ( HandlerTraceLogLevel = 4 HandlerDetailedTraceLogLevel = 5 + + HandlerDefaultMaxStreamDuration = time.Minute * 10 ) type ServerConfig struct { Runtime *runtime.Runtime + + // The maximum duration of a streaming session. + // The handler will abort if either the CSI driver or + // the client do not complete in this time. + // If not set then HandlerDefaultMaxStreamDuration is used. + MaxStreamDur time.Duration } type Server struct { @@ -61,6 +70,10 @@ func NewServer(config ServerConfig) (*Server, error) { return nil, err } + if config.MaxStreamDur <= 0 { + config.MaxStreamDur = HandlerDefaultMaxStreamDuration + } + return &Server{ config: config, grpcServer: grpc.NewServer(options...), diff --git a/pkg/internal/server/grpc/server_test.go b/pkg/internal/server/grpc/server_test.go index b77952ac..2b3c6d05 100644 --- a/pkg/internal/server/grpc/server_test.go +++ b/pkg/internal/server/grpc/server_test.go @@ -20,6 +20,7 @@ import ( "context" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" @@ -75,6 +76,7 @@ func TestNewServer(t *testing.T) { assert.NotNil(t, s) assert.NotNil(t, s.grpcServer) assert.Equal(t, s.config.Runtime, &rt) + assert.Equal(t, HandlerDefaultMaxStreamDuration, s.config.MaxStreamDur) err = s.Start() assert.Error(t, err) @@ -89,11 +91,16 @@ func TestNewServer(t *testing.T) { rt.TLSCertFile = rta.TLSCertFile rt.TLSKeyFile = rta.TLSKeyFile - s, err := NewServer(ServerConfig{Runtime: &rt}) + expMaxStreamDur := HandlerDefaultMaxStreamDuration + time.Minute + s, err := NewServer(ServerConfig{ + Runtime: &rt, + MaxStreamDur: expMaxStreamDur, + }) assert.NoError(t, err) assert.NotNil(t, s) assert.NotNil(t, s.grpcServer) assert.Equal(t, s.config.Runtime, &rt) + assert.Equal(t, expMaxStreamDur, s.config.MaxStreamDur) assert.False(t, s.isReady()) // initialized to not ready err = s.Start() diff --git a/pkg/sidecar/sidecar.go b/pkg/sidecar/sidecar.go index 667bf453..c1ef82f3 100644 --- a/pkg/sidecar/sidecar.go +++ b/pkg/sidecar/sidecar.go @@ -32,26 +32,28 @@ import ( ) const ( - defaultCSISocket = "/run/csi/socket" - defaultCSITimeout = time.Minute // Default timeout of short CSI calls like GetPluginInfo. - defaultGRPCPort = 50051 - defaultHTTPEndpoint = "" - defaultKubeAPIQPS = 5.0 - defaultKubeAPIBurst = 10 - defaultKubeconfig = "" - defaultMetricsPath = "/metrics" - - flagCSIAddress = "csi-address" - flagCSITimeout = "timeout" - flagGRPCPort = "port" - flagHTTPEndpoint = "http-endpoint" - flagKubeAPIBurst = "kube-api-burst" - flagKubeAPIQPS = "kube-api-qps" - flagKubeconfig = "kubeconfig" - flagMetricsPath = "metrics-path" - flagTLSCert = "tls-cert" - flagTLSKey = "tls-key" - flagVersion = "version" + defaultCSISocket = "/run/csi/socket" + defaultCSITimeout = time.Minute // Default timeout of short CSI calls like GetPluginInfo. + defaultGRPCPort = 50051 + defaultHTTPEndpoint = "" + defaultKubeAPIBurst = 10 + defaultKubeAPIQPS = 5.0 + defaultKubeconfig = "" + defaultMaxStreamingDurationMin = 10 + defaultMetricsPath = "/metrics" + + flagCSIAddress = "csi-address" + flagCSITimeout = "timeout" + flagGRPCPort = "port" + flagHTTPEndpoint = "http-endpoint" + flagKubeAPIBurst = "kube-api-burst" + flagKubeAPIQPS = "kube-api-qps" + flagKubeconfig = "kubeconfig" + flagMaxStreamingDurationMin = "max-streaming-duration-min" + flagMetricsPath = "metrics-path" + flagTLSCert = "tls-cert" + flagTLSKey = "tls-key" + flagVersion = "version" // tlsCertEnvVar is an environment variable that specifies the path to tls certificate file. tlsCertEnvVar = "TLS_CERT_PATH" @@ -87,7 +89,7 @@ func Run(argv []string, version string) int { // TBD May need to exposed metric HTTP end point // here because the wait for the CSI driver is open ended. - grpcServer, err := startGRPCServerAndValidateCSIDriver(rt) + grpcServer, err := startGRPCServerAndValidateCSIDriver(s.createServerConfig(rt)) if err != nil { klog.Error(err) return 1 @@ -106,17 +108,18 @@ type sidecarFlagSet struct { version string // flag variables - csiAddress *string - csiTimeout *time.Duration - grpcPort *int - httpEndpoint *string - kubeAPIBurst *int - kubeAPIQPS *float64 - kubeconfig *string - metricsPath *string - showVersion *bool - tlsCert *string - tlsKey *string + csiAddress *string + csiTimeout *time.Duration + grpcPort *int + httpEndpoint *string + kubeAPIBurst *int + kubeAPIQPS *float64 + kubeconfig *string + maxStreamingDurMin *int + metricsPath *string + showVersion *bool + tlsCert *string + tlsKey *string } var sidecarFlagSetErrorHandling flag.ErrorHandling = flag.ExitOnError // UT interception point. @@ -137,6 +140,8 @@ func newSidecarFlagSet(name, version string) *sidecarFlagSet { s.tlsCert = s.String(flagTLSCert, os.Getenv(tlsCertEnvVar), "Path to the TLS certificate file. Can also be set with the environment variable "+tlsCertEnvVar+".") s.tlsKey = s.String(flagTLSKey, os.Getenv(tlsKeyEnvVar), "Path to the TLS private key file. Can also be set with the environment variable "+tlsKeyEnvVar+".") + s.maxStreamingDurMin = s.Int(flagMaxStreamingDurationMin, defaultMaxStreamingDurationMin, "The maximum duration in minutes for any individual streaming session") + s.kubeAPIQPS = s.Float64(flagKubeAPIQPS, defaultKubeAPIQPS, "QPS to use while communicating with the kubernetes apiserver. Defaults to 5.0.") s.kubeAPIBurst = s.Int(flagKubeAPIBurst, defaultKubeAPIBurst, "Burst to use while communicating with the kubernetes apiserver. Defaults to 10.") @@ -220,11 +225,20 @@ func (s *sidecarFlagSet) runtimeArgsToArgv(progName string, rta runtime.Args) [] return argv } +func (s *sidecarFlagSet) createServerConfig(rt *runtime.Runtime) grpc.ServerConfig { + return grpc.ServerConfig{ + Runtime: rt, + MaxStreamDur: time.Duration(*s.maxStreamingDurMin * 60), + } +} + // startGRPCServerAndValidateCSIDriver starts the GRPC server and waits // for it to validate the CSI driver capabilities. -func startGRPCServerAndValidateCSIDriver(rt *runtime.Runtime) (*grpc.Server, error) { +func startGRPCServerAndValidateCSIDriver(config grpc.ServerConfig) (*grpc.Server, error) { + rt := config.Runtime + // create the GRPC server. - grpcServer, err := grpc.NewServer(grpc.ServerConfig{Runtime: rt}) + grpcServer, err := grpc.NewServer(config) if err != nil { klog.Errorf("Failed to start GRPC server: %v", err) return nil, err diff --git a/pkg/sidecar/sidecar_test.go b/pkg/sidecar/sidecar_test.go index cba7de43..59e61948 100644 --- a/pkg/sidecar/sidecar_test.go +++ b/pkg/sidecar/sidecar_test.go @@ -35,6 +35,7 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" "github.com/kubernetes-csi/external-snapshot-metadata/pkg/internal/runtime" + "github.com/kubernetes-csi/external-snapshot-metadata/pkg/internal/server/grpc" ) func TestSidecarFlagSet(t *testing.T) { @@ -83,7 +84,7 @@ func TestSidecarFlagSet(t *testing.T) { assert.Equal(t, fmt.Sprintf("%s %s\n", progName, version), string(output)) }) - t.Run("default-runtime-args", func(t *testing.T) { + t.Run("default-args", func(t *testing.T) { defer saveAndResetGlobalState()() expTLSCertFile := "/tls/certFile" @@ -112,6 +113,11 @@ func TestSidecarFlagSet(t *testing.T) { } assert.Equal(t, expRTA, rta) + + rt := &runtime.Runtime{} + config := sfs.createServerConfig(rt) + assert.Equal(t, rt, config.Runtime) + assert.Equal(t, time.Duration(defaultMaxStreamingDurationMin*60), config.MaxStreamDur) }) } @@ -142,7 +148,7 @@ func TestStartGRPCServerAndValidateCSIDriver(t *testing.T) { rt.GRPCPort = -1 // invalid port - s, err := startGRPCServerAndValidateCSIDriver(rt) + s, err := startGRPCServerAndValidateCSIDriver(grpc.ServerConfig{Runtime: rt}) assert.Error(t, err) assert.Contains(t, err.Error(), "invalid port") assert.Nil(t, s) @@ -156,7 +162,7 @@ func TestStartGRPCServerAndValidateCSIDriver(t *testing.T) { rt := rth.RuntimeForFakeCSIDriver(t) - s, err := startGRPCServerAndValidateCSIDriver(rt) + s, err := startGRPCServerAndValidateCSIDriver(grpc.ServerConfig{Runtime: rt}) assert.Error(t, err) assert.Contains(t, err.Error(), "error waiting for CSI driver to become ready") // probe unimplemented. assert.Nil(t, s) @@ -194,6 +200,7 @@ func TestRun(t *testing.T) { sfs := &sidecarFlagSet{} argv := sfs.runtimeArgsToArgv("progName", rt.Args) + argv = append(argv, flagMaxStreamingDurationMin, fmt.Sprintf("%d", defaultMaxStreamingDurationMin+1)) // invoke Run() in a goroutine so as not to block. wg := sync.WaitGroup{}