diff --git a/cmd/baton-aws/config.go b/cmd/baton-aws/config.go index aeb67041..474b90f0 100644 --- a/cmd/baton-aws/config.go +++ b/cmd/baton-aws/config.go @@ -29,6 +29,7 @@ type config struct { SCIMEndpoint string `mapstructure:"scim-endpoint"` SCIMToken string `mapstructure:"scim-token"` + SCIMEnabled bool `mapstructure:"scim-enabled"` UseAssumeRole bool `mapstructure:"use-assume-role"` } @@ -66,6 +67,7 @@ func cmdFlags(cmd *cobra.Command) { cmd.PersistentFlags().String("global-secret-access-key", "", "The global-secret-access-key for the aws account. ($BATON_GLOBAL_SECRET_ACCESS_KEY)") cmd.PersistentFlags().String("global-access-key-id", "", "The global-access-key-id for the aws account. ($BATON_GLOBAL_ACCESS_KEY_ID)") cmd.PersistentFlags().Bool("use-assume-role", false, "Enable support for assume role. ($BATON_GLOBAL_USE_ASSUME_ROLE)") + cmd.PersistentFlags().Bool("scim-enabled", false, "Enable support for pulling SSO User status from the AWS SCIM API. ($BATON_SCIM_ENABLED)") cmd.PersistentFlags().String("scim-endpoint", "", "The SCIMv2 endpoint for aws identity center. ($BATON_SCIM_ENDPOINT)") cmd.PersistentFlags().String("scim-token", "", "The SCIMv2 token for aws identity center. ($BATON_SCIM_TOKEN)") } diff --git a/cmd/baton-aws/main.go b/cmd/baton-aws/main.go index cc17597b..efca1ffd 100644 --- a/cmd/baton-aws/main.go +++ b/cmd/baton-aws/main.go @@ -53,6 +53,7 @@ func getConnector(ctx context.Context, cfg *config) (types.ConnectorServer, erro RoleARN: cfg.RoleARN, SCIMEndpoint: cfg.SCIMEndpoint, SCIMToken: cfg.SCIMToken, + SCIMEnabled: cfg.SCIMEnabled, } cb, err := connector.New(ctx, config) diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index b7f52c2d..2e5659b2 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -88,6 +88,7 @@ type Config struct { RoleARN string SCIMToken string SCIMEndpoint string + SCIMEnabled bool } type AWS struct { @@ -273,7 +274,7 @@ func New(ctx context.Context, config Config) (*AWS, error) { ssoRegion: config.GlobalAwsSsoRegion, scimEndpoint: config.SCIMEndpoint, scimToken: config.SCIMToken, - scimEnabled: config.SCIMEndpoint != "" && config.SCIMToken != "", + scimEnabled: config.SCIMEnabled, baseClient: httpClient, baseConfig: baseConfig.Copy(), _onceCallingConfig: map[string]*sync.Once{},