Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions lib/cloud/mocks/aws_rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,15 @@ func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.D
}
var out []*rds.DBInstance
efs := engineFilterSet(filters)
clusterIDs := clusterIdentifierFilterSet(filters)
for _, instance := range in {
if instanceEngineMatches(instance, efs) {
out = append(out, instance)
if len(efs) > 0 && !instanceEngineMatches(instance, efs) {
continue
}
if len(clusterIDs) > 0 && !instanceClusterIDMatches(instance, clusterIDs) {
continue
}
out = append(out, instance)
}
return out, nil
}
Expand All @@ -345,9 +350,18 @@ func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBC

// engineFilterSet builds a string set of engine names from a list of RDS filters.
func engineFilterSet(filters []*rds.Filter) map[string]struct{} {
return filterValues(filters, "engine")
}

// clusterIdentifierFilterSet builds a string set of ClusterIDs from a list of RDS filters.
func clusterIdentifierFilterSet(filters []*rds.Filter) map[string]struct{} {
return filterValues(filters, "db-cluster-id")
}

func filterValues(filters []*rds.Filter, filterKey string) map[string]struct{} {
out := make(map[string]struct{})
for _, f := range filters {
if aws.StringValue(f.Name) != "engine" {
if aws.StringValue(f.Name) != filterKey {
continue
}
for _, v := range f.Values {
Expand All @@ -363,6 +377,12 @@ func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct
return ok
}

// instanceClusterIDMatches returns whether an RDS DBInstance ClusterID matches any ClusterID in a filter set.
func instanceClusterIDMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(instance.DBClusterIdentifier)]
return ok
}

// clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set.
func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool {
_, ok := filterSet[aws.StringValue(cluster.Engine)]
Expand Down
1 change: 1 addition & 0 deletions lib/integrations/awsoidc/listdatabases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ func TestListDatabases(t *testing.T) {
"engine-version": "",
"region": "",
"status": "available",
"vpc-id": "vpc-999",
"teleport.dev/cloud": "AWS",
},
},
Expand Down
38 changes: 25 additions & 13 deletions lib/services/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,9 @@ func labelsFromRDSV2Instance(rdsInstance *rdsTypesV2.DBInstance, meta *types.AWS
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance)
labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsInstance.DBInstanceStatus)
if rdsInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsInstance.TagList))
}

Expand All @@ -720,7 +723,7 @@ func NewDatabaseFromRDSV2Cluster(cluster *rdsTypesV2.DBCluster, firstInstance *r
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region),
Labels: labelsFromRDSV2Cluster(cluster, metadata, RDSEndpointTypePrimary),
Labels: labelsFromRDSV2Cluster(cluster, metadata, RDSEndpointTypePrimary, firstInstance),
}, aws.StringValue(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand Down Expand Up @@ -777,17 +780,20 @@ func MetadataFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, rdsInstance *rds

// labelsFromRDSV2Cluster creates database labels for the provided RDS cluster.
// It uses aws sdk v2.
func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType RDSEndpointType) map[string]string {
func labelsFromRDSV2Cluster(rdsCluster *rdsTypesV2.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstance *rdsTypesV2.DBInstance) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(endpointType)
labels[types.DiscoveryLabelStatus] = aws.StringValue(rdsCluster.Status)
if memberInstance != nil && memberInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstance.DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList))
}

// NewDatabaseFromRDSCluster creates a database resource from an RDS cluster (Aurora).
func NewDatabaseFromRDSCluster(cluster *rds.DBCluster) (types.Database, error) {
func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -799,7 +805,7 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster) (types.Database, error) {
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v", metadata.Region),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypePrimary),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypePrimary, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand All @@ -809,7 +815,7 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster) (types.Database, error) {
}

// NewDatabaseFromRDSClusterReaderEndpoint creates a database resource from an RDS cluster reader endpoint (Aurora).
func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster) (types.Database, error) {
func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -821,7 +827,7 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster) (types.Data
return types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeReader)),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeReader),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeReader, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeReader)),
types.DatabaseSpecV3{
Protocol: protocol,
Expand All @@ -831,7 +837,7 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster) (types.Data
}

// NewDatabasesFromRDSClusterCustomEndpoints creates database resources from RDS cluster custom endpoints (Aurora).
func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Databases, error) {
func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) {
metadata, err := MetadataFromRDSCluster(cluster)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -859,7 +865,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Da
database, err := types.NewDatabaseV3(
setAWSDBName(types.Metadata{
Description: fmt.Sprintf("Aurora cluster in %v (%v endpoint)", metadata.Region, string(RDSEndpointTypeCustom)),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeCustom),
Labels: labelsFromRDSCluster(cluster, metadata, RDSEndpointTypeCustom, memberInstances),
}, aws.StringValue(cluster.DBClusterIdentifier), string(RDSEndpointTypeCustom), endpointDetails.ClusterCustomEndpointName),
types.DatabaseSpecV3{
Protocol: protocol,
Expand All @@ -885,7 +891,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Da

// NewDatabasesFromRDSCluster creates all database resources from an RDS Aurora
// cluster.
func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error) {
func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) {
var errors []error
var databases types.Databases

Expand All @@ -906,7 +912,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error)

// Add a database from primary endpoint, if any writer instances.
if cluster.Endpoint != nil && hasWriterInstance {
database, err := NewDatabaseFromRDSCluster(cluster)
database, err := NewDatabaseFromRDSCluster(cluster, memberInstances)
if err != nil {
errors = append(errors, err)
} else {
Expand All @@ -917,7 +923,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error)
// Add a database from reader endpoint, if any reader instances.
// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Overview.Endpoints.html#Aurora.Endpoints.Reader
if cluster.ReaderEndpoint != nil && hasReaderInstance {
database, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster)
database, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster, memberInstances)
if err != nil {
errors = append(errors, err)
} else {
Expand All @@ -927,7 +933,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error)

// Add databases from custom endpoints
if len(cluster.CustomEndpoints) > 0 {
customEndpointDatabases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster)
customEndpointDatabases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster, memberInstances)
if err != nil {
errors = append(errors, err)
}
Expand Down Expand Up @@ -1626,15 +1632,21 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsInstance.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsInstance.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(RDSEndpointTypeInstance)
if rdsInstance.DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(rdsInstance.DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsInstance.TagList))
}

// labelsFromRDSCluster creates database labels for the provided RDS cluster.
func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType RDSEndpointType) map[string]string {
func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType RDSEndpointType, memberInstances []*rds.DBInstance) map[string]string {
labels := labelsFromAWSMetadata(meta)
labels[types.DiscoveryLabelEngine] = aws.StringValue(rdsCluster.Engine)
labels[types.DiscoveryLabelEngineVersion] = aws.StringValue(rdsCluster.EngineVersion)
labels[types.DiscoveryLabelEndpointType] = string(endpointType)
if len(memberInstances) > 0 && memberInstances[0].DBSubnetGroup != nil {
labels[types.DiscoveryLabelVPCID] = aws.StringValue(memberInstances[0].DBSubnetGroup.VpcId)
}
return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList))
}

Expand Down
27 changes: 19 additions & 8 deletions lib/services/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) {
types.DiscoveryLabelEngineVersion: "13.0",
types.DiscoveryLabelEndpointType: "instance",
types.DiscoveryLabelStatus: "available",
types.DiscoveryLabelVPCID: "vpc-asd",
"key": "val",
},
}, types.DatabaseSpecV3{
Expand Down Expand Up @@ -893,6 +894,8 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) {

// TestDatabaseFromRDSCluster tests converting an RDS cluster to a database resource.
func TestDatabaseFromRDSCluster(t *testing.T) {
vpcid := uuid.NewString()
dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String(vpcid)}}}
cluster := &rds.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"),
DBClusterIdentifier: aws.String("cluster-1"),
Expand Down Expand Up @@ -934,6 +937,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "primary",
types.DiscoveryLabelVPCID: vpcid,
"key": "val",
},
}, types.DatabaseSpecV3{
Expand All @@ -942,7 +946,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSCluster(cluster)
actual, err := NewDatabaseFromRDSCluster(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -958,6 +962,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "reader",
types.DiscoveryLabelVPCID: vpcid,
"key": "val",
},
}, types.DatabaseSpecV3{
Expand All @@ -966,7 +971,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -979,6 +984,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "custom",
types.DiscoveryLabelVPCID: vpcid,
"key": "val",
}

Expand Down Expand Up @@ -1010,7 +1016,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
})
require.NoError(t, err)

databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster)
databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Equal(t, types.Databases{expectedMyEndpoint1, expectedMyEndpoint2}, databases)
})
Expand All @@ -1021,7 +1027,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) {
aws.String("badendpoint1"),
aws.String("badendpoint2"),
}
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster)
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers)
require.Error(t, err)
})
}
Expand Down Expand Up @@ -1127,6 +1133,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "primary",
types.DiscoveryLabelStatus: "available",
types.DiscoveryLabelVPCID: "vpc-123",
"key": "val",
},
}, types.DatabaseSpecV3{
Expand All @@ -1153,6 +1160,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) {

// TestDatabaseFromRDSClusterNameOverride tests converting an RDS cluster to a database resource with overridden name.
func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String("vpc-123")}}}
for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels {
cluster := &rds.DBCluster{
DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"),
Expand Down Expand Up @@ -1195,6 +1203,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "primary",
types.DiscoveryLabelVPCID: "vpc-123",
overrideLabel: "mycluster-2",
"key": "val",
},
Expand All @@ -1204,7 +1213,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSCluster(cluster)
actual, err := NewDatabaseFromRDSCluster(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -1220,6 +1229,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "reader",
types.DiscoveryLabelVPCID: "vpc-123",
overrideLabel: "mycluster-2",
"key": "val",
},
Expand All @@ -1229,7 +1239,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
AWS: expectedAWS,
})
require.NoError(t, err)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster)
actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expected, actual))
})
Expand All @@ -1242,6 +1252,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
types.DiscoveryLabelEngine: RDSEngineAuroraMySQL,
types.DiscoveryLabelEngineVersion: "8.0.0",
types.DiscoveryLabelEndpointType: "custom",
types.DiscoveryLabelVPCID: "vpc-123",
overrideLabel: "mycluster-2",
"key": "val",
}
Expand Down Expand Up @@ -1274,7 +1285,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
})
require.NoError(t, err)

databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster)
databases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster, dbInstanceMembers)
require.NoError(t, err)
require.Equal(t, types.Databases{expectedMyEndpoint1, expectedMyEndpoint2}, databases)
})
Expand All @@ -1285,7 +1296,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) {
aws.String("badendpoint1"),
aws.String("badendpoint2"),
}
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster)
_, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers)
require.Error(t, err)
})
}
Expand Down
3 changes: 2 additions & 1 deletion lib/srv/db/cloud/resource_checker_url_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/opensearchservice"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -98,7 +99,7 @@ func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Databas
if err != nil {
return trace.Wrap(err)
}
databases, err := services.NewDatabasesFromRDSCluster(rdsCluster)
databases, err := services.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{})
if err != nil {
c.log.Warnf("Could not convert RDS cluster %q to database resources: %v.",
aws.StringValue(rdsCluster.DBClusterIdentifier), err)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/cloud/resource_checker_url_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestURLChecker_AWS(t *testing.T) {
mocks.WithRDSClusterReader,
mocks.WithRDSClusterCustomEndpoint("my-custom"),
)
rdsClusterDBs, err := services.NewDatabasesFromRDSCluster(rdsCluster)
rdsClusterDBs, err := services.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{})
require.NoError(t, err)
require.Len(t, rdsClusterDBs, 3) // Primary, reader, custom.
testCases = append(testCases, append(rdsClusterDBs, rdsInstanceDB)...)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/common/renaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel st
overrideLabel: name,
}),
}
database, err := services.NewDatabaseFromRDSCluster(cluster)
database, err := services.NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{})
require.NoError(t, err)
return database
}
Expand Down
Loading