Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/ISSUE_TEMPLATE/testplan.md
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,7 @@ tsh bench web sessions --max=5000 --web user ls
- [ ] Can update registered database using `tctl create -f`.
- [ ] Can delete registered database using `tctl rm`.
- [ ] Verify discovery.
Please configure discovery in Discovery Service instead of Database Service.
- [ ] AWS
- [ ] Can detect and register RDS instances.
- [ ] Can detect and register RDS instances in an external AWS account when `assume_role_arn` and `external_id` is set.
Expand All @@ -936,6 +937,7 @@ tsh bench web sessions --max=5000 --web user ls
- [ ] Can detect and register Redshift serverless workgroups, and their VPC endpoints.
- [ ] Can detect and register ElastiCache Redis clusters.
- [ ] Can detect and register MemoryDB clusters.
- [ ] Can detect and register OpenSearch domains.
- [ ] Azure
- [ ] Can detect and register MySQL and Postgres single-server instances.
- [ ] Can detect and register MySQL and Postgres flexible-server instances.
Expand Down
18 changes: 18 additions & 0 deletions api/types/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ type Database interface {
// SupportsAutoUsers returns true if this database supports automatic
// user provisioning.
SupportsAutoUsers() bool
// GetEndpointType returns the endpoint type of the database, if available.
GetEndpointType() string
}

// NewDatabaseV3 creates a new database resource.
Expand Down Expand Up @@ -947,6 +949,22 @@ func (d *DatabaseV3) SupportAWSIAMRoleARNAsUsers() bool {
return d.GetType() == DatabaseTypeMongoAtlas
}

// GetEndpointType returns the endpoint type of the database, if available.
func (d *DatabaseV3) GetEndpointType() string {
if endpointType, ok := d.GetStaticLabels()[DiscoveryLabelEndpointType]; ok {
return endpointType
}
switch d.GetType() {
case DatabaseTypeElastiCache:
return d.GetAWS().ElastiCache.EndpointType
case DatabaseTypeMemoryDB:
return d.GetAWS().MemoryDB.EndpointType
case DatabaseTypeOpenSearch:
return d.GetAWS().OpenSearch.EndpointType
}
return ""
}

const (
// DatabaseProtocolPostgreSQL is the PostgreSQL database protocol.
DatabaseProtocolPostgreSQL = "postgres"
Expand Down
5 changes: 5 additions & 0 deletions api/utils/aws/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ func IsKeyspacesEndpoint(uri string) bool {
return hasCassandraPrefix && IsAWSEndpoint(uri)
}

// IsOpenSearchEndpoint returns true if input URI is an OpenSearch endpoint.
func IsOpenSearchEndpoint(uri string) bool {
return isAWSServiceEndpoint(uri, OpenSearchServiceName)
}

// RDSEndpointDetails contains information about an RDS endpoint.
type RDSEndpointDetails struct {
// InstanceID is the identifier of an RDS instance.
Expand Down
13 changes: 13 additions & 0 deletions lib/cloud/mocks/aws_memorydb.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
type MemoryDBMock struct {
memorydbiface.MemoryDBAPI

Unauth bool
Clusters []*memorydb.Cluster
Users []*memorydb.User
TagsByARN map[string][]*memorydb.Tag
Expand Down Expand Up @@ -60,6 +61,9 @@ func (m *MemoryDBMock) DescribeSubnetGroupsWithContext(aws.Context, *memorydb.De
}

func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
if aws.StringValue(input.ClusterName) == "" {
return &memorydb.DescribeClustersOutput{
Clusters: m.Clusters,
Expand All @@ -77,6 +81,9 @@ func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memoryd
}

func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTagsInput, _ ...request.Option) (*memorydb.ListTagsOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
if m.TagsByARN == nil {
return nil, trace.NotFound("no tags")
}
Expand All @@ -92,12 +99,18 @@ func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTa
}

func (m *MemoryDBMock) DescribeUsersWithContext(aws.Context, *memorydb.DescribeUsersInput, ...request.Option) (*memorydb.DescribeUsersOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
return &memorydb.DescribeUsersOutput{
Users: m.Users,
}, nil
}

func (m *MemoryDBMock) UpdateUserWithContext(_ aws.Context, input *memorydb.UpdateUserInput, opts ...request.Option) (*memorydb.UpdateUserOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
for _, user := range m.Users {
if aws.StringValue(user.Name) == aws.StringValue(input.UserName) {
return &memorydb.UpdateUserOutput{}, nil
Expand Down
22 changes: 20 additions & 2 deletions lib/cloud/mocks/aws_opensearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,21 @@ import (
"github.com/aws/aws-sdk-go/service/opensearchservice"
"github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface"
"github.com/gravitational/trace"
"golang.org/x/exp/slices"
)

type OpenSearchMock struct {
opensearchserviceiface.OpenSearchServiceAPI

Unauth bool
Domains []*opensearchservice.DomainStatus
TagsByARN map[string][]*opensearchservice.Tag
}

func (o *OpenSearchMock) ListDomainNamesWithContext(aws.Context, *opensearchservice.ListDomainNamesInput, ...request.Option) (*opensearchservice.ListDomainNamesOutput, error) {
if o.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
out := &opensearchservice.ListDomainNamesOutput{}
for _, domain := range o.Domains {
out.DomainNames = append(out.DomainNames, &opensearchservice.DomainInfo{
Expand All @@ -45,12 +50,25 @@ func (o *OpenSearchMock) ListDomainNamesWithContext(aws.Context, *opensearchserv
return out, nil
}

func (o *OpenSearchMock) DescribeDomainsWithContext(aws.Context, *opensearchservice.DescribeDomainsInput, ...request.Option) (*opensearchservice.DescribeDomainsOutput, error) {
out := &opensearchservice.DescribeDomainsOutput{DomainStatusList: o.Domains}
func (o *OpenSearchMock) DescribeDomainsWithContext(_ aws.Context, input *opensearchservice.DescribeDomainsInput, _ ...request.Option) (*opensearchservice.DescribeDomainsOutput, error) {
if o.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
out := &opensearchservice.DescribeDomainsOutput{}
for _, domain := range o.Domains {
if slices.ContainsFunc(input.DomainNames, func(other *string) bool {
return aws.StringValue(other) == aws.StringValue(domain.DomainName)
}) {
out.DomainStatusList = append(out.DomainStatusList, domain)
}
}
return out, nil
}

func (o *OpenSearchMock) ListTagsWithContext(_ aws.Context, request *opensearchservice.ListTagsInput, _ ...request.Option) (*opensearchservice.ListTagsOutput, error) {
if o.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
tags, found := o.TagsByARN[aws.StringValue(request.ARN)]
if !found {
return nil, trace.NotFound("tags not found")
Expand Down
19 changes: 18 additions & 1 deletion lib/cloud/mocks/aws_redshift_serverless.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,18 @@ import (
type RedshiftServerlessMock struct {
redshiftserverlessiface.RedshiftServerlessAPI

Unauth bool
Workgroups []*redshiftserverless.Workgroup
Endpoints []*redshiftserverless.EndpointAccess
TagsByARN map[string][]*redshiftserverless.Tag
GetCredentialsOutput *redshiftserverless.GetCredentialsOutput
}

func (m RedshiftServerlessMock) GetWorkgroupWithContext(_ aws.Context, input *redshiftserverless.GetWorkgroupInput, _ ...request.Option) (*redshiftserverless.GetWorkgroupOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}

for _, workgroup := range m.Workgroups {
if aws.StringValue(workgroup.WorkgroupName) == aws.StringValue(input.WorkgroupName) {
return new(redshiftserverless.GetWorkgroupOutput).SetWorkgroup(workgroup), nil
Expand All @@ -47,6 +52,9 @@ func (m RedshiftServerlessMock) GetWorkgroupWithContext(_ aws.Context, input *re
return nil, trace.NotFound("workgroup %q not found", aws.StringValue(input.WorkgroupName))
}
func (m RedshiftServerlessMock) GetEndpointAccessWithContext(_ aws.Context, input *redshiftserverless.GetEndpointAccessInput, _ ...request.Option) (*redshiftserverless.GetEndpointAccessOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
for _, endpoint := range m.Endpoints {
if aws.StringValue(endpoint.EndpointName) == aws.StringValue(input.EndpointName) {
return new(redshiftserverless.GetEndpointAccessOutput).SetEndpoint(endpoint), nil
Expand All @@ -55,18 +63,27 @@ func (m RedshiftServerlessMock) GetEndpointAccessWithContext(_ aws.Context, inpu
return nil, trace.NotFound("endpoint %q not found", aws.StringValue(input.EndpointName))
}
func (m RedshiftServerlessMock) ListWorkgroupsPagesWithContext(_ aws.Context, input *redshiftserverless.ListWorkgroupsInput, fn func(*redshiftserverless.ListWorkgroupsOutput, bool) bool, _ ...request.Option) error {
if m.Unauth {
return trace.AccessDenied("unauthorized")
}
fn(&redshiftserverless.ListWorkgroupsOutput{
Workgroups: m.Workgroups,
}, true)
return nil
}
func (m RedshiftServerlessMock) ListEndpointAccessPagesWithContext(_ aws.Context, input *redshiftserverless.ListEndpointAccessInput, fn func(*redshiftserverless.ListEndpointAccessOutput, bool) bool, _ ...request.Option) error {
if m.Unauth {
return trace.AccessDenied("unauthorized")
}
fn(&redshiftserverless.ListEndpointAccessOutput{
Endpoints: m.Endpoints,
}, true)
return nil
}
func (m RedshiftServerlessMock) ListTagsForResourceWithContext(_ aws.Context, input *redshiftserverless.ListTagsForResourceInput, _ ...request.Option) (*redshiftserverless.ListTagsForResourceOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
}
if m.TagsByARN == nil {
return &redshiftserverless.ListTagsForResourceOutput{}, nil
}
Expand All @@ -75,7 +92,7 @@ func (m RedshiftServerlessMock) ListTagsForResourceWithContext(_ aws.Context, in
}, nil
}
func (m RedshiftServerlessMock) GetCredentialsWithContext(aws.Context, *redshiftserverless.GetCredentialsInput, ...request.Option) (*redshiftserverless.GetCredentialsOutput, error) {
if m.GetCredentialsOutput == nil {
if m.Unauth || m.GetCredentialsOutput == nil {
return nil, trace.AccessDenied("access denied")
}
return m.GetCredentialsOutput, nil
Expand Down
4 changes: 3 additions & 1 deletion lib/configurators/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ func makeDatabaseActionsBuildOption(flags configurators.BootstrapFlags, targetCf
case configurators.DatabaseServiceByDiscoveryServiceConfig:
return databaseActionsBuildOption{
withDiscovery: false,
withMetadata: false, // Discovered databases should have correct metadata.
withAuth: true,
withAuthBoundary: boundary,
// Discovered databases should be checked by URL validator which
// requires same permissions as the metadata service.
withMetadata: true,
}

case configurators.DatabaseService:
Expand Down
Loading