diff --git a/.github/ISSUE_TEMPLATE/testplan.md b/.github/ISSUE_TEMPLATE/testplan.md index 7c68d0c1d49ab..587089155f9c3 100644 --- a/.github/ISSUE_TEMPLATE/testplan.md +++ b/.github/ISSUE_TEMPLATE/testplan.md @@ -926,6 +926,7 @@ tsh bench web sessions --max=5000 --web user ls - [ ] Can update registered database using `tctl create -f`. - [ ] Can delete registered database using `tctl rm`. - [ ] Verify discovery. + Please configure discovery in Discovery Service instead of Database Service. - [ ] AWS - [ ] Can detect and register RDS instances. - [ ] Can detect and register RDS instances in an external AWS account when `assume_role_arn` and `external_id` is set. @@ -936,6 +937,7 @@ tsh bench web sessions --max=5000 --web user ls - [ ] Can detect and register Redshift serverless workgroups, and their VPC endpoints. - [ ] Can detect and register ElastiCache Redis clusters. - [ ] Can detect and register MemoryDB clusters. + - [ ] Can detect and register OpenSearch domains. - [ ] Azure - [ ] Can detect and register MySQL and Postgres single-server instances. - [ ] Can detect and register MySQL and Postgres flexible-server instances. diff --git a/api/types/database.go b/api/types/database.go index ba5b969817f1f..e95e393275e65 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -130,6 +130,8 @@ type Database interface { // SupportsAutoUsers returns true if this database supports automatic // user provisioning. SupportsAutoUsers() bool + // GetEndpointType returns the endpoint type of the database, if available. + GetEndpointType() string } // NewDatabaseV3 creates a new database resource. @@ -947,6 +949,22 @@ func (d *DatabaseV3) SupportAWSIAMRoleARNAsUsers() bool { return d.GetType() == DatabaseTypeMongoAtlas } +// GetEndpointType returns the endpoint type of the database, if available. +func (d *DatabaseV3) GetEndpointType() string { + if endpointType, ok := d.GetStaticLabels()[DiscoveryLabelEndpointType]; ok { + return endpointType + } + switch d.GetType() { + case DatabaseTypeElastiCache: + return d.GetAWS().ElastiCache.EndpointType + case DatabaseTypeMemoryDB: + return d.GetAWS().MemoryDB.EndpointType + case DatabaseTypeOpenSearch: + return d.GetAWS().OpenSearch.EndpointType + } + return "" +} + const ( // DatabaseProtocolPostgreSQL is the PostgreSQL database protocol. DatabaseProtocolPostgreSQL = "postgres" diff --git a/api/utils/aws/endpoint.go b/api/utils/aws/endpoint.go index 7fef2c2dffa56..e7458cd8747cb 100644 --- a/api/utils/aws/endpoint.go +++ b/api/utils/aws/endpoint.go @@ -74,6 +74,11 @@ func IsKeyspacesEndpoint(uri string) bool { return hasCassandraPrefix && IsAWSEndpoint(uri) } +// IsOpenSearchEndpoint returns true if input URI is an OpenSearch endpoint. +func IsOpenSearchEndpoint(uri string) bool { + return isAWSServiceEndpoint(uri, OpenSearchServiceName) +} + // RDSEndpointDetails contains information about an RDS endpoint. type RDSEndpointDetails struct { // InstanceID is the identifier of an RDS instance. diff --git a/lib/cloud/mocks/aws_memorydb.go b/lib/cloud/mocks/aws_memorydb.go index 83cb714ee49c9..36c0d528c2bfc 100644 --- a/lib/cloud/mocks/aws_memorydb.go +++ b/lib/cloud/mocks/aws_memorydb.go @@ -30,6 +30,7 @@ import ( type MemoryDBMock struct { memorydbiface.MemoryDBAPI + Unauth bool Clusters []*memorydb.Cluster Users []*memorydb.User TagsByARN map[string][]*memorydb.Tag @@ -60,6 +61,9 @@ func (m *MemoryDBMock) DescribeSubnetGroupsWithContext(aws.Context, *memorydb.De } func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } if aws.StringValue(input.ClusterName) == "" { return &memorydb.DescribeClustersOutput{ Clusters: m.Clusters, @@ -77,6 +81,9 @@ func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memoryd } func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTagsInput, _ ...request.Option) (*memorydb.ListTagsOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } if m.TagsByARN == nil { return nil, trace.NotFound("no tags") } @@ -92,12 +99,18 @@ func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTa } func (m *MemoryDBMock) DescribeUsersWithContext(aws.Context, *memorydb.DescribeUsersInput, ...request.Option) (*memorydb.DescribeUsersOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } return &memorydb.DescribeUsersOutput{ Users: m.Users, }, nil } func (m *MemoryDBMock) UpdateUserWithContext(_ aws.Context, input *memorydb.UpdateUserInput, opts ...request.Option) (*memorydb.UpdateUserOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } for _, user := range m.Users { if aws.StringValue(user.Name) == aws.StringValue(input.UserName) { return &memorydb.UpdateUserOutput{}, nil diff --git a/lib/cloud/mocks/aws_opensearch.go b/lib/cloud/mocks/aws_opensearch.go index 2004660b5de44..604553d33dc3f 100644 --- a/lib/cloud/mocks/aws_opensearch.go +++ b/lib/cloud/mocks/aws_opensearch.go @@ -24,16 +24,21 @@ import ( "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" "github.com/gravitational/trace" + "golang.org/x/exp/slices" ) type OpenSearchMock struct { opensearchserviceiface.OpenSearchServiceAPI + Unauth bool Domains []*opensearchservice.DomainStatus TagsByARN map[string][]*opensearchservice.Tag } func (o *OpenSearchMock) ListDomainNamesWithContext(aws.Context, *opensearchservice.ListDomainNamesInput, ...request.Option) (*opensearchservice.ListDomainNamesOutput, error) { + if o.Unauth { + return nil, trace.AccessDenied("unauthorized") + } out := &opensearchservice.ListDomainNamesOutput{} for _, domain := range o.Domains { out.DomainNames = append(out.DomainNames, &opensearchservice.DomainInfo{ @@ -45,12 +50,25 @@ func (o *OpenSearchMock) ListDomainNamesWithContext(aws.Context, *opensearchserv return out, nil } -func (o *OpenSearchMock) DescribeDomainsWithContext(aws.Context, *opensearchservice.DescribeDomainsInput, ...request.Option) (*opensearchservice.DescribeDomainsOutput, error) { - out := &opensearchservice.DescribeDomainsOutput{DomainStatusList: o.Domains} +func (o *OpenSearchMock) DescribeDomainsWithContext(_ aws.Context, input *opensearchservice.DescribeDomainsInput, _ ...request.Option) (*opensearchservice.DescribeDomainsOutput, error) { + if o.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + out := &opensearchservice.DescribeDomainsOutput{} + for _, domain := range o.Domains { + if slices.ContainsFunc(input.DomainNames, func(other *string) bool { + return aws.StringValue(other) == aws.StringValue(domain.DomainName) + }) { + out.DomainStatusList = append(out.DomainStatusList, domain) + } + } return out, nil } func (o *OpenSearchMock) ListTagsWithContext(_ aws.Context, request *opensearchservice.ListTagsInput, _ ...request.Option) (*opensearchservice.ListTagsOutput, error) { + if o.Unauth { + return nil, trace.AccessDenied("unauthorized") + } tags, found := o.TagsByARN[aws.StringValue(request.ARN)] if !found { return nil, trace.NotFound("tags not found") diff --git a/lib/cloud/mocks/aws_redshift_serverless.go b/lib/cloud/mocks/aws_redshift_serverless.go index f2a00d85f8ccc..0dda9d5b8e317 100644 --- a/lib/cloud/mocks/aws_redshift_serverless.go +++ b/lib/cloud/mocks/aws_redshift_serverless.go @@ -32,6 +32,7 @@ import ( type RedshiftServerlessMock struct { redshiftserverlessiface.RedshiftServerlessAPI + Unauth bool Workgroups []*redshiftserverless.Workgroup Endpoints []*redshiftserverless.EndpointAccess TagsByARN map[string][]*redshiftserverless.Tag @@ -39,6 +40,10 @@ type RedshiftServerlessMock struct { } func (m RedshiftServerlessMock) GetWorkgroupWithContext(_ aws.Context, input *redshiftserverless.GetWorkgroupInput, _ ...request.Option) (*redshiftserverless.GetWorkgroupOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + for _, workgroup := range m.Workgroups { if aws.StringValue(workgroup.WorkgroupName) == aws.StringValue(input.WorkgroupName) { return new(redshiftserverless.GetWorkgroupOutput).SetWorkgroup(workgroup), nil @@ -47,6 +52,9 @@ func (m RedshiftServerlessMock) GetWorkgroupWithContext(_ aws.Context, input *re return nil, trace.NotFound("workgroup %q not found", aws.StringValue(input.WorkgroupName)) } func (m RedshiftServerlessMock) GetEndpointAccessWithContext(_ aws.Context, input *redshiftserverless.GetEndpointAccessInput, _ ...request.Option) (*redshiftserverless.GetEndpointAccessOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } for _, endpoint := range m.Endpoints { if aws.StringValue(endpoint.EndpointName) == aws.StringValue(input.EndpointName) { return new(redshiftserverless.GetEndpointAccessOutput).SetEndpoint(endpoint), nil @@ -55,18 +63,27 @@ func (m RedshiftServerlessMock) GetEndpointAccessWithContext(_ aws.Context, inpu return nil, trace.NotFound("endpoint %q not found", aws.StringValue(input.EndpointName)) } func (m RedshiftServerlessMock) ListWorkgroupsPagesWithContext(_ aws.Context, input *redshiftserverless.ListWorkgroupsInput, fn func(*redshiftserverless.ListWorkgroupsOutput, bool) bool, _ ...request.Option) error { + if m.Unauth { + return trace.AccessDenied("unauthorized") + } fn(&redshiftserverless.ListWorkgroupsOutput{ Workgroups: m.Workgroups, }, true) return nil } func (m RedshiftServerlessMock) ListEndpointAccessPagesWithContext(_ aws.Context, input *redshiftserverless.ListEndpointAccessInput, fn func(*redshiftserverless.ListEndpointAccessOutput, bool) bool, _ ...request.Option) error { + if m.Unauth { + return trace.AccessDenied("unauthorized") + } fn(&redshiftserverless.ListEndpointAccessOutput{ Endpoints: m.Endpoints, }, true) return nil } func (m RedshiftServerlessMock) ListTagsForResourceWithContext(_ aws.Context, input *redshiftserverless.ListTagsForResourceInput, _ ...request.Option) (*redshiftserverless.ListTagsForResourceOutput, error) { + if m.Unauth { + return nil, trace.AccessDenied("unauthorized") + } if m.TagsByARN == nil { return &redshiftserverless.ListTagsForResourceOutput{}, nil } @@ -75,7 +92,7 @@ func (m RedshiftServerlessMock) ListTagsForResourceWithContext(_ aws.Context, in }, nil } func (m RedshiftServerlessMock) GetCredentialsWithContext(aws.Context, *redshiftserverless.GetCredentialsInput, ...request.Option) (*redshiftserverless.GetCredentialsOutput, error) { - if m.GetCredentialsOutput == nil { + if m.Unauth || m.GetCredentialsOutput == nil { return nil, trace.AccessDenied("access denied") } return m.GetCredentialsOutput, nil diff --git a/lib/configurators/aws/aws.go b/lib/configurators/aws/aws.go index 82971d55030b7..4d82f94e5bbe1 100644 --- a/lib/configurators/aws/aws.go +++ b/lib/configurators/aws/aws.go @@ -126,9 +126,11 @@ func makeDatabaseActionsBuildOption(flags configurators.BootstrapFlags, targetCf case configurators.DatabaseServiceByDiscoveryServiceConfig: return databaseActionsBuildOption{ withDiscovery: false, - withMetadata: false, // Discovered databases should have correct metadata. withAuth: true, withAuthBoundary: boundary, + // Discovered databases should be checked by URL validator which + // requires same permissions as the metadata service. + withMetadata: true, } case configurators.DatabaseService: diff --git a/lib/configurators/aws/aws_test.go b/lib/configurators/aws/aws_test.go index 23d1290c19cde..e400af3ef50a4 100644 --- a/lib/configurators/aws/aws_test.go +++ b/lib/configurators/aws/aws_test.go @@ -1147,7 +1147,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"rds:ModifyDBInstance", "rds:ModifyDBCluster"}, + Actions: []string{"rds:DescribeDBInstances", "rds:DescribeDBClusters", "rds:ModifyDBInstance", "rds:ModifyDBCluster"}, }, { Effect: awslib.EffectAllow, @@ -1159,7 +1159,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"rds:ModifyDBInstance", "rds:ModifyDBCluster", "rds-db:connect"}, + Actions: []string{"rds:DescribeDBInstances", "rds:DescribeDBClusters", "rds:ModifyDBInstance", "rds:ModifyDBCluster", "rds-db:connect"}, }, { Effect: awslib.EffectAllow, @@ -1178,6 +1178,11 @@ func TestAWSIAMDocuments(t *testing.T) { }, }, statements: []*awslib.Statement{ + { + Effect: awslib.EffectAllow, + Resources: []string{"*"}, + Actions: []string{"rds:DescribeDBProxies", "rds:DescribeDBProxyEndpoints"}, + }, { Effect: awslib.EffectAllow, Resources: []string{roleTarget.String()}, @@ -1188,7 +1193,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"rds-db:connect"}, + Actions: []string{"rds:DescribeDBProxies", "rds:DescribeDBProxyEndpoints", "rds-db:connect"}, }, { Effect: awslib.EffectAllow, @@ -1207,6 +1212,11 @@ func TestAWSIAMDocuments(t *testing.T) { }, }, statements: []*awslib.Statement{ + { + Effect: awslib.EffectAllow, + Resources: []string{"*"}, + Actions: []string{"redshift:DescribeClusters"}, + }, { Effect: awslib.EffectAllow, Resources: []string{roleTarget.String()}, @@ -1217,7 +1227,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"redshift:GetClusterCredentials"}, + Actions: []string{"redshift:DescribeClusters", "redshift:GetClusterCredentials"}, }, { Effect: awslib.EffectAllow, @@ -1235,11 +1245,18 @@ func TestAWSIAMDocuments(t *testing.T) { }, }, }, + statements: []*awslib.Statement{ + { + Effect: awslib.EffectAllow, + Resources: []string{"*"}, + Actions: []string{"redshift-serverless:GetEndpointAccess", "redshift-serverless:GetWorkgroup"}, + }, + }, boundaryStatements: []*awslib.Statement{ { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"sts:AssumeRole"}, + Actions: []string{"redshift-serverless:GetEndpointAccess", "redshift-serverless:GetWorkgroup", "sts:AssumeRole"}, }, }, }, @@ -1256,7 +1273,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"elasticache:DescribeUsers", "elasticache:ModifyUser"}, + Actions: []string{"elasticache:DescribeReplicationGroups", "elasticache:DescribeUsers", "elasticache:ModifyUser"}, }, { Effect: awslib.EffectAllow, @@ -1278,7 +1295,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"elasticache:DescribeUsers", "elasticache:ModifyUser", "elasticache:Connect"}, + Actions: []string{"elasticache:DescribeReplicationGroups", "elasticache:DescribeUsers", "elasticache:ModifyUser", "elasticache:Connect"}, }, { Effect: awslib.EffectAllow, @@ -1310,7 +1327,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"memorydb:DescribeUsers", "memorydb:UpdateUser"}, + Actions: []string{"memorydb:DescribeClusters", "memorydb:DescribeUsers", "memorydb:UpdateUser"}, }, { Effect: awslib.EffectAllow, @@ -1327,7 +1344,7 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"memorydb:DescribeUsers", "memorydb:UpdateUser"}, + Actions: []string{"memorydb:DescribeClusters", "memorydb:DescribeUsers", "memorydb:UpdateUser"}, }, { Effect: awslib.EffectAllow, @@ -1350,11 +1367,18 @@ func TestAWSIAMDocuments(t *testing.T) { }, }, }, + statements: []*awslib.Statement{ + { + Effect: awslib.EffectAllow, + Resources: []string{"*"}, + Actions: []string{"es:DescribeDomains"}, + }, + }, boundaryStatements: []*awslib.Statement{ { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"sts:AssumeRole"}, + Actions: []string{"es:DescribeDomains", "sts:AssumeRole"}, }, }, }, @@ -1373,7 +1397,17 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"rds:ModifyDBInstance", "rds:ModifyDBCluster"}, + Actions: []string{"rds:DescribeDBInstances", "rds:DescribeDBClusters", "rds:ModifyDBInstance", "rds:ModifyDBCluster"}, + }, + { + Effect: awslib.EffectAllow, + Resources: []string{"*"}, + Actions: []string{"rds:DescribeDBProxies", "rds:DescribeDBProxyEndpoints"}, + }, + { + Effect: awslib.EffectAllow, + Resources: []string{"*"}, + Actions: []string{"redshift:DescribeClusters"}, }, { Effect: awslib.EffectAllow, @@ -1390,12 +1424,17 @@ func TestAWSIAMDocuments(t *testing.T) { { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"rds:ModifyDBInstance", "rds:ModifyDBCluster", "rds-db:connect"}, + Actions: []string{"rds:DescribeDBInstances", "rds:DescribeDBClusters", "rds:ModifyDBInstance", "rds:ModifyDBCluster", "rds-db:connect"}, + }, + { + Effect: awslib.EffectAllow, + Resources: []string{"*"}, + Actions: []string{"rds:DescribeDBProxies", "rds:DescribeDBProxyEndpoints"}, }, { Effect: awslib.EffectAllow, Resources: []string{"*"}, - Actions: []string{"redshift:GetClusterCredentials"}, + Actions: []string{"redshift:DescribeClusters", "redshift:GetClusterCredentials"}, }, { Effect: awslib.EffectAllow, diff --git a/lib/srv/db/cloud/meta.go b/lib/srv/db/cloud/meta.go index 36e315b833191..48e082b052f01 100644 --- a/lib/srv/db/cloud/meta.go +++ b/lib/srv/db/cloud/meta.go @@ -339,7 +339,7 @@ func describeRDSProxy(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyName // fetchRDSProxyCustomEndpointMetadata fetches metadata about specified RDS // proxy custom endpoint. func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*types.AWS, error) { - rdsProxyEndpoint, err := describeRDSProxyCustomEndpoint(ctx, rdsClient, proxyEndpointName, uri) + rdsProxyEndpoint, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, uri) if err != nil { return nil, trace.Wrap(err) } @@ -352,9 +352,9 @@ func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface return services.MetadataFromRDSProxyCustomEndpoint(rdsProxy, rdsProxyEndpoint) } -// describeRDSProxyCustomEndpoint returns AWS RDS Proxy endpoint for the -// specified RDS Proxy custom endpoint. -func describeRDSProxyCustomEndpoint(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*rds.DBProxyEndpoint, error) { +// describeRDSProxyCustomEndpointAndFindURI returns AWS RDS Proxy endpoint for +// the specified RDS Proxy custom endpoint. +func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*rds.DBProxyEndpoint, error) { out, err := rdsClient.DescribeDBProxyEndpointsWithContext(ctx, &rds.DescribeDBProxyEndpointsInput{ DBProxyEndpointName: aws.String(proxyEndpointName), }) diff --git a/lib/srv/db/cloud/resource_checker.go b/lib/srv/db/cloud/resource_checker.go index 780c5b292a0c1..70d56fdd4a793 100644 --- a/lib/srv/db/cloud/resource_checker.go +++ b/lib/srv/db/cloud/resource_checker.go @@ -72,16 +72,18 @@ func NewDiscoveryResourceChecker(cfg DiscoveryResourceCheckerConfig) (DiscoveryR return nil, trace.Wrap(err) } - c := &discoveryResourceChecker{} - - // TODO(greedy52) implement url checker. - // TODO(greedy52) implement name checker. - if checker, err := newCrednentialsChecker(cfg); err != nil { + credentialsChecker, err := newCrednentialsChecker(cfg) + if err != nil { return nil, trace.Wrap(err) - } else { - c.checkers = append(c.checkers, checker) } - return c, nil + + // TODO(greedy52) implement name checker. + return &discoveryResourceChecker{ + checkers: []DiscoveryResourceChecker{ + newURLChecker(cfg), + credentialsChecker, + }, + }, nil } // discoveryResourceChecker is a composite checker. diff --git a/lib/srv/db/cloud/resource_checker_url.go b/lib/srv/db/cloud/resource_checker_url.go new file mode 100644 index 0000000000000..f8ae2e85fb334 --- /dev/null +++ b/lib/srv/db/cloud/resource_checker_url.go @@ -0,0 +1,145 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cloud + +import ( + "context" + "fmt" + "net" + "os" + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils" + apiawsutils "github.com/gravitational/teleport/api/utils/aws" + "github.com/gravitational/teleport/lib/cloud" +) + +// urlChecker validates the database has the correct URL. +type urlChecker struct { + clients cloud.Clients + log logrus.FieldLogger + warnOnError bool + + warnAWSOnce sync.Once + + // TODO(greedy52) consider caching describe call responses to avoid + // repeated calls: + // - metadata service + // - multiple endpoints from the same cloud resource +} + +func newURLChecker(cfg DiscoveryResourceCheckerConfig) *urlChecker { + return &urlChecker{ + clients: cfg.Clients, + log: cfg.Log, + warnOnError: getWarnOnError(), + } +} + +// getWarnOnError returns true if urlChecker should only log a warning instead +// of returning errors when check fails. +// +// DELETE IN 16.0.0 The environement variable is a temporary toggle to disable +// returning errors by urlChecker, in case Database Service doesn't have proper +// permissions and basic endpoint checks fail for unknown reasons. Remove after +// one or two releases when implementation is stable. +func getWarnOnError() bool { + value := os.Getenv("TELEPORT_DATABASE_URL_CHECK_WARN_ON_ERROR") + if value == "" { + return false + } + + boolValue, err := utils.ParseBool(value) + if err != nil { + logrus.Warnf("Invalid bool value for TELEPORT_DATABASE_URL_CHECK_WARN_ON_ERROR: %q.", value) + } + return boolValue +} + +type checkDatabaseFunc func(context.Context, types.Database) error +type isEndpointFunc func(string) bool + +func convIsEndpoint(isEndpoint isEndpointFunc) checkDatabaseFunc { + return func(_ context.Context, database types.Database) error { + if isEndpoint(database.GetURI()) { + return nil + } + return trace.BadParameter("expect a %q endpoint for database %q but got %v", database.GetType(), database.GetName(), database.GetURI()) + } +} + +// Check permforms url checks. +func (c *urlChecker) Check(ctx context.Context, database types.Database) error { + checkersByDatabaseType := map[string]checkDatabaseFunc{ + types.DatabaseTypeRDS: c.checkAWS(c.checkRDS, convIsEndpoint(apiawsutils.IsRDSEndpoint)), + types.DatabaseTypeRDSProxy: c.checkAWS(c.checkRDSProxy, convIsEndpoint(apiawsutils.IsRDSEndpoint)), + types.DatabaseTypeRedshift: c.checkAWS(c.checkRedshift, convIsEndpoint(apiawsutils.IsRedshiftEndpoint)), + types.DatabaseTypeRedshiftServerless: c.checkAWS(c.checkRedshiftServerless, convIsEndpoint(apiawsutils.IsRedshiftServerlessEndpoint)), + types.DatabaseTypeElastiCache: c.checkAWS(c.checkElastiCache, convIsEndpoint(apiawsutils.IsElastiCacheEndpoint)), + types.DatabaseTypeMemoryDB: c.checkAWS(c.checkMemoryDB, convIsEndpoint(apiawsutils.IsMemoryDBEndpoint)), + types.DatabaseTypeOpenSearch: c.checkAWS(c.checkOpenSearch, c.checkOpenSearchEndpoint), + types.DatabaseTypeAzure: c.checkAzure, + } + + if check := checkersByDatabaseType[database.GetType()]; check != nil { + err := check(ctx, database) + if err != nil && c.warnOnError { + c.log.Warnf("URL check failed for %q: %v.", database.GetName(), err) + return nil + } + return trace.Wrap(err) + } + + c.log.Debugf("URL checker does not support database type %q.", database.GetType()) + return nil +} + +func requireDatabaseIsEndpoint(ctx context.Context, database types.Database, isEndpoint isEndpointFunc) error { + return trace.Wrap(convIsEndpoint(isEndpoint)(ctx, database)) +} + +func requireDatabaseAddressPort(database types.Database, wantURLHost *string, wantURLPort *int64) error { + wantURL := fmt.Sprintf("%v:%v", aws.StringValue(wantURLHost), aws.Int64Value(wantURLPort)) + if database.GetURI() != wantURL { + return trace.BadParameter("expect database URL %q but got %q for database %q", wantURL, database.GetURI(), database.GetName()) + } + return nil +} + +func requireDatabaseHost(database types.Database, wantURLHost string) error { + host, _, _ := net.SplitHostPort(database.GetURI()) + if host != wantURLHost { + return trace.BadParameter("expect database URL %q: but got %q for database %q", wantURLHost, database.GetURI(), database.GetName()) + } + return nil +} + +func requireContainsDatabaseURLAndEndpointType(in types.Databases, database types.Database, resource any) error { + matchURLAndType := func(other types.Database) bool { + return other.GetURI() == database.GetURI() && other.GetEndpointType() == database.GetEndpointType() + } + if slices.ContainsFunc(in, matchURLAndType) { + return nil + } + return trace.BadParameter("cannot find %v in %#v", database.GetURI(), resource) +} diff --git a/lib/srv/db/cloud/resource_checker_url_aws.go b/lib/srv/db/cloud/resource_checker_url_aws.go new file mode 100644 index 0000000000000..7787e69822cb1 --- /dev/null +++ b/lib/srv/db/cloud/resource_checker_url_aws.go @@ -0,0 +1,250 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cloud + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/opensearchservice" + "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + apiawsutils "github.com/gravitational/teleport/api/utils/aws" + "github.com/gravitational/teleport/lib/cloud" + cloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/services" +) + +func (c *urlChecker) checkAWS(describeCheck, basicEndpointCheck checkDatabaseFunc) checkDatabaseFunc { + return func(ctx context.Context, database types.Database) error { + err := describeCheck(ctx, database) + + // Database Service may not have enough permissions to permform the + // describes. Log a warning and permform a basic endpoint validation + // instead. + if trace.IsAccessDenied(err) { + c.logAWSAccessDeniedError(database, err) + + if err := basicEndpointCheck(ctx, database); err != nil { + return trace.Wrap(err) + } + c.log.Debugf("AWS database %q URL validated by basic endpoint check.", database.GetName()) + return nil + } + + if err != nil { + return trace.Wrap(err) + } + c.log.Debugf("AWS database %q URL validated by describe check.", database.GetName()) + return nil + } +} + +func (c *urlChecker) logAWSAccessDeniedError(database types.Database, accessDeniedError error) { + c.warnAWSOnce.Do(func() { + // TODO(greedy52) add links to doc. + c.log.Warn("No permissions to describe AWS resource metadata that is needed for validating databases created by Discovery Service. Basic AWS endpoint validation will be performed instead. For best security, please provide the Database Service with the proper IAM permissions. Enable --debug mode to see details on which databases require more IAM permissions. See Database Access documentation for more details.") + }) + + c.log.Debugf("No permissions to describe database %q for URL validation.", database.GetName()) +} + +func (c *urlChecker) checkRDS(ctx context.Context, database types.Database) error { + meta := database.GetAWS() + rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta)) + if err != nil { + return trace.Wrap(err) + } + + if meta.RDS.ClusterID != "" { + return trace.Wrap(c.checkRDSCluster(ctx, database, rdsClient, meta.RDS.ClusterID)) + } + return trace.Wrap(c.checkRDSInstance(ctx, database, rdsClient, meta.RDS.InstanceID)) +} + +func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, instanceID string) error { + rdsInstance, err := describeRDSInstance(ctx, rdsClient, instanceID) + if err != nil { + return trace.Wrap(err) + } + if rdsInstance.Endpoint == nil { + return trace.BadParameter("empty endpoint") + } + return trace.Wrap(requireDatabaseAddressPort(database, rdsInstance.Endpoint.Address, rdsInstance.Endpoint.Port)) +} + +func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, clusterID string) error { + rdsCluster, err := describeRDSCluster(ctx, rdsClient, clusterID) + if err != nil { + return trace.Wrap(err) + } + databases, err := services.NewDatabasesFromRDSCluster(rdsCluster) + if err != nil { + c.log.Warnf("Could not convert RDS cluster %q to database resources: %v.", + aws.StringValue(rdsCluster.DBClusterIdentifier), err) + + // services.NewDatabasesFromRDSCluster maybe partially successful. + if len(databases) == 0 { + return nil + } + } + return trace.Wrap(requireContainsDatabaseURLAndEndpointType(databases, database, rdsCluster)) +} + +func (c *urlChecker) checkRDSProxy(ctx context.Context, database types.Database) error { + meta := database.GetAWS() + rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta)) + if err != nil { + return trace.Wrap(err) + } + if meta.RDSProxy.CustomEndpointName != "" { + return trace.Wrap(c.checkRDSProxyCustomEndpoint(ctx, database, rdsClient, meta.RDSProxy.CustomEndpointName)) + } + return trace.Wrap(c.checkRDSProxyPrimaryEndpoint(ctx, database, rdsClient, meta.RDSProxy.Name)) +} + +func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyName string) error { + rdsProxy, err := describeRDSProxy(ctx, rdsClient, proxyName) + if err != nil { + return trace.Wrap(err) + } + // Port has to be fetched from a separate API. Instead of fetching that, + // just validate the host domain. + return requireDatabaseHost(database, aws.StringValue(rdsProxy.Endpoint)) +} + +func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyEndpointName string) error { + _, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, database.GetURI()) + return trace.Wrap(err) +} + +func (c *urlChecker) checkRedshift(ctx context.Context, database types.Database) error { + meta := database.GetAWS() + redshift, err := c.clients.GetAWSRedshiftClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta)) + if err != nil { + return trace.Wrap(err) + } + cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID) + if err != nil { + return trace.Wrap(err) + } + if cluster.Endpoint == nil { + return trace.BadParameter("missing endpoint in Redshift cluster %v", aws.StringValue(cluster.ClusterIdentifier)) + } + return trace.Wrap(requireDatabaseAddressPort(database, cluster.Endpoint.Address, cluster.Endpoint.Port)) +} + +func (c *urlChecker) checkRedshiftServerless(ctx context.Context, database types.Database) error { + meta := database.GetAWS() + client, err := c.clients.GetAWSRedshiftServerlessClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta)) + if err != nil { + return trace.Wrap(err) + } + + if meta.RedshiftServerless.EndpointName != "" { + return trace.Wrap(c.checkRedshiftServerlessVPCEndpoint(ctx, database, client, meta.RedshiftServerless.EndpointName)) + } + return trace.Wrap(c.checkRedshiftServerlessWorkgroup(ctx, database, client, meta.RedshiftServerless.WorkgroupName)) +} + +func (c *urlChecker) checkRedshiftServerlessVPCEndpoint(ctx context.Context, database types.Database, client redshiftserverlessiface.RedshiftServerlessAPI, endpointName string) error { + endpoint, err := describeRedshiftServerlessVCPEndpoint(ctx, client, endpointName) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(requireDatabaseAddressPort(database, endpoint.Address, endpoint.Port)) +} + +func (c *urlChecker) checkRedshiftServerlessWorkgroup(ctx context.Context, database types.Database, client redshiftserverlessiface.RedshiftServerlessAPI, workgroupName string) error { + workgroup, err := describeRedshiftServerlessWorkgroup(ctx, client, workgroupName) + if err != nil { + return trace.Wrap(err) + } + if workgroup.Endpoint == nil { + return trace.BadParameter("missing endpoint") + } + return trace.Wrap(requireDatabaseAddressPort(database, workgroup.Endpoint.Address, workgroup.Endpoint.Port)) +} + +func (c *urlChecker) checkElastiCache(ctx context.Context, database types.Database) error { + meta := database.GetAWS() + elastiCacheClient, err := c.clients.GetAWSElastiCacheClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta)) + if err != nil { + return trace.Wrap(err) + } + cluster, err := describeElastiCacheCluster(ctx, elastiCacheClient, meta.ElastiCache.ReplicationGroupID) + if err != nil { + return trace.Wrap(err) + } + databases, err := services.NewDatabasesFromElastiCacheReplicationGroup(cluster, nil) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(requireContainsDatabaseURLAndEndpointType(databases, database, cluster)) +} + +func (c *urlChecker) checkMemoryDB(ctx context.Context, database types.Database) error { + meta := database.GetAWS() + memoryDBClient, err := c.clients.GetAWSMemoryDBClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta)) + if err != nil { + return trace.Wrap(err) + } + cluster, err := describeMemoryDBCluster(ctx, memoryDBClient, meta.MemoryDB.ClusterName) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(requireDatabaseAddressPort(database, cluster.ClusterEndpoint.Address, cluster.ClusterEndpoint.Port)) +} + +func (c *urlChecker) checkOpenSearch(ctx context.Context, database types.Database) error { + meta := database.GetAWS() + client, err := c.clients.GetAWSOpenSearchClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta)) + if err != nil { + return trace.Wrap(err) + } + + domains, err := client.DescribeDomainsWithContext(ctx, &opensearchservice.DescribeDomainsInput{ + DomainNames: []*string{aws.String(meta.OpenSearch.DomainName)}, + }) + if err != nil { + return trace.Wrap(cloudaws.ConvertRequestFailureError(err)) + } + if len(domains.DomainStatusList) != 1 { + return trace.BadParameter("expect 1 domain but got %v", domains.DomainStatusList) + } + + databases, err := services.NewDatabasesFromOpenSearchDomain(domains.DomainStatusList[0], nil) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(requireContainsDatabaseURLAndEndpointType(databases, database, domains.DomainStatusList[0])) +} + +func (c *urlChecker) checkOpenSearchEndpoint(ctx context.Context, database types.Database) error { + switch database.GetAWS().OpenSearch.EndpointType { + case apiawsutils.OpenSearchDefaultEndpoint, apiawsutils.OpenSearchVPCEndpoint: + return trace.Wrap(convIsEndpoint(apiawsutils.IsOpenSearchEndpoint)(ctx, database)) + default: + // Custom endpoint can be anything. For best security, don't allow it. + // Primary endpoint should also be discovered and users can still use + // that. + return trace.BadParameter(`cannot validate OpenSearch custom domain %v. Please provide Database Service "es:DescribeDomains" permission to validate the URL.`, database.GetURI()) + } +} diff --git a/lib/srv/db/cloud/resource_checker_url_aws_test.go b/lib/srv/db/cloud/resource_checker_url_aws_test.go new file mode 100644 index 0000000000000..edc3b32098e10 --- /dev/null +++ b/lib/srv/db/cloud/resource_checker_url_aws_test.go @@ -0,0 +1,195 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cloud + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go/service/elasticache" + "github.com/aws/aws-sdk-go/service/memorydb" + "github.com/aws/aws-sdk-go/service/opensearchservice" + "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go/service/redshift" + "github.com/aws/aws-sdk-go/service/redshiftserverless" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + apiawsutils "github.com/gravitational/teleport/api/utils/aws" + "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/mocks" + "github.com/gravitational/teleport/lib/services" +) + +func TestURLChecker_AWS(t *testing.T) { + t.Parallel() + + log := logrus.New() + log.SetLevel(logrus.DebugLevel) + ctx := context.Background() + region := "us-west-2" + var testCases types.Databases + + // RDS. + rdsInstance := mocks.RDSInstance("rds-instance", region, nil) + rdsInstanceDB, err := services.NewDatabaseFromRDSInstance(rdsInstance) + require.NoError(t, err) + rdsCluster := mocks.RDSCluster("rds-cluster", region, nil, + mocks.WithRDSClusterReader, + mocks.WithRDSClusterCustomEndpoint("my-custom"), + ) + rdsClusterDBs, err := services.NewDatabasesFromRDSCluster(rdsCluster) + require.NoError(t, err) + require.Len(t, rdsClusterDBs, 3) // Primary, reader, custom. + testCases = append(testCases, append(rdsClusterDBs, rdsInstanceDB)...) + + // RDS Proxy. + rdsProxy := mocks.RDSProxy("rds-proxy", region, "some-vpc") + rdsProxyDB, err := services.NewDatabaseFromRDSProxy(rdsProxy, 1234, nil) + require.NoError(t, err) + rdsProxyCustomEndpoint := mocks.RDSProxyCustomEndpoint(rdsProxy, "my-custom", region) + rdsProxyCustomEndpointDB, err := services.NewDatabaseFromRDSProxyCustomEndpoint(rdsProxy, rdsProxyCustomEndpoint, 1234, nil) + require.NoError(t, err) + testCases = append(testCases, rdsProxyDB, rdsProxyCustomEndpointDB) + + // Redshift. + redshiftCluster := mocks.RedshiftCluster("redshift-cluster", region, nil) + redshiftClusterDB, err := services.NewDatabaseFromRedshiftCluster(redshiftCluster) + require.NoError(t, err) + testCases = append(testCases, redshiftClusterDB) + + // Redshift Serverless. + redshiftServerlessWorkgroup := mocks.RedshiftServerlessWorkgroup("redshift-serverless", region) + redshiftServerlessDB, err := services.NewDatabaseFromRedshiftServerlessWorkgroup(redshiftServerlessWorkgroup, nil) + require.NoError(t, err) + redshiftServerlessVPCEndpoint := mocks.RedshiftServerlessEndpointAccess(redshiftServerlessWorkgroup, "vpc-endpoint", region) + redshiftServerlessVPCEndpointDB, err := services.NewDatabaseFromRedshiftServerlessVPCEndpoint(redshiftServerlessVPCEndpoint, redshiftServerlessWorkgroup, nil) + require.NoError(t, err) + testCases = append(testCases, redshiftServerlessDB, redshiftServerlessVPCEndpointDB) + + // ElastiCache. + elastiCacheCluster := mocks.ElastiCacheCluster("elasticache", region, mocks.WithElastiCacheReaderEndpoint) + elastiCacheClusterDBs, err := services.NewDatabasesFromElastiCacheNodeGroups(elastiCacheCluster, nil) + require.NoError(t, err) + require.Len(t, elastiCacheClusterDBs, 2) // Primary, reader. + elastiCacheClusterConfigurationMode := mocks.ElastiCacheCluster("elasticache-configuration", region, mocks.WithElastiCacheConfigurationEndpoint) + elastiCacheClusterConfigurationModeDB, err := services.NewDatabaseFromElastiCacheConfigurationEndpoint(elastiCacheClusterConfigurationMode, nil) + require.NoError(t, err) + testCases = append(testCases, append(elastiCacheClusterDBs, elastiCacheClusterConfigurationModeDB)...) + + // MemoryDB. + memoryDBCluster := mocks.MemoryDBCluster("memorydb", region) + memoryDBClusterDB, err := services.NewDatabaseFromMemoryDBCluster(memoryDBCluster, nil) + require.NoError(t, err) + testCases = append(testCases, memoryDBClusterDB) + + // OpenSearch. + openSearchDomain := mocks.OpenSearchDomain("opensearch", region, mocks.WithOpenSearchCustomEndpoint("custom.com")) + openSearchDBs, err := services.NewDatabasesFromOpenSearchDomain(openSearchDomain, nil) + require.NoError(t, err) + require.Len(t, openSearchDBs, 2) // Primary, custom. + openSearchVPCDomain := mocks.OpenSearchDomain("opensearch-vpc", region, mocks.WithOpenSearchVPCEndpoint("vpc")) + openSearchVPCDomainDBs, err := services.NewDatabasesFromOpenSearchDomain(openSearchVPCDomain, nil) + require.NoError(t, err) + require.Len(t, openSearchVPCDomainDBs, 1) + testCases = append(testCases, append(openSearchDBs, openSearchVPCDomainDBs...)...) + + // Mock cloud clients. + mockClients := &cloud.TestCloudClients{ + RDS: &mocks.RDSMock{ + DBInstances: []*rds.DBInstance{rdsInstance}, + DBClusters: []*rds.DBCluster{rdsCluster}, + DBProxies: []*rds.DBProxy{rdsProxy}, + DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyCustomEndpoint}, + }, + Redshift: &mocks.RedshiftMock{ + Clusters: []*redshift.Cluster{redshiftCluster}, + }, + RedshiftServerless: &mocks.RedshiftServerlessMock{ + Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, + Endpoints: []*redshiftserverless.EndpointAccess{redshiftServerlessVPCEndpoint}, + }, + ElastiCache: &mocks.ElastiCacheMock{ + ReplicationGroups: []*elasticache.ReplicationGroup{elastiCacheClusterConfigurationMode, elastiCacheCluster}, + }, + MemoryDB: &mocks.MemoryDBMock{ + Clusters: []*memorydb.Cluster{memoryDBCluster}, + }, + OpenSearch: &mocks.OpenSearchMock{ + Domains: []*opensearchservice.DomainStatus{openSearchDomain, openSearchVPCDomain}, + }, + STS: &mocks.STSMock{}, + } + mockClientsUnauth := &cloud.TestCloudClients{ + RDS: &mocks.RDSMockUnauth{}, + Redshift: &mocks.RedshiftMockUnauth{}, + RedshiftServerless: &mocks.RedshiftServerlessMock{Unauth: true}, + ElastiCache: &mocks.ElastiCacheMock{Unauth: true}, + MemoryDB: &mocks.MemoryDBMock{Unauth: true}, + OpenSearch: &mocks.OpenSearchMock{Unauth: true}, + STS: &mocks.STSMock{}, + } + + // Test both check methods. + // Note that "No permissions" logs should only be printed during the second + // group ("basic endpoint check"). + methods := []struct { + name string + clients cloud.Clients + }{ + { + name: "API check", + clients: mockClients, + }, + { + name: "basic endpoint check", + clients: mockClientsUnauth, + }, + } + + for _, method := range methods { + t.Run(method.name, func(t *testing.T) { + c := newURLChecker(DiscoveryResourceCheckerConfig{ + Clients: method.clients, + Log: log, + }) + + for _, database := range testCases { + t.Run(database.GetName(), func(t *testing.T) { + t.Run("valid", func(t *testing.T) { + // Special case for OpenSearch custom endpoint where basic endpoint check always fails. + if database.GetAWS().OpenSearch.EndpointType == apiawsutils.OpenSearchCustomEndpoint && + method.name == "basic endpoint check" { + require.Error(t, c.Check(ctx, database)) + return + } + + require.NoError(t, c.Check(ctx, database)) + }) + + // Make a copy and set an invalid URI. + t.Run("invalid", func(t *testing.T) { + invalid := database.Copy() + invalid.SetURI("localhost:12345") + require.Error(t, c.Check(ctx, invalid)) + }) + }) + } + }) + } +} diff --git a/lib/srv/db/cloud/resource_checker_url_azure.go b/lib/srv/db/cloud/resource_checker_url_azure.go new file mode 100644 index 0000000000000..7ead927cb77db --- /dev/null +++ b/lib/srv/db/cloud/resource_checker_url_azure.go @@ -0,0 +1,51 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cloud + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/azure" + "github.com/gravitational/teleport/lib/defaults" +) + +func (c *urlChecker) checkAzure(ctx context.Context, database types.Database) error { + // TODO check by fetching the resources from Azure and compare the URLs. + if err := c.checkIsAzureEndpoint(ctx, database); err != nil { + return trace.Wrap(err) + } + c.log.Debugf("Azure database %q URL validated.", database.GetName()) + return nil +} + +func (c *urlChecker) checkIsAzureEndpoint(ctx context.Context, database types.Database) error { + switch database.GetProtocol() { + case defaults.ProtocolRedis: + return trace.Wrap(requireDatabaseIsEndpoint(ctx, database, azure.IsCacheForRedisEndpoint)) + + case defaults.ProtocolMySQL, defaults.ProtocolPostgres: + return trace.Wrap(requireDatabaseIsEndpoint(ctx, database, azure.IsDatabaseEndpoint)) + + case defaults.ProtocolSQLServer: + return trace.Wrap(requireDatabaseIsEndpoint(ctx, database, azure.IsMSSQLServerEndpoint)) + } + c.log.Debugf("URL checker does not support Azure database type %q protocol %q.", database.GetType(), database.GetProtocol()) + return nil +} diff --git a/lib/srv/db/cloud/resource_checker_url_azure_test.go b/lib/srv/db/cloud/resource_checker_url_azure_test.go new file mode 100644 index 0000000000000..32ee138a483c3 --- /dev/null +++ b/lib/srv/db/cloud/resource_checker_url_azure_test.go @@ -0,0 +1,82 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cloud + +import ( + "context" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" +) + +func TestURLChecker_Azure(t *testing.T) { + t.Parallel() + + log := logrus.New() + log.SetLevel(logrus.DebugLevel) + ctx := context.Background() + + testCases := types.Databases{ + mustMakeAzureDatabase(t, "mysql", defaults.ProtocolMySQL, "mysql.mysql.database.azure.com:3306", types.Azure{}), + mustMakeAzureDatabase(t, "postgres", defaults.ProtocolPostgres, "postgres.postgres.database.azure.com:5432", types.Azure{}), + mustMakeAzureDatabase(t, "redis", defaults.ProtocolRedis, "redis.redis.cache.windows.net:6380", types.Azure{ + ResourceID: "/subscriptions//resourceGroups//providers/Microsoft.Cache/Redis/redis", + }), + mustMakeAzureDatabase(t, "redis-enterprise", defaults.ProtocolRedis, "redis-enterprise.region.redisenterprise.cache.azure.net", types.Azure{ + ResourceID: "/subscriptions//resourceGroups//providers/Microsoft.Cache/redisEnterprise/databases/default", + }), + mustMakeAzureDatabase(t, "sqlserver", defaults.ProtocolSQLServer, "sqlserver.database.windows.net:1433", types.Azure{}), + } + + c := newURLChecker(DiscoveryResourceCheckerConfig{ + Log: log, + }) + for _, database := range testCases { + t.Run(database.GetName(), func(t *testing.T) { + t.Run("valid", func(t *testing.T) { + require.NoError(t, c.Check(ctx, database)) + }) + + // Make a copy and set an invalid URI. + t.Run("invalid", func(t *testing.T) { + invalid := database.Copy() + invalid.SetURI("localhost:12345") + require.Error(t, c.Check(ctx, invalid)) + }) + }) + } +} + +func mustMakeAzureDatabase(t *testing.T, name, protocol, uri string, azure types.Azure) types.Database { + t.Helper() + + database, err := types.NewDatabaseV3( + types.Metadata{ + Name: name, + }, types.DatabaseSpecV3{ + URI: uri, + Protocol: protocol, + Azure: azure, + }, + ) + require.NoError(t, err) + return database +}