From 5d7d85f5a0388bb0faa0d9250f96b35814cff1f9 Mon Sep 17 00:00:00 2001 From: Tonis Tiigi Date: Sun, 17 Dec 2023 23:39:51 -0800 Subject: [PATCH] pb: add extra validation to protobuf types Signed-off-by: Tonis Tiigi (cherry picked from commit 838635998dcae34bbde59e3eab129ab85bd37bef) --- client/validation_test.go | 9 ++++++--- control/control.go | 3 +++ frontend/gateway/client/attestation.go | 6 ++++++ frontend/gateway/gateway.go | 15 +++++++++++++++ util/tracing/transform/attribute.go | 21 ++++++++++++++++----- util/tracing/transform/span.go | 23 +++++++++++++++++++---- 6 files changed, 65 insertions(+), 12 deletions(-) diff --git a/client/validation_test.go b/client/validation_test.go index 62fade4a07d9..672054a6a301 100644 --- a/client/validation_test.go +++ b/client/validation_test.go @@ -11,6 +11,7 @@ import ( "github.com/moby/buildkit/frontend/gateway/client" sppb "github.com/moby/buildkit/sourcepolicy/pb" "github.com/moby/buildkit/util/testutil/integration" + "github.com/moby/buildkit/util/testutil/workers" ocispecs "github.com/opencontainers/image-spec/specs-go/v1" "github.com/stretchr/testify/require" ) @@ -25,6 +26,7 @@ var validationTests = []func(t *testing.T, sb integration.Sandbox){ func testValidateNullConfig(t *testing.T, sb integration.Sandbox) { requiresLinux(t) + workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter) ctx := sb.Context() @@ -63,6 +65,7 @@ func testValidateNullConfig(t *testing.T, sb integration.Sandbox) { func testValidateInvalidConfig(t *testing.T, sb integration.Sandbox) { requiresLinux(t) + workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter) ctx := sb.Context() @@ -104,11 +107,12 @@ func testValidateInvalidConfig(t *testing.T, sb integration.Sandbox) { }, }, "", b, nil) require.Error(t, err) - require.Contains(t, err.Error(), "invalid image config for export: missing os") + require.Contains(t, err.Error(), "invalid image config: os and architecture must be specified together") } func testValidatePlatformsEmpty(t *testing.T, sb integration.Sandbox) { requiresLinux(t) + workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter) ctx := sb.Context() @@ -147,6 +151,7 @@ func testValidatePlatformsEmpty(t *testing.T, sb integration.Sandbox) { func testValidatePlatformsInvalid(t *testing.T, sb integration.Sandbox) { requiresLinux(t) + workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter) ctx := sb.Context() @@ -279,7 +284,6 @@ func testValidateSourcePolicy(t *testing.T, sb integration.Sandbox) { for _, tc := range tcases { t.Run(tc.name, func(t *testing.T) { - var viaFrontend bool b := func(ctx context.Context, c client.Client) (*client.Result, error) { @@ -310,7 +314,6 @@ func testValidateSourcePolicy(t *testing.T, sb integration.Sandbox) { _, err = c.Build(ctx, SolveOpt{}, "", b, nil) require.Error(t, err) require.Contains(t, err.Error(), tc.exp) - }) } } diff --git a/control/control.go b/control/control.go index 276003604db5..40058f8fe1f1 100644 --- a/control/control.go +++ b/control/control.go @@ -420,6 +420,9 @@ func (c *Controller) Solve(ctx context.Context, req *controlapi.SolveRequest) (* var cacheImports []frontend.CacheOptionsEntry for _, im := range req.Cache.Imports { + if im == nil { + continue + } cacheImports = append(cacheImports, frontend.CacheOptionsEntry{ Type: im.Type, Attrs: im.Attrs, diff --git a/frontend/gateway/client/attestation.go b/frontend/gateway/client/attestation.go index 5ffe67233c50..c5112db9db64 100644 --- a/frontend/gateway/client/attestation.go +++ b/frontend/gateway/client/attestation.go @@ -30,8 +30,14 @@ func AttestationToPB[T any](a *result.Attestation[T]) (*pb.Attestation, error) { } func AttestationFromPB[T any](a *pb.Attestation) (*result.Attestation[T], error) { + if a == nil { + return nil, errors.Errorf("invalid nil attestation") + } subjects := make([]result.InTotoSubject, len(a.InTotoSubjects)) for i, subject := range a.InTotoSubjects { + if subject == nil { + return nil, errors.Errorf("invalid nil attestation subject") + } subjects[i] = result.InTotoSubject{ Kind: subject.Kind, Name: subject.Name, diff --git a/frontend/gateway/gateway.go b/frontend/gateway/gateway.go index bc31d3e0b5b8..92a6ffd85c5e 100644 --- a/frontend/gateway/gateway.go +++ b/frontend/gateway/gateway.go @@ -646,12 +646,21 @@ func (lbf *llbBridgeForwarder) registerResultIDs(results ...solver.Result) (ids func (lbf *llbBridgeForwarder) Solve(ctx context.Context, req *pb.SolveRequest) (*pb.SolveResponse, error) { var cacheImports []frontend.CacheOptionsEntry for _, e := range req.CacheImports { + if e == nil { + return nil, errors.Errorf("invalid nil cache import") + } cacheImports = append(cacheImports, frontend.CacheOptionsEntry{ Type: e.Type, Attrs: e.Attrs, }) } + for _, p := range req.SourcePolicies { + if p == nil { + return nil, errors.Errorf("invalid nil source policy") + } + } + ctx = tracing.ContextWithSpanFromContext(ctx, lbf.callCtx) res, err := lbf.llbBridge.Solve(ctx, frontend.SolveRequest{ Evaluate: req.Evaluate, @@ -1076,6 +1085,12 @@ func (lbf *llbBridgeForwarder) ReleaseContainer(ctx context.Context, in *pb.Rele } func (lbf *llbBridgeForwarder) Warn(ctx context.Context, in *pb.WarnRequest) (*pb.WarnResponse, error) { + // validate ranges are valid + for _, r := range in.Ranges { + if r == nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid source range") + } + } err := lbf.llbBridge.Warn(ctx, in.Digest, string(in.Short), frontend.WarnOpts{ Level: int(in.Level), SourceInfo: in.Info, diff --git a/util/tracing/transform/attribute.go b/util/tracing/transform/attribute.go index 2debe8835924..bc0df048d0a2 100644 --- a/util/tracing/transform/attribute.go +++ b/util/tracing/transform/attribute.go @@ -13,6 +13,9 @@ func Attributes(attrs []*commonpb.KeyValue) []attribute.KeyValue { out := make([]attribute.KeyValue, 0, len(attrs)) for _, a := range attrs { + if a == nil { + continue + } kv := attribute.KeyValue{ Key: attribute.Key(a.Key), Value: toValue(a.Value), @@ -42,7 +45,9 @@ func toValue(v *commonpb.AnyValue) attribute.Value { func boolArray(kv []*commonpb.AnyValue) attribute.Value { arr := make([]bool, len(kv)) for i, v := range kv { - arr[i] = v.GetBoolValue() + if v != nil { + arr[i] = v.GetBoolValue() + } } return attribute.BoolSliceValue(arr) } @@ -50,7 +55,9 @@ func boolArray(kv []*commonpb.AnyValue) attribute.Value { func intArray(kv []*commonpb.AnyValue) attribute.Value { arr := make([]int64, len(kv)) for i, v := range kv { - arr[i] = v.GetIntValue() + if v != nil { + arr[i] = v.GetIntValue() + } } return attribute.Int64SliceValue(arr) } @@ -58,7 +65,9 @@ func intArray(kv []*commonpb.AnyValue) attribute.Value { func doubleArray(kv []*commonpb.AnyValue) attribute.Value { arr := make([]float64, len(kv)) for i, v := range kv { - arr[i] = v.GetDoubleValue() + if v != nil { + arr[i] = v.GetDoubleValue() + } } return attribute.Float64SliceValue(arr) } @@ -66,13 +75,15 @@ func doubleArray(kv []*commonpb.AnyValue) attribute.Value { func stringArray(kv []*commonpb.AnyValue) attribute.Value { arr := make([]string, len(kv)) for i, v := range kv { - arr[i] = v.GetStringValue() + if v != nil { + arr[i] = v.GetStringValue() + } } return attribute.StringSliceValue(arr) } func arrayValues(kv []*commonpb.AnyValue) attribute.Value { - if len(kv) == 0 { + if len(kv) == 0 || kv[0] == nil { return attribute.StringSliceValue([]string{}) } diff --git a/util/tracing/transform/span.go b/util/tracing/transform/span.go index 9f7924c4a7e1..2273e3635d9d 100644 --- a/util/tracing/transform/span.go +++ b/util/tracing/transform/span.go @@ -32,14 +32,20 @@ func Spans(sdl []*tracepb.ResourceSpans) []tracesdk.ReadOnlySpan { } for _, sdi := range sd.ScopeSpans { - sda := make([]tracesdk.ReadOnlySpan, len(sdi.Spans)) - for i, s := range sdi.Spans { - sda[i] = &readOnlySpan{ + if sdi == nil { + continue + } + sda := make([]tracesdk.ReadOnlySpan, 0, len(sdi.Spans)) + for _, s := range sdi.Spans { + if s == nil { + continue + } + sda = append(sda, &readOnlySpan{ pb: s, is: sdi.Scope, resource: sd.Resource, schemaURL: sd.SchemaUrl, - } + }) } out = append(out, sda...) } @@ -170,6 +176,9 @@ var _ tracesdk.ReadOnlySpan = &readOnlySpan{} // status transform a OTLP span status into span code. func statusCode(st *tracepb.Status) codes.Code { + if st == nil { + return codes.Unset + } switch st.Code { case tracepb.Status_STATUS_CODE_ERROR: return codes.Error @@ -186,6 +195,9 @@ func links(links []*tracepb.Span_Link) []tracesdk.Link { sl := make([]tracesdk.Link, 0, len(links)) for _, otLink := range links { + if otLink == nil { + continue + } // This redefinition is necessary to prevent otLink.*ID[:] copies // being reused -- in short we need a new otLink per iteration. otLink := otLink @@ -226,6 +238,9 @@ func spanEvents(es []*tracepb.Span_Event) []tracesdk.Event { if messageEvents >= maxMessageEventsPerSpan { break } + if e == nil { + continue + } messageEvents++ events = append(events, tracesdk.Event{