Skip to content

Commit

Permalink
Add experiment flag
Browse files Browse the repository at this point in the history
  • Loading branch information
strideynet committed Dec 17, 2024
1 parent d02bf53 commit 5225a81
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 23 deletions.
41 changes: 41 additions & 0 deletions lib/auth/machineid/workloadidentityv1/experiment/experiment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package experiment

import (
"os"
"sync"
)

var mu sync.Mutex

var experimentEnabled = os.Getenv("TELEPORT_WORKLOAD_IDENTITY_UX_EXPERIMENT") == "1"

// Enabled returns true if the workload identity UX experiment is
// enabled.
func Enabled() bool {
mu.Lock()
defer mu.Unlock()
return experimentEnabled
}

// SetEnabled sets the experiment enabled flag.
func SetEnabled(enabled bool) {
mu.Lock()
defer mu.Unlock()
experimentEnabled = enabled
}
46 changes: 33 additions & 13 deletions lib/auth/machineid/workloadidentityv1/issuer_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/jwt"
Expand Down Expand Up @@ -125,16 +126,20 @@ func NewIssuanceService(cfg *IssuanceServiceConfig) (*IssuanceService, error) {
}, nil
}

// getFieldStringValue
// TODO(noah): This is a fairly gnarly first pass of a reflection based
// attribute extraction function. This will eventually be replaced potentially
// by the chosen expression/predicate language mechanism.
// getFieldStringValue returns a string value from the given attribute set.
// The attribute is specified as a dot-separated path to the field in the
// attribute set.
//
// The specified attribute must be a string field. If the attribute is not
// found, an error is returned.
//
// TODO(noah): This function will be replaced by the Teleport predicate language
// in a coming PR.
func getFieldStringValue(attrs *workloadidentityv1pb.Attrs, attr string) (string, error) {
// join.gitlab.username
attrParts := strings.Split(attr, ".")
message := attrs.ProtoReflect()
// TODO: Improve errors by including the fully qualified attribute (e.g
// add up the parts of the attribute path processed thus far)
// TODO(noah): Improve errors by including the fully qualified attribute
// (e.g add up the parts of the attribute path processed thus far)
for i, part := range attrParts {
fieldDesc := message.Descriptor().Fields().ByTextName(part)
if fieldDesc == nil {
Expand All @@ -158,12 +163,14 @@ func getFieldStringValue(attrs *workloadidentityv1pb.Attrs, attr string) (string
return "", nil
}

// This place is not a place of honor...
// no highly esteemed deed is commemorated here...
// nothing valued is here.
// templateString takes a given input string and replaces any values within
// {{ }} with values from the attribute set.
//
// If the specified value is not found in the attribute set, an error is
// returned.
//
// What is here was dangerous and repulsive to us.
// This message is a warning about danger.
// TODO(noah): In a coming PR, this will be replaced by evaluating the values
// within the handlebars as expressions.
func templateString(in string, attrs *workloadidentityv1pb.Attrs) (string, error) {
re := regexp.MustCompile(`\{\{(.*?)\}\}`)
matches := re.FindAllStringSubmatch(in, -1)
Expand Down Expand Up @@ -227,10 +234,16 @@ func (s *IssuanceService) deriveAttrs(
return attrs, nil
}

var defaultMaxTTL = 24 * time.Hour

func (s *IssuanceService) IssueWorkloadIdentity(
ctx context.Context,
req *workloadidentityv1pb.IssueWorkloadIdentityRequest,
) (*workloadidentityv1pb.IssueWorkloadIdentityResponse, error) {
if !experiment.Enabled() {
return nil, trace.AccessDenied("workload identity issuance experiment is disabled")
}

authCtx, err := s.authorizer.Authorize(ctx)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -284,8 +297,15 @@ func (s *IssuanceService) IssueWorkloadIdentity(
return nil, trace.Wrap(err, "templating spec.spiffe.hint")
}

// TODO: Actually like calculate the TTL.
// TODO(noah): Add more sophisticated control of the TTL.
ttl := time.Hour
if req.RequestedTtl != nil && req.RequestedTtl.AsDuration() != 0 {
ttl := req.RequestedTtl.AsDuration()
if ttl > defaultMaxTTL {
ttl = defaultMaxTTL
}
}

now := s.clock.Now()
notBefore := now.Add(-1 * time.Minute)
notAfter := now.Add(ttl)
Expand Down
26 changes: 16 additions & 10 deletions lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment"
"github.com/gravitational/teleport/lib/cryptosuites"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
Expand Down Expand Up @@ -81,9 +82,13 @@ func newTestTLSServer(t testing.TB) (*auth.TestTLSServer, *eventstest.MockRecord
}

func TestIssueWorkloadIdentity(t *testing.T) {
t.Parallel()
experimentStatus := experiment.Enabled()
defer experiment.SetEnabled(experimentStatus)
experiment.SetEnabled(true)

srv, eventRecorder := newTestTLSServer(t)
ctx := context.Background()
clock := srv.Auth().GetClock()

// Upsert a fake proxy to ensure we have a public address to use for the
// issuer.
Expand Down Expand Up @@ -234,15 +239,15 @@ func TestIssueWorkloadIdentity(t *testing.T) {
protocmp.Transform(),
protocmp.IgnoreFields(
&workloadidentityv1pb.Credential{},
"expiry",
"expires_at",
),
protocmp.IgnoreOneofs(
&workloadidentityv1pb.Credential{},
"credential",
),
))
// Check expiry makes sense
require.WithinDuration(t, time.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second)
require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second)

// Check the JWT
parsed, err := jwt.ParseSigned(cred.GetJwtSvid().GetJwt())
Expand All @@ -259,8 +264,8 @@ func TestIssueWorkloadIdentity(t *testing.T) {
require.NotEmpty(t, claims.ID)
require.Equal(t, jwt.Audience{"example.com", "test.example.com"}, claims.Audience)
require.Equal(t, wantIssuer, claims.Issuer)
require.WithinDuration(t, time.Now().Add(wantTTL), claims.Expiry.Time(), 5*time.Second)
require.WithinDuration(t, time.Now(), claims.IssuedAt.Time(), 5*time.Second)
require.WithinDuration(t, clock.Now().Add(wantTTL), claims.Expiry.Time(), 5*time.Second)
require.WithinDuration(t, clock.Now(), claims.IssuedAt.Time(), 5*time.Second)

// Check audit log event
evt, ok := eventRecorder.LastEvent().(*events.SPIFFESVIDIssued)
Expand Down Expand Up @@ -324,25 +329,25 @@ func TestIssueWorkloadIdentity(t *testing.T) {
protocmp.Transform(),
protocmp.IgnoreFields(
&workloadidentityv1pb.Credential{},
"expiry",
"expires_at",
),
protocmp.IgnoreOneofs(
&workloadidentityv1pb.Credential{},
"credential",
),
))
// Check expiry makes sense
require.WithinDuration(t, time.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second)
require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second)

// Check the X509
cert, err := x509.ParseCertificate(cred.GetX509Svid().GetCert())
require.NoError(t, err)
// Check included public key matches
require.Equal(t, workloadKey.Public(), cert.PublicKey)
// Check cert expiry
require.WithinDuration(t, time.Now().Add(wantTTL), cert.NotAfter, time.Second)
require.WithinDuration(t, clock.Now().Add(wantTTL), cert.NotAfter, time.Second)
// Check cert nbf
require.WithinDuration(t, time.Now().Add(-1*time.Minute), cert.NotBefore, time.Second)
require.WithinDuration(t, clock.Now().Add(-1*time.Minute), cert.NotBefore, time.Second)
// Check cert TTL
require.Equal(t, cert.NotAfter.Sub(cert.NotBefore), wantTTL+time.Minute)

Expand All @@ -366,7 +371,8 @@ func TestIssueWorkloadIdentity(t *testing.T) {

// Check cert signature is valid
_, err = cert.Verify(x509.VerifyOptions{
Roots: spiffeX509CAPool,
Roots: spiffeX509CAPool,
CurrentTime: srv.Auth().GetClock().Now(),
})
require.NoError(t, err)

Expand Down

0 comments on commit 5225a81

Please sign in to comment.