From 2791f59232954071d8d3722c27708addc78509dc Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 13 Dec 2024 15:22:23 +0000 Subject: [PATCH] Add experiment flag --- .../experiment/experiment.go | 41 +++++++++++++++++ .../workloadidentityv1/issuer_service.go | 46 +++++++++++++------ .../workloadidentityv1_test.go | 26 +++++++---- 3 files changed, 90 insertions(+), 23 deletions(-) create mode 100644 lib/auth/machineid/workloadidentityv1/experiment/experiment.go diff --git a/lib/auth/machineid/workloadidentityv1/experiment/experiment.go b/lib/auth/machineid/workloadidentityv1/experiment/experiment.go new file mode 100644 index 0000000000000..fafe51ea83d1f --- /dev/null +++ b/lib/auth/machineid/workloadidentityv1/experiment/experiment.go @@ -0,0 +1,41 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package experiment + +import ( + "os" + "sync" +) + +var mu sync.Mutex + +var experimentEnabled = os.Getenv("TELEPORT_WORKLOAD_IDENTITY_UX_EXPERIMENT") == "1" + +// Enabled returns true if the workload identity UX experiment is +// enabled. +func Enabled() bool { + mu.Lock() + defer mu.Unlock() + return experimentEnabled +} + +// SetEnabled sets the experiment enabled flag. +func SetEnabled(enabled bool) { + mu.Lock() + defer mu.Unlock() + experimentEnabled = enabled +} diff --git a/lib/auth/machineid/workloadidentityv1/issuer_service.go b/lib/auth/machineid/workloadidentityv1/issuer_service.go index c9381ce15b415..399c4574ee220 100644 --- a/lib/auth/machineid/workloadidentityv1/issuer_service.go +++ b/lib/auth/machineid/workloadidentityv1/issuer_service.go @@ -42,6 +42,7 @@ import ( "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/jwt" @@ -125,16 +126,20 @@ func NewIssuanceService(cfg *IssuanceServiceConfig) (*IssuanceService, error) { }, nil } -// getFieldStringValue -// TODO(noah): This is a fairly gnarly first pass of a reflection based -// attribute extraction function. This will eventually be replaced potentially -// by the chosen expression/predicate language mechanism. +// getFieldStringValue returns a string value from the given attribute set. +// The attribute is specified as a dot-separated path to the field in the +// attribute set. +// +// The specified attribute must be a string field. If the attribute is not +// found, an error is returned. +// +// TODO(noah): This function will be replaced by the Teleport predicate language +// in a coming PR. func getFieldStringValue(attrs *workloadidentityv1pb.Attrs, attr string) (string, error) { - // join.gitlab.username attrParts := strings.Split(attr, ".") message := attrs.ProtoReflect() - // TODO: Improve errors by including the fully qualified attribute (e.g - // add up the parts of the attribute path processed thus far) + // TODO(noah): Improve errors by including the fully qualified attribute + // (e.g add up the parts of the attribute path processed thus far) for i, part := range attrParts { fieldDesc := message.Descriptor().Fields().ByTextName(part) if fieldDesc == nil { @@ -158,12 +163,14 @@ func getFieldStringValue(attrs *workloadidentityv1pb.Attrs, attr string) (string return "", nil } -// This place is not a place of honor... -// no highly esteemed deed is commemorated here... -// nothing valued is here. +// templateString takes a given input string and replaces any values within +// {{ }} with values from the attribute set. +// +// If the specified value is not found in the attribute set, an error is +// returned. // -// What is here was dangerous and repulsive to us. -// This message is a warning about danger. +// TODO(noah): In a coming PR, this will be replaced by evaluating the values +// within the handlebars as expressions. func templateString(in string, attrs *workloadidentityv1pb.Attrs) (string, error) { re := regexp.MustCompile(`\{\{(.*?)\}\}`) matches := re.FindAllStringSubmatch(in, -1) @@ -227,10 +234,16 @@ func (s *IssuanceService) deriveAttrs( return attrs, nil } +var defaultMaxTTL = 24 * time.Hour + func (s *IssuanceService) IssueWorkloadIdentity( ctx context.Context, req *workloadidentityv1pb.IssueWorkloadIdentityRequest, ) (*workloadidentityv1pb.IssueWorkloadIdentityResponse, error) { + if !experiment.Enabled() { + return nil, trace.AccessDenied("workload identity issuance experiment is disabled") + } + authCtx, err := s.authorizer.Authorize(ctx) if err != nil { return nil, trace.Wrap(err) @@ -284,8 +297,15 @@ func (s *IssuanceService) IssueWorkloadIdentity( return nil, trace.Wrap(err, "templating spec.spiffe.hint") } - // TODO: Actually like calculate the TTL. + // TODO(noah): Add more sophisticated control of the TTL. ttl := time.Hour + if req.RequestedTtl != nil && req.RequestedTtl.AsDuration() != 0 { + ttl := req.RequestedTtl.AsDuration() + if ttl > defaultMaxTTL { + ttl = defaultMaxTTL + } + } + now := s.clock.Now() notBefore := now.Add(-1 * time.Minute) notAfter := now.Add(ttl) diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go index 58cbb8e478040..3b7a7b1d85759 100644 --- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go +++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go @@ -43,6 +43,7 @@ import ( "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment" "github.com/gravitational/teleport/lib/cryptosuites" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" @@ -81,9 +82,13 @@ func newTestTLSServer(t testing.TB) (*auth.TestTLSServer, *eventstest.MockRecord } func TestIssueWorkloadIdentity(t *testing.T) { - t.Parallel() + experimentStatus := experiment.Enabled() + defer experiment.SetEnabled(experimentStatus) + experiment.SetEnabled(true) + srv, eventRecorder := newTestTLSServer(t) ctx := context.Background() + clock := srv.Auth().GetClock() // Upsert a fake proxy to ensure we have a public address to use for the // issuer. @@ -234,7 +239,7 @@ func TestIssueWorkloadIdentity(t *testing.T) { protocmp.Transform(), protocmp.IgnoreFields( &workloadidentityv1pb.Credential{}, - "expiry", + "expires_at", ), protocmp.IgnoreOneofs( &workloadidentityv1pb.Credential{}, @@ -242,7 +247,7 @@ func TestIssueWorkloadIdentity(t *testing.T) { ), )) // Check expiry makes sense - require.WithinDuration(t, time.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second) + require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second) // Check the JWT parsed, err := jwt.ParseSigned(cred.GetJwtSvid().GetJwt()) @@ -259,8 +264,8 @@ func TestIssueWorkloadIdentity(t *testing.T) { require.NotEmpty(t, claims.ID) require.Equal(t, jwt.Audience{"example.com", "test.example.com"}, claims.Audience) require.Equal(t, wantIssuer, claims.Issuer) - require.WithinDuration(t, time.Now().Add(wantTTL), claims.Expiry.Time(), 5*time.Second) - require.WithinDuration(t, time.Now(), claims.IssuedAt.Time(), 5*time.Second) + require.WithinDuration(t, clock.Now().Add(wantTTL), claims.Expiry.Time(), 5*time.Second) + require.WithinDuration(t, clock.Now(), claims.IssuedAt.Time(), 5*time.Second) // Check audit log event evt, ok := eventRecorder.LastEvent().(*events.SPIFFESVIDIssued) @@ -324,7 +329,7 @@ func TestIssueWorkloadIdentity(t *testing.T) { protocmp.Transform(), protocmp.IgnoreFields( &workloadidentityv1pb.Credential{}, - "expiry", + "expires_at", ), protocmp.IgnoreOneofs( &workloadidentityv1pb.Credential{}, @@ -332,7 +337,7 @@ func TestIssueWorkloadIdentity(t *testing.T) { ), )) // Check expiry makes sense - require.WithinDuration(t, time.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second) + require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second) // Check the X509 cert, err := x509.ParseCertificate(cred.GetX509Svid().GetCert()) @@ -340,9 +345,9 @@ func TestIssueWorkloadIdentity(t *testing.T) { // Check included public key matches require.Equal(t, workloadKey.Public(), cert.PublicKey) // Check cert expiry - require.WithinDuration(t, time.Now().Add(wantTTL), cert.NotAfter, time.Second) + require.WithinDuration(t, clock.Now().Add(wantTTL), cert.NotAfter, time.Second) // Check cert nbf - require.WithinDuration(t, time.Now().Add(-1*time.Minute), cert.NotBefore, time.Second) + require.WithinDuration(t, clock.Now().Add(-1*time.Minute), cert.NotBefore, time.Second) // Check cert TTL require.Equal(t, cert.NotAfter.Sub(cert.NotBefore), wantTTL+time.Minute) @@ -366,7 +371,8 @@ func TestIssueWorkloadIdentity(t *testing.T) { // Check cert signature is valid _, err = cert.Verify(x509.VerifyOptions{ - Roots: spiffeX509CAPool, + Roots: spiffeX509CAPool, + CurrentTime: srv.Auth().GetClock().Now(), }) require.NoError(t, err)