Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SCIM endpoints to get user status #21

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/baton-aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ type config struct {
GlobalAwsSsoEnabled bool `mapstructure:"global-aws-sso-enabled"`
GlobalAwsOrgsEnabled bool `mapstructure:"global-aws-orgs-enabled"`

SCIMEndpoint string `mapstructure:"scim-endpoint"`
SCIMToken string `mapstructure:"scim-token"`

UseAssumeRole bool `mapstructure:"use-assume-role"`
}

Expand Down Expand Up @@ -63,4 +66,6 @@ func cmdFlags(cmd *cobra.Command) {
cmd.PersistentFlags().String("global-secret-access-key", "", "The global-secret-access-key for the aws account. ($BATON_GLOBAL_SECRET_ACCESS_KEY)")
cmd.PersistentFlags().String("global-access-key-id", "", "The global-access-key-id for the aws account. ($BATON_GLOBAL_ACCESS_KEY_ID)")
cmd.PersistentFlags().Bool("use-assume-role", false, "Enable support for assume role. ($BATON_GLOBAL_USE_ASSUME_ROLE)")
cmd.PersistentFlags().String("scim-endpoint", "", "The SCIMv2 endpoint for aws identity center. ($BATON_SCIM_ENDPOINT)")
cmd.PersistentFlags().String("scim-token", "", "The SCIMv2 token for aws identity center. ($BATON_SCIM_TOKEN)")
}
2 changes: 2 additions & 0 deletions cmd/baton-aws/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ func getConnector(ctx context.Context, cfg *config) (types.ConnectorServer, erro
GlobalAwsSsoEnabled: cfg.GlobalAwsSsoEnabled,
ExternalID: cfg.ExternalID,
RoleARN: cfg.RoleARN,
SCIMEndpoint: cfg.SCIMEndpoint,
SCIMToken: cfg.SCIMToken,
}

cb, err := connector.New(ctx, config)
Expand Down
41 changes: 39 additions & 2 deletions pkg/connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"sync"

awsSdk "github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -85,14 +86,19 @@ type Config struct {
GlobalAwsSsoEnabled bool
ExternalID string
RoleARN string
SCIMToken string
SCIMEndpoint string
}

type AWS struct {
useAssumeRole bool
orgsEnabled bool
ssoEnabled bool
globalRegion string
ssoRegion string
scimEnabled bool
scimToken string
scimEndpoint string
globalRegion string
roleARN string
externalID string
globalBindingExternalID string
Expand Down Expand Up @@ -134,6 +140,30 @@ func (o *AWS) ssoAdminClient(ctx context.Context) (*awsSsoAdmin.Client, error) {
return awsSsoAdmin.NewFromConfig(callingConfig), nil
}

func (o *AWS) ssoSCIMClient(ctx context.Context) (*awsIdentityCenterSCIMClient, error) {
if !o.scimEnabled {
return nil, nil
}

normalizedEndpoint, err := NormalizeAWSIdentityCenterSCIMUrl(o.scimEndpoint)
if err != nil {
return nil, fmt.Errorf("aws-connector-scim: invalid endpoint: %w", err)
}
ep, err := url.Parse(normalizedEndpoint)
if err != nil {
return nil, fmt.Errorf("aws-connector-scim: invalid endpoint: %w", err)
}
if len(o.scimToken) == 0 {
return nil, fmt.Errorf("aws-connector-scim: token is required")
}
return &awsIdentityCenterSCIMClient{
Client: o.baseClient,
Endpoint: ep,
Token: o.scimToken,
scimEnabled: o.scimEnabled,
}, nil
}

func (o *AWS) stsClient(ctx context.Context) (*sts.Client, error) {
callingConfig, err := o.getCallingConfig(ctx, o.globalRegion)
if err != nil {
Expand Down Expand Up @@ -241,6 +271,9 @@ func New(ctx context.Context, config Config) (*AWS, error) {
globalAccessKeyID: config.GlobalAccessKeyID,
globalSecretAccessKey: config.GlobalSecretAccessKey,
ssoRegion: config.GlobalAwsSsoRegion,
scimEndpoint: config.SCIMEndpoint,
scimToken: config.SCIMToken,
scimEnabled: config.SCIMEndpoint != "" && config.SCIMToken != "",
baseClient: httpClient,
baseConfig: baseConfig.Copy(),
_onceCallingConfig: map[string]*sync.Once{},
Expand Down Expand Up @@ -332,8 +365,12 @@ func (c *AWS) ResourceSyncers(ctx context.Context) []connectorbuilder.ResourceSy
if err != nil {
return rs
}
scimClient, err := c.ssoSCIMClient(ctx)
if err != nil {
return rs
}
if c.ssoEnabled {
rs = append(rs, ssoUserBuilder(c.ssoRegion, ssoAdminClient, identityStoreClient, ix))
rs = append(rs, ssoUserBuilder(c.ssoRegion, ssoAdminClient, identityStoreClient, ix, scimClient))
rs = append(rs, ssoGroupBuilder(c.ssoRegion, ssoAdminClient, identityStoreClient, ix))
}
if c.orgsEnabled {
Expand Down
8 changes: 1 addition & 7 deletions pkg/connector/iam_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,11 @@ func (o *iamGroupResourceType) Grants(ctx context.Context, resource *v2.Resource
return rv, "", nil, nil
}

// TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken
err = bag.Next(awsSdk.ToString(resp.Marker))
nextPage, err := bag.NextToken(awsSdk.ToString(resp.Marker))
if err != nil {
return nil, "", nil, err
}

nextPage, err := bag.Marshal()
if err != nil {
return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err)
}

return rv, nextPage, nil, nil
}

Expand Down
8 changes: 1 addition & 7 deletions pkg/connector/iam_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,11 @@ func (o *iamUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa
return rv, "", nil, nil
}

// TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken
err = bag.Next(awsSdk.ToString(resp.Marker))
nextPage, err := bag.NextToken(awsSdk.ToString(resp.Marker))
if err != nil {
return nil, "", nil, err
}

nextPage, err := bag.Marshal()
if err != nil {
return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err)
}

return rv, nextPage, nil, nil
}

Expand Down
8 changes: 1 addition & 7 deletions pkg/connector/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,11 @@ func (o *roleResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pagin
return rv, "", nil, nil
}

// TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken
err = bag.Next(awsSdk.ToString(resp.Marker))
nextPage, err := bag.NextToken(awsSdk.ToString(resp.Marker))
if err != nil {
return nil, "", nil, err
}

nextPage, err := bag.Marshal()
if err != nil {
return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err)
}

return rv, nextPage, nil, nil
}

Expand Down
8 changes: 1 addition & 7 deletions pkg/connector/sso_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,11 @@ func (o *ssoGroupResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *p
rv = append(rv, groupResource)
}

// TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken
err = bag.Next(awsSdk.ToString(resp.NextToken))
nextPage, err := bag.NextToken(awsSdk.ToString(resp.NextToken))
if err != nil {
return nil, "", nil, err
}

nextPage, err := bag.Marshal()
if err != nil {
return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err)
}

return rv, nextPage, nil, nil
}

Expand Down
23 changes: 15 additions & 8 deletions pkg/connector/sso_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ssoUserResourceType struct {
ssoClient *awsSsoAdmin.Client
identityStoreClient *awsIdentityStore.Client
identityInstance *awsSsoAdminTypes.InstanceMetadata
scimClient *awsIdentityCenterSCIMClient
region string
}

Expand Down Expand Up @@ -56,6 +57,10 @@ func (o *ssoUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa

rv := make([]*v2.Resource, 0, len(resp.Users))
for _, user := range resp.Users {
status, err := o.scimClient.getUserStatus(ctx, awsSdk.ToString(user.UserId))
if err != nil {
return nil, "", nil, fmt.Errorf("aws-connector: failed to get user status from scim: %w", err)
}
userARN := ssoUserToARN(o.region, awsSdk.ToString(o.identityInstance.IdentityStoreId), awsSdk.ToString(user.UserId))
annos := &v2.V1Identifier{
Id: userARN,
Expand All @@ -68,6 +73,7 @@ func (o *ssoUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa
[]resourceSdk.UserTraitOption{
resourceSdk.WithEmail(getSsoUserEmail(user), true),
resourceSdk.WithUserProfile(profile),
resourceSdk.WithStatus(status),
},
resourceSdk.WithAnnotation(annos),
)
Expand All @@ -77,17 +83,11 @@ func (o *ssoUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa
rv = append(rv, userResource)
}

// TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken
err = bag.Next(awsSdk.ToString(resp.NextToken))
nextPage, err := bag.NextToken(awsSdk.ToString(resp.NextToken))
if err != nil {
return nil, "", nil, err
}

nextPage, err := bag.Marshal()
if err != nil {
return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err)
}

return rv, nextPage, nil, nil
}

Expand All @@ -99,13 +99,20 @@ func (o *ssoUserResourceType) Grants(_ context.Context, _ *v2.Resource, _ *pagin
return nil, "", nil, nil
}

func ssoUserBuilder(region string, ssoClient *awsSsoAdmin.Client, identityStoreClient *awsIdentityStore.Client, identityInstance *awsSsoAdminTypes.InstanceMetadata) *ssoUserResourceType {
func ssoUserBuilder(
region string,
ssoClient *awsSsoAdmin.Client,
identityStoreClient *awsIdentityStore.Client,
identityInstance *awsSsoAdminTypes.InstanceMetadata,
scimClient *awsIdentityCenterSCIMClient,
) *ssoUserResourceType {
return &ssoUserResourceType{
resourceType: resourceTypeSSOUser,
region: region,
identityInstance: identityInstance,
identityStoreClient: identityStoreClient,
ssoClient: ssoClient,
scimClient: scimClient,
}
}

Expand Down
Loading
Loading