Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 124 additions & 9 deletions tool/tctl/common/plugin/awsic.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"log/slog"
"net/http"
"net/url"
"slices"

Expand All @@ -41,9 +42,10 @@ const (
awsicPluginNameHelp = "Name of the AWS Identity Center integration instance to update. Defaults to " + apicommon.OriginAWSIdentityCenter + "."
awsicRolesSyncModeFlag = "roles-sync-mode"
awsicRolesSyncModeHelp = "Control account-assignment role creation. ALL creates roles for all possible account assignments. NONE creates no roles, and also implies a totally-exclusive group import filter."
notAWSICPluginMsg = "%q is not an AWS Identity Center integration"
)

type awsICArgs struct {
type awsICInstallArgs struct {
cmd *kingpin.CmdClause
defaultOwners []string
scimToken string
Expand All @@ -64,7 +66,7 @@ type awsICArgs struct {
excludeAccountIDFilters []string
}

func (a *awsICArgs) validate(ctx context.Context, log *slog.Logger) error {
func (a *awsICInstallArgs) validate(ctx context.Context, log *slog.Logger) error {
if !awsutils.IsKnownRegion(a.region) {
return trace.BadParameter("unknown AWS region: %s", a.region)
}
Expand All @@ -84,7 +86,7 @@ func (a *awsICArgs) validate(ctx context.Context, log *slog.Logger) error {
return nil
}

func (a *awsICArgs) validateSystemCredentialInput() error {
func (a *awsICInstallArgs) validateSystemCredentialInput() error {
if !a.useSystemCredentials {
return trace.BadParameter("--use-system-credentials must be set. The tctl-based AWS IAM Identity Center plugin installation only supports AWS local system credentials")
}
Expand All @@ -100,7 +102,7 @@ func (a *awsICArgs) validateSystemCredentialInput() error {
return nil
}

func (a *awsICArgs) validateSCIMBaseURL(ctx context.Context, log *slog.Logger) error {
func (a *awsICInstallArgs) validateSCIMBaseURL(ctx context.Context, log *slog.Logger) error {
validatedBaseUrl, err := icutils.EnsureSCIMEndpointURL(a.scimURL)
if err == nil {
a.scimURL = validatedBaseUrl
Expand All @@ -116,7 +118,7 @@ func (a *awsICArgs) validateSCIMBaseURL(ctx context.Context, log *slog.Logger) e
return trace.Wrap(err)
}

func (a *awsICArgs) parseGroupFilters() (icfilters.Filters, error) {
func (a *awsICInstallArgs) parseGroupFilters() (icfilters.Filters, error) {
filters := make([]*types.AWSICResourceFilter, 0, len(a.groupNameFilters)+len(a.excludeGroupNameFilters))
for _, n := range a.groupNameFilters {
filters = append(filters, &types.AWSICResourceFilter{
Expand All @@ -131,7 +133,7 @@ func (a *awsICArgs) parseGroupFilters() (icfilters.Filters, error) {
return icfilters.New(filters)
}

func (a *awsICArgs) parseAccountFilters() (icfilters.Filters, error) {
func (a *awsICInstallArgs) parseAccountFilters() (icfilters.Filters, error) {
filtersCap := len(a.accountNameFilters) + len(a.excludeAccountNameFilters) + len(a.accountIDFilters) + len(a.excludeAccountIDFilters)
filters := make([]*types.AWSICResourceFilter, 0, filtersCap)
for _, n := range a.accountNameFilters {
Expand Down Expand Up @@ -161,7 +163,7 @@ func (a *awsICArgs) parseAccountFilters() (icfilters.Filters, error) {
return icfilters.New(filters)
}

func (a *awsICArgs) parseUserFilters() ([]*types.AWSICUserSyncFilter, error) {
func (a *awsICInstallArgs) parseUserFilters() ([]*types.AWSICUserSyncFilter, error) {
result := make([]*types.AWSICUserSyncFilter, 0, len(a.userOrigins)+len(a.userLabels))

if len(a.userOrigins) > 0 {
Expand Down Expand Up @@ -241,7 +243,7 @@ func (p *PluginsCommand) initInstallAWSIC(parent *kingpin.CmdClause) {
}

// InstallAWSIC installs AWS Identity Center plugin.
func (p *PluginsCommand) InstallAWSIC(ctx context.Context, args installPluginArgs) error {
func (p *PluginsCommand) InstallAWSIC(ctx context.Context, args pluginServices) error {
awsICArgs := p.install.awsIC
if err := awsICArgs.validate(ctx, p.config.Logger); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -356,7 +358,7 @@ func (p *PluginsCommand) initEditAWSIC(parent *kingpin.CmdClause) {
EnumVar(&p.edit.awsIC.rolesSyncMode, types.AWSICRolesSyncModeAll, types.AWSICRolesSyncModeNone)
}

func (p *PluginsCommand) EditAWSIC(ctx context.Context, args installPluginArgs) error {
func (p *PluginsCommand) EditAWSIC(ctx context.Context, args pluginServices) error {
plugin, err := args.plugins.GetPlugin(ctx, &pluginspb.GetPluginRequest{
Name: p.edit.awsIC.pluginName,
})
Expand All @@ -382,3 +384,116 @@ func (p *PluginsCommand) EditAWSIC(ctx context.Context, args installPluginArgs)
}
return nil
}

type awsICRotateCredsArgs struct {
cmd *kingpin.CmdClause
pluginName string
payload string
requireValidation bool
}

func (p *PluginsCommand) initRotateCredsAWSIC(parent *kingpin.CmdClause) {
p.rotateCreds.awsic.cmd = parent.Command("awsic", "Rotate the AWS Identity Center SCIM bearer token.")
cmd := p.rotateCreds.awsic.cmd
args := &p.rotateCreds.awsic

cmd.Flag("plugin-name", "Name of the AWSIC plugin instance to update. Defaults to "+apicommon.OriginAWSIdentityCenter+".").
Default(apicommon.OriginAWSIdentityCenter).
StringVar(&args.pluginName)

cmd.Arg("token", "The new SCIM bearer token.").
PlaceHolder("TOKEN").
Required().
StringVar(&p.rotateCreds.awsic.payload)

cmd.Flag("validate-token", "Validate that the supplied token is valid for the configured downstream SCIM service").
Default("true").
BoolVar(&args.requireValidation)
}

func (p *PluginsCommand) RotateAWSICCreds(ctx context.Context, args pluginServices) error {
cliArgs := &p.rotateCreds.awsic

slog.InfoContext(ctx, "Fetching plugin...", "plugin_name", cliArgs.pluginName)
plugin, err := args.plugins.GetPlugin(ctx, &pluginspb.GetPluginRequest{
Name: cliArgs.pluginName,
WithSecrets: true,
})
if err != nil {
return trace.Wrap(err, "fetching plugin %q", cliArgs.pluginName)
}

awsicSettings := plugin.Spec.GetAwsIc()
if awsicSettings == nil {
return trace.BadParameter(notAWSICPluginMsg, cliArgs.pluginName)
}

if p.rotateCreds.awsic.requireValidation {
if err := p.rotateCreds.awsic.validateToken(ctx, awsicSettings, args); err != nil {
return trace.Wrap(err, "validating SCIM bearer token")
}
}

staticCredsRef := plugin.Credentials.GetStaticCredentialsRef()
if staticCredsRef == nil {
return trace.BadParameter("plugin has no credentials reference")
}

req := pluginspb.UpdatePluginStaticCredentialsRequest{
Target: &pluginspb.UpdatePluginStaticCredentialsRequest_Query{
Query: &pluginspb.CredentialQuery{
Labels: staticCredsRef.Labels,
},
},
Credential: &types.PluginStaticCredentialsSpecV1{
Credentials: &types.PluginStaticCredentialsSpecV1_APIToken{
APIToken: p.rotateCreds.awsic.payload,
},
},
}

_, err = args.plugins.UpdatePluginStaticCredentials(ctx, &req)
if err != nil {
return trace.Wrap(err, "updating credentials")
}

return nil
}

func (args *awsICRotateCredsArgs) validateToken(ctx context.Context, awsicSettings *types.PluginAWSICSettings, env pluginServices) error {
provisioningSpec := awsicSettings.ProvisioningSpec
if provisioningSpec == nil {
return trace.BadParameter("plugin is missing provisioning spec")
}

slog.InfoContext(ctx, "Validating token", "scim_server", provisioningSpec.BaseUrl)

endPoint, err := url.JoinPath(provisioningSpec.BaseUrl, "ServiceProviderConfig")
if err != nil {
return trace.Wrap(err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endPoint, nil)
if err != nil {
return trace.Wrap(err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", args.payload))
req.Header.Set("Accept", "application/scim+json")

resp, err := env.httpProvider.RoundTrip(req)
if err != nil {
return trace.Wrap(err)
}
resp.Body.Close()

switch resp.StatusCode {
case http.StatusOK:
case http.StatusUnauthorized, http.StatusForbidden:
return trace.BadParameter("invalid token")
case http.StatusInternalServerError:
return trace.BadParameter("internal server error")
default:
return trace.BadParameter("unexpected status code %v", resp.StatusCode)
}
return nil
}
Loading
Loading