diff --git a/api/types/integration_awsoidc.go b/api/types/integration_awsoidc.go index 2e8adb598c4a1..4c1cbaff9cec5 100644 --- a/api/types/integration_awsoidc.go +++ b/api/types/integration_awsoidc.go @@ -22,8 +22,13 @@ const ( // This value must match the Audience defined in the IAM Identity Provider of the Integration. IntegrationAWSOIDCAudience = "discover.teleport" - // IntegrationAWSOIDCSubject identifies the system that is going to use the token. + // IntegrationAWSOIDCSubject identifies the system that is going to use the + // token as the Teleport Proxy. IntegrationAWSOIDCSubject = "system:proxy" + + // IntegrationAWSOIDCSubject identifies the system that is going to use the + // token as the Teleport Auth service. + IntegrationAWSOIDCSubjectAuth = "system:auth" ) // GenerateAWSOIDCTokenRequest are the parameters used to request an AWS OIDC Integration token. diff --git a/lib/auth/externalcloudaudit.go b/lib/auth/externalcloudaudit.go new file mode 100644 index 0000000000000..5f2976fad823e --- /dev/null +++ b/lib/auth/externalcloudaudit.go @@ -0,0 +1,68 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/integrations/externalcloudaudit" + "github.com/gravitational/teleport/lib/jwt" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/oidc" +) + +// GenerateExternalCloudAuditOIDCToken generates a signed OIDC token for use by +// the ExternalCloudAudit feature when authenticating to customer AWS accounts. +func (a *Server) GenerateExternalCloudAuditOIDCToken(ctx context.Context) (string, error) { + clusterName, err := a.GetDomainName() + if err != nil { + return "", trace.Wrap(err) + } + + ca, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.OIDCIdPCA, + DomainName: clusterName, + }, true /*loadKeys*/) + if err != nil { + return "", trace.Wrap(err) + } + + signer, err := a.GetKeyStore().GetJWTSigner(ctx, ca) + if err != nil { + return "", trace.Wrap(err) + } + + privateKey, err := services.GetJWTSigner(signer, ca.GetClusterName(), a.clock) + if err != nil { + return "", trace.Wrap(err) + } + + issuer, err := oidc.IssuerForCluster(ctx, a) + if err != nil { + return "", trace.Wrap(err) + } + + token, err := privateKey.SignAWSOIDC(jwt.SignParams{ + Username: a.ServerID, + Audience: types.IntegrationAWSOIDCAudience, + Subject: types.IntegrationAWSOIDCSubjectAuth, + Issuer: issuer, + Expires: a.clock.Now().Add(externalcloudaudit.TokenLifetime), + }) + return token, trace.Wrap(err) +} diff --git a/lib/integrations/awsoidc/idp_thumbprint.go b/lib/integrations/awsoidc/idp_thumbprint.go index 35f33a0aea70f..2241e7e5f1be8 100644 --- a/lib/integrations/awsoidc/idp_thumbprint.go +++ b/lib/integrations/awsoidc/idp_thumbprint.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/lib" + "github.com/gravitational/teleport/lib/utils/oidc" ) // ThumbprintIdP returns the thumbprint as required by AWS when adding an OIDC Identity Provider. @@ -34,7 +35,7 @@ import ( // Returns the thumbprint of the top intermediate CA that signed the TLS cert used to serve HTTPS requests. // In case of a self signed certificate, then it returns the thumbprint of the TLS cert itself. func ThumbprintIdP(ctx context.Context, publicAddress string) (string, error) { - issuer, err := IssuerFromPublicAddress(publicAddress) + issuer, err := oidc.IssuerFromPublicAddress(publicAddress) if err != nil { return "", trace.Wrap(err) } diff --git a/lib/integrations/externalcloudaudit/configurator.go b/lib/integrations/externalcloudaudit/configurator.go new file mode 100644 index 0000000000000..e860315592b58 --- /dev/null +++ b/lib/integrations/externalcloudaudit/configurator.go @@ -0,0 +1,443 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package externalcloudaudit + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/types/externalcloudaudit" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/services" +) + +const ( + // TokenLifetime is the lifetime of OIDC tokens used by the + // ExternalCloudAudit service with the AWS OIDC integration. + TokenLifetime = time.Hour + + refreshBeforeExpirationPeriod = 15 * time.Minute + refreshCheckInterval = 30 * time.Second + retrieveTimeout = 30 * time.Second +) + +// Configurator provides functionality necessary for configuring the External +// Cloud Audit feature. +// +// Specifically: +// - IsUsed() reports whether the feature is currently activated and in use. +// - GetSpec() provides the current cluster ExternalCloudAuditSpec +// - CredentialsProvider() provides AWS credentials for the necessary customer +// resources that can be used with aws-sdk-go-v2 +// - CredentialsProviderSDKV1() provides AWS credentials for the necessary customer +// resources that can be used with aws-sdk-go +// +// Configurator is a dependency to both the S3 session uploader and the Athena +// audit logger. They are both initialized before Auth. However, Auth needs to +// be initialized in order to provide signatures for the OIDC tokens. That's +// why SetGenerateOIDCTokenFn() must be called after auth is initialized to inject +// the OIDC token source dynamically. +// +// If auth needs to emit any events during initialization (before +// SetGenerateOIDCTokenFn is called) that is okay. Events are written to +// SQS first, credentials from the Configurator are not needed until the batcher +// reads the events from SQS and tries to write a batch to the customer S3 +// bucket. If the batcher tries to write a batch before the Configurator is +// initialized and gets an error when trying to retrieve credentials, that's +// still okay, it will always retry. +type Configurator struct { + // spec is set during initialization of the Configurator. It won't + // change, because every change of spec triggers an Auth service reload. + spec *externalcloudaudit.ExternalCloudAuditSpec + isUsed bool + + credentialsCache *credentialsCache +} + +// Options holds options for the Configurator. +type Options struct { + clock clockwork.Clock + stsClient stscreds.AssumeRoleWithWebIdentityAPIClient +} + +func (o *Options) setDefaults(ctx context.Context, region string) error { + if o.clock == nil { + o.clock = clockwork.NewRealClock() + } + if o.stsClient == nil { + var useFips aws.FIPSEndpointState + if modules.GetModules().IsBoringBinary() { + useFips = aws.FIPSEndpointStateEnabled + } + cfg, err := config.LoadDefaultConfig( + ctx, + config.WithRegion(region), + config.WithUseFIPSEndpoint(useFips), + config.WithRetryMaxAttempts(10), + ) + if err != nil { + return trace.Wrap(err) + } + o.stsClient = sts.NewFromConfig(cfg) + } + return nil +} + +// WithClock is a functional option to set the clock. +func WithClock(clock clockwork.Clock) func(*Options) { + return func(opts *Options) { + opts.clock = clock + } +} + +// WithSTSClient is a functional option to set the sts client. +func WithSTSClient(clt stscreds.AssumeRoleWithWebIdentityAPIClient) func(*Options) { + return func(opts *Options) { + opts.stsClient = clt + } +} + +// NewConfigurator returns a new Configurator set up with the current active +// cluster ExternalCloudAudit spec from [ecaSvc]. +// +// If the External Cloud Audit feature is not used in this cluster then a valid +// instance will be returned where IsUsed() will return false. +func NewConfigurator(ctx context.Context, ecaSvc services.ExternalCloudAuditGetter, integrationSvc services.IntegrationsGetter, optFns ...func(*Options)) (*Configurator, error) { + active, err := ecaSvc.GetClusterExternalCloudAudit(ctx) + if err != nil { + if trace.IsNotFound(err) { + return &Configurator{isUsed: false}, nil + } + return nil, trace.Wrap(err) + } + return newConfigurator(ctx, &active.Spec, integrationSvc, optFns...) +} + +// NewDraftConfigurator is equivalent to NewConfigurator but is based on the +// current *draft* ExternalCloudAudit configuration instead of the active +// configuration. +// +// If a draft ExternalCloudAudit configuration is not found, an error will be +// returned. +func NewDraftConfigurator(ctx context.Context, ecaSvc services.ExternalCloudAuditGetter, integrationSvc services.IntegrationsGetter, optFns ...func(*Options)) (*Configurator, error) { + draft, err := ecaSvc.GetDraftExternalCloudAudit(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return newConfigurator(ctx, &draft.Spec, integrationSvc, optFns...) +} + +func newConfigurator(ctx context.Context, spec *externalcloudaudit.ExternalCloudAuditSpec, integrationSvc services.IntegrationsGetter, optFns ...func(*Options)) (*Configurator, error) { + // ExternalCloudAudit is only available in Cloud Enterprise + // (IsUsageBasedBilling indicates Teleport Team, where this is not supported) + if !modules.GetModules().Features().Cloud || modules.GetModules().Features().IsUsageBasedBilling { + return &Configurator{isUsed: false}, nil + } + + oidcIntegrationName := spec.IntegrationName + integration, err := integrationSvc.GetIntegration(ctx, oidcIntegrationName) + if err != nil { + if trace.IsNotFound(err) { + return nil, trace.NotFound( + "ExternalCloudAudit: configured AWS OIDC integration %q not found", + oidcIntegrationName) + } + } + awsOIDCSpec := integration.GetAWSOIDCIntegrationSpec() + if awsOIDCSpec == nil { + return nil, trace.NotFound( + "ExternalCloudAudit: configured integration %q does not appear to be an AWS OIDC integration", + oidcIntegrationName) + } + awsRoleARN := awsOIDCSpec.RoleARN + + options := &Options{} + for _, optFn := range optFns { + optFn(options) + } + if err := options.setDefaults(ctx, spec.Region); err != nil { + return nil, trace.Wrap(err) + } + + credentialsCache, err := newCredentialsCache(ctx, spec.Region, awsRoleARN, options) + if err != nil { + return nil, trace.Wrap(err) + } + go credentialsCache.run(ctx) + + return &Configurator{ + isUsed: true, + spec: spec, + credentialsCache: credentialsCache, + }, nil +} + +// IsUsed returns a boolean indicating whether the ExternalCloudAudit feature is +// currently in active use. +func (c *Configurator) IsUsed() bool { + return c.isUsed +} + +// GetSpec returns the current active ExternalCloudAuditSpec. +func (c *Configurator) GetSpec() *externalcloudaudit.ExternalCloudAuditSpec { + return c.spec +} + +// GenerateOIDCTokenFn is a function that should return a valid, signed JWT for +// authenticating to AWS via OIDC. +type GenerateOIDCTokenFn func(ctx context.Context) (string, error) + +// SetGenerateOIDCTokenFn sets the source of OIDC tokens for this Configurator. +func (c *Configurator) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { + c.credentialsCache.setGenerateOIDCTokenFn(fn) +} + +// CredentialsProvider returns an aws.CredentialsProvider that can be used to +// authenticate with the customer AWS account via the configured AWS OIDC +// integration with aws-sdk-go-v2. +func (p *Configurator) CredentialsProvider() aws.CredentialsProvider { + return p.credentialsCache +} + +// CredentialsProviderSDKV1 returns a credentials.ProviderWithContext that can be used to +// authenticate with the customer AWS account via the configured AWS OIDC +// integration with aws-sdk-go. +func (p *Configurator) CredentialsProviderSDKV1() credentials.ProviderWithContext { + return &v1Adapter{cc: p.credentialsCache} +} + +// WaitForFirstCredentials waits for the internal credentials cache to finish +// fetching its first credentials (or getting an error attempting to do so). +// This can be called after SetGenerateOIDCTokenFn to make sure any returned +// credential providers won't return errors simply due to the cache not being +// ready yet. +func (p *Configurator) WaitForFirstCredentials(ctx context.Context) { + p.credentialsCache.waitForFirstCredsOrErr(ctx) +} + +// credentialsCache is used to store and refresh AWS credentials used with +// AWS OIDC integration. +// +// Credentials are valid for 1h, but they cannot be refreshed if Proxy is down, +// so we attempt to refresh the credentials early and retry on failure. +// +// credentialsCache is a dependency to both the s3 session uploader and the +// athena audit logger. They are both initialized before auth. However AWS +// credentials using OIDC integration can be obtained only after auth is +// initialized. That's why generateOIDCTokenFn is injected dynamically after +// auth is initialized. Before initialization, credentialsCache will return +// an error on any Retrieve call. +type credentialsCache struct { + log *logrus.Entry + + roleARN string + + // generateOIDCTokenFn is dynamically set after auth is initialized. + generateOIDCTokenFn GenerateOIDCTokenFn + + // initialized communicates (via closing channel) that generateOIDCTokenFn is set. + initialized chan struct{} + closeInitialized func() + + // gotFirstCredsOrErr communicates (via closing channel) that the first + // credsOrErr has been set. + gotFirstCredsOrErr chan struct{} + closeGotFirstCredsOrErr func() + + credsOrErr credsOrErr + credsOrErrMu sync.RWMutex + + stsClient stscreds.AssumeRoleWithWebIdentityAPIClient + clock clockwork.Clock +} + +type credsOrErr struct { + creds aws.Credentials + err error +} + +func newCredentialsCache(ctx context.Context, region, roleARN string, options *Options) (*credentialsCache, error) { + initialized := make(chan struct{}) + gotFirstCredsOrErr := make(chan struct{}) + return &credentialsCache{ + roleARN: roleARN, + log: logrus.WithField(trace.Component, "ExternalCloudAudit.CredentialsCache"), + initialized: initialized, + closeInitialized: sync.OnceFunc(func() { close(initialized) }), + gotFirstCredsOrErr: gotFirstCredsOrErr, + closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }), + credsOrErr: credsOrErr{ + err: errors.New("ExternalCloudAudit: credential cache not yet initialized"), + }, + clock: options.clock, + stsClient: options.stsClient, + }, nil +} + +func (cc *credentialsCache) setGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { + cc.generateOIDCTokenFn = fn + cc.closeInitialized() +} + +// Retrieve implements [aws.CredentialsProvider] and returns the latest cached +// credentials, or an error if no credentials have been generated yet or the +// last generated credentials have expired. +func (cc *credentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) { + cc.credsOrErrMu.RLock() + defer cc.credsOrErrMu.RUnlock() + return cc.credsOrErr.creds, cc.credsOrErr.err +} + +func (cc *credentialsCache) run(ctx context.Context) { + // Wait for initialized signal before running loop. + select { + case <-cc.initialized: + case <-ctx.Done(): + cc.log.Debug("Context canceled before initialized.") + return + } + + cc.refreshIfNeeded(ctx) + + ticker := cc.clock.NewTicker(refreshCheckInterval) + defer ticker.Stop() + for { + select { + case <-ticker.Chan(): + cc.refreshIfNeeded(ctx) + case <-ctx.Done(): + cc.log.Debugf("Context canceled, stopping refresh loop.") + return + } + } +} + +func (cc *credentialsCache) refreshIfNeeded(ctx context.Context) { + credsFromCache, err := cc.Retrieve(ctx) + if err == nil && + credsFromCache.HasKeys() && + cc.clock.Now().Add(refreshBeforeExpirationPeriod).Before(credsFromCache.Expires) { + // No need to refresh, credentials in cache are still valid for longer + // than refreshBeforeExpirationPeriod + return + } + cc.log.Debugf("Refreshing credentials.") + + creds, err := cc.refresh(ctx) + if err != nil { + // If we were not able to refresh, check if existing credentials in cache are still valid. + // If yes, just log debug, it will be retried on next interval check. + if credsFromCache.HasKeys() && cc.clock.Now().Before(credsFromCache.Expires) { + cc.log.Warnf("Failed to retrieve new credentials: %v", err) + cc.log.Debugf("Using existing credentials expiring in %s.", credsFromCache.Expires.Sub(cc.clock.Now()).Round(time.Second).String()) + return + } + // If existing creds are expired, update cached error. + cc.setCredsOrErr(credsOrErr{err: trace.Wrap(err)}) + return + } + // Refresh went well, update cached creds. + cc.setCredsOrErr(credsOrErr{creds: creds}) + cc.log.Debugf("Successfully refreshed credentials, new expiry at %v", creds.Expires) +} + +func (cc *credentialsCache) setCredsOrErr(coe credsOrErr) { + cc.credsOrErrMu.Lock() + defer cc.credsOrErrMu.Unlock() + cc.credsOrErr = coe + cc.closeGotFirstCredsOrErr() +} + +func (cc *credentialsCache) refresh(ctx context.Context) (aws.Credentials, error) { + oidcToken, err := cc.generateOIDCTokenFn(ctx) + if err != nil { + return aws.Credentials{}, trace.Wrap(err) + } + + roleProvider := stscreds.NewWebIdentityRoleProvider( + cc.stsClient, + cc.roleARN, + identityToken(oidcToken), + func(wiro *stscreds.WebIdentityRoleOptions) { + wiro.Duration = TokenLifetime + }, + ) + + ctx, cancel := context.WithTimeout(ctx, retrieveTimeout) + defer cancel() + + creds, err := roleProvider.Retrieve(ctx) + return creds, trace.Wrap(err) +} + +func (cc *credentialsCache) waitForFirstCredsOrErr(ctx context.Context) { + select { + case <-ctx.Done(): + case <-cc.gotFirstCredsOrErr: + } +} + +// identityToken is an implementation of [stscreds.IdentityTokenRetriever] for returning a static token. +type identityToken string + +// GetIdentityToken returns the token configured. +func (j identityToken) GetIdentityToken() ([]byte, error) { + return []byte(j), nil +} + +// v1Adapter wraps the credentialsCache to implement +// [credentials.ProviderWithContext] used by aws-sdk-go (v1). +type v1Adapter struct { + cc *credentialsCache +} + +var _ credentials.ProviderWithContext = (*v1Adapter)(nil) + +// RetrieveWithContext returns cached credentials. +func (a *v1Adapter) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + credsV2, err := a.cc.Retrieve(ctx) + if err != nil { + return credentials.Value{}, trace.Wrap(err) + } + + return credentials.Value{ + AccessKeyID: credsV2.AccessKeyID, + SecretAccessKey: credsV2.SecretAccessKey, + SessionToken: credsV2.SessionToken, + ProviderName: credsV2.Source, + }, nil +} + +// Retrieve returns cached credentials. +func (a *v1Adapter) Retrieve() (credentials.Value, error) { + return a.RetrieveWithContext(context.Background()) +} + +// IsExpired always returns true in order to opt out of AWS SDK credential +// caching. Retrieve(WithContext) already returns cached credentials. +func (a *v1Adapter) IsExpired() bool { + return true +} diff --git a/lib/integrations/externalcloudaudit/configurator_test.go b/lib/integrations/externalcloudaudit/configurator_test.go new file mode 100644 index 0000000000000..5564a86f7472e --- /dev/null +++ b/lib/integrations/externalcloudaudit/configurator_test.go @@ -0,0 +1,397 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package externalcloudaudit + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/aws/aws-sdk-go/aws" + "github.com/google/uuid" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/externalcloudaudit" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" +) + +func testOIDCIntegration(t *testing.T) *types.IntegrationV1 { + oidcIntegration, err := types.NewIntegrationAWSOIDC( + types.Metadata{Name: "aws-integration-1"}, + &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "role1", + }, + ) + require.NoError(t, err) + return oidcIntegration +} + +func testDraftExternalCloudAudit(t *testing.T) *externalcloudaudit.ExternalCloudAudit { + draft, err := externalcloudaudit.NewDraftExternalCloudAudit(header.Metadata{}, externalcloudaudit.ExternalCloudAuditSpec{ + IntegrationName: "aws-integration-1", + PolicyName: "ecaPolicy", + Region: "us-west-2", + SessionsRecordingsURI: "s3://bucket/sess_rec", + AthenaWorkgroup: "primary", + GlueDatabase: "teleport_db", + GlueTable: "teleport_table", + AuditEventsLongTermURI: "s3://bucket/events", + AthenaResultsURI: "s3://bucket/results", + }) + require.NoError(t, err) + return draft +} + +func TestConfiguratorIsUsed(t *testing.T) { + ctx := context.Background() + + draftConfig := testDraftExternalCloudAudit(t) + tests := []struct { + name string + modules *modules.TestModules + resourceServiceFn func(t *testing.T, s services.ExternalCloudAudits) + wantIsUsed bool + }{ + { + name: "not cloud", + modules: &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: false, + }, + }, + wantIsUsed: false, + }, + { + name: "cloud team", + modules: &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + IsUsageBasedBilling: true, + }, + }, + wantIsUsed: false, + }, + { + name: "cloud enterprise without config", + modules: &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + IsUsageBasedBilling: false, + }, + }, + wantIsUsed: false, + }, + { + name: "cloud enterprise with only draft", + modules: &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + IsUsageBasedBilling: false, + }, + }, + // Just create draft, external cloud audit should be disabled, it's + // active only when the draft is promoted to cluster external cloud + // audit resource. + resourceServiceFn: func(t *testing.T, s services.ExternalCloudAudits) { + _, err := s.UpsertDraftExternalCloudAudit(ctx, draftConfig) + require.NoError(t, err) + }, + wantIsUsed: false, + }, + { + name: "cloud enterprise with cluster config", + modules: &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + IsUsageBasedBilling: false, + }, + }, + // Create draft and promote it to cluster. + resourceServiceFn: func(t *testing.T, s services.ExternalCloudAudits) { + _, err := s.UpsertDraftExternalCloudAudit(ctx, draftConfig) + require.NoError(t, err) + err = s.PromoteToClusterExternalCloudAudit(ctx) + require.NoError(t, err) + }, + wantIsUsed: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mem, err := memory.New(memory.Config{}) + require.NoError(t, err) + + integrationSvc, err := local.NewIntegrationsService(mem) + require.NoError(t, err) + _, err = integrationSvc.CreateIntegration(ctx, testOIDCIntegration(t)) + require.NoError(t, err) + + ecaSvc := local.NewExternalCloudAuditService(mem) + if tt.resourceServiceFn != nil { + tt.resourceServiceFn(t, ecaSvc) + } + + modules.SetTestModules(t, tt.modules) + + c, err := NewConfigurator(ctx, ecaSvc, integrationSvc) + require.NoError(t, err) + require.Equal(t, tt.wantIsUsed, c.IsUsed(), + "Configurator.IsUsed() = %v, want %v", c.IsUsed(), tt.wantIsUsed) + if c.IsUsed() { + require.Equal(t, draftConfig.Spec, *c.GetSpec()) + } + }) + } +} + +func TestCredentialsCache(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + modules.SetTestModules(t, &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + IsUsageBasedBilling: false, + }, + }) + + mem, err := memory.New(memory.Config{}) + require.NoError(t, err) + + // Pre-req: existing AWS OIDC integration + integrationSvc, err := local.NewIntegrationsService(mem) + require.NoError(t, err) + oidcIntegration := testOIDCIntegration(t) + _, err = integrationSvc.CreateIntegration(ctx, oidcIntegration) + require.NoError(t, err) + + // Pre-req: existing cluster ExternalCloudAudit configuration + draftConfig := testDraftExternalCloudAudit(t) + svc := local.NewExternalCloudAuditService(mem) + _, err = svc.UpsertDraftExternalCloudAudit(ctx, draftConfig) + require.NoError(t, err) + err = svc.PromoteToClusterExternalCloudAudit(ctx) + require.NoError(t, err) + + clock := clockwork.NewFakeClock() + stsClient := &fakeSTSClient{ + clock: clock, + } + + // Create a configurator with a fake clock and STS client. + c, err := NewConfigurator(ctx, svc, integrationSvc, WithClock(clock), WithSTSClient(stsClient)) + require.NoError(t, err) + require.True(t, c.IsUsed()) + + // Set the GenerateOIDCTokenFn to a dumb faked function. + c.SetGenerateOIDCTokenFn(func(ctx context.Context) (string, error) { + return uuid.NewString(), nil + }) + + provider := c.CredentialsProvider() + providerV1 := c.CredentialsProviderSDKV1() + + checkRetrieveCredentials := func(t require.TestingT, expectErr error) { + _, err = providerV1.RetrieveWithContext(ctx) + assert.ErrorIs(t, err, expectErr) + _, err := provider.Retrieve(ctx) + assert.ErrorIs(t, err, expectErr) + } + checkRetrieveCredentialsWithExpiry := func(t require.TestingT, expectExpiry time.Time) { + _, err = providerV1.RetrieveWithContext(ctx) + assert.NoError(t, err) + creds, err := provider.Retrieve(ctx) + assert.NoError(t, err) + if err == nil { + assert.WithinDuration(t, expectExpiry, creds.Expires, time.Minute) + } + } + + // Assert that credentials can be retrieved when everything is happy. + // EventuallyWithT is necessary to allow credentialsCache.run to be + // scheduled after SetGenerateOIDCTokenFn above. + initialCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + }, time.Second, time.Millisecond) + + // Assert that the good cached credentials are still used even if sts starts + // returning errors. + stsError := errors.New("test error") + stsClient.setError(stsError) + // Test immediately + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute before first refresh attempt + clock.Advance(TokenLifetime - refreshBeforeExpirationPeriod - time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute after first refresh attempt + clock.Advance(2 * time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute before credential expiry + clock.Advance(refreshBeforeExpirationPeriod - 2*time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + + // Advance 1 minute past the credential expiry and make sure we get the + // expected error. + clock.Advance(2 * time.Minute) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentials(t, stsError) + }, time.Second, time.Millisecond) + + // Fix STS and make sure we stop getting errors within refreshCheckInterval + stsClient.setError(nil) + clock.Advance(refreshCheckInterval) + newCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, newCredentialExpiry) + }, time.Second, time.Millisecond) + + // Test that even if STS is returning errors for 5 minutes surrounding the + // expected refresh time and the expiry time, no errors are observed. + expectedRefreshTime := newCredentialExpiry.Add(-refreshBeforeExpirationPeriod) + credentialsUpdated := false + for done := newCredentialExpiry.Add(10 * time.Minute); clock.Now().Before(done); clock.Advance(time.Minute) { + if clock.Now().Sub(expectedRefreshTime).Abs() < 5*time.Minute || + clock.Now().Sub(newCredentialExpiry).Abs() < 5*time.Minute { + stsClient.setError(stsError) + } else { + stsClient.setError(nil) + if !credentialsUpdated && clock.Now().After(expectedRefreshTime) { + // For the test we need to make sure the credentials actually get + // updated during the window between expectedRefreshTime and + // newCredentialExpiry where STS is not returning errors, and we might + // need to sleep a bit to give the cache run loop time to get scheduled + // and updated the cached creds. To solve that we wait for the current + // credential expiry to match the newer value. + expectedExpiry := expectedRefreshTime.Add(5*time.Minute + TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + creds, err := provider.Retrieve(ctx) + assert.NoError(t, err) + assert.WithinDuration(t, expectedExpiry, creds.Expires, 2*time.Minute) + }, time.Second, time.Millisecond) + credentialsUpdated = true + } + } + + // Assert that there is never an error getting credentials. + checkRetrieveCredentials(t, nil) + + } +} + +// TestDraftConfigurator models the way the connection tester will use the +// configurator to synchronously get credentials for the current draft +// ExternalCloudAuditSpec. +func TestDraftConfigurator(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + modules.SetTestModules(t, &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + IsUsageBasedBilling: false, + }, + }) + + mem, err := memory.New(memory.Config{}) + require.NoError(t, err) + + // Pre-req: existing AWS OIDC integration + integrationSvc, err := local.NewIntegrationsService(mem) + require.NoError(t, err) + oidcIntegration := testOIDCIntegration(t) + _, err = integrationSvc.CreateIntegration(ctx, oidcIntegration) + require.NoError(t, err) + + // Pre-req: existing draft ExternalCloudAudit configuration + draftConfig := testDraftExternalCloudAudit(t) + svc := local.NewExternalCloudAuditService(mem) + _, err = svc.UpsertDraftExternalCloudAudit(ctx, draftConfig) + require.NoError(t, err) + + clock := clockwork.NewFakeClock() + stsClient := &fakeSTSClient{ + clock: clock, + } + + // Create a draft configurator with a fake clock and STS client. + c, err := NewDraftConfigurator(ctx, svc, integrationSvc, WithClock(clock), WithSTSClient(stsClient)) + require.NoError(t, err) + require.True(t, c.IsUsed()) + + // Set the GenerateOIDCTokenFn to a faked function for the test. + c.SetGenerateOIDCTokenFn(func(ctx context.Context) (string, error) { + // Can sleep here to confirm that WaitForFirstCredentials works. + // time.Sleep(time.Second) + return uuid.NewString(), nil + }) + + // Wait for the first set of credentials to be ready. + c.WaitForFirstCredentials(ctx) + + // Get credentials, make sure there's no error and the expiry looks right. + provider := c.CredentialsProvider() + creds, err := provider.Retrieve(ctx) + require.NoError(t, err) + require.WithinDuration(t, clock.Now().Add(TokenLifetime), creds.Expires, time.Minute) +} + +type fakeSTSClient struct { + clock clockwork.Clock + err error + sync.Mutex +} + +func (f *fakeSTSClient) setError(err error) { + f.Lock() + f.err = err + f.Unlock() +} + +func (f *fakeSTSClient) getError() error { + f.Lock() + defer f.Unlock() + return f.err +} + +func (f *fakeSTSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if err := f.getError(); err != nil { + return nil, err + } + + expiration := f.clock.Now().Add(time.Second * time.Duration(*params.DurationSeconds)) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{ + Expiration: &expiration, + // These are example values taken from https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + SessionToken: aws.String("AQoDYXdzEE0a8ANXXXXXXXXNO1ewxE5TijQyp+IEXAMPLE"), + SecretAccessKey: aws.String("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"), + AccessKeyId: aws.String("ASgeIAIOSFODNN7EXAMPLE"), + }, + }, nil +} diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 318d5fcdcb549..f1f18a89c05dd 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/oidc" proxyutils "github.com/gravitational/teleport/lib/utils/proxy" ) @@ -531,7 +532,7 @@ func (s *localSite) setupTunnelForOpenSSHEICENode(ctx context.Context, targetSer return nil, trace.BadParameter("missing aws cloud metadata") } - issuer, err := awsoidc.IssuerForCluster(ctx, s.accessPoint) + issuer, err := oidc.IssuerForCluster(ctx, s.accessPoint) if err != nil { return nil, trace.BadParameter("failed to get issuer %v", err) } diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 0ad4d007af523..235528aeadbf5 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/interval" + "github.com/gravitational/teleport/lib/utils/oidc" ) const ( @@ -76,7 +77,7 @@ func (process *TeleportProcess) initDeployServiceUpdater() error { } } - issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) + issuer, err := oidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 4a02cd19f6cfa..790110a559672 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -53,6 +53,7 @@ import ( "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/oidc" ) // Server is a forwarding server. Server is used to create a single in-memory @@ -667,7 +668,7 @@ func (s *Server) sendSSHPublicKeyToTarget(ctx context.Context) (ssh.Signer, erro return nil, trace.BadParameter("missing aws cloud metadata") } - issuer, err := awsoidc.IssuerForCluster(ctx, s.authClient) + issuer, err := oidc.IssuerForCluster(ctx, s.authClient) if err != nil { return nil, trace.BadParameter("failed to get issuer %v", err) } diff --git a/lib/integrations/awsoidc/issuer.go b/lib/utils/oidc/issuer.go similarity index 99% rename from lib/integrations/awsoidc/issuer.go rename to lib/utils/oidc/issuer.go index ff4d60f3c751a..26836121471ae 100644 --- a/lib/integrations/awsoidc/issuer.go +++ b/lib/utils/oidc/issuer.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package awsoidc +package oidc import ( "context" diff --git a/lib/integrations/awsoidc/issuer_test.go b/lib/utils/oidc/issuer_test.go similarity index 88% rename from lib/integrations/awsoidc/issuer_test.go rename to lib/utils/oidc/issuer_test.go index 8708ce604a3ec..cf7675dc91f10 100644 --- a/lib/integrations/awsoidc/issuer_test.go +++ b/lib/utils/oidc/issuer_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package awsoidc +package oidc import ( "context" @@ -106,7 +106,7 @@ func TestIssuerForCluster(t *testing.T) { { name: "api returns not found", mockErr: &trace.NotFoundError{}, - checkErr: notFounCheck, + checkErr: notFoundCheck, }, { name: "api returns an empty list of proxies", @@ -129,3 +129,11 @@ func TestIssuerForCluster(t *testing.T) { }) } } + +func badParameterCheck(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.True(t, trace.IsBadParameter(err), `expected "bad parameter", but got %v`, err) +} + +func notFoundCheck(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.True(t, trace.IsNotFound(err), `expected "not found", but got %v`, err) +} diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 77e02c06db2b7..cca23d74cef53 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/utils/oidc" "github.com/gravitational/teleport/lib/web/scripts/oneoff" "github.com/gravitational/teleport/lib/web/ui" ) @@ -93,7 +94,7 @@ func (h *Handler) awsOIDCClientRequest(ctx context.Context, region string, p htt return nil, trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind()) } - issuer, err := awsoidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr) + issuer, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr) if err != nil { return nil, trace.Wrap(err) } @@ -470,7 +471,7 @@ func (h *Handler) awsOIDCConfigureIdP(w http.ResponseWriter, r *http.Request, p return nil, trace.BadParameter("invalid role %q", role) } - proxyAddr, err := awsoidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr) + proxyAddr, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/oidcidp.go b/lib/web/oidcidp.go index b196ed54ba606..dee948da4bf51 100644 --- a/lib/web/oidcidp.go +++ b/lib/web/oidcidp.go @@ -22,6 +22,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/jwt" + "github.com/gravitational/teleport/lib/utils/oidc" ) const ( @@ -31,7 +32,7 @@ const ( // openidConfiguration returns the openid-configuration for setting up the AWS OIDC Integration func (h *Handler) openidConfiguration(_ http.ResponseWriter, _ *http.Request, _ httprouter.Params) (interface{}, error) { - issuer, err := awsoidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr) + issuer, err := oidc.IssuerFromPublicAddress(h.cfg.PublicProxyAddr) if err != nil { return nil, trace.Wrap(err) }