Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions api/types/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ type Database interface {
GetAWS() AWS
// SetStatusAWS sets the database AWS metadata in the status field.
SetStatusAWS(AWS)
// SetAWSExternalID sets the database AWS external ID in the Spec.AWS field.
SetAWSExternalID(id string)
// SetAWSAssumeRole sets the database AWS assume role arn in the Spec.AWS field.
SetAWSAssumeRole(roleARN string)
// GetGCP returns GCP information for Cloud SQL databases.
GetGCP() GCPCloudSQL
// GetAzure returns Azure database server metadata.
Expand Down Expand Up @@ -341,6 +345,16 @@ func (d *DatabaseV3) SetStatusAWS(aws AWS) {
d.Status.AWS = aws
}

// SetAWSExternalID sets the database AWS external ID in the Spec.AWS field.
func (d *DatabaseV3) SetAWSExternalID(id string) {
d.Spec.AWS.ExternalID = id
}

// SetAWSAssumeRole sets the database AWS assume role arn in the Spec.AWS field.
func (d *DatabaseV3) SetAWSAssumeRole(roleARN string) {
d.Spec.AWS.AssumeRoleARN = roleARN
}

// GetGCP returns GCP information for Cloud SQL databases.
func (d *DatabaseV3) GetGCP() GCPCloudSQL {
return d.Spec.GCP
Expand Down
13 changes: 9 additions & 4 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,18 @@ type awsAssumeRoleOpts struct {
// when getting an AWS session.
type AWSAssumeRoleOptionFn func(*awsAssumeRoleOpts)

// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) AWSAssumeRoleOptionFn {
return func(options *awsAssumeRoleOpts) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
}
}

// WithAssumeRoleFromAWSMeta extracts options needed from AWS metadata for
// assuming an AWS role.
func WithAssumeRoleFromAWSMeta(meta types.AWS) AWSAssumeRoleOptionFn {
return func(options *awsAssumeRoleOpts) {
options.assumeRoleARN = meta.AssumeRoleARN
options.assumeRoleExternalID = meta.ExternalID
}
return WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID)
}

// WithChainedAssumeRole sets a role to assume with a base session to use
Expand Down
62 changes: 62 additions & 0 deletions lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ import (
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/redshift"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -918,9 +920,22 @@ func (m *mockGKEAPI) ListClusters(ctx context.Context, projectID string, locatio

func TestDiscoveryDatabase(t *testing.T) {
awsRedshiftResource, awsRedshiftDB := makeRedshiftCluster(t, "aws-redshift", "us-east-1")
awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1")
azRedisResource, azRedisDB := makeAzureRedisServer(t, "az-redis", "sub1", "group1", "East US")

role := services.AssumeRole{RoleARN: "arn:aws:iam::123456789012:role/test-role", ExternalID: "test123"}
awsRDSDBWithRole := awsRDSDB.Copy()
awsRDSDBWithRole.SetAWSAssumeRole("arn:aws:iam::123456789012:role/test-role")
awsRDSDBWithRole.SetAWSExternalID("test123")

testClients := &cloud.TestCloudClients{
STS: &mocks.STSMock{},
RDS: &mocks.RDSMock{
DBInstances: []*rds.DBInstance{awsRDSInstance},
DBEngineVersions: []*rds.DBEngineVersion{
{Engine: aws.String(services.RDSEnginePostgres)},
},
},
Redshift: &mocks.RedshiftMock{
Clusters: []*redshift.Cluster{awsRedshiftResource},
},
Expand Down Expand Up @@ -949,6 +964,16 @@ func TestDiscoveryDatabase(t *testing.T) {
}},
expectDatabases: []types.Database{awsRedshiftDB},
},
{
name: "discover AWS database with assumed role",
awsMatchers: []services.AWSMatcher{{
Types: []string{services.AWSMatcherRDS},
Tags: map[string]utils.Strings{types.Wildcard: {types.Wildcard}},
Regions: []string{"us-west-1"},
AssumeRole: role,
}},
expectDatabases: []types.Database{awsRDSDBWithRole},
},
{
name: "discover Azure database",
azureMatchers: []services.AzureMatcher{{
Expand Down Expand Up @@ -979,6 +1004,26 @@ func TestDiscoveryDatabase(t *testing.T) {
}},
expectDatabases: []types.Database{awsRedshiftDB},
},
{
name: "update existing database with assumed role",
existingDatabases: []types.Database{
mustNewDatabase(t, types.Metadata{
Name: "aws-rds",
Description: "should be updated",
Labels: map[string]string{types.OriginLabel: types.OriginCloud},
}, types.DatabaseSpecV3{
Protocol: "postgres",
URI: "should.be.updated.com:12345",
}),
},
awsMatchers: []services.AWSMatcher{{
Types: []string{services.AWSMatcherRDS},
Tags: map[string]utils.Strings{types.Wildcard: {types.Wildcard}},
Regions: []string{"us-west-1"},
AssumeRole: role,
}},
expectDatabases: []types.Database{awsRDSDBWithRole},
},
{
name: "delete existing database",
existingDatabases: []types.Database{
Expand Down Expand Up @@ -1091,6 +1136,23 @@ func TestDiscoveryDatabase(t *testing.T) {
}
}

func makeRDSInstance(t *testing.T, name, region string) (*rds.DBInstance, types.Database) {
instance := &rds.DBInstance{
DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)),
DBInstanceIdentifier: aws.String(name),
DbiResourceId: aws.String(uuid.New().String()),
Engine: aws.String(services.RDSEnginePostgres),
DBInstanceStatus: aws.String("available"),
Endpoint: &rds.Endpoint{
Address: aws.String("localhost"),
Port: aws.Int64(5432),
},
}
database, err := services.NewDatabaseFromRDSInstance(instance)
require.NoError(t, err)
return instance, database
}

func makeRedshiftCluster(t *testing.T, name, region string) (*redshift.Cluster, types.Database) {
t.Helper()
cluster := &redshift.Cluster{
Expand Down
4 changes: 4 additions & 0 deletions lib/srv/discovery/fetchers/db/aws_elasticache.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type elastiCacheFetcherConfig struct {
ElastiCache elasticacheiface.ElastiCacheAPI
// Region is the AWS region to query databases in.
Region string
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}

// CheckAndSetDefaults validates the config and sets defaults.
Expand Down Expand Up @@ -74,6 +76,7 @@ func newElastiCacheFetcher(config elastiCacheFetcherConfig) (common.Fetcher, err
trace.Component: "watch:elasticache",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
Expand Down Expand Up @@ -168,6 +171,7 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels
}
}

applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

Expand Down
26 changes: 6 additions & 20 deletions lib/srv/discovery/fetchers/db/aws_elasticache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@ func TestElastiCacheFetcher(t *testing.T) {
aws.StringValue(elasticacheUnsupported.ARN): elasticacheUnsupportedTags,
}

tests := []struct {
name string
inputClients cloud.AWSClients
inputLabels map[string]string
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputClients: &cloud.TestCloudClients{
Expand All @@ -62,7 +57,7 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd, elasticacheDatabaseQA},
},
{
Expand All @@ -73,7 +68,7 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: envProdLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", envProdLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd},
},
{
Expand All @@ -84,7 +79,7 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd},
},
{
Expand All @@ -95,20 +90,11 @@ func TestElastiCacheFetcher(t *testing.T) {
TagsByARN: elasticacheTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{elasticacheDatabaseProd},
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()

fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherElastiCache, "us-east-2", toTypeLabels(test.inputLabels))
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}

func makeElastiCacheCluster(t *testing.T, name, region, env string, opts ...func(*elasticache.ReplicationGroup)) (*elasticache.ReplicationGroup, types.Database, []*elasticache.Tag) {
Expand Down
4 changes: 4 additions & 0 deletions lib/srv/discovery/fetchers/db/aws_memorydb.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type memoryDBFetcherConfig struct {
MemoryDB memorydbiface.MemoryDBAPI
// Region is the AWS region to query databases in.
Region string
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}

// CheckAndSetDefaults validates the config and sets defaults.
Expand Down Expand Up @@ -74,6 +76,7 @@ func newMemoryDBFetcher(config memoryDBFetcherConfig) (common.Fetcher, error) {
trace.Component: "watch:memorydb",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
Expand Down Expand Up @@ -136,6 +139,7 @@ func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e
databases = append(databases, database)
}
}
applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

Expand Down
26 changes: 6 additions & 20 deletions lib/srv/discovery/fetchers/db/aws_memorydb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,7 @@ func TestMemoryDBFetcher(t *testing.T) {
aws.StringValue(memorydbUnsupported.ARN): memorydbUnsupportedTags,
}

tests := []struct {
name string
inputClients cloud.AWSClients
inputLabels map[string]string
wantDatabases types.Databases
}{
tests := []awsFetcherTest{
{
name: "fetch all",
inputClients: &cloud.TestCloudClients{
Expand All @@ -61,7 +56,7 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{memorydbDatabaseProd, memorydbDatabaseTest},
},
{
Expand All @@ -72,7 +67,7 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: envProdLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", envProdLabels),
wantDatabases: types.Databases{memorydbDatabaseProd},
},
{
Expand All @@ -83,7 +78,7 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{memorydbDatabaseProd},
},
{
Expand All @@ -94,20 +89,11 @@ func TestMemoryDBFetcher(t *testing.T) {
TagsByARN: memorydbTagsByARN,
},
},
inputLabels: wildcardLabels,
inputMatchers: makeAWSMatchersForType(services.AWSMatcherMemoryDB, "us-east-1", wildcardLabels),
wantDatabases: types.Databases{memorydbDatabaseProd},
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()

fetchers := mustMakeAWSFetchersForMatcher(t, test.inputClients, services.AWSMatcherMemoryDB, "us-east-2", toTypeLabels(test.inputLabels))
require.ElementsMatch(t, test.wantDatabases, mustGetDatabases(t, fetchers))
})
}
testAWSFetchers(t, tests...)
}

func makeMemoryDBCluster(t *testing.T, name, region, env string, opts ...func(*memorydb.Cluster)) (*memorydb.Cluster, types.Database, []*memorydb.Tag) {
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/discovery/fetchers/db/aws_rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type rdsFetcherConfig struct {
RDS rdsiface.RDSAPI
// Region is the AWS region to query databases in.
Region string
// AssumeRole is the AWS IAM role to assume before discovering databases.
AssumeRole services.AssumeRole
}

// CheckAndSetDefaults validates the config and sets defaults.
Expand Down Expand Up @@ -76,6 +78,7 @@ func newRDSDBInstancesFetcher(config rdsFetcherConfig) (common.Fetcher, error) {
trace.Component: "watch:rds",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
Expand All @@ -87,6 +90,7 @@ func (f *rdsDBInstancesFetcher) Get(ctx context.Context) (types.ResourcesWithLab
return nil, trace.Wrap(err)
}

applyAssumeRoleToDatabases(rdsDatabases, f.cfg.AssumeRole)
return filterDatabasesByLabels(rdsDatabases, f.cfg.Labels, f.log).AsResources(), nil
}

Expand Down Expand Up @@ -172,6 +176,7 @@ func newRDSAuroraClustersFetcher(config rdsFetcherConfig) (common.Fetcher, error
trace.Component: "watch:aurora",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
Expand All @@ -183,6 +188,7 @@ func (f *rdsAuroraClustersFetcher) Get(ctx context.Context) (types.ResourcesWith
return nil, trace.Wrap(err)
}

applyAssumeRoleToDatabases(auroraDatabases, f.cfg.AssumeRole)
return filterDatabasesByLabels(auroraDatabases, f.cfg.Labels, f.log).AsResources(), nil
}

Expand Down
2 changes: 2 additions & 0 deletions lib/srv/discovery/fetchers/db/aws_rds_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func newRDSDBProxyFetcher(config rdsFetcherConfig) (common.Fetcher, error) {
trace.Component: "watch:rdsproxy",
"labels": config.Labels,
"region": config.Region,
"role": config.AssumeRole,
}),
}, nil
}
Expand All @@ -60,6 +61,7 @@ func (f *rdsDBProxyFetcher) Get(ctx context.Context) (types.ResourcesWithLabels,
return nil, trace.Wrap(err)
}

applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole)
return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil
}

Expand Down
Loading