diff --git a/lib/cloud/mocks/aws.go b/lib/cloud/mocks/aws.go index 1fb0d979743de..21dabfe707dd9 100644 --- a/lib/cloud/mocks/aws.go +++ b/lib/cloud/mocks/aws.go @@ -27,22 +27,11 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/eks/eksiface" - "github.com/aws/aws-sdk-go/service/elasticache" - "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" - "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" - "github.com/aws/aws-sdk-go/service/redshift" - "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" "golang.org/x/exp/slices" ) @@ -118,192 +107,6 @@ func (m *STSMock) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*re }, nil } -// RDSMock mocks AWS RDS API. 
-type RDSMock struct { - rdsiface.RDSAPI - DBInstances []*rds.DBInstance - DBClusters []*rds.DBCluster - DBProxies []*rds.DBProxy - DBProxyEndpoints []*rds.DBProxyEndpoint - DBEngineVersions []*rds.DBEngineVersion - DBProxyTargetPort int64 -} - -func (m *RDSMock) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return nil, trace.Wrap(err) - } - instances, err := applyInstanceFilters(m.DBInstances, input.Filters) - if err != nil { - return nil, trace.Wrap(err) - } - if aws.StringValue(input.DBInstanceIdentifier) == "" { - return &rds.DescribeDBInstancesOutput{ - DBInstances: instances, - }, nil - } - for _, instance := range instances { - if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { - return &rds.DescribeDBInstancesOutput{ - DBInstances: []*rds.DBInstance{instance}, - }, nil - } - } - return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) -} - -func (m *RDSMock) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return trace.Wrap(err) - } - instances, err := applyInstanceFilters(m.DBInstances, input.Filters) - if err != nil { - return trace.Wrap(err) - } - fn(&rds.DescribeDBInstancesOutput{ - DBInstances: instances, - }, true) - return nil -} - -func (m *RDSMock) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return nil, trace.Wrap(err) - } - clusters, err := applyClusterFilters(m.DBClusters, 
input.Filters) - if err != nil { - return nil, trace.Wrap(err) - } - if aws.StringValue(input.DBClusterIdentifier) == "" { - return &rds.DescribeDBClustersOutput{ - DBClusters: clusters, - }, nil - } - for _, cluster := range clusters { - if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { - return &rds.DescribeDBClustersOutput{ - DBClusters: []*rds.DBCluster{cluster}, - }, nil - } - } - return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) -} - -func (m *RDSMock) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return trace.Wrap(err) - } - clusters, err := applyClusterFilters(m.DBClusters, input.Filters) - if err != nil { - return trace.Wrap(err) - } - fn(&rds.DescribeDBClustersOutput{ - DBClusters: clusters, - }, true) - return nil -} - -func (m *RDSMock) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - for i, instance := range m.DBInstances { - if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { - if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { - m.DBInstances[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) - } - return &rds.ModifyDBInstanceOutput{ - DBInstance: m.DBInstances[i], - }, nil - } - } - return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) -} - -func (m *RDSMock) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - for i, cluster := range m.DBClusters { - if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { - if 
aws.BoolValue(input.EnableIAMDatabaseAuthentication) { - m.DBClusters[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) - } - return &rds.ModifyDBClusterOutput{ - DBCluster: m.DBClusters[i], - }, nil - } - } - return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) -} - -func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { - if aws.StringValue(input.DBProxyName) == "" { - return &rds.DescribeDBProxiesOutput{ - DBProxies: m.DBProxies, - }, nil - } - for _, dbProxy := range m.DBProxies { - if aws.StringValue(dbProxy.DBProxyName) == aws.StringValue(input.DBProxyName) { - return &rds.DescribeDBProxiesOutput{ - DBProxies: []*rds.DBProxy{dbProxy}, - }, nil - } - } - return nil, trace.NotFound("proxy %v not found", aws.StringValue(input.DBProxyName)) -} - -func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { - inputProxyName := aws.StringValue(input.DBProxyName) - inputProxyEndpointName := aws.StringValue(input.DBProxyEndpointName) - - if inputProxyName == "" && inputProxyEndpointName == "" { - return &rds.DescribeDBProxyEndpointsOutput{ - DBProxyEndpoints: m.DBProxyEndpoints, - }, nil - } - - var endpoints []*rds.DBProxyEndpoint - for _, dbProxyEndpoiont := range m.DBProxyEndpoints { - if inputProxyEndpointName != "" && - inputProxyEndpointName != aws.StringValue(dbProxyEndpoiont.DBProxyEndpointName) { - continue - } - - if inputProxyName != "" && - inputProxyName != aws.StringValue(dbProxyEndpoiont.DBProxyName) { - continue - } - - endpoints = append(endpoints, dbProxyEndpoiont) - } - if len(endpoints) == 0 { - return nil, trace.NotFound("proxy endpoint %v not found", aws.StringValue(input.DBProxyEndpointName)) - } - return &rds.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: 
endpoints}, nil -} - -func (m *RDSMock) DescribeDBProxyTargetsWithContext(ctx aws.Context, input *rds.DescribeDBProxyTargetsInput, options ...request.Option) (*rds.DescribeDBProxyTargetsOutput, error) { - // only mocking to return a port here - return &rds.DescribeDBProxyTargetsOutput{ - Targets: []*rds.DBProxyTarget{{ - Port: aws.Int64(m.DBProxyTargetPort), - }}, - }, nil -} - -func (m *RDSMock) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - fn(&rds.DescribeDBProxiesOutput{ - DBProxies: m.DBProxies, - }, true) - return nil -} - -func (m *RDSMock) DescribeDBProxyEndpointsPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, fn func(*rds.DescribeDBProxyEndpointsOutput, bool) bool, options ...request.Option) error { - fn(&rds.DescribeDBProxyEndpointsOutput{ - DBProxyEndpoints: m.DBProxyEndpoints, - }, true) - return nil -} - -func (m *RDSMock) ListTagsForResourceWithContext(ctx aws.Context, input *rds.ListTagsForResourceInput, options ...request.Option) (*rds.ListTagsForResourceOutput, error) { - return &rds.ListTagsForResourceOutput{}, nil -} - // IAMMock mocks AWS IAM API. type IAMMock struct { iamiface.IAMAPI @@ -394,137 +197,6 @@ func (m *IAMMock) DeleteUserPolicyWithContext(ctx aws.Context, input *iam.Delete return &iam.DeleteUserPolicyOutput{}, nil } -// RedshiftMock mocks AWS Redshift API. 
-type RedshiftMock struct { - redshiftiface.RedshiftAPI - Clusters []*redshift.Cluster - GetClusterCredentialsOutput *redshift.GetClusterCredentialsOutput -} - -func (m *RedshiftMock) GetClusterCredentialsWithContext(aws.Context, *redshift.GetClusterCredentialsInput, ...request.Option) (*redshift.GetClusterCredentialsOutput, error) { - if m.GetClusterCredentialsOutput == nil { - return nil, trace.AccessDenied("access denied") - } - return m.GetClusterCredentialsOutput, nil -} - -func (m *RedshiftMock) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) { - if aws.StringValue(input.ClusterIdentifier) == "" { - return &redshift.DescribeClustersOutput{ - Clusters: m.Clusters, - }, nil - } - for _, cluster := range m.Clusters { - if aws.StringValue(cluster.ClusterIdentifier) == aws.StringValue(input.ClusterIdentifier) { - return &redshift.DescribeClustersOutput{ - Clusters: []*redshift.Cluster{cluster}, - }, nil - } - } - return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.ClusterIdentifier)) -} - -func (m *RedshiftMock) DescribeClustersPagesWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, fn func(*redshift.DescribeClustersOutput, bool) bool, options ...request.Option) error { - fn(&redshift.DescribeClustersOutput{ - Clusters: m.Clusters, - }, true) - return nil -} - -// RDSMockUnauth is a mock RDS client that returns access denied to each call. 
-type RDSMockUnauth struct { - rdsiface.RDSAPI -} - -func (m *RDSMockUnauth) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxiesPagesWithContext(ctx 
aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -// RDSMockByDBType is a mock RDS client that mocks API calls by DB type -type RDSMockByDBType struct { - rdsiface.RDSAPI - DBInstances rdsiface.RDSAPI - DBClusters rdsiface.RDSAPI - DBProxies rdsiface.RDSAPI -} - -func (m *RDSMockByDBType) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - return m.DBInstances.DescribeDBInstancesWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - return m.DBInstances.ModifyDBInstanceWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - return m.DBInstances.DescribeDBInstancesPagesWithContext(ctx, input, fn, options...) -} - -func (m *RDSMockByDBType) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - return m.DBClusters.DescribeDBClustersWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - return m.DBClusters.ModifyDBClusterWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - return m.DBClusters.DescribeDBClustersPagesWithContext(aws, input, fn, options...) 
-} - -func (m *RDSMockByDBType) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { - return m.DBProxies.DescribeDBProxiesWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { - return m.DBProxies.DescribeDBProxyEndpointsWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - return m.DBProxies.DescribeDBProxiesPagesWithContext(ctx, input, fn, options...) -} - -// RedshiftMockUnauth is a mock Redshift client that returns access denied to each call. -type RedshiftMockUnauth struct { - redshiftiface.RedshiftAPI -} - -func (m *RedshiftMockUnauth) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - // IAMErrorMock is a mock IAM client that returns the provided Error to all // APIs. If Error is not provided, all APIs returns trace.AccessDenied by // default. @@ -561,332 +233,6 @@ func (m *IAMErrorMock) PutUserPolicyWithContext(ctx aws.Context, input *iam.PutU return nil, trace.AccessDenied("unauthorized") } -// ElastiCache mocks AWS ElastiCache API. -type ElastiCacheMock struct { - elasticacheiface.ElastiCacheAPI - // Unauth set to true will make API calls return unauthorized errors. 
- Unauth bool - - ReplicationGroups []*elasticache.ReplicationGroup - Users []*elasticache.User - TagsByARN map[string][]*elasticache.Tag -} - -func (m *ElastiCacheMock) AddMockUser(user *elasticache.User, tagsMap map[string]string) { - m.Users = append(m.Users, user) - m.addTags(aws.StringValue(user.ARN), tagsMap) -} - -func (m *ElastiCacheMock) addTags(arn string, tagsMap map[string]string) { - if m.TagsByARN == nil { - m.TagsByARN = make(map[string][]*elasticache.Tag) - } - - var tags []*elasticache.Tag - for key, value := range tagsMap { - tags = append(tags, &elasticache.Tag{ - Key: aws.String(key), - Value: aws.String(value), - }) - } - m.TagsByARN[arn] = tags -} - -func (m *ElastiCacheMock) DescribeUsersWithContext(_ aws.Context, input *elasticache.DescribeUsersInput, opts ...request.Option) (*elasticache.DescribeUsersOutput, error) { - if m.Unauth { - return nil, trace.AccessDenied("unauthorized") - } - if input.UserId == nil { - return &elasticache.DescribeUsersOutput{Users: m.Users}, nil - } - for _, user := range m.Users { - if aws.StringValue(user.UserId) == aws.StringValue(input.UserId) { - return &elasticache.DescribeUsersOutput{Users: []*elasticache.User{user}}, nil - } - } - return nil, trace.NotFound("ElastiCache UserId %v not found", aws.StringValue(input.UserId)) -} - -func (m *ElastiCacheMock) DescribeReplicationGroupsWithContext(_ aws.Context, input *elasticache.DescribeReplicationGroupsInput, opts ...request.Option) (*elasticache.DescribeReplicationGroupsOutput, error) { - if m.Unauth { - return nil, trace.AccessDenied("unauthorized") - } - for _, replicationGroup := range m.ReplicationGroups { - if aws.StringValue(replicationGroup.ReplicationGroupId) == aws.StringValue(input.ReplicationGroupId) { - return &elasticache.DescribeReplicationGroupsOutput{ - ReplicationGroups: []*elasticache.ReplicationGroup{replicationGroup}, - }, nil - } - } - return nil, trace.NotFound("ElastiCache %v not found", aws.StringValue(input.ReplicationGroupId)) -} - 
-func (m *ElastiCacheMock) DescribeReplicationGroupsPagesWithContext(_ aws.Context, _ *elasticache.DescribeReplicationGroupsInput, fn func(*elasticache.DescribeReplicationGroupsOutput, bool) bool, _ ...request.Option) error { - if m.Unauth { - return trace.AccessDenied("unauthorized") - } - fn(&elasticache.DescribeReplicationGroupsOutput{ - ReplicationGroups: m.ReplicationGroups, - }, true) - return nil -} - -func (m *ElastiCacheMock) DescribeUsersPagesWithContext(_ aws.Context, _ *elasticache.DescribeUsersInput, fn func(*elasticache.DescribeUsersOutput, bool) bool, _ ...request.Option) error { - if m.Unauth { - return trace.AccessDenied("unauthorized") - } - fn(&elasticache.DescribeUsersOutput{ - Users: m.Users, - }, true) - return nil -} - -func (m *ElastiCacheMock) DescribeCacheClustersPagesWithContext(aws.Context, *elasticache.DescribeCacheClustersInput, func(*elasticache.DescribeCacheClustersOutput, bool) bool, ...request.Option) error { - if m.Unauth { - return trace.AccessDenied("unauthorized") - } - return trace.NotImplemented("elasticache:DescribeCacheClustersPagesWithContext is not implemented") -} - -func (m *ElastiCacheMock) DescribeCacheSubnetGroupsPagesWithContext(aws.Context, *elasticache.DescribeCacheSubnetGroupsInput, func(*elasticache.DescribeCacheSubnetGroupsOutput, bool) bool, ...request.Option) error { - if m.Unauth { - return trace.AccessDenied("unauthorized") - } - return trace.NotImplemented("elasticache:DescribeCacheSubnetGroupsPagesWithContext is not implemented") -} - -func (m *ElastiCacheMock) ListTagsForResourceWithContext(_ aws.Context, input *elasticache.ListTagsForResourceInput, _ ...request.Option) (*elasticache.TagListMessage, error) { - if m.Unauth { - return nil, trace.AccessDenied("unauthorized") - } - if m.TagsByARN == nil { - return nil, trace.NotFound("no tags") - } - - tags, ok := m.TagsByARN[aws.StringValue(input.ResourceName)] - if !ok { - return nil, trace.NotFound("no tags") - } - - return &elasticache.TagListMessage{ - 
TagList: tags, - }, nil -} - -func (m *ElastiCacheMock) ModifyUserWithContext(_ aws.Context, input *elasticache.ModifyUserInput, opts ...request.Option) (*elasticache.ModifyUserOutput, error) { - if m.Unauth { - return nil, trace.AccessDenied("unauthorized") - } - for _, user := range m.Users { - if aws.StringValue(user.UserId) == aws.StringValue(input.UserId) { - return &elasticache.ModifyUserOutput{}, nil - } - } - return nil, trace.NotFound("user %s not found", aws.StringValue(input.UserId)) -} - -type OpenSearchMock struct { - opensearchserviceiface.OpenSearchServiceAPI - - Domains []*opensearchservice.DomainStatus - TagsByARN map[string][]*opensearchservice.Tag -} - -func (o *OpenSearchMock) ListDomainNamesWithContext(aws.Context, *opensearchservice.ListDomainNamesInput, ...request.Option) (*opensearchservice.ListDomainNamesOutput, error) { - out := &opensearchservice.ListDomainNamesOutput{} - for _, domain := range o.Domains { - out.DomainNames = append(out.DomainNames, &opensearchservice.DomainInfo{ - DomainName: domain.DomainName, - EngineType: aws.String("OpenSearch"), - }) - } - - return out, nil -} - -func (o *OpenSearchMock) DescribeDomainsWithContext(aws.Context, *opensearchservice.DescribeDomainsInput, ...request.Option) (*opensearchservice.DescribeDomainsOutput, error) { - out := &opensearchservice.DescribeDomainsOutput{DomainStatusList: o.Domains} - return out, nil -} - -func (o *OpenSearchMock) ListTagsWithContext(_ aws.Context, request *opensearchservice.ListTagsInput, _ ...request.Option) (*opensearchservice.ListTagsOutput, error) { - tags, found := o.TagsByARN[aws.StringValue(request.ARN)] - if !found { - return nil, trace.NotFound("tags not found") - } - return &opensearchservice.ListTagsOutput{TagList: tags}, nil -} - -// MemoryDBMock mocks AWS MemoryDB API. 
-type MemoryDBMock struct { - memorydbiface.MemoryDBAPI - - Clusters []*memorydb.Cluster - Users []*memorydb.User - TagsByARN map[string][]*memorydb.Tag -} - -func (m *MemoryDBMock) AddMockUser(user *memorydb.User, tagsMap map[string]string) { - m.Users = append(m.Users, user) - m.addTags(aws.StringValue(user.ARN), tagsMap) -} - -func (m *MemoryDBMock) addTags(arn string, tagsMap map[string]string) { - if m.TagsByARN == nil { - m.TagsByARN = make(map[string][]*memorydb.Tag) - } - - var tags []*memorydb.Tag - for key, value := range tagsMap { - tags = append(tags, &memorydb.Tag{ - Key: aws.String(key), - Value: aws.String(value), - }) - } - m.TagsByARN[arn] = tags -} - -func (m *MemoryDBMock) DescribeSubnetGroupsWithContext(aws.Context, *memorydb.DescribeSubnetGroupsInput, ...request.Option) (*memorydb.DescribeSubnetGroupsOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) { - if aws.StringValue(input.ClusterName) == "" { - return &memorydb.DescribeClustersOutput{ - Clusters: m.Clusters, - }, nil - } - - for _, cluster := range m.Clusters { - if aws.StringValue(input.ClusterName) == aws.StringValue(cluster.Name) { - return &memorydb.DescribeClustersOutput{ - Clusters: []*memorydb.Cluster{cluster}, - }, nil - } - } - return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.ClusterName)) -} - -func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTagsInput, _ ...request.Option) (*memorydb.ListTagsOutput, error) { - if m.TagsByARN == nil { - return nil, trace.NotFound("no tags") - } - - tags, ok := m.TagsByARN[aws.StringValue(input.ResourceArn)] - if !ok { - return nil, trace.NotFound("no tags") - } - - return &memorydb.ListTagsOutput{ - TagList: tags, - }, nil -} - -func (m *MemoryDBMock) DescribeUsersWithContext(aws.Context, 
*memorydb.DescribeUsersInput, ...request.Option) (*memorydb.DescribeUsersOutput, error) { - return &memorydb.DescribeUsersOutput{ - Users: m.Users, - }, nil -} - -func (m *MemoryDBMock) UpdateUserWithContext(_ aws.Context, input *memorydb.UpdateUserInput, opts ...request.Option) (*memorydb.UpdateUserOutput, error) { - for _, user := range m.Users { - if aws.StringValue(user.Name) == aws.StringValue(input.UserName) { - return &memorydb.UpdateUserOutput{}, nil - } - } - return nil, trace.NotFound("user %s not found", aws.StringValue(input.UserName)) -} - -// checkEngineFilters checks RDS filters to detect unrecognized engine filters. -func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVersion) error { - if len(filters) == 0 { - return nil - } - recognizedEngines := make(map[string]struct{}) - for _, e := range engineVersions { - recognizedEngines[aws.StringValue(e.Engine)] = struct{}{} - } - for _, f := range filters { - if aws.StringValue(f.Name) != "engine" { - continue - } - for _, v := range f.Values { - if _, ok := recognizedEngines[aws.StringValue(v)]; !ok { - return trace.Errorf("unrecognized engine name %q", aws.StringValue(v)) - } - } - } - return nil -} - -// applyInstanceFilters filters RDS DBInstances using the provided RDS filters. -func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.DBInstance, error) { - if len(filters) == 0 { - return in, nil - } - var out []*rds.DBInstance - efs := engineFilterSet(filters) - for _, instance := range in { - if instanceEngineMatches(instance, efs) { - out = append(out, instance) - } - } - return out, nil -} - -// applyClusterFilters filters RDS DBClusters using the provided RDS filters. 
-func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBCluster, error) { - if len(filters) == 0 { - return in, nil - } - var out []*rds.DBCluster - efs := engineFilterSet(filters) - for _, cluster := range in { - if clusterEngineMatches(cluster, efs) { - out = append(out, cluster) - } - } - return out, nil -} - -// engineFilterSet builds a string set of engine names from a list of RDS filters. -func engineFilterSet(filters []*rds.Filter) map[string]struct{} { - out := make(map[string]struct{}) - for _, f := range filters { - if aws.StringValue(f.Name) != "engine" { - continue - } - for _, v := range f.Values { - out[aws.StringValue(v)] = struct{}{} - } - } - return out -} - -// instanceEngineMatches returns whether an RDS DBInstance engine matches any engine name in a filter set. -func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool { - _, ok := filterSet[aws.StringValue(instance.Engine)] - return ok -} - -// clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set. -func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool { - _, ok := filterSet[aws.StringValue(cluster.Engine)] - return ok -} - -// RedshiftGetClusterCredentialsOutput return a sample redshift.GetClusterCredentialsOutput. -func RedshiftGetClusterCredentialsOutput(user, password string, clock clockwork.Clock) *redshift.GetClusterCredentialsOutput { - if clock == nil { - clock = clockwork.NewRealClock() - } - return &redshift.GetClusterCredentialsOutput{ - DbUser: aws.String(user), - DbPassword: aws.String(password), - Expiration: aws.Time(clock.Now().Add(15 * time.Minute)), - } -} - // EKSMock is a mock EKS client. 
 type EKSMock struct {
 	eksiface.EKSAPI
diff --git a/lib/cloud/mocks/aws_elasticache.go b/lib/cloud/mocks/aws_elasticache.go
new file mode 100644
index 0000000000000..944e35ee2a32d
--- /dev/null
+++ b/lib/cloud/mocks/aws_elasticache.go
@@ -0,0 +1,195 @@
+/*
+Copyright 2023 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package mocks
+
+import (
+	"fmt"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/request"
+	"github.com/aws/aws-sdk-go/service/elasticache"
+	"github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface"
+	"github.com/gravitational/trace"
+)
+
+// ElastiCacheMock mocks AWS ElastiCache API.
+type ElastiCacheMock struct {
+	elasticacheiface.ElastiCacheAPI
+	// Unauth set to true will make API calls return unauthorized errors.
+ Unauth bool + + ReplicationGroups []*elasticache.ReplicationGroup + Users []*elasticache.User + TagsByARN map[string][]*elasticache.Tag +} + +func (m *ElastiCacheMock) AddMockUser(user *elasticache.User, tagsMap map[string]string) { + m.Users = append(m.Users, user) + m.addTags(aws.StringValue(user.ARN), tagsMap) +} + +func (m *ElastiCacheMock) addTags(arn string, tagsMap map[string]string) { + if m.TagsByARN == nil { + m.TagsByARN = make(map[string][]*elasticache.Tag) + } + + var tags []*elasticache.Tag + for key, value := range tagsMap { + tags = append(tags, &elasticache.Tag{ + Key: aws.String(key), + Value: aws.String(value), + }) + } + m.TagsByARN[arn] = tags +} + +func (m *ElastiCacheMock) DescribeUsersWithContext(_ aws.Context, input *elasticache.DescribeUsersInput, opts ...request.Option) (*elasticache.DescribeUsersOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + if input.UserId == nil { + return &elasticache.DescribeUsersOutput{Users: m.Users}, nil + } + for _, user := range m.Users { + if aws.StringValue(user.UserId) == aws.StringValue(input.UserId) { + return &elasticache.DescribeUsersOutput{Users: []*elasticache.User{user}}, nil + } + } + return nil, trace.NotFound("ElastiCache UserId %v not found", aws.StringValue(input.UserId)) +} + +func (m *ElastiCacheMock) DescribeReplicationGroupsWithContext(_ aws.Context, input *elasticache.DescribeReplicationGroupsInput, opts ...request.Option) (*elasticache.DescribeReplicationGroupsOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + for _, replicationGroup := range m.ReplicationGroups { + if aws.StringValue(replicationGroup.ReplicationGroupId) == aws.StringValue(input.ReplicationGroupId) { + return &elasticache.DescribeReplicationGroupsOutput{ + ReplicationGroups: []*elasticache.ReplicationGroup{replicationGroup}, + }, nil + } + } + return nil, trace.NotFound("ElastiCache %v not found", aws.StringValue(input.ReplicationGroupId)) +} + 
+func (m *ElastiCacheMock) DescribeReplicationGroupsPagesWithContext(_ aws.Context, _ *elasticache.DescribeReplicationGroupsInput, fn func(*elasticache.DescribeReplicationGroupsOutput, bool) bool, _ ...request.Option) error { + if m.Unauth { + return trace.AccessDenied("unauthorized") + } + fn(&elasticache.DescribeReplicationGroupsOutput{ + ReplicationGroups: m.ReplicationGroups, + }, true) + return nil +} + +func (m *ElastiCacheMock) DescribeUsersPagesWithContext(_ aws.Context, _ *elasticache.DescribeUsersInput, fn func(*elasticache.DescribeUsersOutput, bool) bool, _ ...request.Option) error { + if m.Unauth { + return trace.AccessDenied("unauthorized") + } + fn(&elasticache.DescribeUsersOutput{ + Users: m.Users, + }, true) + return nil +} + +func (m *ElastiCacheMock) DescribeCacheClustersPagesWithContext(aws.Context, *elasticache.DescribeCacheClustersInput, func(*elasticache.DescribeCacheClustersOutput, bool) bool, ...request.Option) error { + if m.Unauth { + return trace.AccessDenied("unauthorized") + } + return trace.NotImplemented("elasticache:DescribeCacheClustersPagesWithContext is not implemented") +} + +func (m *ElastiCacheMock) DescribeCacheSubnetGroupsPagesWithContext(aws.Context, *elasticache.DescribeCacheSubnetGroupsInput, func(*elasticache.DescribeCacheSubnetGroupsOutput, bool) bool, ...request.Option) error { + if m.Unauth { + return trace.AccessDenied("unauthorized") + } + return trace.NotImplemented("elasticache:DescribeCacheSubnetGroupsPagesWithContext is not implemented") +} + +func (m *ElastiCacheMock) ListTagsForResourceWithContext(_ aws.Context, input *elasticache.ListTagsForResourceInput, _ ...request.Option) (*elasticache.TagListMessage, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + if m.TagsByARN == nil { + return nil, trace.NotFound("no tags") + } + + tags, ok := m.TagsByARN[aws.StringValue(input.ResourceName)] + if !ok { + return nil, trace.NotFound("no tags") + } + + return &elasticache.TagListMessage{ + 
TagList: tags,
+	}, nil
+}
+
+func (m *ElastiCacheMock) ModifyUserWithContext(_ aws.Context, input *elasticache.ModifyUserInput, opts ...request.Option) (*elasticache.ModifyUserOutput, error) {
+	if m.Unauth {
+		return nil, trace.AccessDenied("unauthorized")
+	}
+	for _, user := range m.Users {
+		if aws.StringValue(user.UserId) == aws.StringValue(input.UserId) {
+			return &elasticache.ModifyUserOutput{}, nil
+		}
+	}
+	return nil, trace.NotFound("user %s not found", aws.StringValue(input.UserId))
+}
+
+// ElastiCacheCluster returns a sample elasticache.ReplicationGroup.
+func ElastiCacheCluster(name, region string, opts ...func(*elasticache.ReplicationGroup)) *elasticache.ReplicationGroup {
+	cluster := &elasticache.ReplicationGroup{
+		ARN:                      aws.String(fmt.Sprintf("arn:aws:elasticache:%s:123456789012:replicationgroup:%s", region, name)),
+		ReplicationGroupId:       aws.String(name),
+		Status:                   aws.String("available"),
+		TransitEncryptionEnabled: aws.Bool(true),
+
+		// Default has one primary endpoint in the only node group.
+		NodeGroups: []*elasticache.NodeGroup{{
+			PrimaryEndpoint: &elasticache.Endpoint{
+				Address: aws.String(fmt.Sprintf("master.%v-cluster.xxxxxx.use1.cache.amazonaws.com", name)),
+				Port:    aws.Int64(6379),
+			},
+		}},
+	}
+
+	for _, opt := range opts {
+		opt(cluster)
+	}
+	return cluster
+}
+
+// WithElastiCacheReaderEndpoint is an option function for
+// ElastiCacheCluster to set a reader endpoint.
+func WithElastiCacheReaderEndpoint(cluster *elasticache.ReplicationGroup) {
+	cluster.NodeGroups = append(cluster.NodeGroups, &elasticache.NodeGroup{
+		ReaderEndpoint: &elasticache.Endpoint{
+			Address: aws.String(fmt.Sprintf("replica.%v-cluster.xxxxxx.use1.cache.amazonaws.com", aws.StringValue(cluster.ReplicationGroupId))),
+			Port:    aws.Int64(6379),
+		},
+	})
+}
+
+// WithElastiCacheConfigurationEndpoint is an option function for
+// ElastiCacheCluster to set a configuration endpoint.
+func WithElastiCacheConfigurationEndpoint(cluster *elasticache.ReplicationGroup) { + cluster.ClusterEnabled = aws.Bool(true) + cluster.ConfigurationEndpoint = &elasticache.Endpoint{ + Address: aws.String(fmt.Sprintf("clustercfg.%v-shards.xxxxxx.use1.cache.amazonaws.com", aws.StringValue(cluster.ReplicationGroupId))), + Port: aws.Int64(6379), + } +} diff --git a/lib/cloud/mocks/aws_memorydb.go b/lib/cloud/mocks/aws_memorydb.go new file mode 100644 index 0000000000000..83cb714ee49c9 --- /dev/null +++ b/lib/cloud/mocks/aws_memorydb.go @@ -0,0 +1,126 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mocks + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/memorydb" + "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" + "github.com/gravitational/trace" +) + +// MemoryDBMock mocks AWS MemoryDB API. 
+type MemoryDBMock struct { + memorydbiface.MemoryDBAPI + + Clusters []*memorydb.Cluster + Users []*memorydb.User + TagsByARN map[string][]*memorydb.Tag +} + +func (m *MemoryDBMock) AddMockUser(user *memorydb.User, tagsMap map[string]string) { + m.Users = append(m.Users, user) + m.addTags(aws.StringValue(user.ARN), tagsMap) +} + +func (m *MemoryDBMock) addTags(arn string, tagsMap map[string]string) { + if m.TagsByARN == nil { + m.TagsByARN = make(map[string][]*memorydb.Tag) + } + + var tags []*memorydb.Tag + for key, value := range tagsMap { + tags = append(tags, &memorydb.Tag{ + Key: aws.String(key), + Value: aws.String(value), + }) + } + m.TagsByARN[arn] = tags +} + +func (m *MemoryDBMock) DescribeSubnetGroupsWithContext(aws.Context, *memorydb.DescribeSubnetGroupsInput, ...request.Option) (*memorydb.DescribeSubnetGroupsOutput, error) { + return nil, trace.AccessDenied("unauthorized") +} + +func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) { + if aws.StringValue(input.ClusterName) == "" { + return &memorydb.DescribeClustersOutput{ + Clusters: m.Clusters, + }, nil + } + + for _, cluster := range m.Clusters { + if aws.StringValue(input.ClusterName) == aws.StringValue(cluster.Name) { + return &memorydb.DescribeClustersOutput{ + Clusters: []*memorydb.Cluster{cluster}, + }, nil + } + } + return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.ClusterName)) +} + +func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTagsInput, _ ...request.Option) (*memorydb.ListTagsOutput, error) { + if m.TagsByARN == nil { + return nil, trace.NotFound("no tags") + } + + tags, ok := m.TagsByARN[aws.StringValue(input.ResourceArn)] + if !ok { + return nil, trace.NotFound("no tags") + } + + return &memorydb.ListTagsOutput{ + TagList: tags, + }, nil +} + +func (m *MemoryDBMock) DescribeUsersWithContext(aws.Context, 
*memorydb.DescribeUsersInput, ...request.Option) (*memorydb.DescribeUsersOutput, error) { + return &memorydb.DescribeUsersOutput{ + Users: m.Users, + }, nil +} + +func (m *MemoryDBMock) UpdateUserWithContext(_ aws.Context, input *memorydb.UpdateUserInput, opts ...request.Option) (*memorydb.UpdateUserOutput, error) { + for _, user := range m.Users { + if aws.StringValue(user.Name) == aws.StringValue(input.UserName) { + return &memorydb.UpdateUserOutput{}, nil + } + } + return nil, trace.NotFound("user %s not found", aws.StringValue(input.UserName)) +} + +// MemoryDBCluster returns a sample memorydb.Cluster. +func MemoryDBCluster(name, region string, opts ...func(*memorydb.Cluster)) *memorydb.Cluster { + cluster := &memorydb.Cluster{ + ARN: aws.String(fmt.Sprintf("arn:aws:memorydb:%s:123456789012:cluster:%s", region, name)), + Name: aws.String(name), + Status: aws.String("available"), + TLSEnabled: aws.Bool(true), + ClusterEndpoint: &memorydb.Endpoint{ + Address: aws.String(fmt.Sprintf("clustercfg.%s.xxxxxx.memorydb.%s.amazonaws.com", name, region)), + Port: aws.Int64(6379), + }, + } + + for _, opt := range opts { + opt(cluster) + } + return cluster +} diff --git a/lib/cloud/mocks/aws_opensearch.go b/lib/cloud/mocks/aws_opensearch.go new file mode 100644 index 0000000000000..2004660b5de44 --- /dev/null +++ b/lib/cloud/mocks/aws_opensearch.go @@ -0,0 +1,98 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package mocks + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/opensearchservice" + "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" + "github.com/gravitational/trace" +) + +type OpenSearchMock struct { + opensearchserviceiface.OpenSearchServiceAPI + + Domains []*opensearchservice.DomainStatus + TagsByARN map[string][]*opensearchservice.Tag +} + +func (o *OpenSearchMock) ListDomainNamesWithContext(aws.Context, *opensearchservice.ListDomainNamesInput, ...request.Option) (*opensearchservice.ListDomainNamesOutput, error) { + out := &opensearchservice.ListDomainNamesOutput{} + for _, domain := range o.Domains { + out.DomainNames = append(out.DomainNames, &opensearchservice.DomainInfo{ + DomainName: domain.DomainName, + EngineType: aws.String("OpenSearch"), + }) + } + + return out, nil +} + +func (o *OpenSearchMock) DescribeDomainsWithContext(aws.Context, *opensearchservice.DescribeDomainsInput, ...request.Option) (*opensearchservice.DescribeDomainsOutput, error) { + out := &opensearchservice.DescribeDomainsOutput{DomainStatusList: o.Domains} + return out, nil +} + +func (o *OpenSearchMock) ListTagsWithContext(_ aws.Context, request *opensearchservice.ListTagsInput, _ ...request.Option) (*opensearchservice.ListTagsOutput, error) { + tags, found := o.TagsByARN[aws.StringValue(request.ARN)] + if !found { + return nil, trace.NotFound("tags not found") + } + return &opensearchservice.ListTagsOutput{TagList: tags}, nil +} + +// OpenSearchDomain returns a sample opensearchservice.DomainStatus. 
+func OpenSearchDomain(name, region string, opts ...func(status *opensearchservice.DomainStatus)) *opensearchservice.DomainStatus { + domain := &opensearchservice.DomainStatus{ + ARN: aws.String(fmt.Sprintf("arn:aws:es:%s:123456789012:domain/%s", region, name)), + DomainId: aws.String("123456789012/" + name), + DomainName: aws.String(name), + Created: aws.Bool(true), + Deleted: aws.Bool(false), + EngineVersion: aws.String("OpenSearch_2.5"), + + Endpoint: aws.String(fmt.Sprintf("search-%s-aaaabbbbcccc4444.%s.es.amazonaws.com", name, region)), + } + + for _, opt := range opts { + opt(domain) + } + return domain +} + +func WithOpenSearchVPCEndpoint(name string) func(*opensearchservice.DomainStatus) { + return func(status *opensearchservice.DomainStatus) { + if status.Endpoints == nil { + status.Endpoints = map[string]*string{} + } + status.Endpoints[name] = aws.String(fmt.Sprintf("vpc-%v-%v", name, aws.StringValue(status.Endpoint))) + status.Endpoint = nil + } +} + +func WithOpenSearchCustomEndpoint(endpoint string) func(*opensearchservice.DomainStatus) { + return func(status *opensearchservice.DomainStatus) { + status.DomainEndpointOptions = &opensearchservice.DomainEndpointOptions{ + CustomEndpoint: aws.String(endpoint), + CustomEndpointEnabled: aws.Bool(true), + EnforceHTTPS: aws.Bool(true), + } + } +} diff --git a/lib/cloud/mocks/aws_rds.go b/lib/cloud/mocks/aws_rds.go new file mode 100644 index 0000000000000..18c8a32189c76 --- /dev/null +++ b/lib/cloud/mocks/aws_rds.go @@ -0,0 +1,463 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mocks + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/google/uuid" + "github.com/gravitational/trace" + + libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" +) + +// RDSMock mocks AWS RDS API. +type RDSMock struct { + rdsiface.RDSAPI + DBInstances []*rds.DBInstance + DBClusters []*rds.DBCluster + DBProxies []*rds.DBProxy + DBProxyEndpoints []*rds.DBProxyEndpoint + DBEngineVersions []*rds.DBEngineVersion + DBProxyTargetPort int64 +} + +func (m *RDSMock) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { + if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { + return nil, trace.Wrap(err) + } + instances, err := applyInstanceFilters(m.DBInstances, input.Filters) + if err != nil { + return nil, trace.Wrap(err) + } + if aws.StringValue(input.DBInstanceIdentifier) == "" { + return &rds.DescribeDBInstancesOutput{ + DBInstances: instances, + }, nil + } + for _, instance := range instances { + if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { + return &rds.DescribeDBInstancesOutput{ + DBInstances: []*rds.DBInstance{instance}, + }, nil + } + } + return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) +} + +func (m *RDSMock) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { + if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { + return trace.Wrap(err) + } + instances, err := 
applyInstanceFilters(m.DBInstances, input.Filters) + if err != nil { + return trace.Wrap(err) + } + fn(&rds.DescribeDBInstancesOutput{ + DBInstances: instances, + }, true) + return nil +} + +func (m *RDSMock) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { + if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { + return nil, trace.Wrap(err) + } + clusters, err := applyClusterFilters(m.DBClusters, input.Filters) + if err != nil { + return nil, trace.Wrap(err) + } + if aws.StringValue(input.DBClusterIdentifier) == "" { + return &rds.DescribeDBClustersOutput{ + DBClusters: clusters, + }, nil + } + for _, cluster := range clusters { + if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { + return &rds.DescribeDBClustersOutput{ + DBClusters: []*rds.DBCluster{cluster}, + }, nil + } + } + return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) +} + +func (m *RDSMock) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { + if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { + return trace.Wrap(err) + } + clusters, err := applyClusterFilters(m.DBClusters, input.Filters) + if err != nil { + return trace.Wrap(err) + } + fn(&rds.DescribeDBClustersOutput{ + DBClusters: clusters, + }, true) + return nil +} + +func (m *RDSMock) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { + for i, instance := range m.DBInstances { + if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { + if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { + m.DBInstances[i].IAMDatabaseAuthenticationEnabled = 
aws.Bool(true)
+			}
+			return &rds.ModifyDBInstanceOutput{
+				DBInstance: m.DBInstances[i],
+			}, nil
+		}
+	}
+	return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier))
+}
+
+func (m *RDSMock) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) {
+	for i, cluster := range m.DBClusters {
+		if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) {
+			if aws.BoolValue(input.EnableIAMDatabaseAuthentication) {
+				m.DBClusters[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true)
+			}
+			return &rds.ModifyDBClusterOutput{
+				DBCluster: m.DBClusters[i],
+			}, nil
+		}
+	}
+	return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier))
+}
+
+func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) {
+	if aws.StringValue(input.DBProxyName) == "" {
+		return &rds.DescribeDBProxiesOutput{
+			DBProxies: m.DBProxies,
+		}, nil
+	}
+	for _, dbProxy := range m.DBProxies {
+		if aws.StringValue(dbProxy.DBProxyName) == aws.StringValue(input.DBProxyName) {
+			return &rds.DescribeDBProxiesOutput{
+				DBProxies: []*rds.DBProxy{dbProxy},
+			}, nil
+		}
+	}
+	return nil, trace.NotFound("proxy %v not found", aws.StringValue(input.DBProxyName))
+}
+
+func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) {
+	inputProxyName := aws.StringValue(input.DBProxyName)
+	inputProxyEndpointName := aws.StringValue(input.DBProxyEndpointName)
+
+	if inputProxyName == "" && inputProxyEndpointName == "" {
+		return &rds.DescribeDBProxyEndpointsOutput{
+			DBProxyEndpoints: m.DBProxyEndpoints,
+		}, nil
+	}
+
+	var endpoints []*rds.DBProxyEndpoint
+	for _, dbProxyEndpoint := range m.DBProxyEndpoints {
+		if inputProxyEndpointName != "" &&
+			inputProxyEndpointName != aws.StringValue(dbProxyEndpoint.DBProxyEndpointName) {
+			continue
+		}
+
+		if inputProxyName != "" &&
+			inputProxyName != aws.StringValue(dbProxyEndpoint.DBProxyName) {
+			continue
+		}
+
+		endpoints = append(endpoints, dbProxyEndpoint)
+	}
+	if len(endpoints) == 0 {
+		return nil, trace.NotFound("proxy endpoint %v not found", aws.StringValue(input.DBProxyEndpointName))
+	}
+	return &rds.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil
+}
+
+func (m *RDSMock) DescribeDBProxyTargetsWithContext(ctx aws.Context, input *rds.DescribeDBProxyTargetsInput, options ...request.Option) (*rds.DescribeDBProxyTargetsOutput, error) {
+	// only mocking to return a port here
+	return &rds.DescribeDBProxyTargetsOutput{
+		Targets: []*rds.DBProxyTarget{{
+			Port: aws.Int64(m.DBProxyTargetPort),
+		}},
+	}, nil
+}
+
+func (m *RDSMock) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error {
+	fn(&rds.DescribeDBProxiesOutput{
+		DBProxies: m.DBProxies,
+	}, true)
+	return nil
+}
+
+func (m *RDSMock) DescribeDBProxyEndpointsPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, fn func(*rds.DescribeDBProxyEndpointsOutput, bool) bool, options ...request.Option) error {
+	fn(&rds.DescribeDBProxyEndpointsOutput{
+		DBProxyEndpoints: m.DBProxyEndpoints,
+	}, true)
+	return nil
+}
+
+func (m *RDSMock) ListTagsForResourceWithContext(ctx aws.Context, input *rds.ListTagsForResourceInput, options ...request.Option) (*rds.ListTagsForResourceOutput, error) {
+	return &rds.ListTagsForResourceOutput{}, nil
+}
+
+// RDSMockUnauth is a mock RDS client that returns access denied to each call.
+type RDSMockUnauth struct { + rdsiface.RDSAPI +} + +func (m *RDSMockUnauth) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { + return nil, trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { + return nil, trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { + return nil, trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { + return nil, trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { + return trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { + return trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { + return nil, trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { + return nil, trace.AccessDenied("unauthorized") +} + +func (m *RDSMockUnauth) DescribeDBProxiesPagesWithContext(ctx 
aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { + return trace.AccessDenied("unauthorized") +} + +// RDSMockByDBType is a mock RDS client that mocks API calls by DB type +type RDSMockByDBType struct { + rdsiface.RDSAPI + DBInstances rdsiface.RDSAPI + DBClusters rdsiface.RDSAPI + DBProxies rdsiface.RDSAPI +} + +func (m *RDSMockByDBType) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { + return m.DBInstances.DescribeDBInstancesWithContext(ctx, input, options...) +} + +func (m *RDSMockByDBType) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { + return m.DBInstances.ModifyDBInstanceWithContext(ctx, input, options...) +} + +func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { + return m.DBInstances.DescribeDBInstancesPagesWithContext(ctx, input, fn, options...) +} + +func (m *RDSMockByDBType) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { + return m.DBClusters.DescribeDBClustersWithContext(ctx, input, options...) +} + +func (m *RDSMockByDBType) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { + return m.DBClusters.ModifyDBClusterWithContext(ctx, input, options...) +} + +func (m *RDSMockByDBType) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { + return m.DBClusters.DescribeDBClustersPagesWithContext(aws, input, fn, options...) 
+} + +func (m *RDSMockByDBType) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { + return m.DBProxies.DescribeDBProxiesWithContext(ctx, input, options...) +} + +func (m *RDSMockByDBType) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { + return m.DBProxies.DescribeDBProxyEndpointsWithContext(ctx, input, options...) +} + +func (m *RDSMockByDBType) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { + return m.DBProxies.DescribeDBProxiesPagesWithContext(ctx, input, fn, options...) +} + +// checkEngineFilters checks RDS filters to detect unrecognized engine filters. +func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVersion) error { + if len(filters) == 0 { + return nil + } + recognizedEngines := make(map[string]struct{}) + for _, e := range engineVersions { + recognizedEngines[aws.StringValue(e.Engine)] = struct{}{} + } + for _, f := range filters { + if aws.StringValue(f.Name) != "engine" { + continue + } + for _, v := range f.Values { + if _, ok := recognizedEngines[aws.StringValue(v)]; !ok { + return trace.Errorf("unrecognized engine name %q", aws.StringValue(v)) + } + } + } + return nil +} + +// applyInstanceFilters filters RDS DBInstances using the provided RDS filters. +func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.DBInstance, error) { + if len(filters) == 0 { + return in, nil + } + var out []*rds.DBInstance + efs := engineFilterSet(filters) + for _, instance := range in { + if instanceEngineMatches(instance, efs) { + out = append(out, instance) + } + } + return out, nil +} + +// applyClusterFilters filters RDS DBClusters using the provided RDS filters. 
+func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBCluster, error) { + if len(filters) == 0 { + return in, nil + } + var out []*rds.DBCluster + efs := engineFilterSet(filters) + for _, cluster := range in { + if clusterEngineMatches(cluster, efs) { + out = append(out, cluster) + } + } + return out, nil +} + +// engineFilterSet builds a string set of engine names from a list of RDS filters. +func engineFilterSet(filters []*rds.Filter) map[string]struct{} { + out := make(map[string]struct{}) + for _, f := range filters { + if aws.StringValue(f.Name) != "engine" { + continue + } + for _, v := range f.Values { + out[aws.StringValue(v)] = struct{}{} + } + } + return out +} + +// instanceEngineMatches returns whether an RDS DBInstance engine matches any engine name in a filter set. +func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool { + _, ok := filterSet[aws.StringValue(instance.Engine)] + return ok +} + +// clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set. +func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool { + _, ok := filterSet[aws.StringValue(cluster.Engine)] + return ok +} + +// RDSInstance returns a sample rds.DBInstance. 
+func RDSInstance(name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) *rds.DBInstance { + instance := &rds.DBInstance{ + DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)), + DBInstanceIdentifier: aws.String(name), + DbiResourceId: aws.String(uuid.New().String()), + Engine: aws.String("postgres"), + DBInstanceStatus: aws.String("available"), + Endpoint: &rds.Endpoint{ + Address: aws.String(fmt.Sprintf("%v.aabbccdd.%v.rds.amazonaws.com", name, region)), + Port: aws.Int64(5432), + }, + TagList: libcloudaws.LabelsToTags[rds.Tag](labels), + } + for _, opt := range opts { + opt(instance) + } + return instance +} + +// RDSCluster returns a sample rds.DBCluster. +func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) *rds.DBCluster { + cluster := &rds.DBCluster{ + DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)), + DBClusterIdentifier: aws.String(name), + DbClusterResourceId: aws.String(uuid.New().String()), + Engine: aws.String("aurora-mysql"), + EngineMode: aws.String("provisioned"), + Status: aws.String("available"), + Endpoint: aws.String(fmt.Sprintf("%v.cluster-aabbccdd.%v.rds.amazonaws.com", name, region)), + ReaderEndpoint: aws.String(fmt.Sprintf("%v-co.cluster-aabbccdd.%v.rds.amazonaws.com", name, region)), + Port: aws.Int64(3306), + TagList: libcloudaws.LabelsToTags[rds.Tag](labels), + DBClusterMembers: []*rds.DBClusterMember{{ + IsClusterWriter: aws.Bool(true), // One writer by default. + }}, + } + for _, opt := range opts { + opt(cluster) + } + return cluster +} + +func WithRDSClusterReader(cluster *rds.DBCluster) { + cluster.DBClusterMembers = append(cluster.DBClusterMembers, &rds.DBClusterMember{ + IsClusterWriter: aws.Bool(false), // Add reader. 
+ }) +} + +func WithRDSClusterCustomEndpoint(name string) func(*rds.DBCluster) { + return func(cluster *rds.DBCluster) { + parsed, _ := arn.Parse(aws.StringValue(cluster.DBClusterArn)) + cluster.CustomEndpoints = append(cluster.CustomEndpoints, aws.String( + fmt.Sprintf("%v.cluster-custom-aabbccdd.%v.rds.amazonaws.com", name, parsed.Region), + )) + } +} + +// RDSProxy returns a sample rds.DBProxy. +func RDSProxy(name, region, vpcID string) *rds.DBProxy { + return &rds.DBProxy{ + DBProxyArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:123456789012:db-proxy:prx-%s", region, name)), + DBProxyName: aws.String(name), + EngineFamily: aws.String(rds.EngineFamilyMysql), + Endpoint: aws.String(fmt.Sprintf("%s.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)), + VpcId: aws.String(vpcID), + RequireTLS: aws.Bool(true), + Status: aws.String("available"), + } +} + +// RDSProxyCustomEndpoint returns a sample rds.DBProxyEndpoint. +func RDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, name, region string) *rds.DBProxyEndpoint { + return &rds.DBProxyEndpoint{ + Endpoint: aws.String(fmt.Sprintf("%s.endpoint.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)), + DBProxyEndpointName: aws.String(name), + DBProxyName: rdsProxy.DBProxyName, + DBProxyEndpointArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db-proxy-endpoint:prx-endpoint-%v", region, name)), + TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly), + Status: aws.String("available"), + } +} diff --git a/lib/cloud/mocks/aws_redshift.go b/lib/cloud/mocks/aws_redshift.go new file mode 100644 index 0000000000000..6db5293235237 --- /dev/null +++ b/lib/cloud/mocks/aws_redshift.go @@ -0,0 +1,107 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mocks + +import ( + "fmt" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/redshift" + "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" +) + +// RedshiftMock mocks AWS Redshift API. +type RedshiftMock struct { + redshiftiface.RedshiftAPI + Clusters []*redshift.Cluster + GetClusterCredentialsOutput *redshift.GetClusterCredentialsOutput +} + +func (m *RedshiftMock) GetClusterCredentialsWithContext(aws.Context, *redshift.GetClusterCredentialsInput, ...request.Option) (*redshift.GetClusterCredentialsOutput, error) { + if m.GetClusterCredentialsOutput == nil { + return nil, trace.AccessDenied("access denied") + } + return m.GetClusterCredentialsOutput, nil +} + +func (m *RedshiftMock) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) { + if aws.StringValue(input.ClusterIdentifier) == "" { + return &redshift.DescribeClustersOutput{ + Clusters: m.Clusters, + }, nil + } + for _, cluster := range m.Clusters { + if aws.StringValue(cluster.ClusterIdentifier) == aws.StringValue(input.ClusterIdentifier) { + return &redshift.DescribeClustersOutput{ + Clusters: []*redshift.Cluster{cluster}, + }, nil + } + } + return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.ClusterIdentifier)) +} + +func (m *RedshiftMock) 
DescribeClustersPagesWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, fn func(*redshift.DescribeClustersOutput, bool) bool, options ...request.Option) error {
+	fn(&redshift.DescribeClustersOutput{
+		Clusters: m.Clusters,
+	}, true)
+	return nil
+}
+
+// RedshiftMockUnauth is a mock Redshift client that returns access denied for every call.
+type RedshiftMockUnauth struct {
+	redshiftiface.RedshiftAPI
+}
+
+func (m *RedshiftMockUnauth) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) {
+	return nil, trace.AccessDenied("unauthorized")
+}
+
+// RedshiftGetClusterCredentialsOutput returns a sample redshift.GetClusterCredentialsOutput.
+func RedshiftGetClusterCredentialsOutput(user, password string, clock clockwork.Clock) *redshift.GetClusterCredentialsOutput {
+	if clock == nil {
+		clock = clockwork.NewRealClock()
+	}
+	return &redshift.GetClusterCredentialsOutput{
+		DbUser:     aws.String(user),
+		DbPassword: aws.String(password),
+		Expiration: aws.Time(clock.Now().Add(15 * time.Minute)),
+	}
+}
+
+// RedshiftCluster returns a sample redshift.Cluster.
+func RedshiftCluster(name, region string, labels map[string]string, opts ...func(*redshift.Cluster)) *redshift.Cluster { + cluster := &redshift.Cluster{ + ClusterIdentifier: aws.String(name), + ClusterNamespaceArn: aws.String(fmt.Sprintf("arn:aws:redshift:%s:123456789012:namespace:%s", region, name)), + ClusterStatus: aws.String("available"), + Endpoint: &redshift.Endpoint{ + Address: aws.String(fmt.Sprintf("%v.aabbccdd.%v.redshift.amazonaws.com", name, region)), + Port: aws.Int64(5439), + }, + Tags: libcloudaws.LabelsToTags[redshift.Tag](labels), + } + for _, opt := range opts { + opt(cluster) + } + return cluster +} diff --git a/lib/services/database.go b/lib/services/database.go index ebf45a192602f..c452972eeb296 100644 --- a/lib/services/database.go +++ b/lib/services/database.go @@ -793,6 +793,60 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Da return databases, trace.NewAggregate(errors...) } +// NewDatabasesFromRDSCluster creates all database resources from an RDS Aurora +// cluster. +func NewDatabasesFromRDSCluster(cluster *rds.DBCluster) (types.Databases, error) { + var errors []error + var databases types.Databases + + // Find out what types of instances the cluster has. Some examples: + // - Aurora cluster with one instance: one writer + // - Aurora cluster with three instances: one writer and two readers + // - Secondary cluster of a global database: one or more readers + var hasWriterInstance, hasReaderInstance bool + for _, clusterMember := range cluster.DBClusterMembers { + if clusterMember != nil { + if aws.BoolValue(clusterMember.IsClusterWriter) { + hasWriterInstance = true + } else { + hasReaderInstance = true + } + } + } + + // Add a database from primary endpoint, if any writer instances. 
+ if cluster.Endpoint != nil && hasWriterInstance { + database, err := NewDatabaseFromRDSCluster(cluster) + if err != nil { + errors = append(errors, err) + } else { + databases = append(databases, database) + } + } + + // Add a database from reader endpoint, if any reader instances. + // https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Overview.Endpoints.html#Aurora.Endpoints.Reader + if cluster.ReaderEndpoint != nil && hasReaderInstance { + database, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster) + if err != nil { + errors = append(errors, err) + } else { + databases = append(databases, database) + } + } + + // Add databases from custom endpoints + if len(cluster.CustomEndpoints) > 0 { + customEndpointDatabases, err := NewDatabasesFromRDSClusterCustomEndpoints(cluster) + if err != nil { + errors = append(errors, err) + } + databases = append(databases, customEndpointDatabases...) + } + + return databases, trace.NewAggregate(errors...) +} + // NewDatabaseFromRDSProxy creates database resource from RDS Proxy. func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, port int64, tags []*rds.Tag) (types.Database, error) { metadata, err := MetadataFromRDSProxy(dbProxy) @@ -910,6 +964,27 @@ func NewDatabasesFromElastiCacheNodeGroups(cluster *elasticache.ReplicationGroup return databases, nil } +// NewDatabasesFromElastiCacheReplicationGroup creates all database resources +// from an ElastiCache ReplicationGroup. +func NewDatabasesFromElastiCacheReplicationGroup(cluster *elasticache.ReplicationGroup, extraLabels map[string]string) (types.Databases, error) { + // Create database using configuration endpoint for Redis with cluster + // mode enabled. 
+ if aws.BoolValue(cluster.ClusterEnabled) { + database, err := NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels) + if err != nil { + return nil, trace.Wrap(err) + } + return types.Databases{database}, nil + } + + // Create databases using primary and reader endpoints for Redis with + // cluster mode disabled. When cluster mode is disabled, it is expected + // there is only one node group (aka shard) with one primary endpoint + // and one reader endpoint. + databases, err := NewDatabasesFromElastiCacheNodeGroups(cluster, extraLabels) + return databases, trace.Wrap(err) +} + // newElastiCacheDatabase returns a new ElastiCache database. func newElastiCacheDatabase(cluster *elasticache.ReplicationGroup, endpoint *elasticache.Endpoint, endpointType string, extraLabels map[string]string) (types.Database, error) { metadata, err := MetadataFromElastiCacheCluster(cluster, endpointType) diff --git a/lib/srv/discovery/fetchers/db/aws_elasticache.go b/lib/srv/discovery/fetchers/db/aws_elasticache.go index fab1883190377..d16fa61d05127 100644 --- a/lib/srv/discovery/fetchers/db/aws_elasticache.go +++ b/lib/srv/discovery/fetchers/db/aws_elasticache.go @@ -146,28 +146,11 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels extraLabels := services.ExtraElastiCacheLabels(cluster, tags, allNodes, allSubnetGroups) - // Create database using configuration endpoint for Redis with cluster - // mode enabled. - if aws.BoolValue(cluster.ClusterEnabled) { - if database, err := services.NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels); err != nil { - f.log.Infof("Could not convert ElastiCache cluster %q configuration endpoint to database resource: %v.", - aws.StringValue(cluster.ReplicationGroupId), err) - } else { - databases = append(databases, database) - } - - continue - } - - // Create databases using primary and reader endpoints for Redis with - // cluster mode disabled. 
When cluster mode is disabled, it is expected - // there is only one node group (aka shard) with one primary endpoint - // and one reader endpoint. - if databasesFromNodeGroups, err := services.NewDatabasesFromElastiCacheNodeGroups(cluster, extraLabels); err != nil { - f.log.Infof("Could not convert ElastiCache cluster %q node groups to database resources: %v.", + if dbs, err := services.NewDatabasesFromElastiCacheReplicationGroup(cluster, extraLabels); err != nil { + f.log.Infof("Could not convert ElastiCache cluster %q to database resources: %v.", aws.StringValue(cluster.ReplicationGroupId), err) } else { - databases = append(databases, databasesFromNodeGroups...) + databases = append(databases, dbs...) } } diff --git a/lib/srv/discovery/fetchers/db/aws_elasticache_test.go b/lib/srv/discovery/fetchers/db/aws_elasticache_test.go index 7afb22efe3acd..e7203296c0224 100644 --- a/lib/srv/discovery/fetchers/db/aws_elasticache_test.go +++ b/lib/srv/discovery/fetchers/db/aws_elasticache_test.go @@ -17,7 +17,6 @@ limitations under the License. 
package db import ( - "fmt" "testing" "github.com/aws/aws-sdk-go/aws" @@ -33,8 +32,8 @@ import ( func TestElastiCacheFetcher(t *testing.T) { t.Parallel() - elasticacheProd, elasticacheDatabaseProd, elasticacheProdTags := makeElastiCacheCluster(t, "ec1", "us-east-1", "prod") - elasticacheQA, elasticacheDatabaseQA, elasticacheQATags := makeElastiCacheCluster(t, "ec2", "us-east-1", "qa", withElastiCacheConfigurationEndpoint()) + elasticacheProd, elasticacheDatabasesProd, elasticacheProdTags := makeElastiCacheCluster(t, "ec1", "us-east-1", "prod", mocks.WithElastiCacheReaderEndpoint) + elasticacheQA, elasticacheDatabasesQA, elasticacheQATags := makeElastiCacheCluster(t, "ec2", "us-east-1", "qa", mocks.WithElastiCacheConfigurationEndpoint) elasticacheUnavailable, _, elasticacheUnavailableTags := makeElastiCacheCluster(t, "ec4", "us-east-1", "prod", func(cluster *elasticache.ReplicationGroup) { cluster.Status = aws.String("deleting") }) @@ -58,7 +57,7 @@ func TestElastiCacheFetcher(t *testing.T) { }, }, inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels), - wantDatabases: types.Databases{elasticacheDatabaseProd, elasticacheDatabaseQA}, + wantDatabases: append(elasticacheDatabasesProd, elasticacheDatabasesQA...), }, { name: "fetch prod", @@ -69,7 +68,7 @@ func TestElastiCacheFetcher(t *testing.T) { }, }, inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", envProdLabels), - wantDatabases: types.Databases{elasticacheDatabaseProd}, + wantDatabases: elasticacheDatabasesProd, }, { name: "skip unavailable", @@ -80,7 +79,7 @@ func TestElastiCacheFetcher(t *testing.T) { }, }, inputMatchers: makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels), - wantDatabases: types.Databases{elasticacheDatabaseProd}, + wantDatabases: elasticacheDatabasesProd, }, { name: "skip unsupported", @@ -91,31 +90,14 @@ func TestElastiCacheFetcher(t *testing.T) { }, }, inputMatchers: 
makeAWSMatchersForType(services.AWSMatcherElastiCache, "us-east-1", wildcardLabels), - wantDatabases: types.Databases{elasticacheDatabaseProd}, + wantDatabases: elasticacheDatabasesProd, }, } testAWSFetchers(t, tests...) } -func makeElastiCacheCluster(t *testing.T, name, region, env string, opts ...func(*elasticache.ReplicationGroup)) (*elasticache.ReplicationGroup, types.Database, []*elasticache.Tag) { - cluster := &elasticache.ReplicationGroup{ - ARN: aws.String(fmt.Sprintf("arn:aws:elasticache:%s:123456789012:replicationgroup:%s", region, name)), - ReplicationGroupId: aws.String(name), - Status: aws.String("available"), - TransitEncryptionEnabled: aws.Bool(true), - - // Default has one primary endpoint in the only node group. - NodeGroups: []*elasticache.NodeGroup{{ - PrimaryEndpoint: &elasticache.Endpoint{ - Address: aws.String("primary.localhost"), - Port: aws.Int64(6379), - }, - }}, - } - - for _, opt := range opts { - opt(cluster) - } +func makeElastiCacheCluster(t *testing.T, name, region, env string, opts ...func(*elasticache.ReplicationGroup)) (*elasticache.ReplicationGroup, types.Databases, []*elasticache.Tag) { + cluster := mocks.ElastiCacheCluster(name, region, opts...) tags := []*elasticache.Tag{{ Key: aws.String("env"), @@ -126,23 +108,10 @@ func makeElastiCacheCluster(t *testing.T, name, region, env string, opts ...func if aws.BoolValue(cluster.ClusterEnabled) { database, err := services.NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels) require.NoError(t, err) - return cluster, database, tags + return cluster, types.Databases{database}, tags } databases, err := services.NewDatabasesFromElastiCacheNodeGroups(cluster, extraLabels) require.NoError(t, err) - require.Len(t, databases, 1) - return cluster, databases[0], tags -} - -// withElastiCacheConfigurationEndpoint returns an option function for -// makeElastiCacheCluster to set a configuration endpoint. 
-func withElastiCacheConfigurationEndpoint() func(*elasticache.ReplicationGroup) { - return func(cluster *elasticache.ReplicationGroup) { - cluster.ClusterEnabled = aws.Bool(true) - cluster.ConfigurationEndpoint = &elasticache.Endpoint{ - Address: aws.String("configuration.localhost"), - Port: aws.Int64(6379), - } - } + return cluster, databases, tags } diff --git a/lib/srv/discovery/fetchers/db/aws_memorydb_test.go b/lib/srv/discovery/fetchers/db/aws_memorydb_test.go index ab9340a6eac66..2b81c39ef5c19 100644 --- a/lib/srv/discovery/fetchers/db/aws_memorydb_test.go +++ b/lib/srv/discovery/fetchers/db/aws_memorydb_test.go @@ -16,7 +16,6 @@ limitations under the License. package db import ( - "fmt" "testing" "github.com/aws/aws-sdk-go/aws" @@ -97,20 +96,7 @@ func TestMemoryDBFetcher(t *testing.T) { } func makeMemoryDBCluster(t *testing.T, name, region, env string, opts ...func(*memorydb.Cluster)) (*memorydb.Cluster, types.Database, []*memorydb.Tag) { - cluster := &memorydb.Cluster{ - ARN: aws.String(fmt.Sprintf("arn:aws:memorydb:%s:123456789012:cluster:%s", region, name)), - Name: aws.String(name), - Status: aws.String("available"), - TLSEnabled: aws.Bool(true), - ClusterEndpoint: &memorydb.Endpoint{ - Address: aws.String("memorydb.localhost"), - Port: aws.Int64(6379), - }, - } - - for _, opt := range opts { - opt(cluster) - } + cluster := mocks.MemoryDBCluster(name, region, opts...) 
tags := []*memorydb.Tag{{ Key: aws.String("env"), diff --git a/lib/srv/discovery/fetchers/db/aws_opensearch_test.go b/lib/srv/discovery/fetchers/db/aws_opensearch_test.go index 3381a1c701069..c259762408430 100644 --- a/lib/srv/discovery/fetchers/db/aws_opensearch_test.go +++ b/lib/srv/discovery/fetchers/db/aws_opensearch_test.go @@ -15,7 +15,6 @@ package db import ( - "fmt" "testing" "github.com/aws/aws-sdk-go/aws" @@ -38,21 +37,8 @@ func TestOpenSearchFetcher(t *testing.T) { status.Created = aws.Bool(false) }) - prodVPC, prodVPCDBs := makeOpenSearchDomain(t, tags, "os3", "us-east-1", "prod", func(status *opensearchservice.DomainStatus) { - if status.Endpoints == nil { - status.Endpoints = map[string]*string{} - } - status.Endpoints["vpc"] = aws.String("vpc-" + aws.StringValue(status.Endpoint)) - status.Endpoint = nil - }) - - prodCustom, prodCustomDBs := makeOpenSearchDomain(t, tags, "os4", "us-east-1", "prod", func(status *opensearchservice.DomainStatus) { - status.DomainEndpointOptions = &opensearchservice.DomainEndpointOptions{ - CustomEndpoint: aws.String("opensearch.example.com"), - CustomEndpointEnabled: aws.Bool(true), - EnforceHTTPS: aws.Bool(true), - } - }) + prodVPC, prodVPCDBs := makeOpenSearchDomain(t, tags, "os3", "us-east-1", "prod", mocks.WithOpenSearchVPCEndpoint("vpc")) + prodCustom, prodCustomDBs := makeOpenSearchDomain(t, tags, "os4", "us-east-1", "prod", mocks.WithOpenSearchCustomEndpoint("opensearch.example.com")) test, testDBs := makeOpenSearchDomain(t, tags, "os5", "us-east-1", "test") @@ -128,20 +114,7 @@ func TestOpenSearchFetcher(t *testing.T) { } func makeOpenSearchDomain(t *testing.T, tagMap map[string][]*opensearchservice.Tag, name, region, env string, opts ...func(status *opensearchservice.DomainStatus)) (*opensearchservice.DomainStatus, types.Databases) { - domain := &opensearchservice.DomainStatus{ - ARN: aws.String(fmt.Sprintf("arn:aws:es:%s:123456789012:domain/%s", region, name)), - DomainId: aws.String("123456789012/" + name), - 
DomainName: aws.String(name), - Created: aws.Bool(true), - Deleted: aws.Bool(false), - EngineVersion: aws.String("OpenSearch_2.5"), - - Endpoint: aws.String(fmt.Sprintf("search-%s-aaaabbbbcccc4444.%s.es.amazonaws.com", name, region)), - } - - for _, opt := range opts { - opt(domain) - } + domain := mocks.OpenSearchDomain(name, region, opts...) tags := []*opensearchservice.Tag{{ Key: aws.String("env"), diff --git a/lib/srv/discovery/fetchers/db/aws_rds.go b/lib/srv/discovery/fetchers/db/aws_rds.go index 0d8aacf3cc7ca..18bc4e51c552f 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds.go +++ b/lib/srv/discovery/fetchers/db/aws_rds.go @@ -215,54 +215,12 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type continue } - // Find out what types of instances the cluster has. Some examples: - // - Aurora cluster with one instance: one writer - // - Aurora cluster with three instances: one writer and two readers - // - Secondary cluster of a global database: one or more readers - var hasWriterInstance, hasReaderInstance bool - for _, clusterMember := range cluster.DBClusterMembers { - if clusterMember != nil { - if aws.BoolValue(clusterMember.IsClusterWriter) { - hasWriterInstance = true - } else { - hasReaderInstance = true - } - } - } - - // Add a database from primary endpoint, if any writer instances. - if cluster.Endpoint != nil && hasWriterInstance { - database, err := services.NewDatabaseFromRDSCluster(cluster) - if err != nil { - f.log.Warnf("Could not convert RDS cluster %q to database resource: %v.", - aws.StringValue(cluster.DBClusterIdentifier), err) - } else { - databases = append(databases, database) - } - } - - // Add a database from reader endpoint, if any reader instances. 
- // https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Overview.Endpoints.html#Aurora.Endpoints.Reader - if cluster.ReaderEndpoint != nil && hasReaderInstance { - database, err := services.NewDatabaseFromRDSClusterReaderEndpoint(cluster) - if err != nil { - f.log.Warnf("Could not convert RDS cluster %q reader endpoint to database resource: %v.", - aws.StringValue(cluster.DBClusterIdentifier), err) - } else { - databases = append(databases, database) - } - } - - // Add databases from custom endpoints - if len(cluster.CustomEndpoints) > 0 { - customEndpointDatabases, err := services.NewDatabasesFromRDSClusterCustomEndpoints(cluster) - if err != nil { - f.log.Warnf("Could not convert RDS cluster %q custom endpoints to database resources: %v.", - aws.StringValue(cluster.DBClusterIdentifier), err) - } - - databases = append(databases, customEndpointDatabases...) + dbs, err := services.NewDatabasesFromRDSCluster(cluster) + if err != nil { + f.log.Warnf("Could not convert RDS cluster %q to database resources: %v.", + aws.StringValue(cluster.DBClusterIdentifier), err) } + databases = append(databases, dbs...) } return databases, nil } diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go index 9d901d7b9aa9f..481cee244b856 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go @@ -17,10 +17,8 @@ limitations under the License. 
package db import ( - "fmt" "testing" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/rds" "github.com/stretchr/testify/require" @@ -68,30 +66,14 @@ func TestRDSDBProxyFetcher(t *testing.T) { } func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types.Database) { - rdsProxy := &rds.DBProxy{ - DBProxyArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:123456789012:db-proxy:prx-%s", region, name)), - DBProxyName: aws.String(name), - EngineFamily: aws.String(rds.EngineFamilyMysql), - Endpoint: aws.String("localhost"), - VpcId: aws.String(vpcID), - RequireTLS: aws.Bool(true), - Status: aws.String("available"), - } - + rdsProxy := mocks.RDSProxy(name, region, vpcID) rdsProxyDatabase, err := services.NewDatabaseFromRDSProxy(rdsProxy, 9999, nil) require.NoError(t, err) return rdsProxy, rdsProxyDatabase } func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rds.DBProxy, name, region string) (*rds.DBProxyEndpoint, types.Database) { - rdsProxyEndpoint := &rds.DBProxyEndpoint{ - Endpoint: aws.String("localhost"), - DBProxyEndpointName: aws.String(name), - DBProxyName: rdsProxy.DBProxyName, - DBProxyEndpointArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db-proxy-endpoint:prx-endpoint-%v", region, name)), - TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly), - Status: aws.String("available"), - } + rdsProxyEndpoint := mocks.RDSProxyCustomEndpoint(rdsProxy, name, region) rdsProxyEndpointDatabase, err := services.NewDatabaseFromRDSProxyCustomEndpoint(rdsProxy, rdsProxyEndpoint, 9999, nil) require.NoError(t, err) return rdsProxyEndpoint, rdsProxyEndpointDatabase diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go index f86f80cd4a661..5bfbb37743735 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go @@ -17,18 +17,15 @@ limitations under the License. 
package db import ( - "fmt" "testing" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds/rdsiface" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" - libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" ) @@ -205,71 +202,29 @@ func TestRDSFetchers(t *testing.T) { } func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) (*rds.DBInstance, types.Database) { - instance := &rds.DBInstance{ - DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)), - DBInstanceIdentifier: aws.String(name), - DbiResourceId: aws.String(uuid.New().String()), - Engine: aws.String(services.RDSEnginePostgres), - DBInstanceStatus: aws.String("available"), - Endpoint: &rds.Endpoint{ - Address: aws.String("localhost"), - Port: aws.Int64(5432), - }, - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), - } - for _, opt := range opts { - opt(instance) - } - + instance := mocks.RDSInstance(name, region, labels, opts...) 
database, err := services.NewDatabaseFromRDSInstance(instance) require.NoError(t, err) return instance, database } func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) (*rds.DBCluster, types.Database) { - cluster := &rds.DBCluster{ - DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)), - DBClusterIdentifier: aws.String(name), - DbClusterResourceId: aws.String(uuid.New().String()), - Engine: aws.String(services.RDSEngineAuroraMySQL), - EngineMode: aws.String(services.RDSEngineModeProvisioned), - Status: aws.String("available"), - Endpoint: aws.String("localhost"), - Port: aws.Int64(3306), - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), - DBClusterMembers: []*rds.DBClusterMember{{ - IsClusterWriter: aws.Bool(true), // Only one writer. - }}, - } - for _, opt := range opts { - opt(cluster) - } - + cluster := mocks.RDSCluster(name, region, labels, opts...) database, err := services.NewDatabaseFromRDSCluster(cluster) require.NoError(t, err) return cluster, database } func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels map[string]string, hasWriter bool) (*rds.DBCluster, types.Databases) { - cluster := &rds.DBCluster{ - DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)), - DBClusterIdentifier: aws.String(name), - DbClusterResourceId: aws.String(uuid.New().String()), - Engine: aws.String(services.RDSEngineAuroraMySQL), - EngineMode: aws.String(services.RDSEngineModeProvisioned), - Status: aws.String("available"), - Endpoint: aws.String("localhost"), - ReaderEndpoint: aws.String("reader.host"), - Port: aws.Int64(3306), - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), - DBClusterMembers: []*rds.DBClusterMember{{ - IsClusterWriter: aws.Bool(false), // Add reader by default. Writer is added below based on hasWriter. 
- }}, - CustomEndpoints: []*string{ - aws.String("custom1.cluster-custom-example.us-east-1.rds.amazonaws.com"), - aws.String("custom2.cluster-custom-example.us-east-1.rds.amazonaws.com"), + cluster := mocks.RDSCluster(name, region, labels, + func(cluster *rds.DBCluster) { + // Disable writer by default. If hasWriter, writer endpoint will be added below. + cluster.DBClusterMembers = nil }, - } + mocks.WithRDSClusterReader, + mocks.WithRDSClusterCustomEndpoint("custom1"), + mocks.WithRDSClusterCustomEndpoint("custom2"), + ) var databases types.Databases diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_test.go index 8a6d17be2bbe0..8665c086ea488 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_test.go @@ -17,7 +17,6 @@ limitations under the License. package db import ( - "fmt" "testing" "github.com/aws/aws-sdk-go/aws" @@ -74,22 +73,7 @@ func TestRedshiftFetcher(t *testing.T) { } func makeRedshiftCluster(t *testing.T, region, env string, opts ...func(*redshift.Cluster)) (*redshift.Cluster, types.Database) { - cluster := &redshift.Cluster{ - ClusterIdentifier: aws.String(env), - ClusterNamespaceArn: aws.String(fmt.Sprintf("arn:aws:redshift:%s:123456789012:namespace:%s", region, env)), - ClusterStatus: aws.String("available"), - Endpoint: &redshift.Endpoint{ - Address: aws.String("localhost"), - Port: aws.Int64(5439), - }, - Tags: []*redshift.Tag{{ - Key: aws.String("env"), - Value: aws.String(env), - }}, - } - for _, opt := range opts { - opt(cluster) - } + cluster := mocks.RedshiftCluster(env, region, map[string]string{"env": env}, opts...) database, err := services.NewDatabaseFromRedshiftCluster(cluster) require.NoError(t, err)