Skip to content

Commit

Permalink
panic recover (#876)
Browse files Browse the repository at this point in the history
  • Loading branch information
FoseFx authored Oct 28, 2024
1 parent 3aa7035 commit baffb92
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 0 deletions.
2 changes: 2 additions & 0 deletions libs/common/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func StartNewGRPCServer(ctx context.Context, addr string, registerServerHook fun
func DefaultUnaryInterceptors(metrics *prometheusGrpcProvider.ServerMetrics) []grpc.UnaryServerInterceptor {
return []grpc.UnaryServerInterceptor{
metrics.UnaryServerInterceptor(),
hwgrpc.UnaryPanicRecoverInterceptor(),
hwgrpc.UnaryLoggingInterceptor,
hwgrpc.UnaryErrorQualityControlInterceptor,
hwgrpc.UnaryLocaleInterceptor,
Expand All @@ -106,6 +107,7 @@ func DefaultUnaryInterceptors(metrics *prometheusGrpcProvider.ServerMetrics) []g
func DefaultStreamInterceptors(metrics *prometheusGrpcProvider.ServerMetrics) []grpc.StreamServerInterceptor {
return []grpc.StreamServerInterceptor{
metrics.StreamServerInterceptor(),
hwgrpc.StreamPanicRecoverInterceptor(),
hwgrpc.StreamLoggingInterceptor,
hwgrpc.StreamErrorQualityControlInterceptor,
hwgrpc.StreamLocaleInterceptor,
Expand Down
51 changes: 51 additions & 0 deletions libs/common/hwgrpc/panic_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package hwgrpc

import (
"context"
"runtime/debug"
"telemetry"

"common/hwerr"
"common/locale"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
"github.com/prometheus/client_golang/prometheus"
zlog "github.com/rs/zerolog/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)

var panicsRecovered = telemetry.NewLazyCounter(prometheus.CounterOpts{
Name: "services_panics_recovered_total",
Help: "Total number of panics recovered by PanicRecoverInterceptor",
})

func recoveryHandlerFn() recovery.RecoveryHandlerFuncContext {
return func(ctx context.Context, recovered any) (err error) {
zlog.Ctx(ctx).
Error().
Any("recovered", recovered).
Str("stack", string(debug.Stack())).
Msg("recovered a panic")

panicsRecovered.Counter().Inc()

return hwerr.NewStatusError(ctx, codes.Internal, "panic recovered", locale.GenericError(ctx))
}
}

func UnaryPanicRecoverInterceptor() grpc.UnaryServerInterceptor {
panicsRecovered.Ensure()

return recovery.UnaryServerInterceptor(
recovery.WithRecoveryHandlerContext(recoveryHandlerFn()),
)
}

func StreamPanicRecoverInterceptor() grpc.StreamServerInterceptor {
panicsRecovered.Ensure()

return recovery.StreamServerInterceptor(
recovery.WithRecoveryHandlerContext(recoveryHandlerFn()),
)
}
80 changes: 80 additions & 0 deletions libs/common/hwgrpc/panic_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package hwgrpc

import (
"context"
"telemetry"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type recoveryAssertService struct {
testpb.TestServiceServer
}

func (s *recoveryAssertService) Ping(ctx context.Context, ping *testpb.PingRequest) (*testpb.PingResponse, error) {
if ping.GetValue() == "panic" {
panic("very bad thing happened")
}
return s.TestServiceServer.Ping(ctx, ping)
}

func (s *recoveryAssertService) PingList(ping *testpb.PingListRequest, stream testpb.TestService_PingListServer) error {
if ping.Value == "panic" {
panic("very bad thing happened")
}
return s.TestServiceServer.PingList(ping, stream)
}

type RecoverySuite struct {
*testpb.InterceptorTestSuite
}

func TestPanicRecoverInterceptor(t *testing.T) {
telemetry.SetupMetrics(context.Background(), nil)
s := &RecoverySuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
TestService: &recoveryAssertService{TestServiceServer: &testpb.TestPingService{}},
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(StreamPanicRecoverInterceptor()),
grpc.UnaryInterceptor(UnaryPanicRecoverInterceptor()),
},
},
}
suite.Run(t, s)
}

func (s *RecoverySuite) TestUnary_SuccessfulRequest() {
_, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing)
s.Require().NoError(err)
}

func (s *RecoverySuite) TestUnary_PanicRequest() {
_, err := s.Client.Ping(s.SimpleCtx(), &testpb.PingRequest{Value: "panic"})
s.Require().Error(err)
st, ok := status.FromError(err)
s.Require().True(ok, "not a status error")
s.Require().Equal(codes.Internal, st.Code())
}

func (s *RecoverySuite) TestStream_SuccessfulReceive() {
stream, err := s.Client.PingList(s.SimpleCtx(), testpb.GoodPingList)
s.Require().NoError(err, "should not fail on establishing the stream")
pong, err := stream.Recv()
s.Require().NoError(err, "no error must occur")
s.Require().NotNil(pong, "pong must not be nil")
}

func (s *RecoverySuite) TestStream_PanickingReceive() {
stream, err := s.Client.PingList(s.SimpleCtx(), &testpb.PingListRequest{Value: "panic"})
s.Require().NoError(err, "should not fail on establishing the stream")
_, err = stream.Recv()
s.Require().Error(err)
st, ok := status.FromError(err)
s.Require().True(ok, "not a status error")
s.Require().Equal(codes.Internal, st.Code())
}
27 changes: 27 additions & 0 deletions libs/telemetry/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package telemetry
import (
"context"
"errors"
"github.com/prometheus/client_golang/prometheus/promauto"
"hwutil"
"net/http"
"os"
Expand Down Expand Up @@ -106,3 +107,29 @@ func SetupMetrics(ctx context.Context, shutdown func(error)) {
func PrometheusRegistry() *prometheus.Registry {
return prometheusRegistry
}

// LazyCounter prevents access to PrometheusRegistry, before it is initialized
// by creating the counter only when it is needed
type LazyCounter struct {
opts prometheus.CounterOpts
counter *prometheus.Counter
}

func NewLazyCounter(opts prometheus.CounterOpts) LazyCounter {
return LazyCounter{
opts: opts,
counter: nil,
}
}

func (lc *LazyCounter) Counter() prometheus.Counter {
if lc.counter != nil {
return *lc.counter
}
lc.counter = hwutil.PtrTo(promauto.With(prometheusRegistry).NewCounter(lc.opts))
return *lc.counter
}

func (lc *LazyCounter) Ensure() {
lc.Counter()
}

0 comments on commit baffb92

Please sign in to comment.