diff --git a/cmd/baton-aws/config.go b/cmd/baton-aws/config.go index f5d1ee1a..aeb67041 100644 --- a/cmd/baton-aws/config.go +++ b/cmd/baton-aws/config.go @@ -27,6 +27,9 @@ type config struct { GlobalAwsSsoEnabled bool `mapstructure:"global-aws-sso-enabled"` GlobalAwsOrgsEnabled bool `mapstructure:"global-aws-orgs-enabled"` + SCIMEndpoint string `mapstructure:"scim-endpoint"` + SCIMToken string `mapstructure:"scim-token"` + UseAssumeRole bool `mapstructure:"use-assume-role"` } @@ -63,4 +66,6 @@ func cmdFlags(cmd *cobra.Command) { cmd.PersistentFlags().String("global-secret-access-key", "", "The global-secret-access-key for the aws account. ($BATON_GLOBAL_SECRET_ACCESS_KEY)") cmd.PersistentFlags().String("global-access-key-id", "", "The global-access-key-id for the aws account. ($BATON_GLOBAL_ACCESS_KEY_ID)") cmd.PersistentFlags().Bool("use-assume-role", false, "Enable support for assume role. ($BATON_GLOBAL_USE_ASSUME_ROLE)") + cmd.PersistentFlags().String("scim-endpoint", "", "The SCIMv2 endpoint for aws identity center. ($BATON_SCIM_ENDPOINT)") + cmd.PersistentFlags().String("scim-token", "", "The SCIMv2 token for aws identity center. ($BATON_SCIM_TOKEN)") } diff --git a/cmd/baton-aws/main.go b/cmd/baton-aws/main.go index d42ced75..cc17597b 100644 --- a/cmd/baton-aws/main.go +++ b/cmd/baton-aws/main.go @@ -51,6 +51,8 @@ func getConnector(ctx context.Context, cfg *config) (types.ConnectorServer, erro GlobalAwsSsoEnabled: cfg.GlobalAwsSsoEnabled, ExternalID: cfg.ExternalID, RoleARN: cfg.RoleARN, + SCIMEndpoint: cfg.SCIMEndpoint, + SCIMToken: cfg.SCIMToken, } cb, err := connector.New(ctx, config) diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 35eb6d0f..b7f52c2d 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/url" "sync" awsSdk "github.com/aws/aws-sdk-go-v2/aws" @@ -85,14 +86,19 @@ type Config struct { GlobalAwsSsoEnabled bool ExternalID string RoleARN string + SCIMToken string + SCIMEndpoint string } type AWS struct { useAssumeRole bool orgsEnabled bool ssoEnabled bool - globalRegion string ssoRegion string + scimEnabled bool + scimToken string + scimEndpoint string + globalRegion string roleARN string externalID string globalBindingExternalID string @@ -134,6 +140,30 @@ func (o *AWS) ssoAdminClient(ctx context.Context) (*awsSsoAdmin.Client, error) { return awsSsoAdmin.NewFromConfig(callingConfig), nil } +func (o *AWS) ssoSCIMClient(ctx context.Context) (*awsIdentityCenterSCIMClient, error) { + if !o.scimEnabled { + return nil, nil + } + + normalizedEndpoint, err := NormalizeAWSIdentityCenterSCIMUrl(o.scimEndpoint) + if err != nil { + return nil, fmt.Errorf("aws-connector-scim: invalid endpoint: %w", err) + } + ep, err := url.Parse(normalizedEndpoint) + if err != nil { + return nil, fmt.Errorf("aws-connector-scim: invalid endpoint: %w", err) + } + if len(o.scimToken) == 0 { + return nil, fmt.Errorf("aws-connector-scim: token is required") + } + return &awsIdentityCenterSCIMClient{ + Client: o.baseClient, + Endpoint: ep, + Token: o.scimToken, + scimEnabled: o.scimEnabled, + }, nil +} + func (o *AWS) stsClient(ctx context.Context) (*sts.Client, error) { callingConfig, err := o.getCallingConfig(ctx, o.globalRegion) if err != nil { @@ -241,6 +271,9 @@ func New(ctx context.Context, config Config) (*AWS, error) { globalAccessKeyID: config.GlobalAccessKeyID, globalSecretAccessKey: config.GlobalSecretAccessKey, ssoRegion: config.GlobalAwsSsoRegion, + scimEndpoint: config.SCIMEndpoint, + scimToken: config.SCIMToken, + scimEnabled: config.SCIMEndpoint != "" && config.SCIMToken != "", baseClient: httpClient, baseConfig: baseConfig.Copy(), _onceCallingConfig: map[string]*sync.Once{}, @@ -332,8 +365,12 @@ func (c *AWS) ResourceSyncers(ctx context.Context) []connectorbuilder.ResourceSy if err != nil { return rs } + scimClient, err := c.ssoSCIMClient(ctx) + if err != nil { + return rs + } if c.ssoEnabled { - rs = append(rs, ssoUserBuilder(c.ssoRegion, ssoAdminClient, identityStoreClient, ix)) + rs = append(rs, ssoUserBuilder(c.ssoRegion, ssoAdminClient, identityStoreClient, ix, scimClient)) rs = append(rs, ssoGroupBuilder(c.ssoRegion, ssoAdminClient, identityStoreClient, ix)) } if c.orgsEnabled { diff --git a/pkg/connector/iam_group.go b/pkg/connector/iam_group.go index 10256fd0..bb349226 100644 --- a/pkg/connector/iam_group.go +++ b/pkg/connector/iam_group.go @@ -142,17 +142,11 @@ func (o *iamGroupResourceType) Grants(ctx context.Context, resource *v2.Resource return rv, "", nil, nil } - // TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken - err = bag.Next(awsSdk.ToString(resp.Marker)) + nextPage, err := bag.NextToken(awsSdk.ToString(resp.Marker)) if err != nil { return nil, "", nil, err } - nextPage, err := bag.Marshal() - if err != nil { - return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err) - } - return rv, nextPage, nil, nil } diff --git a/pkg/connector/iam_user.go b/pkg/connector/iam_user.go index 181e7a4c..0a7b91c4 100644 --- a/pkg/connector/iam_user.go +++ b/pkg/connector/iam_user.go @@ -67,17 +67,11 @@ func (o *iamUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa return rv, "", nil, nil } - // TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken - err = bag.Next(awsSdk.ToString(resp.Marker)) + nextPage, err := bag.NextToken(awsSdk.ToString(resp.Marker)) if err != nil { return nil, "", nil, err } - nextPage, err := bag.Marshal() - if err != nil { - return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err) - } - return rv, nextPage, nil, nil } diff --git a/pkg/connector/role.go b/pkg/connector/role.go index b5003b22..bfa4322c 100644 --- a/pkg/connector/role.go +++ b/pkg/connector/role.go @@ -74,17 +74,11 @@ func (o *roleResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pagin return rv, "", nil, nil } - // TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken - err = bag.Next(awsSdk.ToString(resp.Marker)) + nextPage, err := bag.NextToken(awsSdk.ToString(resp.Marker)) if err != nil { return nil, "", nil, err } - nextPage, err := bag.Marshal() - if err != nil { - return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err) - } - return rv, nextPage, nil, nil } diff --git a/pkg/connector/sso_group.go b/pkg/connector/sso_group.go index 779eb971..a8c2fca5 100644 --- a/pkg/connector/sso_group.go +++ b/pkg/connector/sso_group.go @@ -76,17 +76,11 @@ func (o *ssoGroupResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *p rv = append(rv, groupResource) } - // TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken - err = bag.Next(awsSdk.ToString(resp.NextToken)) + nextPage, err := bag.NextToken(awsSdk.ToString(resp.NextToken)) if err != nil { return nil, "", nil, err } - nextPage, err := bag.Marshal() - if err != nil { - return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err) - } - return rv, nextPage, nil, nil } diff --git a/pkg/connector/sso_user.go b/pkg/connector/sso_user.go index 98bdb8b9..c7914576 100644 --- a/pkg/connector/sso_user.go +++ b/pkg/connector/sso_user.go @@ -21,6 +21,7 @@ type ssoUserResourceType struct { ssoClient *awsSsoAdmin.Client identityStoreClient *awsIdentityStore.Client identityInstance *awsSsoAdminTypes.InstanceMetadata + scimClient *awsIdentityCenterSCIMClient region string } @@ -56,6 +57,10 @@ func (o *ssoUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa rv := make([]*v2.Resource, 0, len(resp.Users)) for _, user := range resp.Users { + status, err := o.scimClient.getUserStatus(ctx, awsSdk.ToString(user.UserId)) + if err != nil { + return nil, "", nil, fmt.Errorf("aws-connector: failed to get user status from scim: %w", err) + } userARN := ssoUserToARN(o.region, awsSdk.ToString(o.identityInstance.IdentityStoreId), awsSdk.ToString(user.UserId)) annos := &v2.V1Identifier{ Id: userARN, @@ -68,6 +73,7 @@ func (o *ssoUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa []resourceSdk.UserTraitOption{ resourceSdk.WithEmail(getSsoUserEmail(user), true), resourceSdk.WithUserProfile(profile), + resourceSdk.WithStatus(status), }, resourceSdk.WithAnnotation(annos), ) @@ -77,17 +83,11 @@ func (o *ssoUserResourceType) List(ctx context.Context, _ *v2.ResourceId, pt *pa rv = append(rv, userResource) } - // TODO(lauren) update connector-sdk version and simplify this by just calling bag.NextToken - err = bag.Next(awsSdk.ToString(resp.NextToken)) + nextPage, err := bag.NextToken(awsSdk.ToString(resp.NextToken)) if err != nil { return nil, "", nil, err } - nextPage, err := bag.Marshal() - if err != nil { - return nil, "", nil, fmt.Errorf("aws-connector: failed to marshal pagination bag: %w", err) - } - return rv, nextPage, nil, nil } @@ -99,13 +99,20 @@ func (o *ssoUserResourceType) Grants(_ context.Context, _ *v2.Resource, _ *pagin return nil, "", nil, nil } -func ssoUserBuilder(region string, ssoClient *awsSsoAdmin.Client, identityStoreClient *awsIdentityStore.Client, identityInstance *awsSsoAdminTypes.InstanceMetadata) *ssoUserResourceType { +func ssoUserBuilder( + region string, + ssoClient *awsSsoAdmin.Client, + identityStoreClient *awsIdentityStore.Client, + identityInstance *awsSsoAdminTypes.InstanceMetadata, + scimClient *awsIdentityCenterSCIMClient, +) *ssoUserResourceType { return &ssoUserResourceType{ resourceType: resourceTypeSSOUser, region: region, identityInstance: identityInstance, identityStoreClient: identityStoreClient, ssoClient: ssoClient, + scimClient: scimClient, } } diff --git a/pkg/connector/sso_user_scim.go b/pkg/connector/sso_user_scim.go new file mode 100644 index 00000000..6370af90 --- /dev/null +++ b/pkg/connector/sso_user_scim.go @@ -0,0 +1,248 @@ +package connector + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "regexp" + "strings" + "time" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" +) + +var awsSSOSCIMUIDPattern = regexp.MustCompile("^[0-9a-zA-Z-]{1,64}$") + +type SCIMUserEmail struct { + Value string `json:"value"` + Type string `json:"type"` + Primary bool `json:"primary"` +} + +type SCIMUserAddress struct { + Type string `json:"type"` +} + +// SCIMUser is an AWS Identity Center SCIM User. +type SCIMUser struct { + ID string `json:"id,omitempty"` + Schemas []string `json:"schemas"` + Username string `json:"userName"` + Name struct { + FamilyName string `json:"familyName"` + GivenName string `json:"givenName"` + } `json:"name"` + DisplayName string `json:"displayName"` + Active bool `json:"active"` + Emails []SCIMUserEmail `json:"emails"` + Addresses []SCIMUserAddress `json:"addresses"` +} + +type awsIdentityCenterSCIMClient struct { + scimEnabled bool + + Client *http.Client + Endpoint *url.URL + Token string +} + +type ssoSCIMRetrier struct { + attempts int64 +} + +func newSSOSCIMRetrier() *ssoSCIMRetrier { + return &ssoSCIMRetrier{} +} + +// Will try 3 times over 120 ms. +func (r *ssoSCIMRetrier) wait(ctx context.Context) bool { + if r.attempts >= 3 { + return false + } + r.attempts++ + + select { + case <-ctx.Done(): + return false + case <-time.After(20 * time.Millisecond * time.Duration(r.attempts)): // 20ms, 40ms, 60ms, ... + return true + } +} + +func (sc *awsIdentityCenterSCIMClient) get(ctx context.Context, path string, target interface{}) error { + endpoint := strings.TrimRight(sc.Endpoint.String(), "/") + path = strings.TrimLeft(path, "/") + path = endpoint + "/" + path + var retry *ssoSCIMRetrier + for { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+sc.Token) + req.Header.Set("Accept", "application/scim+json") + + resp, err := sc.Client.Do(req) + if err != nil { + return err + } + b, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + _ = resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + err := json.Unmarshal(b, target) + if err != nil { + return fmt.Errorf("get: failed to decode response body for '%s': %w", path, err) + } + return nil + case http.StatusTooManyRequests: + // NOTE(morgabra) We don't get any headers back from AWS about this, but the docs say it's possible, so we'll aggressively retry + // a little bit in here and then give up. This has to be pretty aggressive because we call this function for each user in a page. + if retry == nil { + retry = newSSOSCIMRetrier() + } + + ok := retry.wait(ctx) + if !ok { + return fmt.Errorf("get: too many requests (%d)", resp.StatusCode) + } + continue + default: + return fmt.Errorf("get: request status !2xx (%d): %s", resp.StatusCode, b) + } + } +} + +func (sc *awsIdentityCenterSCIMClient) getUser(ctx context.Context, userID string) (*SCIMUser, error) { + scimPath := path.Join("Users", userID) + + user := &SCIMUser{} + err := sc.get(ctx, scimPath, user) + if err != nil { + return nil, fmt.Errorf("aws-connector-scim: failed to get user '%s': %w", userID, err) + } + + if user.ID != userID { + return nil, fmt.Errorf("aws-connector-scim: user id mismatch: got:%s want:%s", user.ID, userID) + } + + return user, nil +} + +func (sc *awsIdentityCenterSCIMClient) getUserStatus(ctx context.Context, userID string) (v2.UserTrait_Status_Status, error) { + status := v2.UserTrait_Status_STATUS_ENABLED + + // If SCIM is enabled, we can fetch the user status from the SCIM API because it's not available in the SSO API. + // This is tragic because the identitystore API is missing the active attribute on the user datatype. + // This extra tragic because pagination doesn't work for SCIM endpoints either, so we can't just use it as the source of truth, + // so we're doomed to making requests. + // https://repost.aws/questions/QUTLAhQGa4ReatoAnQSkx11w/iam-identity-center-identitystore-api-is-missing-the-active-attribute-on-user-datatype + if sc.scimEnabled { + scimUser, err := sc.getUser(ctx, userID) + if err != nil { + return v2.UserTrait_Status_STATUS_UNSPECIFIED, fmt.Errorf("aws-connector: scim.GetUser failed: %w", err) + } + + if scimUser.Active { + status = v2.UserTrait_Status_STATUS_ENABLED + } else { + status = v2.UserTrait_Status_STATUS_DISABLED + } + } + + return status, nil +} + +// NormalizeAWSIdentityCenterSCIMUrl normalizes the AWS Identity Center SCIM URL. +// e.x. https://scim..amazonaws.com/aAaAaAaAaAa-bBbB-cCcC-dDdD-eEeEeEeEeEeE/scim/v2 +func NormalizeAWSIdentityCenterSCIMUrl(u string) (string, error) { + if !strings.Contains(u, "//") { + u = "https://" + u + } + + p, err := url.Parse(u) + if err != nil { + return "", err + } + + if p.Scheme != "https" { + return "", fmt.Errorf("aws-connector-scim: invalid scheme: expected 'https'") + } + + // Host is exactly 'scim..amazonaws.com' + host := strings.ToLower(p.Host) + parts := strings.SplitN(host, ".", 4) + if len(parts) != 4 { + return "", fmt.Errorf("aws-connector-scim: invalid host: expected 'scim..amazonaws.com") + } + if parts[0] != "scim" || parts[2] != "amazonaws" || parts[3] != "com" { + return "", fmt.Errorf("aws-connector-scim: invalid host: expected 'scim..amazonaws.com") + } + if !isRegion(parts[1]) { + return "", fmt.Errorf("aws-connector-scim: invalid host: expected 'scim..amazonaws.com") + } + + // Path is exactly '//scim/v2' + path := p.Path + path = strings.TrimRight(path, "/") + parts = strings.SplitN(path, "/", 4) + if len(parts) != 4 { + return "", fmt.Errorf("aws-connector-scim: invalid path: expected '//scim/v2'") + } + + if !awsSSOSCIMUIDPattern.Match([]byte(parts[1])) { + return "", fmt.Errorf("aws-connector-scim: invalid path: expected '//scim/v2'") + } + parts[2] = strings.ToLower(parts[2]) + parts[3] = strings.ToLower(parts[3]) + if parts[0] != "" || parts[2] != "scim" || parts[3] != "v2" { + return "", fmt.Errorf("aws-connector-scim: invalid path: expected '//scim/v2'") + } + path = strings.Join(parts, "/") + + p = &url.URL{ + Scheme: "https", + Host: host, + Path: path, + } + + return p.String(), nil +} +func isRegion(region string) bool { + _, ok := regions[region] + return ok +} + +var regions = map[string]struct{}{ + "us-east-2": {}, + "us-east-1": {}, + "us-west-1": {}, + "us-west-2": {}, + "af-south-1": {}, + "ap-east-1": {}, + "ap-southeast-3": {}, + "ap-south-1": {}, + "ap-northeast-3": {}, + "ap-northeast-2": {}, + "ap-southeast-1": {}, + "ap-southeast-2": {}, + "ap-northeast-1": {}, + "ca-central-1": {}, + "eu-central-1": {}, + "eu-west-1": {}, + "eu-west-2": {}, + "eu-south-1": {}, + "eu-west-3": {}, + "eu-north-1": {}, + "me-south-1": {}, + "me-central-1": {}, + "sa-east-1": {}, +}