From 57542a59439a3c92f5f925f78de730173819ad78 Mon Sep 17 00:00:00 2001 From: Trent Clarke Date: Mon, 15 Sep 2025 11:09:22 +1000 Subject: [PATCH] [v17][AWSIC] Update SCIM token with `tctl` Implements SCIM token rotation for AWS IC, both to be generally useful and as an example of how to use the Update Creds API for other plugins. Usage: `$ tctl plugins update-creds awsic ${TOKEN} ` Addresses: #55662 Backports: #57691 Changelog: Added the ability to update the AWS Identity Center SCIM token in tctl --- tool/tctl/common/plugin/awsic.go | 133 ++++++++- tool/tctl/common/plugin/awsic_test.go | 253 +++++++++++++++++- tool/tctl/common/plugin/entraid.go | 2 +- tool/tctl/common/plugin/mocks_test.go | 125 +++++++++ tool/tctl/common/plugin/netiq.go | 2 +- tool/tctl/common/plugin/okta.go | 4 +- tool/tctl/common/plugin/plugins_command.go | 53 ++-- .../common/plugin/plugins_command_test.go | 96 +------ tool/tctl/common/plugin/scim.go | 2 +- 9 files changed, 549 insertions(+), 121 deletions(-) create mode 100644 tool/tctl/common/plugin/mocks_test.go diff --git a/tool/tctl/common/plugin/awsic.go b/tool/tctl/common/plugin/awsic.go index bb9d16e797dbb..d6283be73a549 100644 --- a/tool/tctl/common/plugin/awsic.go +++ b/tool/tctl/common/plugin/awsic.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "log/slog" + "net/http" "net/url" "slices" @@ -41,9 +42,10 @@ const ( awsicPluginNameHelp = "Name of the AWS Identity Center integration instance to update. Defaults to " + apicommon.OriginAWSIdentityCenter + "." awsicRolesSyncModeFlag = "roles-sync-mode" awsicRolesSyncModeHelp = "Control account-assignment role creation. ALL creates roles for all possible account assignments. NONE creates no roles, and also implies a totally-exclusive group import filter." + notAWSICPluginMsg = "%q is not an AWS Identity Center integration" ) -type awsICArgs struct { +type awsICInstallArgs struct { cmd *kingpin.CmdClause defaultOwners []string scimToken string @@ -64,7 +66,7 @@ type awsICArgs struct { excludeAccountIDFilters []string } -func (a *awsICArgs) validate(ctx context.Context, log *slog.Logger) error { +func (a *awsICInstallArgs) validate(ctx context.Context, log *slog.Logger) error { if !awsutils.IsKnownRegion(a.region) { return trace.BadParameter("unknown AWS region: %s", a.region) } @@ -84,7 +86,7 @@ func (a *awsICArgs) validate(ctx context.Context, log *slog.Logger) error { return nil } -func (a *awsICArgs) validateSystemCredentialInput() error { +func (a *awsICInstallArgs) validateSystemCredentialInput() error { if !a.useSystemCredentials { return trace.BadParameter("--use-system-credentials must be set. The tctl-based AWS IAM Identity Center plugin installation only supports AWS local system credentials") } @@ -100,7 +102,7 @@ func (a *awsICArgs) validateSystemCredentialInput() error { return nil } -func (a *awsICArgs) validateSCIMBaseURL(ctx context.Context, log *slog.Logger) error { +func (a *awsICInstallArgs) validateSCIMBaseURL(ctx context.Context, log *slog.Logger) error { validatedBaseUrl, err := icutils.EnsureSCIMEndpointURL(a.scimURL) if err == nil { a.scimURL = validatedBaseUrl @@ -116,7 +118,7 @@ func (a *awsICArgs) validateSCIMBaseURL(ctx context.Context, log *slog.Logger) e return trace.Wrap(err) } -func (a *awsICArgs) parseGroupFilters() (icfilters.Filters, error) { +func (a *awsICInstallArgs) parseGroupFilters() (icfilters.Filters, error) { filters := make([]*types.AWSICResourceFilter, 0, len(a.groupNameFilters)+len(a.excludeGroupNameFilters)) for _, n := range a.groupNameFilters { filters = append(filters, &types.AWSICResourceFilter{ @@ -131,7 +133,7 @@ func (a *awsICArgs) parseGroupFilters() (icfilters.Filters, error) { return icfilters.New(filters) } -func (a *awsICArgs) parseAccountFilters() (icfilters.Filters, error) { +func (a *awsICInstallArgs) parseAccountFilters() (icfilters.Filters, error) { filtersCap := len(a.accountNameFilters) + len(a.excludeAccountNameFilters) + len(a.accountIDFilters) + len(a.excludeAccountIDFilters) filters := make([]*types.AWSICResourceFilter, 0, filtersCap) for _, n := range a.accountNameFilters { @@ -161,7 +163,7 @@ func (a *awsICArgs) parseAccountFilters() (icfilters.Filters, error) { return icfilters.New(filters) } -func (a *awsICArgs) parseUserFilters() ([]*types.AWSICUserSyncFilter, error) { +func (a *awsICInstallArgs) parseUserFilters() ([]*types.AWSICUserSyncFilter, error) { result := make([]*types.AWSICUserSyncFilter, 0, len(a.userOrigins)+len(a.userLabels)) if len(a.userOrigins) > 0 { @@ -241,7 +243,7 @@ func (p *PluginsCommand) initInstallAWSIC(parent *kingpin.CmdClause) { } // InstallAWSIC installs AWS Identity Center plugin. -func (p *PluginsCommand) InstallAWSIC(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) InstallAWSIC(ctx context.Context, args pluginServices) error { awsICArgs := p.install.awsIC if err := awsICArgs.validate(ctx, p.config.Logger); err != nil { return trace.Wrap(err) @@ -356,7 +358,7 @@ func (p *PluginsCommand) initEditAWSIC(parent *kingpin.CmdClause) { EnumVar(&p.edit.awsIC.rolesSyncMode, types.AWSICRolesSyncModeAll, types.AWSICRolesSyncModeNone) } -func (p *PluginsCommand) EditAWSIC(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) EditAWSIC(ctx context.Context, args pluginServices) error { plugin, err := args.plugins.GetPlugin(ctx, &pluginspb.GetPluginRequest{ Name: p.edit.awsIC.pluginName, }) @@ -382,3 +384,116 @@ func (p *PluginsCommand) EditAWSIC(ctx context.Context, args installPluginArgs) } return nil } + +type awsICRotateCredsArgs struct { + cmd *kingpin.CmdClause + pluginName string + payload string + requireValidation bool +} + +func (p *PluginsCommand) initRotateCredsAWSIC(parent *kingpin.CmdClause) { + p.rotateCreds.awsic.cmd = parent.Command("awsic", "Rotate the AWS Identity Center SCIM bearer token.") + cmd := p.rotateCreds.awsic.cmd + args := &p.rotateCreds.awsic + + cmd.Flag("plugin-name", "Name of the AWSIC plugin instance to update. Defaults to "+apicommon.OriginAWSIdentityCenter+"."). + Default(apicommon.OriginAWSIdentityCenter). + StringVar(&args.pluginName) + + cmd.Arg("token", "The new SCIM bearer token."). + PlaceHolder("TOKEN"). + Required(). + StringVar(&p.rotateCreds.awsic.payload) + + cmd.Flag("validate-token", "Validate that the supplied token is valid for the configured downstream SCIM service"). + Default("true"). + BoolVar(&args.requireValidation) +} + +func (p *PluginsCommand) RotateAWSICCreds(ctx context.Context, args pluginServices) error { + cliArgs := &p.rotateCreds.awsic + + slog.InfoContext(ctx, "Fetching plugin...", "plugin_name", cliArgs.pluginName) + plugin, err := args.plugins.GetPlugin(ctx, &pluginspb.GetPluginRequest{ + Name: cliArgs.pluginName, + WithSecrets: true, + }) + if err != nil { + return trace.Wrap(err, "fetching plugin %q", cliArgs.pluginName) + } + + awsicSettings := plugin.Spec.GetAwsIc() + if awsicSettings == nil { + return trace.BadParameter(notAWSICPluginMsg, cliArgs.pluginName) + } + + if p.rotateCreds.awsic.requireValidation { + if err := p.rotateCreds.awsic.validateToken(ctx, awsicSettings, args); err != nil { + return trace.Wrap(err, "validating SCIM bearer token") + } + } + + staticCredsRef := plugin.Credentials.GetStaticCredentialsRef() + if staticCredsRef == nil { + return trace.BadParameter("plugin has no credentials reference") + } + + req := pluginspb.UpdatePluginStaticCredentialsRequest{ + Target: &pluginspb.UpdatePluginStaticCredentialsRequest_Query{ + Query: &pluginspb.CredentialQuery{ + Labels: staticCredsRef.Labels, + }, + }, + Credential: &types.PluginStaticCredentialsSpecV1{ + Credentials: &types.PluginStaticCredentialsSpecV1_APIToken{ + APIToken: p.rotateCreds.awsic.payload, + }, + }, + } + + _, err = args.plugins.UpdatePluginStaticCredentials(ctx, &req) + if err != nil { + return trace.Wrap(err, "updating credentials") + } + + return nil +} + +func (args *awsICRotateCredsArgs) validateToken(ctx context.Context, awsicSettings *types.PluginAWSICSettings, env pluginServices) error { + provisioningSpec := awsicSettings.ProvisioningSpec + if provisioningSpec == nil { + return trace.BadParameter("plugin is missing provisioning spec") + } + + slog.InfoContext(ctx, "Validating token", "scim_server", provisioningSpec.BaseUrl) + + endPoint, err := url.JoinPath(provisioningSpec.BaseUrl, "ServiceProviderConfig") + if err != nil { + return trace.Wrap(err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endPoint, nil) + if err != nil { + return trace.Wrap(err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", args.payload)) + req.Header.Set("Accept", "application/scim+json") + + resp, err := env.httpProvider.RoundTrip(req) + if err != nil { + return trace.Wrap(err) + } + resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + case http.StatusUnauthorized, http.StatusForbidden: + return trace.BadParameter("invalid token") + case http.StatusInternalServerError: + return trace.BadParameter("internal server error") + default: + return trace.BadParameter("unexpected status code %v", resp.StatusCode) + } + return nil +} diff --git a/tool/tctl/common/plugin/awsic_test.go b/tool/tctl/common/plugin/awsic_test.go index 55ea37c51669c..6fd702893ce43 100644 --- a/tool/tctl/common/plugin/awsic_test.go +++ b/tool/tctl/common/plugin/awsic_test.go @@ -17,13 +17,20 @@ package plugin import ( + "bytes" "context" + "io" "log/slog" + "net/http" "net/url" "testing" + "github.com/gravitational/trace" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" "github.com/gravitational/teleport/api/types" ) @@ -92,7 +99,7 @@ func TestAWSICUserFilters(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - cliArgs := awsICArgs{ + cliArgs := awsICInstallArgs{ userLabels: test.labelValues, userOrigins: test.originValues, } @@ -134,7 +141,7 @@ func TestAWSICGroupFilters(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - cliArgs := awsICArgs{ + cliArgs := awsICInstallArgs{ groupNameFilters: test.nameValues, } @@ -198,7 +205,7 @@ func TestAWSICAccountFilters(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - cliArgs := awsICArgs{ + cliArgs := awsICInstallArgs{ accountNameFilters: test.nameValues, accountIDFilters: test.idValues, } @@ -256,7 +263,7 @@ func TestSCIMBaseURLValidation(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - cliArgs := awsICArgs{ + cliArgs := awsICInstallArgs{ scimURL: mustParseURL(test.suppliedURL), forceSCIMURL: test.forceURL, } @@ -305,7 +312,7 @@ func TestUseSystemCredentialsInput(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - cliArgs := awsICArgs{ + cliArgs := awsICInstallArgs{ useSystemCredentials: tc.useSystemCredential, assumeRoleARN: tc.assumeRoleARN, } @@ -315,3 +322,239 @@ func TestUseSystemCredentialsInput(t *testing.T) { }) } } + +type mockRoundTripper struct { + mock.Mock +} + +// RoundTrip implements the [http.RoundTripper] interface for the mockRoundTripper +func (m *mockRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + args := m.Called(request) + maybeResponse := args.Get(0) + if maybeResponse == nil { + return nil, args.Error(1) + } + return args.Get(0).(*http.Response), args.Error(1) +} + +func TestRotateAWSICSCIMToken(t *testing.T) { + const ( + scimURL = "https://scim.example.com" + ) + validAWSICPlugin := func() *types.PluginV1 { + return &types.PluginV1{ + Kind: types.KindPlugin, + SubKind: types.PluginSubkindAccess, + Metadata: types.Metadata{ + Name: types.PluginTypeAWSIdentityCenter, + Labels: map[string]string{types.HostedPluginLabel: "true"}, + }, + Spec: types.PluginSpecV1{ + Settings: &types.PluginSpecV1_AwsIc{ + AwsIc: &types.PluginAWSICSettings{ + ProvisioningSpec: &types.AWSICProvisioningSpec{ + BaseUrl: scimURL, + }, + }, + }, + }, + Credentials: &types.PluginCredentialsV1{ + Credentials: &types.PluginCredentialsV1_StaticCredentialsRef{ + StaticCredentialsRef: &types.PluginStaticCredentialsRef{ + Labels: map[string]string{ + "plugin-id": "some-aws-ic-integration", + }, + }, + }, + }, + } + } + + makeResponse := func(status int) *http.Response { + if status == 0 { + return nil + } + return &http.Response{ + Status: http.StatusText(status), + StatusCode: status, + Body: io.NopCloser(&bytes.Buffer{}), + } + } + + testCases := []struct { + name string + cliArgs awsICRotateCredsArgs + pluginValueProvider func() *types.PluginV1 + pluginFetchError error + expectValidation bool + validationError error + validationResponse int + expectUpdate bool + updateError error + assertError require.ErrorAssertionFunc + }{ + { + name: "default", + cliArgs: awsICRotateCredsArgs{ + pluginName: types.PluginTypeAWSIdentityCenter, + requireValidation: true, + payload: "some-token", + }, + pluginValueProvider: validAWSICPlugin, + expectValidation: true, + validationResponse: http.StatusOK, + expectUpdate: true, + assertError: require.NoError, + }, + { + name: "no such plugin", + cliArgs: awsICRotateCredsArgs{ + pluginName: types.PluginTypeAWSIdentityCenter, + requireValidation: true, + payload: "some-token", + }, + pluginValueProvider: func() *types.PluginV1 { return nil }, + pluginFetchError: trace.NotFound("no such plugin"), + assertError: require.Error, + }, + { + name: "wrong plugin type", + cliArgs: awsICRotateCredsArgs{ + pluginName: types.PluginTypeAWSIdentityCenter, + requireValidation: true, + payload: "some-token", + }, + pluginValueProvider: func() *types.PluginV1 { + return &types.PluginV1{ + Kind: types.KindPlugin, + SubKind: types.PluginSubkindAccess, + Metadata: types.Metadata{ + Name: "okta", + Labels: map[string]string{types.HostedPluginLabel: "true"}, + }, + Spec: types.PluginSpecV1{ + Settings: &types.PluginSpecV1_Okta{ + Okta: &types.PluginOktaSettings{}, + }, + }, + } + }, + assertError: require.Error, + }, + { + name: "no such credential", + cliArgs: awsICRotateCredsArgs{ + pluginName: types.PluginTypeAWSIdentityCenter, + requireValidation: true, + payload: "some-token", + }, + pluginValueProvider: validAWSICPlugin, + expectValidation: true, + validationResponse: http.StatusOK, + expectUpdate: true, + updateError: trace.NotFound("no such credential"), + assertError: require.Error, + }, + { + name: "validation failure", + cliArgs: awsICRotateCredsArgs{ + pluginName: types.PluginTypeAWSIdentityCenter, + requireValidation: true, + payload: "some-token", + }, + expectValidation: true, + validationResponse: http.StatusForbidden, + pluginValueProvider: validAWSICPlugin, + expectUpdate: false, + assertError: requireBadParameter, + }, + { + name: "bypass validation", + cliArgs: awsICRotateCredsArgs{ + pluginName: types.PluginTypeAWSIdentityCenter, + requireValidation: false, + payload: "some-token", + }, + expectValidation: false, + validationResponse: http.StatusForbidden, + pluginValueProvider: validAWSICPlugin, + expectUpdate: true, + assertError: require.NoError, + }, + { + name: "update credential access denied", + cliArgs: awsICRotateCredsArgs{ + pluginName: types.PluginTypeAWSIdentityCenter, + requireValidation: true, + payload: "some-token", + }, + expectValidation: true, + validationResponse: http.StatusOK, + pluginValueProvider: validAWSICPlugin, + expectUpdate: true, + updateError: trace.AccessDenied("computer says no"), + assertError: requireAccessDenied, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + cliArgs := PluginsCommand{ + rotateCreds: pluginRotateCredsArgs{ + awsic: test.cliArgs, + }, + } + + pluginsClient := &mockPluginsClient{} + pluginsClient. + On("GetPlugin", anyContext, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + req, ok := args.Get(1).(*pluginsv1.GetPluginRequest) + require.True(t, ok, "expecting a *pluginsv1.GetPluginRequest, got %T", args.Get(1)) + require.Equal(t, test.cliArgs.pluginName, req.Name) + require.True(t, req.WithSecrets) + }). + Return(test.pluginValueProvider(), test.pluginFetchError) + + if test.expectUpdate { + pluginsClient. + On("UpdatePluginStaticCredentials", anyContext, mock.Anything, mock.Anything). + Return(func(ctx context.Context, in *pluginsv1.UpdatePluginStaticCredentialsRequest, _ ...grpc.CallOption) (*pluginsv1.UpdatePluginStaticCredentialsResponse, error) { + q := in.GetQuery() + require.NotNil(t, q, "Update request must specify target labels") + require.NotEmpty(t, q.Labels, "Update request must specify non-empty labels") + + return &pluginsv1.UpdatePluginStaticCredentialsResponse{ + Credential: &types.PluginStaticCredentialsV1{Spec: in.GetCredential()}, + }, test.updateError + }) + } + + roundTripper := &mockRoundTripper{} + if test.expectValidation { + response := makeResponse(test.validationResponse) + defer response.Body.Close() + + roundTripper. + On("RoundTrip", mock.Anything). + Run(func(args mock.Arguments) { + req, ok := args.Get(0).(*http.Request) + require.True(t, ok, "expecting a *http.Request, got %T", args.Get(0)) + require.Equal(t, "Bearer "+test.cliArgs.payload, req.Header.Get("Authorization")) + }). + Return(response, test.validationError) + } + + args := pluginServices{ + plugins: pluginsClient, + httpProvider: roundTripper, + } + + err := cliArgs.RotateAWSICCreds(context.Background(), args) + test.assertError(t, err) + + pluginsClient.AssertExpectations(t) + roundTripper.AssertExpectations(t) + }) + } +} diff --git a/tool/tctl/common/plugin/entraid.go b/tool/tctl/common/plugin/entraid.go index b175b11312665..631d5059b23e9 100644 --- a/tool/tctl/common/plugin/entraid.go +++ b/tool/tctl/common/plugin/entraid.go @@ -239,7 +239,7 @@ func readAzureInputs(acessGraph bool) (entraSettings, error) { // system credentials for EntraID authentication. // Finally, if no system credentials are in use, the script will set up an Azure OIDC integration // in Teleport and a Teleport plugin to synchronize access lists from EntraID to Teleport. -func (p *PluginsCommand) InstallEntra(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) InstallEntra(ctx context.Context, args pluginServices) error { inputs := p.install proxyPublicAddr, err := getProxyPublicAddr(ctx, args.authClient) diff --git a/tool/tctl/common/plugin/mocks_test.go b/tool/tctl/common/plugin/mocks_test.go new file mode 100644 index 0000000000000..336d57cfde5f2 --- /dev/null +++ b/tool/tctl/common/plugin/mocks_test.go @@ -0,0 +1,125 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package plugin + +import ( + "context" + + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/gravitational/teleport/api/client/proto" + pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" + "github.com/gravitational/teleport/api/mfa" + "github.com/gravitational/teleport/api/types" +) + +type mockPluginsClient struct { + mock.Mock +} + +func (m *mockPluginsClient) CreatePlugin(ctx context.Context, in *pluginsv1.CreatePluginRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + result := m.Called(ctx, in, opts) + return result.Get(0).(*emptypb.Empty), result.Error(1) +} + +func (m *mockPluginsClient) GetPlugin(ctx context.Context, in *pluginsv1.GetPluginRequest, opts ...grpc.CallOption) (*types.PluginV1, error) { + result := m.Called(ctx, in, opts) + return result.Get(0).(*types.PluginV1), result.Error(1) +} + +func (m *mockPluginsClient) UpdatePlugin(ctx context.Context, in *pluginsv1.UpdatePluginRequest, opts ...grpc.CallOption) (*types.PluginV1, error) { + result := m.Called(ctx, in, opts) + return result.Get(0).(*types.PluginV1), result.Error(1) +} + +func (m *mockPluginsClient) NeedsCleanup(ctx context.Context, in *pluginsv1.NeedsCleanupRequest, opts ...grpc.CallOption) (*pluginsv1.NeedsCleanupResponse, error) { + result := m.Called(ctx, in, opts) + return result.Get(0).(*pluginsv1.NeedsCleanupResponse), result.Error(1) +} + +func (m *mockPluginsClient) Cleanup(ctx context.Context, in *pluginsv1.CleanupRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + result := m.Called(ctx, in, opts) + return result.Get(0).(*emptypb.Empty), result.Error(1) +} + +func (m *mockPluginsClient) DeletePlugin(ctx context.Context, in *pluginsv1.DeletePluginRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + result := m.Called(ctx, in, opts) + return result.Get(0).(*emptypb.Empty), result.Error(1) +} + +func (m *mockPluginsClient) UpdatePluginStaticCredentials(ctx context.Context, in *pluginsv1.UpdatePluginStaticCredentialsRequest, opts ...grpc.CallOption) (*pluginsv1.UpdatePluginStaticCredentialsResponse, error) { + result := m.Called(ctx, in, opts) + var response *pluginsv1.UpdatePluginStaticCredentialsResponse + + if fn, ok := result.Get(0).(func(context.Context, *pluginsv1.UpdatePluginStaticCredentialsRequest, ...grpc.CallOption) (*pluginsv1.UpdatePluginStaticCredentialsResponse, error)); ok { + return fn(ctx, in, opts...) + } + + if r, ok := result.Get(0).(*pluginsv1.UpdatePluginStaticCredentialsResponse); ok { + response = r + } + return response, result.Error(1) +} + +type mockAuthClient struct { + mock.Mock +} + +func (m *mockAuthClient) GetSAMLConnector(ctx context.Context, id string, withSecrets bool) (types.SAMLConnector, error) { + result := m.Called(ctx, id, withSecrets) + return result.Get(0).(types.SAMLConnector), result.Error(1) +} +func (m *mockAuthClient) CreateSAMLConnector(ctx context.Context, connector types.SAMLConnector) (types.SAMLConnector, error) { + result := m.Called(ctx, connector) + return result.Get(0).(types.SAMLConnector), result.Error(1) +} +func (m *mockAuthClient) UpsertSAMLConnector(ctx context.Context, connector types.SAMLConnector) (types.SAMLConnector, error) { + result := m.Called(ctx, connector) + return result.Get(0).(types.SAMLConnector), result.Error(1) +} +func (m *mockAuthClient) CreateIntegration(ctx context.Context, ig types.Integration) (types.Integration, error) { + result := m.Called(ctx, ig) + return result.Get(0).(types.Integration), result.Error(1) +} +func (m *mockAuthClient) UpdateIntegration(ctx context.Context, ig types.Integration) (types.Integration, error) { + result := m.Called(ctx, ig) + return result.Get(0).(types.Integration), result.Error(1) +} + +func (m *mockAuthClient) GetIntegration(ctx context.Context, name string) (types.Integration, error) { + result := m.Called(ctx, name) + return result.Get(0).(types.Integration), result.Error(1) +} + +func (m *mockAuthClient) Ping(ctx context.Context) (proto.PingResponse, error) { + result := m.Called(ctx) + return result.Get(0).(proto.PingResponse), result.Error(1) +} + +func (m *mockAuthClient) PerformMFACeremony(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { + return &proto.MFAAuthenticateResponse{}, nil +} + +func (m *mockAuthClient) GetRole(ctx context.Context, name string) (types.Role, error) { + result := m.Called(ctx, name) + return result.Get(0).(types.Role), result.Error(1) +} + +// anyContext is an argument matcher for testify mocks that matches any context. +var anyContext any = mock.MatchedBy(func(context.Context) bool { return true }) diff --git a/tool/tctl/common/plugin/netiq.go b/tool/tctl/common/plugin/netiq.go index 66bdb891317a1..80956bf776df2 100644 --- a/tool/tctl/common/plugin/netiq.go +++ b/tool/tctl/common/plugin/netiq.go @@ -172,7 +172,7 @@ func (p *PluginsCommand) netIQSetupGuide(ctx context.Context) (netIQSettings, er return settings, nil } -func (p *PluginsCommand) InstallNetIQ(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) InstallNetIQ(ctx context.Context, args pluginServices) error { settings, err := p.netIQSetupGuide(ctx) if err != nil { if errors.Is(err, errCancel) { diff --git a/tool/tctl/common/plugin/okta.go b/tool/tctl/common/plugin/okta.go index c9824a26156ef..3076c20f962f3 100644 --- a/tool/tctl/common/plugin/okta.go +++ b/tool/tctl/common/plugin/okta.go @@ -108,7 +108,7 @@ type oktaArgs struct { autoGeneratedSCIMToken bool } -func (s *oktaArgs) validateAndCheckDefaults(ctx context.Context, args *installPluginArgs) error { +func (s *oktaArgs) validateAndCheckDefaults(ctx context.Context, args *pluginServices) error { if s.apiToken == "" { if !s.scimEnabled { return trace.BadParameter("API token is required") @@ -165,7 +165,7 @@ func (s *oktaArgs) validateAndCheckDefaults(ctx context.Context, args *installPl return nil } -func (p *PluginsCommand) InstallOkta(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) InstallOkta(ctx context.Context, args pluginServices) error { oktaSettings := p.install.okta if err := oktaSettings.validateAndCheckDefaults(ctx, &args); err != nil { return trace.Wrap(err) diff --git a/tool/tctl/common/plugin/plugins_command.go b/tool/tctl/common/plugin/plugins_command.go index f458f0a049422..6dae914ca83a5 100644 --- a/tool/tctl/common/plugin/plugins_command.go +++ b/tool/tctl/common/plugin/plugins_command.go @@ -22,6 +22,7 @@ import ( "context" "fmt" "log/slog" + "net/http" "github.com/alecthomas/kingpin/v2" "github.com/gravitational/trace" @@ -49,7 +50,7 @@ type pluginInstallArgs struct { scim scimArgs entraID entraArgs netIQ netIQArgs - awsIC awsICArgs + awsIC awsICInstallArgs } type pluginEditArgs struct { @@ -69,15 +70,21 @@ type pluginDeleteArgs struct { name string } +type pluginRotateCredsArgs struct { + cmd *kingpin.CmdClause + awsic awsICRotateCredsArgs +} + // PluginsCommand allows for management of plugins. type PluginsCommand struct { - config *servicecfg.Config - cleanupCmd *kingpin.CmdClause - pluginType string - dryRun bool - install pluginInstallArgs - delete pluginDeleteArgs - edit pluginEditArgs + config *servicecfg.Config + cleanupCmd *kingpin.CmdClause + pluginType string + dryRun bool + install pluginInstallArgs + delete pluginDeleteArgs + edit pluginEditArgs + rotateCreds pluginRotateCredsArgs } // Initialize creates the plugins command and subcommands @@ -94,6 +101,7 @@ func (p *PluginsCommand) Initialize(app *kingpin.Application, _ *tctlcfg.GlobalC p.initInstall(pluginsCommand, config) p.initDelete(pluginsCommand) p.initEdit(pluginsCommand) + p.initRotateCreds(pluginsCommand) } func (p *PluginsCommand) initInstall(parent *kingpin.CmdClause, config *servicecfg.Config) { @@ -119,7 +127,7 @@ func (p *PluginsCommand) initEdit(parent *kingpin.CmdClause) { } // Delete implements `tctl plugins delete`, deleting a plugin from the Teleport cluster -func (p *PluginsCommand) Delete(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) Delete(ctx context.Context, args pluginServices) error { log := p.config.Logger.With("plugin", p.delete.name) req := &pluginsv1.DeletePluginRequest{Name: p.delete.name} @@ -135,7 +143,7 @@ func (p *PluginsCommand) Delete(ctx context.Context, args installPluginArgs) err } // Cleanup cleans up the given plugin. -func (p *PluginsCommand) Cleanup(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) Cleanup(ctx context.Context, args pluginServices) error { needsCleanup, err := args.plugins.NeedsCleanup(ctx, &pluginsv1.NeedsCleanupRequest{ Type: p.pluginType, }) @@ -179,6 +187,12 @@ func (p *PluginsCommand) Cleanup(ctx context.Context, args installPluginArgs) er return nil } +func (p *PluginsCommand) initRotateCreds(parent *kingpin.CmdClause) { + p.rotateCreds.cmd = parent.Command("rotate", "Rotates a plugin's credentials.") + + p.initRotateCredsAWSIC(p.rotateCreds.cmd) +} + type authClient interface { GetSAMLConnector(ctx context.Context, id string, withSecrets bool) (types.SAMLConnector, error) CreateSAMLConnector(ctx context.Context, connector types.SAMLConnector) (types.SAMLConnector, error) @@ -198,16 +212,18 @@ type pluginsClient interface { NeedsCleanup(ctx context.Context, in *pluginsv1.NeedsCleanupRequest, opts ...grpc.CallOption) (*pluginsv1.NeedsCleanupResponse, error) Cleanup(ctx context.Context, in *pluginsv1.CleanupRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) DeletePlugin(ctx context.Context, in *pluginsv1.DeletePluginRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) + UpdatePluginStaticCredentials(ctx context.Context, in *pluginsv1.UpdatePluginStaticCredentialsRequest, opts ...grpc.CallOption) (*pluginsv1.UpdatePluginStaticCredentialsResponse, error) } -type installPluginArgs struct { - authClient authClient - plugins pluginsClient +type pluginServices struct { + authClient authClient + plugins pluginsClient + httpProvider http.RoundTripper } // TryRun runs the plugins command func (p *PluginsCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) { - var commandFunc func(ctx context.Context, args installPluginArgs) error + var commandFunc func(ctx context.Context, args pluginServices) error switch cmd { case p.cleanupCmd.FullCommand(): commandFunc = p.Cleanup @@ -225,6 +241,8 @@ func (p *PluginsCommand) TryRun(ctx context.Context, cmd string, clientFunc comm commandFunc = p.Delete case p.edit.awsIC.cmd.FullCommand(): commandFunc = p.EditAWSIC + case p.rotateCreds.awsic.cmd.FullCommand(): + commandFunc = p.RotateAWSICCreds default: return false, nil } @@ -232,7 +250,12 @@ func (p *PluginsCommand) TryRun(ctx context.Context, cmd string, clientFunc comm if err != nil { return false, trace.Wrap(err) } - err = commandFunc(ctx, installPluginArgs{authClient: client, plugins: client.PluginsClient()}) + args := pluginServices{ + authClient: client, + plugins: client.PluginsClient(), + httpProvider: http.DefaultTransport, + } + err = commandFunc(ctx, args) closeFn(ctx) return true, trace.Wrap(err) diff --git a/tool/tctl/common/plugin/plugins_command_test.go b/tool/tctl/common/plugin/plugins_command_test.go index 6b472bfa5ab31..8566f83ba7dc8 100644 --- a/tool/tctl/common/plugin/plugins_command_test.go +++ b/tool/tctl/common/plugin/plugins_command_test.go @@ -27,12 +27,10 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" "github.com/gravitational/teleport/api/client/proto" pluginsv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" - "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/service/servicecfg" ) @@ -503,7 +501,7 @@ func TestPluginsInstallOkta(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - var args installPluginArgs + var args pluginServices if testCase.expectRequest != nil { pluginsClient := &mockPluginsClient{} @@ -552,9 +550,14 @@ func TestPluginsInstallOkta(t *testing.T) { } } -func requireBadParameter(t require.TestingT, err error, _ ...any) { - require.Error(t, err) - require.True(t, trace.IsBadParameter(err), "Expecting bad parameter, got %T: \"%v\"", err, err) +func requireBadParameter(t require.TestingT, err error, msgAndArgs ...any) { + var bpe *trace.BadParameterError + require.ErrorAs(t, err, &bpe, msgAndArgs...) +} + +func requireAccessDenied(t require.TestingT, err error, msgAndArgs ...any) { + var ade *trace.AccessDeniedError + require.ErrorAs(t, err, &ade, msgAndArgs...) } func mustParseURL(text string) *url.URL { @@ -564,84 +567,3 @@ func mustParseURL(text string) *url.URL { } return url } - -type mockPluginsClient struct { - mock.Mock -} - -func (m *mockPluginsClient) CreatePlugin(ctx context.Context, in *pluginsv1.CreatePluginRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { - result := m.Called(ctx, in, opts) - return result.Get(0).(*emptypb.Empty), result.Error(1) -} - -func (m *mockPluginsClient) GetPlugin(ctx context.Context, in *pluginsv1.GetPluginRequest, opts ...grpc.CallOption) (*types.PluginV1, error) { - result := m.Called(ctx, in, opts) - return result.Get(0).(*types.PluginV1), result.Error(1) -} - -func (m *mockPluginsClient) UpdatePlugin(ctx context.Context, in *pluginsv1.UpdatePluginRequest, opts ...grpc.CallOption) (*types.PluginV1, error) { - result := m.Called(ctx, in, opts) - return result.Get(0).(*types.PluginV1), result.Error(1) -} - -func (m *mockPluginsClient) NeedsCleanup(ctx context.Context, in *pluginsv1.NeedsCleanupRequest, opts ...grpc.CallOption) (*pluginsv1.NeedsCleanupResponse, error) { - result := m.Called(ctx, in, opts) - return result.Get(0).(*pluginsv1.NeedsCleanupResponse), result.Error(1) -} - -func (m *mockPluginsClient) Cleanup(ctx context.Context, in *pluginsv1.CleanupRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { - result := m.Called(ctx, in, opts) - return result.Get(0).(*emptypb.Empty), result.Error(1) -} - -func (m *mockPluginsClient) DeletePlugin(ctx context.Context, in *pluginsv1.DeletePluginRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { - result := m.Called(ctx, in, opts) - return result.Get(0).(*emptypb.Empty), result.Error(1) -} - -type mockAuthClient struct { - mock.Mock -} - -func (m *mockAuthClient) GetSAMLConnector(ctx context.Context, id string, withSecrets bool) (types.SAMLConnector, error) { - result := m.Called(ctx, id, withSecrets) - return result.Get(0).(types.SAMLConnector), result.Error(1) -} -func (m *mockAuthClient) CreateSAMLConnector(ctx context.Context, connector types.SAMLConnector) (types.SAMLConnector, error) { - result := m.Called(ctx, connector) - return result.Get(0).(types.SAMLConnector), result.Error(1) -} -func (m *mockAuthClient) UpsertSAMLConnector(ctx context.Context, connector types.SAMLConnector) (types.SAMLConnector, error) { - result := m.Called(ctx, connector) - return result.Get(0).(types.SAMLConnector), result.Error(1) -} -func (m *mockAuthClient) CreateIntegration(ctx context.Context, ig types.Integration) (types.Integration, error) { - result := m.Called(ctx, ig) - return result.Get(0).(types.Integration), result.Error(1) -} -func (m *mockAuthClient) UpdateIntegration(ctx context.Context, ig types.Integration) (types.Integration, error) { - result := m.Called(ctx, ig) - return result.Get(0).(types.Integration), result.Error(1) -} - -func (m *mockAuthClient) GetIntegration(ctx context.Context, name string) (types.Integration, error) { - result := m.Called(ctx, name) - return result.Get(0).(types.Integration), result.Error(1) -} - -func (m *mockAuthClient) Ping(ctx context.Context) (proto.PingResponse, error) { - result := m.Called(ctx) - return result.Get(0).(proto.PingResponse), result.Error(1) -} - -func (m *mockAuthClient) PerformMFACeremony(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { - return &proto.MFAAuthenticateResponse{}, nil -} - -func (m *mockAuthClient) GetRole(ctx context.Context, name string) (types.Role, error) { - result := m.Called(ctx, name) - return result.Get(0).(types.Role), result.Error(1) -} - -// anyContext is an argument matcher for testify mocks that matches any context. -var anyContext any = mock.MatchedBy(func(context.Context) bool { return true }) diff --git a/tool/tctl/common/plugin/scim.go b/tool/tctl/common/plugin/scim.go index 492736a55347a..5a4f4acaf8a72 100644 --- a/tool/tctl/common/plugin/scim.go +++ b/tool/tctl/common/plugin/scim.go @@ -66,7 +66,7 @@ func (p *PluginsCommand) initInstallSCIM(parent *kingpin.CmdClause) { // InstallSCIM implements `tctl plugins install scim`, installing a SCIM integration // plugin into the teleport cluster -func (p *PluginsCommand) InstallSCIM(ctx context.Context, args installPluginArgs) error { +func (p *PluginsCommand) InstallSCIM(ctx context.Context, args pluginServices) error { scimArgs := p.install.scim pluginName := types.PluginTypeSCIM