diff --git a/api/types/database.go b/api/types/database.go index a38ca4f182f49..92910df16adb3 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -78,6 +78,10 @@ type Database interface { GetAWS() AWS // SetStatusAWS sets the database AWS metadata in the status field. SetStatusAWS(AWS) + // SetAWSExternalID sets the database AWS external ID in the Spec.AWS field. + SetAWSExternalID(id string) + // SetAWSAssumeRole sets the database AWS assume role arn in the Spec.AWS field. + SetAWSAssumeRole(roleARN string) // GetGCP returns GCP information for Cloud SQL databases. GetGCP() GCPCloudSQL // GetAzure returns Azure database server metadata. @@ -341,6 +345,16 @@ func (d *DatabaseV3) SetStatusAWS(aws AWS) { d.Status.AWS = aws } +// SetAWSExternalID sets the database AWS external ID in the Spec.AWS field. +func (d *DatabaseV3) SetAWSExternalID(id string) { + d.Spec.AWS.ExternalID = id +} + +// SetAWSAssumeRole sets the database AWS assume role arn in the Spec.AWS field. +func (d *DatabaseV3) SetAWSAssumeRole(roleARN string) { + d.Spec.AWS.AssumeRoleARN = roleARN +} + // GetGCP returns GCP information for Cloud SQL databases. func (d *DatabaseV3) GetGCP() GCPCloudSQL { return d.Spec.GCP diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index 56ea9666cd4e3..beccd6d28f614 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -253,13 +253,18 @@ type awsAssumeRoleOpts struct { // when getting an AWS session. type AWSAssumeRoleOptionFn func(*awsAssumeRoleOpts) +// WithAssumeRole configures options needed for assuming an AWS role. +func WithAssumeRole(roleARN, externalID string) AWSAssumeRoleOptionFn { + return func(options *awsAssumeRoleOpts) { + options.assumeRoleARN = roleARN + options.assumeRoleExternalID = externalID + } +} + // WithAssumeRoleFromAWSMeta extracts options needed from AWS metadata for // assuming an AWS role. func WithAssumeRoleFromAWSMeta(meta types.AWS) AWSAssumeRoleOptionFn { - return func(options *awsAssumeRoleOpts) { - options.assumeRoleARN = meta.AssumeRoleARN - options.assumeRoleExternalID = meta.ExternalID - } + return WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID) } // WithChainedAssumeRole sets a role to assume with a base session to use diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 76af88fd9cd4c..6d7e667d45ee8 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -39,11 +39,13 @@ import ( "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/eks/eksiface" + "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" @@ -918,9 +920,22 @@ func (m *mockGKEAPI) ListClusters(ctx context.Context, projectID string, locatio func TestDiscoveryDatabase(t *testing.T) { awsRedshiftResource, awsRedshiftDB := makeRedshiftCluster(t, "aws-redshift", "us-east-1") + awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1") azRedisResource, azRedisDB := makeAzureRedisServer(t, "az-redis", "sub1", "group1", "East US") + role := services.AssumeRole{RoleARN: "arn:aws:iam::123456789012:role/test-role", ExternalID: "test123"} + awsRDSDBWithRole := awsRDSDB.Copy() + awsRDSDBWithRole.SetAWSAssumeRole("arn:aws:iam::123456789012:role/test-role") + awsRDSDBWithRole.SetAWSExternalID("test123") + testClients := &cloud.TestCloudClients{ + STS: &mocks.STSMock{}, + RDS: &mocks.RDSMock{ + DBInstances: []*rds.DBInstance{awsRDSInstance}, + DBEngineVersions: []*rds.DBEngineVersion{ + {Engine: aws.String(services.RDSEnginePostgres)}, + }, + }, Redshift: &mocks.RedshiftMock{ Clusters: []*redshift.Cluster{awsRedshiftResource}, }, @@ -949,6 +964,16 @@ func TestDiscoveryDatabase(t *testing.T) { }}, expectDatabases: []types.Database{awsRedshiftDB}, }, + { + name: "discover AWS database with assumed role", + awsMatchers: []services.AWSMatcher{{ + Types: []string{services.AWSMatcherRDS}, + Tags: map[string]utils.Strings{types.Wildcard: {types.Wildcard}}, + Regions: []string{"us-west-1"}, + AssumeRole: role, + }}, + expectDatabases: []types.Database{awsRDSDBWithRole}, + }, { name: "discover Azure database", azureMatchers: []services.AzureMatcher{{ @@ -979,6 +1004,26 @@ func TestDiscoveryDatabase(t *testing.T) { }}, expectDatabases: []types.Database{awsRedshiftDB}, }, + { + name: "update existing database with assumed role", + existingDatabases: []types.Database{ + mustNewDatabase(t, types.Metadata{ + Name: "aws-rds", + Description: "should be updated", + Labels: map[string]string{types.OriginLabel: types.OriginCloud}, + }, types.DatabaseSpecV3{ + Protocol: "postgres", + URI: "should.be.updated.com:12345", + }), + }, + awsMatchers: []services.AWSMatcher{{ + Types: []string{services.AWSMatcherRDS}, + Tags: map[string]utils.Strings{types.Wildcard: {types.Wildcard}}, + Regions: []string{"us-west-1"}, + AssumeRole: role, + }}, + expectDatabases: []types.Database{awsRDSDBWithRole}, + }, { name: "delete existing database", existingDatabases: []types.Database{ @@ -1091,6 +1136,23 @@ func TestDiscoveryDatabase(t *testing.T) { } } +func makeRDSInstance(t *testing.T, name, region string) (*rds.DBInstance, types.Database) { + instance := &rds.DBInstance{ + DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)), + DBInstanceIdentifier: aws.String(name), + DbiResourceId: aws.String(uuid.New().String()), + Engine: aws.String(services.RDSEnginePostgres), + DBInstanceStatus: aws.String("available"), + Endpoint: &rds.Endpoint{ + Address: aws.String("localhost"), + Port: aws.Int64(5432), + }, + } + database, err := services.NewDatabaseFromRDSInstance(instance) + require.NoError(t, err) + return instance, database +} + func makeRedshiftCluster(t *testing.T, name, region string) (*redshift.Cluster, types.Database) { t.Helper() cluster := &redshift.Cluster{ diff --git a/lib/srv/discovery/fetchers/db/aws_elasticache.go b/lib/srv/discovery/fetchers/db/aws_elasticache.go index b46eba0ccd061..b648282b9bd74 100644 --- a/lib/srv/discovery/fetchers/db/aws_elasticache.go +++ b/lib/srv/discovery/fetchers/db/aws_elasticache.go @@ -39,6 +39,8 @@ type elastiCacheFetcherConfig struct { ElastiCache elasticacheiface.ElastiCacheAPI // Region is the AWS region to query databases in. Region string + // AssumeRole is the AWS IAM role to assume before discovering databases. + AssumeRole services.AssumeRole } // CheckAndSetDefaults validates the config and sets defaults. @@ -74,6 +76,7 @@ func newElastiCacheFetcher(config elastiCacheFetcherConfig) (common.Fetcher, err trace.Component: "watch:elasticache", "labels": config.Labels, "region": config.Region, + "role": config.AssumeRole, }), }, nil } @@ -168,6 +171,7 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels } } + applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil } diff --git a/lib/srv/discovery/fetchers/db/aws_elasticache_test.go b/lib/srv/discovery/fetchers/db/aws_elasticache_test.go index cf4ca20000275..7afb22efe3acd 100644 --- a/lib/srv/discovery/fetchers/db/aws_elasticache_test.go +++ b/lib/srv/discovery/fetchers/db/aws_elasticache_test.go @@ -48,12 +48,7 @@ func TestElastiCacheFetcher(t *testing.T) { aws.StringValue(elasticacheUnsupported.ARN): elasticacheUnsupportedTags, } - tests := []struct { - name string - inputClients cloud.AWSClients - inputLabels map[string]string - wantDatabases types.Databases - }{ + tests := []awsFetcherTest{ { name: "fetch all", inputClients: &cloud.TestCloudClients{ @@ -62,7 +57,7 @@ func TestElastiCacheFetcher(t *testing.T) { TagsByARN: elasticacheTagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels), wantDatabases: types.Databases{elasticacheDatabaseProd, elasticacheDatabaseQA}, }, { @@ -73,7 +68,7 @@ func TestElastiCacheFetcher(t *testing.T) { TagsByARN: elasticacheTagsByARN, }, }, - inputLabels: envProdLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", envProdLabels), wantDatabases: types.Databases{elasticacheDatabaseProd}, }, { @@ -84,7 +79,7 @@ func TestElastiCacheFetcher(t *testing.T) { TagsByARN: elasticacheTagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels), wantDatabases: types.Databases{elasticacheDatabaseProd}, }, { @@ -95,20 +90,11 @@ func TestElastiCacheFetcher(t *testing.T) { TagsByARN: elasticacheTagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels), wantDatabases: types.Databases{elasticacheDatabaseProd}, }, } - - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherElastiCache, "us-east-2", toTypeLabels(test.inputLabels)) - require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) - }) - } + testAWSFetchers(t, tests...) } func makeElastiCacheCluster(t *testing.T, name, region, env string, opts ...func(*elasticache.ReplicationGroup)) (*elasticache.ReplicationGroup, types.Database, []*elasticache.Tag) { diff --git a/lib/srv/discovery/fetchers/db/aws_memorydb.go b/lib/srv/discovery/fetchers/db/aws_memorydb.go index 61248848210df..f538421940f26 100644 --- a/lib/srv/discovery/fetchers/db/aws_memorydb.go +++ b/lib/srv/discovery/fetchers/db/aws_memorydb.go @@ -39,6 +39,8 @@ type memoryDBFetcherConfig struct { MemoryDB memorydbiface.MemoryDBAPI // Region is the AWS region to query databases in. Region string + // AssumeRole is the AWS IAM role to assume before discovering databases. + AssumeRole services.AssumeRole } // CheckAndSetDefaults validates the config and sets defaults. @@ -74,6 +76,7 @@ func newMemoryDBFetcher(config memoryDBFetcherConfig) (common.Fetcher, error) { trace.Component: "watch:memorydb", "labels": config.Labels, "region": config.Region, + "role": config.AssumeRole, }), }, nil } @@ -136,6 +139,7 @@ func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e databases = append(databases, database) } } + applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil } diff --git a/lib/srv/discovery/fetchers/db/aws_memorydb_test.go b/lib/srv/discovery/fetchers/db/aws_memorydb_test.go index 0ceaf52d234f6..ab9340a6eac66 100644 --- a/lib/srv/discovery/fetchers/db/aws_memorydb_test.go +++ b/lib/srv/discovery/fetchers/db/aws_memorydb_test.go @@ -47,12 +47,7 @@ func TestMemoryDBFetcher(t *testing.T) { aws.StringValue(memorydbUnsupported.ARN): memorydbUnsupportedTags, } - tests := []struct { - name string - inputClients cloud.AWSClients - inputLabels map[string]string - wantDatabases types.Databases - }{ + tests := []awsFetcherTest{ { name: "fetch all", inputClients: &cloud.TestCloudClients{ @@ -61,7 +56,7 @@ func TestMemoryDBFetcher(t *testing.T) { TagsByARN: memorydbTagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels), wantDatabases: types.Databases{memorydbDatabaseProd, memorydbDatabaseTest}, }, { @@ -72,7 +67,7 @@ func TestMemoryDBFetcher(t *testing.T) { TagsByARN: memorydbTagsByARN, }, }, - inputLabels: envProdLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", envProdLabels), wantDatabases: types.Databases{memorydbDatabaseProd}, }, { @@ -83,7 +78,7 @@ func TestMemoryDBFetcher(t *testing.T) { TagsByARN: memorydbTagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels), wantDatabases: types.Databases{memorydbDatabaseProd}, }, { @@ -94,20 +89,11 @@ func TestMemoryDBFetcher(t *testing.T) { TagsByARN: memorydbTagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels), wantDatabases: types.Databases{memorydbDatabaseProd}, }, } - - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherMemoryDB, "us-east-2", toTypeLabels(test.inputLabels)) - require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) - }) - } + testAWSFetchers(t, tests...) } func makeMemoryDBCluster(t *testing.T, name, region, env string, opts ...func(*memorydb.Cluster)) (*memorydb.Cluster, types.Database, []*memorydb.Tag) { diff --git a/lib/srv/discovery/fetchers/db/aws_rds.go b/lib/srv/discovery/fetchers/db/aws_rds.go index a93d32c56c632..5ce9b8a1066c1 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds.go +++ b/lib/srv/discovery/fetchers/db/aws_rds.go @@ -41,6 +41,8 @@ type rdsFetcherConfig struct { RDS rdsiface.RDSAPI // Region is the AWS region to query databases in. Region string + // AssumeRole is the AWS IAM role to assume before discovering databases. + AssumeRole services.AssumeRole } // CheckAndSetDefaults validates the config and sets defaults. @@ -76,6 +78,7 @@ func newRDSDBInstancesFetcher(config rdsFetcherConfig) (common.Fetcher, error) { trace.Component: "watch:rds", "labels": config.Labels, "region": config.Region, + "role": config.AssumeRole, }), }, nil } @@ -87,6 +90,7 @@ func (f *rdsDBInstancesFetcher) Get(ctx context.Context) (types.ResourcesWithLab return nil, trace.Wrap(err) } + applyAssumeRoleToDatabases(rdsDatabases, f.cfg.AssumeRole) return filterDatabasesByLabels(rdsDatabases, f.cfg.Labels, f.log).AsResources(), nil } @@ -172,6 +176,7 @@ func newRDSAuroraClustersFetcher(config rdsFetcherConfig) (common.Fetcher, error trace.Component: "watch:aurora", "labels": config.Labels, "region": config.Region, + "role": config.AssumeRole, }), }, nil } @@ -183,6 +188,7 @@ func (f *rdsAuroraClustersFetcher) Get(ctx context.Context) (types.ResourcesWith return nil, trace.Wrap(err) } + applyAssumeRoleToDatabases(auroraDatabases, f.cfg.AssumeRole) return filterDatabasesByLabels(auroraDatabases, f.cfg.Labels, f.log).AsResources(), nil } diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go index cac0f6dcf9e6f..8286a1c6c02bb 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go @@ -48,6 +48,7 @@ func newRDSDBProxyFetcher(config rdsFetcherConfig) (common.Fetcher, error) { trace.Component: "watch:rdsproxy", "labels": config.Labels, "region": config.Region, + "role": config.AssumeRole, }), }, nil } @@ -60,6 +61,7 @@ func (f *rdsDBProxyFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, return nil, trace.Wrap(err) } + applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil } diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go index 51eefb202561d..9d901d7b9aa9f 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go @@ -38,40 +38,33 @@ func TestRDSDBProxyFetcher(t *testing.T) { rdsProxyEndpointVpc1, rdsProxyEndpointDatabaseVpc1 := makeRDSProxyCustomEndpoint(t, rdsProxyVpc1, "endpoint-1", "us-east-1") rdsProxyEndpointVpc2, rdsProxyEndpointDatabaseVpc2 := makeRDSProxyCustomEndpoint(t, rdsProxyVpc2, "endpoint-2", "us-east-1") - clients := &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2}, - DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2}, - DBProxyTargetPort: 9999, - }, - } - - tests := []struct { - name string - inputLabels map[string]string - wantDatabases types.Databases - }{ + tests := []awsFetcherTest{ { - name: "fetch all", - inputLabels: wildcardLabels, + name: "fetch all", + inputClients: &cloud.TestCloudClients{ + RDS: &mocks.RDSMock{ + DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2}, + DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2}, + DBProxyTargetPort: 9999, + }, + }, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRDSProxy, "us-east-1", wildcardLabels), wantDatabases: types.Databases{rdsProxyDatabaseVpc1, rdsProxyDatabaseVpc2, rdsProxyEndpointDatabaseVpc1, rdsProxyEndpointDatabaseVpc2}, }, { - name: "fetch vpc1", - inputLabels: map[string]string{"vpc-id": "vpc1"}, + name: "fetch vpc1", + inputClients: &cloud.TestCloudClients{ + RDS: &mocks.RDSMock{ + DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2}, + DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2}, + DBProxyTargetPort: 9999, + }, + }, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRDSProxy, "us-east-1", map[string]string{"vpc-id": "vpc1"}), wantDatabases: types.Databases{rdsProxyDatabaseVpc1, rdsProxyEndpointDatabaseVpc1}, }, } - - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - fetchers := mustMakeAWSFetchersForMatcher(t, clients, services.AWSMatcherRDSProxy, "us-east-2", toTypeLabels(test.inputLabels)) - require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) - }) - } + testAWSFetchers(t, tests...) } func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types.Database) { diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go index 0ed3a6b7981b8..6640884f71d12 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go @@ -55,12 +55,7 @@ func TestRDSFetchers(t *testing.T) { auroraClusterUnknownStatus, auroraDatabaseUnknownStatus := makeRDSCluster(t, "cluster-5", "us-east-1", nil, withRDSClusterStatus("status-does-not-exist")) auroraClusterNoWriter, auroraDatabasesNoWriter := makeRDSClusterWithExtraEndpoints(t, "cluster-6", "us-east-1", envDevLabels, false) - tests := []struct { - name string - inputClients cloud.AWSClients - inputMatchers []services.AWSMatcher - wantDatabases types.Databases - }{ + tests := []awsFetcherTest{ { name: "fetch all", inputClients: &cloud.TestCloudClients{ @@ -206,16 +201,7 @@ func TestRDSFetchers(t *testing.T) { wantDatabases: auroraDatabasesNoWriter, }, } - - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - fetchers := mustMakeAWSFetchers(t, test.inputClients, test.inputMatchers) - require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) - }) - } + testAWSFetchers(t, tests...) } func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) (*rds.DBInstance, types.Database) { diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index 18ab615dedf2b..9836dc2ec73e8 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -40,6 +40,8 @@ type redshiftFetcherConfig struct { Redshift redshiftiface.RedshiftAPI // Region is the AWS region to query databases in. Region string + // AssumeRole is the AWS IAM role to assume before discovering databases. + AssumeRole services.AssumeRole } // CheckAndSetDefaults validates the config and sets defaults. @@ -75,6 +77,7 @@ func newRedshiftFetcher(config redshiftFetcherConfig) (common.Fetcher, error) { trace.Component: "watch:redshift", "labels": config.Labels, "region": config.Region, + "role": config.AssumeRole, }), }, nil } @@ -104,6 +107,7 @@ func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e databases = append(databases, database) } + applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil } diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go index dbf2c71da6499..2436b6a9bb48a 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go @@ -41,6 +41,8 @@ type redshiftServerlessFetcherConfig struct { Region string // Client is the Redshift Serverless API client. Client redshiftserverlessiface.RedshiftServerlessAPI + // AssumeRole is the AWS IAM role to assume before discovering databases. + AssumeRole services.AssumeRole } // CheckAndSetDefaults validates the config and sets defaults. @@ -83,6 +85,7 @@ func newRedshiftServerlessFetcher(config redshiftServerlessFetcherConfig) (commo trace.Component: "watch:rss<", // (r)ed(s)hift (s)erver(<)less "labels": config.Labels, "region": config.Region, + "role": config.AssumeRole, }), }, nil } @@ -106,6 +109,7 @@ func (f *redshiftServerlessFetcher) Get(ctx context.Context) (types.ResourcesWit databases = append(databases, vpcEndpointDatabases...) } + applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil } diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go index 6ca2277611ae0..28d2118f985c3 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go @@ -47,12 +47,7 @@ func TestRedshiftServerlessFetcher(t *testing.T) { endpointNotAvailable := mocks.RedshiftServerlessEndpointAccess(workgroupNotAvailable, "endpoint-creating", "us-east-1") endpointNotAvailable.EndpointStatus = aws.String("creating") - tests := []struct { - name string - inputClients cloud.AWSClients - inputLabels map[string]string - wantDatabases types.Databases - }{ + tests := []awsFetcherTest{ { name: "fetch all", inputClients: &cloud.TestCloudClients{ @@ -62,7 +57,7 @@ func TestRedshiftServerlessFetcher(t *testing.T) { TagsByARN: tagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshiftServerless, "us-east-1", wildcardLabels), wantDatabases: types.Databases{workgroupProdDB, workgroupDevDB, endpointProdDB, endpointProdDev}, }, { @@ -74,7 +69,7 @@ func TestRedshiftServerlessFetcher(t *testing.T) { TagsByARN: tagsByARN, }, }, - inputLabels: envProdLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshiftServerless, "us-east-1", envProdLabels), wantDatabases: types.Databases{workgroupProdDB, endpointProdDB}, }, { @@ -86,20 +81,11 @@ func TestRedshiftServerlessFetcher(t *testing.T) { TagsByARN: tagsByARN, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshiftServerless, "us-east-1", wildcardLabels), wantDatabases: types.Databases{workgroupProdDB}, }, } - - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherRedshiftServerless, "us-east-2", toTypeLabels(test.inputLabels)) - require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) - }) - } + testAWSFetchers(t, tests...) } func makeRedshiftServerlessWorkgroup(t *testing.T, name, region string, labels map[string]string) (*redshiftserverless.Workgroup, types.Database) { diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_test.go index c8103033380e2..8a6d17be2bbe0 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_test.go @@ -38,12 +38,7 @@ func TestRedshiftFetcher(t *testing.T) { redshiftUse1Unavailable, _ := makeRedshiftCluster(t, "us-east-1", "qa", withRedshiftStatus("paused")) redshiftUse1UnknownStatus, redshiftDatabaseUnknownStatus := makeRedshiftCluster(t, "us-east-1", "test", withRedshiftStatus("status-does-not-exist")) - tests := []struct { - name string - inputClients cloud.AWSClients - inputLabels map[string]string - wantDatabases types.Databases - }{ + tests := []awsFetcherTest{ { name: "fetch all", inputClients: &cloud.TestCloudClients{ @@ -51,7 +46,7 @@ func TestRedshiftFetcher(t *testing.T) { Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Dev}, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshift, "us-east-1", wildcardLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUse1Dev}, }, { @@ -61,7 +56,7 @@ func TestRedshiftFetcher(t *testing.T) { Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Dev}, }, }, - inputLabels: envProdLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshift, "us-east-1", envProdLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod}, }, { @@ -71,20 +66,11 @@ func TestRedshiftFetcher(t *testing.T) { Clusters: []*redshift.Cluster{redshiftUse1Prod, redshiftUse1Unavailable, redshiftUse1UnknownStatus}, }, }, - inputLabels: wildcardLabels, + inputMatchers: makeAWSMatchersForType(services.AWSMatcherRedshift, "us-east-1", wildcardLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUnknownStatus}, }, } - - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherRedshift, "us-east-2", toTypeLabels(test.inputLabels)) - require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) - }) - } + testAWSFetchers(t, tests...) } func makeRedshiftCluster(t *testing.T, region, env string, opts ...func(*redshift.Cluster)) (*redshift.Cluster, types.Database) { diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 39c3516c98a68..d9077bbb2e221 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -29,7 +29,7 @@ import ( "github.com/gravitational/teleport/lib/srv/discovery/common" ) -type makeAWSFetcherFunc func(context.Context, cloud.AWSClients, string, types.Labels) (common.Fetcher, error) +type makeAWSFetcherFunc func(context.Context, cloud.AWSClients, string, types.Labels, services.AssumeRole) (common.Fetcher, error) type makeAzureFetcherFunc func(azureFetcherConfig) (common.Fetcher, error) var ( @@ -71,7 +71,7 @@ func MakeAWSFetchers(ctx context.Context, clients cloud.AWSClients, matchers []s for _, makeFetcher := range makeFetchers { for _, region := range matcher.Regions { - fetcher, err := makeFetcher(ctx, clients, region, matcher.Tags) + fetcher, err := makeFetcher(ctx, clients, region, matcher.Tags, matcher.AssumeRole) if err != nil { return nil, trace.Wrap(err) } @@ -116,65 +116,69 @@ func MakeAzureFetchers(clients cloud.AzureClients, matchers []services.AzureMatc } // makeRDSInstanceFetcher returns RDS instance fetcher for the provided region and tags. -func makeRDSInstanceFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region) +func makeRDSInstanceFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) { + rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } fetcher, err := newRDSDBInstancesFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, + Region: region, + Labels: tags, + RDS: rds, + AssumeRole: assumeRole, }) return fetcher, trace.Wrap(err) } // makeRDSAuroraFetcher returns RDS Aurora fetcher for the provided region and tags. -func makeRDSAuroraFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region) +func makeRDSAuroraFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) { + rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } fetcher, err := newRDSAuroraClustersFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, + Region: region, + Labels: tags, + RDS: rds, + AssumeRole: assumeRole, }) return fetcher, trace.Wrap(err) } // makeRDSProxyFetcher returns RDS proxy fetcher for the provided region and tags. -func makeRDSProxyFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region) +func makeRDSProxyFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) { + rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } return newRDSDBProxyFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, + Region: region, + Labels: tags, + RDS: rds, + AssumeRole: assumeRole, }) } // makeRedshiftFetcher returns Redshift fetcher for the provided region and tags. -func makeRedshiftFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) { - redshift, err := clients.GetAWSRedshiftClient(ctx, region) +func makeRedshiftFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) { + redshift, err := clients.GetAWSRedshiftClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } return newRedshiftFetcher(redshiftFetcherConfig{ - Region: region, - Labels: tags, - Redshift: redshift, + Region: region, + Labels: tags, + Redshift: redshift, + AssumeRole: assumeRole, }) } // makeElastiCacheFetcher returns ElastiCache fetcher for the provided region and tags. -func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) { - elastiCache, err := clients.GetAWSElastiCacheClient(ctx, region) +func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) { + elastiCache, err := clients.GetAWSElastiCacheClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } @@ -182,33 +186,36 @@ func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, regio Region: region, Labels: tags, ElastiCache: elastiCache, + AssumeRole: assumeRole, }) } // makeMemoryDBFetcher returns MemoryDB fetcher for the provided region and tags. -func makeMemoryDBFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) { - memorydb, err := clients.GetAWSMemoryDBClient(ctx, region) +func makeMemoryDBFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) { + memorydb, err := clients.GetAWSMemoryDBClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } return newMemoryDBFetcher(memoryDBFetcherConfig{ - Region: region, - Labels: tags, - MemoryDB: memorydb, + Region: region, + Labels: tags, + MemoryDB: memorydb, + AssumeRole: assumeRole, }) } // makeRedshiftServerlessFetcher returns Redshift Serverless fetcher for the // provided region and tags. -func makeRedshiftServerlessFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels) (common.Fetcher, error) { - client, err := clients.GetAWSRedshiftServerlessClient(ctx, region) +func makeRedshiftServerlessFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole services.AssumeRole) (common.Fetcher, error) { + client, err := clients.GetAWSRedshiftServerlessClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } return newRedshiftServerlessFetcher(redshiftServerlessFetcherConfig{ - Region: region, - Labels: tags, - Client: client, + Region: region, + Labels: tags, + Client: client, + AssumeRole: assumeRole, }) } @@ -228,6 +235,14 @@ func filterDatabasesByLabels(databases types.Databases, labels types.Labels, log return matchedDatabases } +// applyAssumeRoleToDatabases applies assume role settings from fetcher to databases. +func applyAssumeRoleToDatabases(databases types.Databases, assumeRole services.AssumeRole) { + for _, db := range databases { + db.SetAWSAssumeRole(assumeRole.RoleARN) + db.SetAWSExternalID(assumeRole.ExternalID) + } +} + // flatten flattens a nested slice [][]T to []T. func flatten[T any](s [][]T) (result []T) { for i := range s { diff --git a/lib/srv/discovery/fetchers/db/helpers_test.go b/lib/srv/discovery/fetchers/db/helpers_test.go index f212643e71da1..27dd07933007c 100644 --- a/lib/srv/discovery/fetchers/db/helpers_test.go +++ b/lib/srv/discovery/fetchers/db/helpers_test.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -43,6 +44,14 @@ func toTypeLabels(labels map[string]string) types.Labels { return result } +func makeAWSMatchersForType(matcherType, region string, tags map[string]string) []services.AWSMatcher { + return []services.AWSMatcher{{ + Types: []string{matcherType}, + Regions: []string{region}, + Tags: toTypeLabels(tags), + }} +} + func mustMakeAWSFetchers(t *testing.T, clients cloud.AWSClients, matchers []services.AWSMatcher) []common.Fetcher { t.Helper() @@ -57,16 +66,6 @@ func mustMakeAWSFetchers(t *testing.T, clients cloud.AWSClients, matchers []serv return fetchers } -func mustMakeAWSFetchersForMatcher(t *testing.T, clients cloud.AWSClients, matcherType, region string, tags types.Labels) []common.Fetcher { - t.Helper() - - return mustMakeAWSFetchers(t, clients, []services.AWSMatcher{{ - Types: []string{matcherType}, - Regions: []string{region}, - Tags: tags, - }}) -} - func mustMakeAzureFetchers(t *testing.T, clients cloud.AzureClients, matchers []services.AzureMatcher) []common.Fetcher { t.Helper() @@ -96,3 +95,73 @@ func mustGetDatabases(t *testing.T, fetchers []common.Fetcher) types.Databases { } return all } + +// testAssumeRole is a fixture for testing fetchers. +// every matcher, stub database, and mock AWS Session created uses this fixture. +// Tests will cover: +// - that fetchers use the configured assume role when using AWS cloud clients. +// - that databases discovered and created by fetchers have the assumed role used to discover them populated. +var testAssumeRole = services.AssumeRole{ + RoleARN: "arn:aws:iam::123456789012:role/test-role", + ExternalID: "externalID123", +} + +// awsFetcherTest is a common test struct for AWS fetchers. +type awsFetcherTest struct { + name string + inputClients *cloud.TestCloudClients + inputMatchers []services.AWSMatcher + wantDatabases types.Databases +} + +// testAWSFetchers is a helper that tests AWS fetchers, since +// all of the AWS fetcher tests are fundamentally the same. +func testAWSFetchers(t *testing.T, tests ...awsFetcherTest) { + t.Helper() + for _, test := range tests { + test := test + require.Nil(t, test.inputClients.STS, "testAWSFetchers injects an STS mock itself, but test input had already configured it. This is a test configuration error.") + stsMock := &mocks.STSMock{} + test.inputClients.STS = stsMock + t.Run(test.name, func(t *testing.T) { + t.Helper() + fetchers := mustMakeAWSFetchers(t, test.inputClients, test.inputMatchers) + require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers)) + }) + t.Run(test.name+" with assume role", func(t *testing.T) { + t.Helper() + matchers := copyAWSMatchersWithAssumeRole(testAssumeRole, test.inputMatchers...) + wantDBs := copyDatabasesWithAWSAssumeRole(testAssumeRole, test.wantDatabases...) + fetchers := mustMakeAWSFetchers(t, test.inputClients, matchers) + require.ElementsMatch(t, wantDBs, mustGetDatabases(t, fetchers)) + require.Equal(t, []string{testAssumeRole.RoleARN}, stsMock.GetAssumedRoleARNs()) + require.Equal(t, []string{testAssumeRole.ExternalID}, stsMock.GetAssumedRoleExternalIDs()) + }) + } +} + +// copyDatabasesWithAWSAssumeRole copies input databases and sets a given AWS assume role for each copy. +func copyDatabasesWithAWSAssumeRole(role services.AssumeRole, databases ...types.Database) types.Databases { + if len(databases) == 0 { + return databases + } + out := make(types.Databases, 0, len(databases)) + for _, db := range databases { + out = append(out, db.Copy()) + } + applyAssumeRoleToDatabases(out, role) + return out +} + +// copyAWSMatchersWithAssumeRole copies input AWS matchers and sets a given AWS assume role for each copy. +func copyAWSMatchersWithAssumeRole(role services.AssumeRole, matchers ...services.AWSMatcher) []services.AWSMatcher { + if len(matchers) == 0 { + return matchers + } + out := make([]services.AWSMatcher, 0, len(matchers)) + for _, m := range matchers { + m.AssumeRole = role + out = append(out, m) + } + return out +}