Skip to content

Commit

Permalink
pb: add extra validation to protobuf types
Browse files Browse the repository at this point in the history
Signed-off-by: Tonis Tiigi <[email protected]>
(cherry picked from commit 838635998dcae34bbde59e3eab129ab85bd37bef)
  • Loading branch information
tonistiigi committed Jan 31, 2024
1 parent e11862c commit 5d7d85f
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 12 deletions.
9 changes: 6 additions & 3 deletions client/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/moby/buildkit/frontend/gateway/client"
sppb "github.com/moby/buildkit/sourcepolicy/pb"
"github.com/moby/buildkit/util/testutil/integration"
"github.com/moby/buildkit/util/testutil/workers"
ocispecs "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/stretchr/testify/require"
)
Expand All @@ -25,6 +26,7 @@ var validationTests = []func(t *testing.T, sb integration.Sandbox){

func testValidateNullConfig(t *testing.T, sb integration.Sandbox) {
requiresLinux(t)
workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter)

ctx := sb.Context()

Expand Down Expand Up @@ -63,6 +65,7 @@ func testValidateNullConfig(t *testing.T, sb integration.Sandbox) {

func testValidateInvalidConfig(t *testing.T, sb integration.Sandbox) {
requiresLinux(t)
workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter)

ctx := sb.Context()

Expand Down Expand Up @@ -104,11 +107,12 @@ func testValidateInvalidConfig(t *testing.T, sb integration.Sandbox) {
},
}, "", b, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid image config for export: missing os")
require.Contains(t, err.Error(), "invalid image config: os and architecture must be specified together")
}

func testValidatePlatformsEmpty(t *testing.T, sb integration.Sandbox) {
requiresLinux(t)
workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter)

ctx := sb.Context()

Expand Down Expand Up @@ -147,6 +151,7 @@ func testValidatePlatformsEmpty(t *testing.T, sb integration.Sandbox) {

func testValidatePlatformsInvalid(t *testing.T, sb integration.Sandbox) {
requiresLinux(t)
workers.CheckFeatureCompat(t, sb, workers.FeatureOCIExporter)

ctx := sb.Context()

Expand Down Expand Up @@ -279,7 +284,6 @@ func testValidateSourcePolicy(t *testing.T, sb integration.Sandbox) {

for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {

var viaFrontend bool

b := func(ctx context.Context, c client.Client) (*client.Result, error) {
Expand Down Expand Up @@ -310,7 +314,6 @@ func testValidateSourcePolicy(t *testing.T, sb integration.Sandbox) {
_, err = c.Build(ctx, SolveOpt{}, "", b, nil)
require.Error(t, err)
require.Contains(t, err.Error(), tc.exp)

})
}
}
3 changes: 3 additions & 0 deletions control/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ func (c *Controller) Solve(ctx context.Context, req *controlapi.SolveRequest) (*

var cacheImports []frontend.CacheOptionsEntry
for _, im := range req.Cache.Imports {
if im == nil {
continue
}
cacheImports = append(cacheImports, frontend.CacheOptionsEntry{
Type: im.Type,
Attrs: im.Attrs,
Expand Down
6 changes: 6 additions & 0 deletions frontend/gateway/client/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ func AttestationToPB[T any](a *result.Attestation[T]) (*pb.Attestation, error) {
}

func AttestationFromPB[T any](a *pb.Attestation) (*result.Attestation[T], error) {
if a == nil {
return nil, errors.Errorf("invalid nil attestation")
}
subjects := make([]result.InTotoSubject, len(a.InTotoSubjects))
for i, subject := range a.InTotoSubjects {
if subject == nil {
return nil, errors.Errorf("invalid nil attestation subject")
}
subjects[i] = result.InTotoSubject{
Kind: subject.Kind,
Name: subject.Name,
Expand Down
15 changes: 15 additions & 0 deletions frontend/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,12 +646,21 @@ func (lbf *llbBridgeForwarder) registerResultIDs(results ...solver.Result) (ids
func (lbf *llbBridgeForwarder) Solve(ctx context.Context, req *pb.SolveRequest) (*pb.SolveResponse, error) {
var cacheImports []frontend.CacheOptionsEntry
for _, e := range req.CacheImports {
if e == nil {
return nil, errors.Errorf("invalid nil cache import")
}
cacheImports = append(cacheImports, frontend.CacheOptionsEntry{
Type: e.Type,
Attrs: e.Attrs,
})
}

for _, p := range req.SourcePolicies {
if p == nil {
return nil, errors.Errorf("invalid nil source policy")
}
}

ctx = tracing.ContextWithSpanFromContext(ctx, lbf.callCtx)
res, err := lbf.llbBridge.Solve(ctx, frontend.SolveRequest{
Evaluate: req.Evaluate,
Expand Down Expand Up @@ -1076,6 +1085,12 @@ func (lbf *llbBridgeForwarder) ReleaseContainer(ctx context.Context, in *pb.Rele
}

func (lbf *llbBridgeForwarder) Warn(ctx context.Context, in *pb.WarnRequest) (*pb.WarnResponse, error) {
// validate ranges are valid
for _, r := range in.Ranges {
if r == nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid source range")
}
}
err := lbf.llbBridge.Warn(ctx, in.Digest, string(in.Short), frontend.WarnOpts{
Level: int(in.Level),
SourceInfo: in.Info,
Expand Down
21 changes: 16 additions & 5 deletions util/tracing/transform/attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ func Attributes(attrs []*commonpb.KeyValue) []attribute.KeyValue {

out := make([]attribute.KeyValue, 0, len(attrs))
for _, a := range attrs {
if a == nil {
continue
}
kv := attribute.KeyValue{
Key: attribute.Key(a.Key),
Value: toValue(a.Value),
Expand Down Expand Up @@ -42,37 +45,45 @@ func toValue(v *commonpb.AnyValue) attribute.Value {
func boolArray(kv []*commonpb.AnyValue) attribute.Value {
arr := make([]bool, len(kv))
for i, v := range kv {
arr[i] = v.GetBoolValue()
if v != nil {
arr[i] = v.GetBoolValue()
}
}
return attribute.BoolSliceValue(arr)
}

func intArray(kv []*commonpb.AnyValue) attribute.Value {
arr := make([]int64, len(kv))
for i, v := range kv {
arr[i] = v.GetIntValue()
if v != nil {
arr[i] = v.GetIntValue()
}
}
return attribute.Int64SliceValue(arr)
}

func doubleArray(kv []*commonpb.AnyValue) attribute.Value {
arr := make([]float64, len(kv))
for i, v := range kv {
arr[i] = v.GetDoubleValue()
if v != nil {
arr[i] = v.GetDoubleValue()
}
}
return attribute.Float64SliceValue(arr)
}

func stringArray(kv []*commonpb.AnyValue) attribute.Value {
arr := make([]string, len(kv))
for i, v := range kv {
arr[i] = v.GetStringValue()
if v != nil {
arr[i] = v.GetStringValue()
}
}
return attribute.StringSliceValue(arr)
}

func arrayValues(kv []*commonpb.AnyValue) attribute.Value {
if len(kv) == 0 {
if len(kv) == 0 || kv[0] == nil {
return attribute.StringSliceValue([]string{})
}

Expand Down
23 changes: 19 additions & 4 deletions util/tracing/transform/span.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,20 @@ func Spans(sdl []*tracepb.ResourceSpans) []tracesdk.ReadOnlySpan {
}

for _, sdi := range sd.ScopeSpans {
sda := make([]tracesdk.ReadOnlySpan, len(sdi.Spans))
for i, s := range sdi.Spans {
sda[i] = &readOnlySpan{
if sdi == nil {
continue
}
sda := make([]tracesdk.ReadOnlySpan, 0, len(sdi.Spans))
for _, s := range sdi.Spans {
if s == nil {
continue
}
sda = append(sda, &readOnlySpan{
pb: s,
is: sdi.Scope,
resource: sd.Resource,
schemaURL: sd.SchemaUrl,
}
})
}
out = append(out, sda...)
}
Expand Down Expand Up @@ -170,6 +176,9 @@ var _ tracesdk.ReadOnlySpan = &readOnlySpan{}

// status transform a OTLP span status into span code.
func statusCode(st *tracepb.Status) codes.Code {
if st == nil {
return codes.Unset
}
switch st.Code {
case tracepb.Status_STATUS_CODE_ERROR:
return codes.Error
Expand All @@ -186,6 +195,9 @@ func links(links []*tracepb.Span_Link) []tracesdk.Link {

sl := make([]tracesdk.Link, 0, len(links))
for _, otLink := range links {
if otLink == nil {
continue
}
// This redefinition is necessary to prevent otLink.*ID[:] copies
// being reused -- in short we need a new otLink per iteration.
otLink := otLink
Expand Down Expand Up @@ -226,6 +238,9 @@ func spanEvents(es []*tracepb.Span_Event) []tracesdk.Event {
if messageEvents >= maxMessageEventsPerSpan {
break
}
if e == nil {
continue
}
messageEvents++
events = append(events,
tracesdk.Event{
Expand Down

0 comments on commit 5d7d85f

Please sign in to comment.