diff --git a/libs/common/grpc.go b/libs/common/grpc.go
index 5ca26b9e6..6ae93d02d 100644
--- a/libs/common/grpc.go
+++ b/libs/common/grpc.go
@@ -90,6 +90,7 @@ func StartNewGRPCServer(ctx context.Context, addr string, registerServerHook fun
 func DefaultUnaryInterceptors(metrics *prometheusGrpcProvider.ServerMetrics) []grpc.UnaryServerInterceptor {
 	return []grpc.UnaryServerInterceptor{
 		metrics.UnaryServerInterceptor(),
+		hwgrpc.UnaryPanicRecoverInterceptor(),
 		hwgrpc.UnaryLoggingInterceptor,
 		hwgrpc.UnaryErrorQualityControlInterceptor,
 		hwgrpc.UnaryLocaleInterceptor,
@@ -106,6 +107,7 @@ func DefaultUnaryInterceptors(metrics *prometheusGrpcProvider.ServerMetrics) []g
 func DefaultStreamInterceptors(metrics *prometheusGrpcProvider.ServerMetrics) []grpc.StreamServerInterceptor {
 	return []grpc.StreamServerInterceptor{
 		metrics.StreamServerInterceptor(),
+		hwgrpc.StreamPanicRecoverInterceptor(),
 		hwgrpc.StreamLoggingInterceptor,
 		hwgrpc.StreamErrorQualityControlInterceptor,
 		hwgrpc.StreamLocaleInterceptor,
diff --git a/libs/common/hwgrpc/panic_interceptor.go b/libs/common/hwgrpc/panic_interceptor.go
new file mode 100644
index 000000000..e7ad94913
--- /dev/null
+++ b/libs/common/hwgrpc/panic_interceptor.go
@@ -0,0 +1,63 @@
+package hwgrpc
+
+import (
+	"context"
+	"runtime/debug"
+	"telemetry"
+
+	"common/hwerr"
+	"common/locale"
+
+	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
+	"github.com/prometheus/client_golang/prometheus"
+	zlog "github.com/rs/zerolog/log"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+)
+
+// panicsRecovered counts the panics handled by the recovery handler below.
+// It is a lazy counter so that declaring it at package level does not touch
+// the Prometheus registry before telemetry has set it up.
+var panicsRecovered = telemetry.NewLazyCounter(prometheus.CounterOpts{
+	Name: "services_panics_recovered_total",
+	Help: "Total number of panics recovered by PanicRecoverInterceptor",
+})
+
+// recoveryHandlerFn builds the handler shared by the unary and stream
+// interceptors: it logs the recovered value together with a stack trace,
+// increments panicsRecovered and returns a codes.Internal status error.
+func recoveryHandlerFn() recovery.RecoveryHandlerFuncContext {
+	return func(ctx context.Context, recovered any) error {
+		zlog.Ctx(ctx).
+			Error().
+			Any("recovered", recovered).
+			Str("stack", string(debug.Stack())).
+			Msg("recovered a panic")
+
+		panicsRecovered.Counter().Inc()
+
+		return hwerr.NewStatusError(ctx, codes.Internal, "panic recovered", locale.GenericError(ctx))
+	}
+}
+
+// UnaryPanicRecoverInterceptor returns a unary server interceptor that turns
+// a panic in a downstream handler into a codes.Internal error.
+// The panic counter is created eagerly here, while the interceptor is built.
+func UnaryPanicRecoverInterceptor() grpc.UnaryServerInterceptor {
+	panicsRecovered.Ensure()
+
+	return recovery.UnaryServerInterceptor(
+		recovery.WithRecoveryHandlerContext(recoveryHandlerFn()),
+	)
+}
+
+// StreamPanicRecoverInterceptor returns a stream server interceptor that turns
+// a panic in a downstream handler into a codes.Internal error.
+// The panic counter is created eagerly here, while the interceptor is built.
+func StreamPanicRecoverInterceptor() grpc.StreamServerInterceptor {
+	panicsRecovered.Ensure()
+
+	return recovery.StreamServerInterceptor(
+		recovery.WithRecoveryHandlerContext(recoveryHandlerFn()),
+	)
+}
diff --git a/libs/common/hwgrpc/panic_interceptor_test.go b/libs/common/hwgrpc/panic_interceptor_test.go
new file mode 100644
index 000000000..7eaf8ce6b
--- /dev/null
+++ b/libs/common/hwgrpc/panic_interceptor_test.go
@@ -0,0 +1,85 @@
+package hwgrpc
+
+import (
+	"context"
+	"telemetry"
+	"testing"
+
+	"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
+	"github.com/stretchr/testify/suite"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+// recoveryAssertService wraps a TestServiceServer and panics whenever the
+// request value is "panic"; every other request is forwarded unchanged.
+type recoveryAssertService struct {
+	testpb.TestServiceServer
+}
+
+func (s *recoveryAssertService) Ping(ctx context.Context, ping *testpb.PingRequest) (*testpb.PingResponse, error) {
+	if ping.GetValue() == "panic" {
+		panic("very bad thing happened")
+	}
+	return s.TestServiceServer.Ping(ctx, ping)
+}
+
+func (s *recoveryAssertService) PingList(ping *testpb.PingListRequest, stream testpb.TestService_PingListServer) error {
+	if ping.Value == "panic" {
+		panic("very bad thing happened")
+	}
+	return s.TestServiceServer.PingList(ping, stream)
+}
+
+// RecoverySuite bundles the panic-recovery cases on top of testpb.InterceptorTestSuite.
+type RecoverySuite struct {
+	*testpb.InterceptorTestSuite
+}
+
+func TestPanicRecoverInterceptor(t *testing.T) {
+	// Set up telemetry first: building the interceptors below creates the
+	// panic counter on the telemetry registry.
+	telemetry.SetupMetrics(context.Background(), nil)
+	s := &RecoverySuite{
+		InterceptorTestSuite: &testpb.InterceptorTestSuite{
+			TestService: &recoveryAssertService{TestServiceServer: &testpb.TestPingService{}},
+			ServerOpts: []grpc.ServerOption{
+				grpc.StreamInterceptor(StreamPanicRecoverInterceptor()),
+				grpc.UnaryInterceptor(UnaryPanicRecoverInterceptor()),
+			},
+		},
+	}
+	suite.Run(t, s)
+}
+
+func (s *RecoverySuite) TestUnary_SuccessfulRequest() {
+	_, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing)
+	s.Require().NoError(err)
+}
+
+func (s *RecoverySuite) TestUnary_PanicRequest() {
+	_, err := s.Client.Ping(s.SimpleCtx(), &testpb.PingRequest{Value: "panic"})
+	s.Require().Error(err)
+	st, ok := status.FromError(err)
+	s.Require().True(ok, "not a status error")
+	s.Require().Equal(codes.Internal, st.Code())
+}
+
+func (s *RecoverySuite) TestStream_SuccessfulReceive() {
+	stream, err := s.Client.PingList(s.SimpleCtx(), testpb.GoodPingList)
+	s.Require().NoError(err, "should not fail on establishing the stream")
+	pong, err := stream.Recv()
+	s.Require().NoError(err, "no error must occur")
+	s.Require().NotNil(pong, "pong must not be nil")
+}
+
+func (s *RecoverySuite) TestStream_PanickingReceive() {
+	stream, err := s.Client.PingList(s.SimpleCtx(), &testpb.PingListRequest{Value: "panic"})
+	s.Require().NoError(err, "should not fail on establishing the stream")
+	_, err = stream.Recv()
+	s.Require().Error(err)
+	st, ok := status.FromError(err)
+	s.Require().True(ok, "not a status error")
+	s.Require().Equal(codes.Internal, st.Code())
+}
diff --git a/libs/telemetry/setup.go b/libs/telemetry/setup.go
index 03fb95d52..00a2f78d7 100644
--- a/libs/telemetry/setup.go
+++ b/libs/telemetry/setup.go
@@ -3,6 +3,7 @@ package telemetry
 import (
 	"context"
 	"errors"
+	"github.com/prometheus/client_golang/prometheus/promauto"
 	"hwutil"
 	"net/http"
 	"os"
@@ -106,3 +107,34 @@ func SetupMetrics(ctx context.Context, shutdown func(error)) {
 func PrometheusRegistry() *prometheus.Registry {
 	return prometheusRegistry
 }
+
+// LazyCounter defers creating (and registering) a Prometheus counter until it
+// is first needed, so that package-level declarations do not touch
+// prometheusRegistry before it has been initialized.
+//
+// The first call to Counter or Ensure is not synchronized; make it from a
+// single goroutine (e.g. during setup) before the counter is used concurrently.
+type LazyCounter struct {
+	opts    prometheus.CounterOpts
+	counter prometheus.Counter // nil until the first Counter/Ensure call
+}
+
+// NewLazyCounter returns a LazyCounter that will create its counter from opts.
+// It does not access the registry.
+func NewLazyCounter(opts prometheus.CounterOpts) LazyCounter {
+	return LazyCounter{opts: opts}
+}
+
+// Counter returns the underlying counter, creating it on prometheusRegistry
+// via promauto on the first call.
+func (lc *LazyCounter) Counter() prometheus.Counter {
+	if lc.counter == nil {
+		lc.counter = promauto.With(prometheusRegistry).NewCounter(lc.opts)
+	}
+	return lc.counter
+}
+
+// Ensure creates the counter if it does not exist yet, discarding the result.
+func (lc *LazyCounter) Ensure() {
+	lc.Counter()
+}