diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index ac0deddb52b67..d11a77caa22ab 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -610,10 +610,10 @@ func (c *cloudClients) getAWSSessionForRegion(region string) (*awssession.Sessio // getAWSSessionForRole returns AWS session for the specified region and role. func (c *cloudClients) getAWSSessionForRole(ctx context.Context, region string, options awsAssumeRoleOpts) (*awssession.Session, error) { - assumeRoler := sts.New(options.baseSession) cacheKey := fmt.Sprintf("Region[%s]:RoleARN[%s]:ExternalID[%s]", region, options.assumeRoleARN, options.assumeRoleExternalID) return utils.FnCacheGet(ctx, c.awsSessionsCache, cacheKey, func(ctx context.Context) (*awssession.Session, error) { - return newSessionWithRole(ctx, assumeRoler, region, options.assumeRoleARN, options.assumeRoleExternalID) + stsClient := sts.New(options.baseSession) + return newSessionWithRole(ctx, stsClient, region, options.assumeRoleARN, options.assumeRoleExternalID) }) } diff --git a/lib/integrations/awsoidc/listdatabases_test.go b/lib/integrations/awsoidc/listdatabases_test.go index 5438b66d8ad4f..c4f3d3927ba7d 100644 --- a/lib/integrations/awsoidc/listdatabases_test.go +++ b/lib/integrations/awsoidc/listdatabases_test.go @@ -24,6 +24,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" rdsTypes "github.com/aws/aws-sdk-go-v2/service/rds/types" + "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -187,12 +188,14 @@ func TestListDatabases(t *testing.T) { Name: "my-db", Description: "RDS instance in ", Labels: map[string]string{ - "account-id": "123456789012", - "endpoint-type": "instance", - "engine": "postgres", - "engine-version": "", - "region": "", - "status": "available", + "account-id": "123456789012", + "endpoint-type": "instance", + "engine": "postgres", + "engine-version": "", + "region": "", + "status": "available", + "teleport.dev/cloud": "AWS", + "teleport.dev/origin": "cloud", }, }, types.DatabaseSpecV3{ @@ -208,7 +211,7 @@ func TestListDatabases(t *testing.T) { }, ) require.NoError(t, err) - require.Equal(t, expectedDB, ldr.Databases[0]) + require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0])) }, errCheck: noErrorFunc, }, @@ -250,12 +253,14 @@ func TestListDatabases(t *testing.T) { Name: "my-db", Description: "RDS instance in ", Labels: map[string]string{ - "account-id": "123456789012", - "endpoint-type": "instance", - "engine": "postgres", - "engine-version": "", - "region": "", - "status": "available", + "account-id": "123456789012", + "endpoint-type": "instance", + "engine": "postgres", + "engine-version": "", + "region": "", + "status": "available", + "teleport.dev/cloud": "AWS", + "teleport.dev/origin": "cloud", }, }, types.DatabaseSpecV3{ @@ -271,7 +276,7 @@ func TestListDatabases(t *testing.T) { }, ) require.NoError(t, err) - require.Equal(t, expectedDB, ldr.Databases[0]) + require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0])) }, errCheck: noErrorFunc, }, @@ -300,12 +305,14 @@ func TestListDatabases(t *testing.T) { Name: "my-dbc", Description: "Aurora cluster in ", Labels: map[string]string{ - "account-id": "123456789012", - "endpoint-type": "primary", - "engine": "aurora-postgresql", - "engine-version": "", - "region": "", - "status": "available", + "account-id": "123456789012", + "endpoint-type": "primary", + "engine": "aurora-postgresql", + "engine-version": "", + "region": "", + "status": "available", + "teleport.dev/cloud": "AWS", + "teleport.dev/origin": "cloud", }, }, types.DatabaseSpecV3{ @@ -322,7 +329,7 @@ func TestListDatabases(t *testing.T) { }, ) require.NoError(t, err) - require.Equal(t, expectedDB, ldr.Databases[0]) + require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0])) }, errCheck: noErrorFunc, }, diff --git a/lib/services/database.go b/lib/services/database.go index 7dceeeaa7af0d..cb5e15d95ef78 100644 --- a/lib/services/database.go +++ b/lib/services/database.go @@ -948,8 +948,8 @@ func newElastiCacheDatabase(cluster *elasticache.ReplicationGroup, endpoint *ela }) } -// NewDatabaseFromOpenSearchDomain creates a database resource from an OpenSearch domain. -func NewDatabaseFromOpenSearchDomain(domain *opensearchservice.DomainStatus, tags []*opensearchservice.Tag) (types.Databases, error) { +// NewDatabasesFromOpenSearchDomain creates database resources from an OpenSearch domain. +func NewDatabasesFromOpenSearchDomain(domain *opensearchservice.DomainStatus, tags []*opensearchservice.Tag) (types.Databases, error) { var databases types.Databases if aws.StringValue(domain.Endpoint) != "" { diff --git a/lib/srv/db/watcher_test.go b/lib/srv/db/watcher_test.go index 93e09f63d117e..9880c1115c8f1 100644 --- a/lib/srv/db/watcher_test.go +++ b/lib/srv/db/watcher_test.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" + discovery "github.com/gravitational/teleport/lib/srv/discovery/common" ) // TestWatcher verifies that database server properly detects and applies @@ -263,6 +264,7 @@ func TestWatcherCloudFetchers(t *testing.T) { redshiftServerlessDatabase.SetStatusAWS(redshiftServerlessDatabase.GetAWS()) setDiscoveryGroupLabel(redshiftServerlessDatabase, "") redshiftServerlessDatabase.SetOrigin(types.OriginCloud) + discovery.ApplyAWSDatabaseNameSuffix(redshiftServerlessDatabase, services.AWSMatcherRedshiftServerless) // Test an Azure fetcher. azSQLServer, azSQLServerDatabase := makeAzureSQLServer(t, "discovery-azure", "group") setDiscoveryGroupLabel(azSQLServerDatabase, "") @@ -375,5 +377,6 @@ func makeAzureSQLServer(t *testing.T, name, group string) (*armsql.Server, types } database, err := services.NewDatabaseFromAzureSQLServer(server) require.NoError(t, err) + discovery.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherSQLServer) return server, database } diff --git a/lib/srv/discovery/common/renaming.go b/lib/srv/discovery/common/renaming.go new file mode 100644 index 0000000000000..2709cbb451ca9 --- /dev/null +++ b/lib/srv/discovery/common/renaming.go @@ -0,0 +1,271 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "fmt" + "regexp" + "strings" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" +) + +// ApplyAWSDatabaseNameSuffix applies the AWS Database Discovery name suffix to +// the given database. +// Format: ---. +func ApplyAWSDatabaseNameSuffix(db types.Database, matcherType string) { + if hasOverrideLabel(db, types.AWSDatabaseNameOverrideLabels...) { + // don't rewrite manual name override. + return + } + meta := db.GetAWS() + suffix := makeAWSDiscoverySuffix(databaseNamePartValidator, + db.GetName(), + matcherType, + getDBMatcherSubtype(matcherType, db), + meta.Region, + meta.AccountID, + ) + applyDiscoveryNameSuffix(db, suffix) +} + +// ApplyAzureDatabaseNameSuffix applies the Azure Database Discovery name suffix +// to the given database. +// Format: ----. +func ApplyAzureDatabaseNameSuffix(db types.Database, matcherType string) { + if hasOverrideLabel(db, types.AzureDatabaseNameOverrideLabel) { + // don't rewrite manual name override. + return + } + region, _ := db.GetLabel(types.DiscoveryLabelRegion) + group, _ := db.GetLabel(types.DiscoveryLabelAzureResourceGroup) + subID, _ := db.GetLabel(types.DiscoveryLabelAzureSubscriptionID) + suffix := makeAzureDiscoverySuffix(databaseNamePartValidator, + db.GetName(), + matcherType, + getDBMatcherSubtype(matcherType, db), + region, + group, + subID, + ) + applyDiscoveryNameSuffix(db, suffix) +} + +// getDBMatcherSubtype gets a "subtype" for a given DB matcher, based on the +// database metadata. This is needed for AWS RDS and Azure Redis databases +// to ensure unique naming. +// For example, an Aurora cluster can be named the same as an RDS instance in +// the same account, region, etc. +// Likewise, an Azure Redis database can be named the same as an Azure Redis +// Enterprise database. +// By subtyping the matcher type, we can ensure these names do not collide. +func getDBMatcherSubtype(matcherType string, db types.Database) string { + switch matcherType { + case services.AWSMatcherRDS: + if db.GetAWS().RDS.InstanceID == "" { + // distinguish RDS instances from clusters by subtyping the RDS + // matcher as "rds-aurora". + return "aurora" + } + case services.AzureMatcherRedis: + if db.GetAzure().Redis.ClusteringPolicy != "" { + // distinguish Redis databases from Redis Enterprise database by + // subtyping the redis matcher as "redis-enterprise". + return "enterprise" + } + } + return "" +} + +// ApplyEKSNameSuffix applies the AWS EKS Discovery name suffix to the given +// kube cluster. +// Format: -eks--. +func ApplyEKSNameSuffix(cluster types.KubeCluster) { + if hasOverrideLabel(cluster, types.AWSKubeClusterNameOverrideLabels...) { + // don't rewrite manual name override. + return + } + meta := cluster.GetAWSConfig() + suffix := makeAWSDiscoverySuffix(kubeClusterNamePartValidator, + cluster.GetName(), + services.AWSMatcherEKS, + "", // no EKS subtype + meta.Region, + meta.AccountID, + ) + applyDiscoveryNameSuffix(cluster, suffix) +} + +// ApplyAKSNameSuffix applies the Azure AKS Discovery name suffix to the given +// kube cluster. +// Format: -aks---. +func ApplyAKSNameSuffix(cluster types.KubeCluster) { + if hasOverrideLabel(cluster, types.AzureKubeClusterNameOverrideLabel) { + // don't rewrite manual name override. + return + } + meta := cluster.GetAzureConfig() + region, _ := cluster.GetLabel(types.DiscoveryLabelRegion) + suffix := makeAzureDiscoverySuffix(kubeClusterNamePartValidator, + cluster.GetName(), + services.AzureMatcherKubernetes, + "", // no AKS subtype + region, + meta.ResourceGroup, + meta.SubscriptionID, + ) + applyDiscoveryNameSuffix(cluster, suffix) +} + +// ApplyGKENameSuffix applies the GCP GKE Discovery name suffix to the given +// kube cluster. +// Format: -gke--. +func ApplyGKENameSuffix(cluster types.KubeCluster) { + if hasOverrideLabel(cluster, types.GCPKubeClusterNameOverrideLabel) { + // don't rewrite manual name override. + return + } + meta := cluster.GetGCPConfig() + suffix := makeGCPDiscoverySuffix(kubeClusterNamePartValidator, + cluster.GetName(), + services.GCPMatcherKubernetes, + "", // no GKE subtype + meta.Location, + meta.ProjectID, + ) + applyDiscoveryNameSuffix(cluster, suffix) +} + +// hasOverrideLabel is a helper func to check for presence of a name override +// label. +func hasOverrideLabel(r types.ResourceWithLabels, overrideLabels ...string) bool { + for _, label := range overrideLabels { + if val, ok := r.GetLabel(label); ok && val != "" { + return true + } + } + return false +} + +// makeAWSDiscoverySuffix makes a discovery suffix for AWS resources, of the +// form ---. +func makeAWSDiscoverySuffix(fn suffixValidatorFn, name, matcherType, subType, region, accountID string) string { + return makeDiscoverySuffix(fn, name, matcherType, subType, region, accountID) +} + +// makeAzureDiscoverySuffix makes a discovery suffix for Azure resources, of the +// form ----. +func makeAzureDiscoverySuffix(fn suffixValidatorFn, name, matcherType, subType, region, resourceGroup, subscriptionID string) string { + return makeDiscoverySuffix(fn, name, matcherType, subType, region, resourceGroup, subscriptionID) +} + +// makeGCPDiscoverySuffix makes a discovery suffix for GCP resources, of the +// form ---. +func makeGCPDiscoverySuffix(fn suffixValidatorFn, name, matcherType, subType, location, projectID string) string { + return makeDiscoverySuffix(fn, name, matcherType, subType, location, projectID) +} + +// applyDiscoveryNameSuffix takes a resource with labels and a suffix to add +// to the name, then modifies the resource to add a label containing the +// original name and sets a new name with the suffix appended. +// This function does nothing if the suffix is empty. +func applyDiscoveryNameSuffix(resource types.ResourceWithLabels, suffix string) { + if suffix == "" { + // nop if suffix parts aren't given. + return + } + discoveredName := resource.GetName() + labels := resource.GetStaticLabels() + if labels == nil { + labels = make(map[string]string) + } + // set the originally discovered name as a label. + labels[types.DiscoveredNameLabel] = discoveredName + resource.SetStaticLabels(labels) + // update the resource name with a suffix. + resource.SetName(fmt.Sprintf("%s-%s", discoveredName, suffix)) +} + +// suffixValidatorFn is a func that validates a suffix. +type suffixValidatorFn func(string) error + +// databaseNamePartValidator is a suffixValidatorFn for database name suffix +// parts. +func databaseNamePartValidator(part string) error { + // validate the suffix part adding a simple stub prefix "a" and + // validating it as a full database name. + return types.ValidateDatabaseName("a" + part) +} + +// kubeClusterNamePartValidator is a suffixValidatorFn for kube cluster suffix +// parts. +func kubeClusterNamePartValidator(part string) error { + // validate the suffix part adding a simple stub prefix "a" and + // validating it as a full kube cluster name. + return types.ValidateKubeClusterName("a" + part) +} + +// makeDiscoverySuffix takes a list of suffix parts and a suffix validator func, +// sanitizes each part and checks it for validity, then joins the result with +// hyphens "-". +func makeDiscoverySuffix(validatorFn suffixValidatorFn, name string, parts ...string) string { + // convert name to lower case for substring checking. + name = strings.ToLower(name) + var out []string + for _, part := range parts { + part = sanitizeSuffixPart(part) + // skip blank parts. + if part == "" { + continue + } + // skip redundant parts. + if strings.Contains(name, strings.ToLower(part)) { + continue + } + // skip invalid parts. + if err := validatorFn(part); err != nil { + continue + } + out = append(out, part) + } + if len(out) == 0 { + return "" + } + suffix := strings.Join(out, "-") + if err := validatorFn(suffix); err != nil { + // sanity check for the full suffix - if it's somehow invalid, then + // discard it. + return "" + } + return suffix +} + +// sanitizeSuffixPart cleans a suffix part to remove all whitespace, repeating +// and leading/trailing hyphens, and converts the string to all lowercase. +func sanitizeSuffixPart(part string) string { + // convert all whitespace to "-". + part = strings.ReplaceAll(part, " ", "-") + // compact repeating "-" into a single "-", e.g. "a--b----" => "a-b-". + part = repeatingHyphensRegexp.ReplaceAllLiteralString(part, "-") + // trim leading/trailing hyphens out. + part = strings.Trim(part, "-") + return part +} + +// repeatingHyphensRegexp represents a repeating hyphen chars pattern. +var repeatingHyphensRegexp = regexp.MustCompile(`--+`) diff --git a/lib/srv/discovery/common/renaming_test.go b/lib/srv/discovery/common/renaming_test.go new file mode 100644 index 0000000000000..c11f3c4fa9fbf --- /dev/null +++ b/lib/srv/discovery/common/renaming_test.go @@ -0,0 +1,548 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "fmt" + "testing" + + "cloud.google.com/go/container/apiv1/containerpb" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go/service/rds" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + azureutils "github.com/gravitational/teleport/api/utils/azure" + libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/azure" + "github.com/gravitational/teleport/lib/cloud/gcp" + "github.com/gravitational/teleport/lib/services" +) + +// TestMakeDiscoverySuffix tests makeDiscoverySuffix in isolation. +func TestMakeDiscoverySuffix(t *testing.T) { + tests := []struct { + name string + resourceName string + extraParts []string + wantSuffix string + }{ + { + name: "no suffix made without extra parts", + resourceName: "foo", + wantSuffix: "", + }, + { + name: "simple parts", + resourceName: "foo", + extraParts: []string{"one", "two", "three"}, + wantSuffix: "one-two-three", + }, + { + name: "skips empty parts", + resourceName: "foo", + extraParts: []string{"one", "", "three"}, + wantSuffix: "one-three", + }, + { + name: "converts extra whitespace to hyphens", + resourceName: "foo", + extraParts: []string{"one", "t w o", "three"}, + wantSuffix: "one-t-w-o-three", + }, + { + name: "removes repeated hypens", + resourceName: "foo", + extraParts: []string{"one---", "t w -- o ", "---three"}, + wantSuffix: "one-t-w-o-three", + }, + { + name: "removes leading and trailing hypens", + resourceName: "foo", + extraParts: []string{"one---", "t w -- o ", "---three"}, + wantSuffix: "one-t-w-o-three", + }, + { + name: "skips adding redundant info", + resourceName: "PostgreSQL-RDS-us-west-1", + // suffixes are added to make resource names unique. + // Adding info as a suffix when that info is already contained in + // the resource name verbatim would pointlessly make a resource name + // longer and ugly, i.e. we don't want users to see like + // "PostgreSQL-RDS-us-west-1-rds-us-west-1-123456789012" as a resource + // name. + extraParts: []string{"rds", "us-west-1", "123456789012"}, + wantSuffix: "123456789012", + }, + { + name: "skips invalid parts", + resourceName: "foo", + // parentheses are illegal in both database and kube cluster names in Teleport. + extraParts: []string{"mysql", "EastUS", "weird)(group-name", "11111111-2222-3333-4444-555555555555"}, + wantSuffix: "mysql-EastUS-11111111-2222-3333-4444-555555555555", + }, + } + for validatorKind, validatorFn := range map[string]suffixValidatorFn{ + "databases": databaseNamePartValidator, + "kube clusters": kubeClusterNamePartValidator, + } { + for _, test := range tests { + t.Run(fmt.Sprintf("%s/%s", validatorKind, test.name), func(t *testing.T) { + got := makeDiscoverySuffix(validatorFn, test.resourceName, test.extraParts...) + require.Equal(t, test.wantSuffix, got) + }) + } + } +} + +// renameFunc is a callback to specialize on the renaming func to use for a +// resource under test. +type renameFunc func(types.ResourceWithLabels) + +// renameTest is a test helper struct to group common test structure for +// renaming resources. +type renameTest struct { + // resource is the resource under test. It will be modified during test run + // if the resource is renamed. + resource types.ResourceWithLabels + // renameFn is used to specialize the renaming func to use. + renameFn renameFunc + // originalName is the name of the resource as it was before renaming. + originalName string + // nameOverrideLabel is the cloud override label used to manually override a + // resource name. Renaming should be skipped when this label is present. + nameOverrideLabel string + // wantNewName is the name the test expects after the resource is renamed + // according to the discovery renaming format. + wantNewName string +} + +func runRenameTest(t *testing.T, test renameTest) { + t.Helper() + // all tests should start out with the override name label set, to indicate that the resource shouldn't be renamed. + requireOverrideLabelIsSet(t, test.resource, test.nameOverrideLabel) + // try renaming the resource. + test.renameFn(test.resource) + // verify it was not renamed. + requireOverrideLabelSkipsRenaming(t, test.resource, test.originalName, test.nameOverrideLabel) + // clear the override label. + labels := test.resource.GetStaticLabels() + delete(labels, test.nameOverrideLabel) + test.resource.SetStaticLabels(labels) + // now try renaming without an override label. + test.renameFn(test.resource) + // verify that the resource was renamed as we expected. + require.Equal(t, test.wantNewName, test.resource.GetName()) + // verify that the original name was saved as a label after renaming. + requireDiscoveredNameLabel(t, test.resource, test.originalName, test.nameOverrideLabel) +} + +func TestApplyAWSDatabaseNameSuffix(t *testing.T) { + tests := []struct { + desc, + dbName, + region, + accountID, + wantRename string + makeDBFunc func(t *testing.T, name, region, account, overrideLabel string) types.Database + }{ + { + desc: "RDS instance", + dbName: "some-db", + region: "us-west-1", + accountID: "123456789012", + wantRename: "some-db-rds-us-west-1-123456789012", + makeDBFunc: makeRDSInstanceDB, + }, + { + desc: "RDS Aurora cluster", + dbName: "some-db", + region: "us-west-1", + accountID: "123456789012", + wantRename: "some-db-rds-aurora-us-west-1-123456789012", + makeDBFunc: makeAuroraPrimaryDB, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels { + database := tt.makeDBFunc(t, tt.dbName, tt.region, tt.accountID, overrideLabel) + test := renameTest{ + resource: database, + renameFn: func(r types.ResourceWithLabels) { + db := r.(types.Database) + ApplyAWSDatabaseNameSuffix(db, services.AWSMatcherRDS) + }, + originalName: tt.dbName, + nameOverrideLabel: overrideLabel, + wantNewName: tt.wantRename, + } + runRenameTest(t, test) + } + }) + } +} + +func TestApplyAzureDatabaseNameSuffix(t *testing.T) { + tests := []struct { + desc, + dbName, + region, + resourceGroup, + subscriptionID, + matcherType, + wantRename string + makeDBFunc func(t *testing.T, name, region, group, subscription string) types.Database + }{ + { + desc: "Azure MySQL Flex", + dbName: "some-db", + region: "East US", // we normalize regions, so this should become "eastus". + resourceGroup: "Some Group", + subscriptionID: "11111111-2222-3333-4444-555555555555", + matcherType: services.AzureMatcherMySQL, + wantRename: "some-db-mysql-eastus-Some-Group-11111111-2222-3333-4444-555555555555", + makeDBFunc: makeAzureMySQLFlexDatabase, + }, + { + desc: "skips invalid resource group", + dbName: "some-db", + region: "eastus", // use the normalized region. + resourceGroup: "(parens are invalid)", + subscriptionID: "11111111-2222-3333-4444-555555555555", + matcherType: services.AzureMatcherMySQL, + wantRename: "some-db-mysql-eastus-11111111-2222-3333-4444-555555555555", + makeDBFunc: makeAzureMySQLFlexDatabase, + }, + { + desc: "Azure Redis", + dbName: "some-db", + region: "eastus", + resourceGroup: "Some Group", + subscriptionID: "11111111-2222-3333-4444-555555555555", + matcherType: services.AzureMatcherRedis, + wantRename: "some-db-redis-eastus-Some-Group-11111111-2222-3333-4444-555555555555", + makeDBFunc: makeAzureRedisDB, + }, + { + desc: "Azure Redis Enterprise", + dbName: "some-db", + region: "eastus", + resourceGroup: "Some Group", + subscriptionID: "11111111-2222-3333-4444-555555555555", + matcherType: services.AzureMatcherRedis, + wantRename: "some-db-redis-enterprise-eastus-Some-Group-11111111-2222-3333-4444-555555555555", + makeDBFunc: makeAzureRedisEnterpriseDB, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + database := tt.makeDBFunc(t, tt.dbName, tt.region, tt.resourceGroup, tt.subscriptionID) + runRenameTest(t, renameTest{ + resource: database, + renameFn: func(r types.ResourceWithLabels) { + db := r.(types.Database) + ApplyAzureDatabaseNameSuffix(db, tt.matcherType) + }, + originalName: tt.dbName, + nameOverrideLabel: types.AzureDatabaseNameOverrideLabel, + wantNewName: tt.wantRename, + }) + }) + } +} + +func TestApplyEKSNameSuffix(t *testing.T) { + clusterName := "some-cluster" + region := "us-west-1" + accountID := "123456789012" + for _, overrideLabel := range types.AWSKubeClusterNameOverrideLabels { + cluster := makeEKSKubeCluster(t, clusterName, region, accountID, overrideLabel) + test := renameTest{ + resource: cluster, + renameFn: func(r types.ResourceWithLabels) { + c := r.(types.KubeCluster) + ApplyEKSNameSuffix(c) + }, + originalName: clusterName, + nameOverrideLabel: overrideLabel, + wantNewName: "some-cluster-eks-us-west-1-123456789012", + } + runRenameTest(t, test) + } +} + +func TestApplyAKSNameSuffix(t *testing.T) { + clusterName := "some-cluster" + region := "westus" + resourceGroup := "Some Group" + subscriptionID := "11111111-2222-3333-4444-555555555555" + cluster := makeAKSKubeCluster(t, clusterName, region, resourceGroup, subscriptionID) + test := renameTest{ + resource: cluster, + renameFn: func(r types.ResourceWithLabels) { + c := r.(types.KubeCluster) + ApplyAKSNameSuffix(c) + }, + originalName: clusterName, + nameOverrideLabel: types.AzureKubeClusterNameOverrideLabel, + wantNewName: "some-cluster-aks-westus-Some-Group-11111111-2222-3333-4444-555555555555", + } + runRenameTest(t, test) +} + +func TestApplyGKENameSuffix(t *testing.T) { + clusterName := "some-cluster" + region := "central-1" + projectID := "dev-123456" + cluster := makeGKEKubeCluster(t, clusterName, region, projectID) + test := renameTest{ + resource: cluster, + renameFn: func(r types.ResourceWithLabels) { + c := r.(types.KubeCluster) + ApplyGKENameSuffix(c) + }, + originalName: clusterName, + nameOverrideLabel: types.GCPKubeClusterNameOverrideLabel, + wantNewName: "some-cluster-gke-central-1-dev-123456", + } + runRenameTest(t, test) +} + +// requireDiscoveredNameLabel is a test helper that requires a resource have +// the originally "discovered" name as a label. +func requireDiscoveredNameLabel(t *testing.T, r types.ResourceWithLabels, want, overrideLabel string) { + t.Helper() + override, ok := r.GetLabel(overrideLabel) + require.False(t, ok, "override label should not be present") + require.Empty(t, override, "override label should not be present") + got, gotOk := r.GetLabel(types.DiscoveredNameLabel) + require.True(t, gotOk, "should have the original discovered name saved in a label") + require.Equal(t, want, got, "should have the original discovered name saved in a label") +} + +func requireOverrideLabelIsSet(t *testing.T, r types.ResourceWithLabels, overrideLabel string) { + t.Helper() + override, ok := r.GetLabel(overrideLabel) + require.True(t, ok, "override label %v should be present", overrideLabel) + require.NotEmpty(t, override, "override label %v should be present", overrideLabel) + require.Equal(t, override, r.GetName(), "name should equal the %v override label", overrideLabel) +} + +// requireDiscoveredNameLabel is a test helper that requires a resource +// not have the originally "discovered" name as a label, and did not change its name. +func requireOverrideLabelSkipsRenaming(t *testing.T, r types.ResourceWithLabels, originalName, overrideLabel string) { + t.Helper() + requireOverrideLabelIsSet(t, r, overrideLabel) + got, gotOk := r.GetLabel(types.DiscoveredNameLabel) + require.False(t, gotOk, "should not have the original discovered name saved in a label") + require.Empty(t, got, "should not have the original discovered name saved in a label") + require.Equal(t, originalName, r.GetName(), + "should not have renamed the resource when override label %v is present", overrideLabel) +} + +func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database { + t.Helper() + cluster := &rds.DBCluster{ + DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:cluster:%v", region, accountID, name)), + DBClusterIdentifier: aws.String("cluster-1"), + DbClusterResourceId: aws.String("resource-1"), + IAMDatabaseAuthenticationEnabled: aws.Bool(true), + Engine: aws.String("aurora-mysql"), + EngineVersion: aws.String("8.0.0"), + Endpoint: aws.String("localhost"), + Port: aws.Int64(3306), + TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{ + overrideLabel: name, + }), + } + database, err := services.NewDatabaseFromRDSCluster(cluster) + require.NoError(t, err) + return database +} + +func makeRDSInstanceDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database { + t.Helper() + instance := &rds.DBInstance{ + DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:db:%v", region, accountID, name)), + DBInstanceIdentifier: aws.String(name), + DbiResourceId: aws.String(uuid.New().String()), + Engine: aws.String(services.RDSEnginePostgres), + DBInstanceStatus: aws.String("available"), + Endpoint: &rds.Endpoint{ + Address: aws.String("localhost"), + Port: aws.Int64(5432), + }, + TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{ + overrideLabel: name, + }), + } + database, err := services.NewDatabaseFromRDSInstance(instance) + require.NoError(t, err) + return database +} + +func makeAzureMySQLFlexDatabase(t *testing.T, name, region, group, subscription string) types.Database { + t.Helper() + resourceType := "Microsoft.DBforMySQL/flexibleServers" + id := fmt.Sprintf("/subscriptions/%v/resourceGroups/%v/providers/%v/%v", + subscription, + group, + resourceType, + name, + ) + + fqdn := name + ".mysql" + azureutils.DatabaseEndpointSuffix + state := armmysqlflexibleservers.ServerStateReady + version := armmysqlflexibleservers.ServerVersionEight021 + server := &armmysqlflexibleservers.Server{ + Location: ®ion, + Properties: &armmysqlflexibleservers.ServerProperties{ + FullyQualifiedDomainName: &fqdn, + State: &state, + Version: &version, + }, + Tags: labelsToAzureTags(map[string]string{ + types.AzureDatabaseNameOverrideLabel: name, + }), + ID: &id, + Name: &name, + Type: &resourceType, + } + database, err := services.NewDatabaseFromAzureMySQLFlexServer(server) + require.NoError(t, err) + return database +} + +func makeAzureRedisDB(t *testing.T, name, region, group, subscription string) types.Database { + id := fmt.Sprintf("/subscriptions/%v/resourceGroups/%v/providers/Microsoft.Cache/Redis/%v", subscription, group, name) + resourceInfo := &armredis.ResourceInfo{ + Name: to.Ptr(name), + ID: to.Ptr(id), + Location: to.Ptr(region), + Tags: labelsToAzureTags(map[string]string{ + types.AzureDatabaseNameOverrideLabel: name, + }), + Properties: &armredis.Properties{ + HostName: to.Ptr(fmt.Sprintf("%v.redis.cache.windows.net", name)), + SSLPort: to.Ptr(int32(6380)), + ProvisioningState: to.Ptr(armredis.ProvisioningStateSucceeded), + RedisVersion: to.Ptr("6.0"), + }, + } + database, err := services.NewDatabaseFromAzureRedis(resourceInfo) + require.NoError(t, err) + return database +} + +func makeAzureRedisEnterpriseDB(t *testing.T, name, region, group, subscription string) types.Database { + clusterID := fmt.Sprintf("/subscriptions/%v/resourceGroups/%v/providers/Microsoft.Cache/redisEnterprise/%v", subscription, group, name) + databaseID := fmt.Sprintf("%v/databases/default", clusterID) + armCluster := &armredisenterprise.Cluster{ + Name: to.Ptr(name), + ID: to.Ptr(clusterID), + Location: to.Ptr(region), + Tags: labelsToAzureTags(map[string]string{ + types.AzureDatabaseNameOverrideLabel: name, + }), + Properties: &armredisenterprise.ClusterProperties{ + HostName: to.Ptr(fmt.Sprintf("%v.%v.redisenterprise.cache.azure.net", name, region)), + RedisVersion: to.Ptr("6.0"), + }, + } + armDatabase := &armredisenterprise.Database{ + Name: to.Ptr("default"), + ID: to.Ptr(databaseID), + Properties: &armredisenterprise.DatabaseProperties{ + ProvisioningState: to.Ptr(armredisenterprise.ProvisioningStateSucceeded), + Port: to.Ptr(int32(10000)), + ClusteringPolicy: to.Ptr(armredisenterprise.ClusteringPolicyOSSCluster), + ClientProtocol: to.Ptr(armredisenterprise.ProtocolEncrypted), + }, + } + database, err := services.NewDatabaseFromAzureRedisEnterprise(armCluster, armDatabase) + require.NoError(t, err) + return database +} + +func labelsToAzureTags(labels map[string]string) map[string]*string { + tags := make(map[string]*string, len(labels)) + for k, v := range labels { + v := v + tags[k] = &v + } + return tags +} + +func makeEKSKubeCluster(t *testing.T, name, region, accountID, overrideLabel string) types.KubeCluster { + t.Helper() + eksCluster := &eks.Cluster{ + Name: aws.String(name), + Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)), + Status: aws.String(eks.ClusterStatusActive), + Tags: map[string]*string{ + overrideLabel: aws.String(name), + }, + } + kubeCluster, err := services.NewKubeClusterFromAWSEKS(eksCluster) + require.NoError(t, err) + require.True(t, kubeCluster.IsAWS()) + return kubeCluster +} + +func makeAKSKubeCluster(t *testing.T, name, location, group, subID string) types.KubeCluster { + t.Helper() + aksCluster := &azure.AKSCluster{ + Name: name, + GroupName: group, + TenantID: "tenantID", + Location: location, + SubscriptionID: subID, + Tags: map[string]string{ + types.AzureKubeClusterNameOverrideLabel: name, + }, + Properties: azure.AKSClusterProperties{}, + } + kubeCluster, err := services.NewKubeClusterFromAzureAKS(aksCluster) + require.NoError(t, err) + require.True(t, kubeCluster.IsAzure()) + return kubeCluster +} + +func makeGKEKubeCluster(t *testing.T, name, location, projectID string) types.KubeCluster { + gkeCluster := gcp.GKECluster{ + Name: name, + Status: containerpb.Cluster_RUNNING, + Labels: map[string]string{ + types.GCPKubeClusterNameOverrideLabel: name, + }, + ProjectID: projectID, + Location: location, + Description: "desc1", + } + + kubeCluster, err := services.NewKubeClusterFromGCPGKE(gkeCluster) + require.NoError(t, err) + require.True(t, kubeCluster.IsGCP()) + return kubeCluster +} diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 20623d30dc531..2619b1aa9ad8a 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -746,10 +746,8 @@ func (s *Server) Start() error { if s.gcpWatcher != nil { go s.handleGCPDiscovery() } - if len(s.kubeFetchers) > 0 { - if err := s.startKubeWatchers(); err != nil { - return trace.Wrap(err) - } + if err := s.startKubeWatchers(); err != nil { + return trace.Wrap(err) } if err := s.startDatabaseWatchers(); err != nil { return trace.Wrap(err) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 80a516c31f0ec..041a72ecd97c7 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -62,6 +62,7 @@ import ( "github.com/gravitational/teleport/lib/cloud/mocks" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" "github.com/gravitational/teleport/lib/srv/server" ) @@ -506,7 +507,7 @@ func TestDiscoveryKube(t *testing.T) { mustConvertEKSToKubeCluster(t, eksMockClusters[0], mainDiscoveryGroup), mustConvertEKSToKubeCluster(t, eksMockClusters[1], mainDiscoveryGroup), }, - clustersNotUpdated: []string{"eks-cluster1"}, + clustersNotUpdated: []string{mustConvertEKSToKubeCluster(t, eksMockClusters[0], mainDiscoveryGroup).GetName()}, }, { name: "1 cluster in auth that belongs the same discovery group but has unmatched labels + import 2 prod clusters from EKS", @@ -593,7 +594,7 @@ func TestDiscoveryKube(t *testing.T) { mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], mainDiscoveryGroup), mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][1], mainDiscoveryGroup), }, - clustersNotUpdated: []string{"aks-cluster1"}, + clustersNotUpdated: []string{mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], mainDiscoveryGroup).GetName()}, }, { name: "no clusters in auth server, import 2 prod clusters from GKE", @@ -895,6 +896,7 @@ func mustConvertEKSToKubeCluster(t *testing.T, eksCluster *eks.Cluster, discover cluster, err := services.NewKubeClusterFromAWSEKS(eksCluster) require.NoError(t, err) cluster.GetStaticLabels()[types.TeleportInternalDiscoveryGroupName] = discoveryGroup + common.ApplyEKSNameSuffix(cluster) return cluster } @@ -902,6 +904,7 @@ func mustConvertAKSToKubeCluster(t *testing.T, azureCluster *azure.AKSCluster, d cluster, err := services.NewKubeClusterFromAzureAKS(azureCluster) require.NoError(t, err) cluster.GetStaticLabels()[types.TeleportInternalDiscoveryGroupName] = discoveryGroup + common.ApplyAKSNameSuffix(cluster) return cluster } @@ -975,6 +978,7 @@ func mustConvertGKEToKubeCluster(t *testing.T, gkeCluster gcp.GKECluster, discov cluster, err := services.NewKubeClusterFromGCPGKE(gkeCluster) require.NoError(t, err) cluster.GetStaticLabels()[types.TeleportInternalDiscoveryGroupName] = discoveryGroup + common.ApplyGKENameSuffix(cluster) return cluster } @@ -1061,7 +1065,7 @@ func TestDiscoveryDatabase(t *testing.T) { name: "update existing database", existingDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-redshift", + Name: awsRedshiftDB.GetName(), Description: "should be updated", Labels: map[string]string{types.OriginLabel: types.OriginCloud, types.TeleportInternalDiscoveryGroupName: mainDiscoveryGroup}, }, types.DatabaseSpecV3{ @@ -1085,7 +1089,7 @@ func TestDiscoveryDatabase(t *testing.T) { name: "update existing database with assumed role", existingDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-rds", + Name: awsRDSDBWithRole.GetName(), Description: "should be updated", Labels: map[string]string{types.OriginLabel: types.OriginCloud, types.TeleportInternalDiscoveryGroupName: mainDiscoveryGroup}, }, types.DatabaseSpecV3{ @@ -1105,7 +1109,7 @@ func TestDiscoveryDatabase(t *testing.T) { name: "delete existing database", existingDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-redshift", + Name: awsRedshiftDB.GetName(), Description: "should not be deleted", Labels: map[string]string{types.OriginLabel: types.OriginCloud}, }, types.DatabaseSpecV3{ @@ -1120,7 +1124,7 @@ func TestDiscoveryDatabase(t *testing.T) { }}, expectDatabases: []types.Database{ mustNewDatabase(t, types.Metadata{ - Name: "aws-redshift", + Name: awsRedshiftDB.GetName(), Description: "should not be deleted", Labels: map[string]string{types.OriginLabel: types.OriginCloud}, }, types.DatabaseSpecV3{ @@ -1241,6 +1245,7 @@ func makeRDSInstance(t *testing.T, name, region string, discoveryGroup string) ( staticLabels := database.GetStaticLabels() staticLabels[types.TeleportInternalDiscoveryGroupName] = discoveryGroup database.SetStaticLabels(staticLabels) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherRDS) return instance, database } @@ -1262,6 +1267,7 @@ func makeRedshiftCluster(t *testing.T, name, region string, discoveryGroup strin staticLabels := database.GetStaticLabels() staticLabels[types.TeleportInternalDiscoveryGroupName] = discoveryGroup database.SetStaticLabels(staticLabels) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherRedshift) return cluster, database } @@ -1284,6 +1290,7 @@ func makeAzureRedisServer(t *testing.T, name, subscription, group, region string staticLabels := database.GetStaticLabels() staticLabels[types.TeleportInternalDiscoveryGroupName] = discoveryGroup database.SetStaticLabels(staticLabels) + common.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherRedis) return resourceInfo, database } diff --git a/lib/srv/discovery/fetchers/aks.go b/lib/srv/discovery/fetchers/aks.go index a9eb2af088c4c..a239f7ab310df 100644 --- a/lib/srv/discovery/fetchers/aks.go +++ b/lib/srv/discovery/fetchers/aks.go @@ -103,9 +103,18 @@ func (a *aksFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) kubeClusters = append(kubeClusters, kubeCluster) } + + a.rewriteKubeClusters(kubeClusters) return kubeClusters.AsResources(), nil } +// rewriteKubeClusters rewrites the discovered kube clusters. +func (a *aksFetcher) rewriteKubeClusters(clusters types.KubeClusters) { + for _, c := range clusters { + common.ApplyAKSNameSuffix(c) + } +} + func (a *aksFetcher) getAKSClusters(ctx context.Context) ([]*azure.AKSCluster, error) { var ( clusters []*azure.AKSCluster diff --git a/lib/srv/discovery/fetchers/aks_test.go b/lib/srv/discovery/fetchers/aks_test.go index 5dd1b65b65845..35ce44972dbf9 100644 --- a/lib/srv/discovery/fetchers/aks_test.go +++ b/lib/srv/discovery/fetchers/aks_test.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestAKSFetcher(t *testing.T) { @@ -211,6 +212,7 @@ func aksClustersToResources(t *testing.T, clusters ...*azure.AKSCluster) types.R kubeCluster, err := services.NewKubeClusterFromAzureAKS(cluster) require.NoError(t, err) require.True(t, kubeCluster.IsAzure()) + common.ApplyAKSNameSuffix(kubeCluster) kubeClusters = append(kubeClusters, kubeCluster) } return kubeClusters.AsResources() diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go index f3e900358c258..e3a01a8abf99a 100644 --- a/lib/srv/discovery/fetchers/db/aws.go +++ b/lib/srv/discovery/fetchers/db/aws.go @@ -17,11 +17,122 @@ limitations under the License. package db import ( + "context" + "fmt" + + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) +// awsFetcherPlugin defines an interface that provides database type specific +// functions for use by the common AWS database fetcher. +type awsFetcherPlugin interface { + // GetDatabases fetches databases from AWS API and converts the results to + // Teleport types.Databases. + GetDatabases(context.Context, *awsFetcherConfig) (types.Databases, error) + // ComponentShortName provides the plugin's short component name for + // logging purposes. + ComponentShortName() string +} + +// awsFetcherConfig is the AWS database fetcher configuration. +type awsFetcherConfig struct { + // AWSClients are the AWS API clients. + AWSClients cloud.AWSClients + // Type is the type of DB matcher, for example "rds", "redshift", etc. + Type string + // AssumeRole provides a role ARN and ExternalID to assume an AWS role + // when fetching databases. + AssumeRole types.AssumeRole + // Labels is a selector to match cloud database tags. + Labels types.Labels + // Region is the AWS region selector to match cloud databases. + Region string + // Log is a field logger to provide structured logging for each matcher, + // based on its config settings by default. + Log logrus.FieldLogger +} + +// CheckAndSetDefaults validates the config and sets defaults. +func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error { + if cfg.AWSClients == nil { + return trace.BadParameter("missing parameter AWSClients") + } + if cfg.Type == "" { + return trace.BadParameter("missing parameter Type") + } + if len(cfg.Labels) == 0 { + return trace.BadParameter("missing parameter Labels") + } + if cfg.Region == "" { + return trace.BadParameter("missing parameter Region") + } + if cfg.Log == nil { + cfg.Log = logrus.WithFields(logrus.Fields{ + trace.Component: "watch:" + component, + "labels": cfg.Labels, + "region": cfg.Region, + "role": cfg.AssumeRole, + }) + } + return nil +} + +// newAWSFetcher returns a AWS database fetcher for the provided selectors +// and AWS database-type specific fetcher plugin. +func newAWSFetcher(cfg awsFetcherConfig, plugin awsFetcherPlugin) (*awsFetcher, error) { + if err := cfg.CheckAndSetDefaults(plugin.ComponentShortName()); err != nil { + return nil, trace.Wrap(err) + } + return &awsFetcher{cfg: cfg, plugin: plugin}, nil +} + // awsFetcher is the common base for AWS database fetchers. type awsFetcher struct { + // cfg is the awsFetcher configuration. + cfg awsFetcherConfig + // plugin does AWS database type specific API calls fetch databases. + plugin awsFetcherPlugin +} + +// awsFetcher implements common.Fetcher. +var _ common.Fetcher = (*awsFetcher)(nil) + +// Get returns AWS databases matching the fetcher's selectors. +func (f *awsFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { + databases, err := f.getDatabases(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + f.rewriteDatabases(databases) + return databases.AsResources(), nil +} + +func (f *awsFetcher) getDatabases(ctx context.Context) (types.Databases, error) { + databases, err := f.plugin.GetDatabases(ctx, &f.cfg) + if err != nil { + return nil, trace.Wrap(err) + } + return filterDatabasesByLabels(databases, f.cfg.Labels, f.cfg.Log), nil +} + +// rewriteDatabases rewrites the discovered databases. +func (f *awsFetcher) rewriteDatabases(databases types.Databases) { + for _, db := range databases { + f.applyAssumeRole(db) + common.ApplyAWSDatabaseNameSuffix(db, f.cfg.Type) + } +} + +// applyAssumeRole sets the database AWS AssumeRole metadata to match the +// fetcher's setting. +func (f *awsFetcher) applyAssumeRole(db types.Database) { + db.SetAWSAssumeRole(f.cfg.AssumeRole.RoleARN) + db.SetAWSExternalID(f.cfg.AssumeRole.ExternalID) } // Cloud returns the cloud the fetcher is operating. @@ -34,6 +145,12 @@ func (f *awsFetcher) ResourceType() string { return types.KindDatabase } +// String returns the fetcher's string description. +func (f *awsFetcher) String() string { + return fmt.Sprintf("awsFetcher(Type: %v, Region=%v, Labels=%v)", + f.cfg.Type, f.cfg.Region, f.cfg.Labels) +} + // maxAWSPages is the maximum number of pages to iterate over when fetching aws // databases. const maxAWSPages = 10 diff --git a/lib/srv/discovery/fetchers/db/aws_elasticache.go b/lib/srv/discovery/fetchers/db/aws_elasticache.go index fab1883190377..067e9da39a4f0 100644 --- a/lib/srv/discovery/fetchers/db/aws_elasticache.go +++ b/lib/srv/discovery/fetchers/db/aws_elasticache.go @@ -17,75 +17,41 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// elastiCacheFetcherConfig is the ElastiCache databases fetcher configuration. -type elastiCacheFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // ElastiCache is the ElastiCache API client. - ElastiCache elasticacheiface.ElastiCacheAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newElastiCacheFetcher returns a new AWS fetcher for ElastiCache databases. +func newElastiCacheFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &elastiCachePlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *elastiCacheFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.ElastiCache == nil { - return trace.BadParameter("missing parameter ElastiCache") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} +// elastiCachePlugin retrieves ElastiCache Redis databases. +type elastiCachePlugin struct{} -// elastiCacheFetcher retrieves ElastiCache Redis databases. -type elastiCacheFetcher struct { - awsFetcher - - cfg elastiCacheFetcherConfig - log logrus.FieldLogger +func (f *elastiCachePlugin) ComponentShortName() string { + return "elasticache" } -// newElastiCacheFetcher returns a new ElastiCache databases fetcher instance. -func newElastiCacheFetcher(config elastiCacheFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &elastiCacheFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:elasticache", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns ElastiCache Redis databases matching the watcher's selectors. +// GetDatabases returns ElastiCache Redis databases matching the watcher's selectors. // // TODO(greedy52) support ElastiCache global datastore. -func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - clusters, err := getElastiCacheClusters(ctx, f.cfg.ElastiCache) +func (f *elastiCachePlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + ecClient, err := cfg.AWSClients.GetAWSElastiCacheClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { + return nil, trace.Wrap(err) + } + clusters, err := getElastiCacheClusters(ctx, ecClient) if err != nil { return nil, trace.Wrap(err) } @@ -93,12 +59,12 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels var eligibleClusters []*elasticache.ReplicationGroup for _, cluster := range clusters { if !services.IsElastiCacheClusterSupported(cluster) { - f.log.Debugf("ElastiCache cluster %q is not supported. Skipping.", aws.StringValue(cluster.ReplicationGroupId)) + cfg.Log.Debugf("ElastiCache cluster %q is not supported. Skipping.", aws.StringValue(cluster.ReplicationGroupId)) continue } if !services.IsElastiCacheClusterAvailable(cluster) { - f.log.Debugf("The current status of ElastiCache cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of ElastiCache cluster %q is %q. Skipping.", aws.StringValue(cluster.ReplicationGroupId), aws.StringValue(cluster.Status)) continue @@ -108,25 +74,25 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels } if len(eligibleClusters) == 0 { - return types.ResourcesWithLabels{}, nil + return nil, nil } // Fetch more information to provide extra labels. Do not fail because some // of these labels are missing. - allNodes, err := getElastiCacheNodes(ctx, f.cfg.ElastiCache) + allNodes, err := getElastiCacheNodes(ctx, ecClient) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to describe nodes") + cfg.Log.WithError(err).Debug("No permissions to describe nodes") } else { - f.log.WithError(err).Info("Failed to describe nodes.") + cfg.Log.WithError(err).Info("Failed to describe nodes.") } } - allSubnetGroups, err := getElastiCacheSubnetGroups(ctx, f.cfg.ElastiCache) + allSubnetGroups, err := getElastiCacheSubnetGroups(ctx, ecClient) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to describe subnet groups") + cfg.Log.WithError(err).Debug("No permissions to describe subnet groups") } else { - f.log.WithError(err).Info("Failed to describe subnet groups.") + cfg.Log.WithError(err).Info("Failed to describe subnet groups.") } } @@ -135,12 +101,12 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels // Resource tags are not found in elasticache.ReplicationGroup but can // be on obtained by elasticache.ListTagsForResource (one call per // resource). - tags, err := getElastiCacheResourceTags(ctx, f.cfg.ElastiCache, cluster.ARN) + tags, err := getElastiCacheResourceTags(ctx, ecClient, cluster.ARN) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to list resource tags") + cfg.Log.WithError(err).Debug("No permissions to list resource tags") } else { - f.log.WithError(err).Infof("Failed to list resource tags for ElastiCache cluster %q.", aws.StringValue(cluster.ReplicationGroupId)) + cfg.Log.WithError(err).Infof("Failed to list resource tags for ElastiCache cluster %q.", aws.StringValue(cluster.ReplicationGroupId)) } } @@ -150,7 +116,7 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels // mode enabled. if aws.BoolValue(cluster.ClusterEnabled) { if database, err := services.NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels); err != nil { - f.log.Infof("Could not convert ElastiCache cluster %q configuration endpoint to database resource: %v.", + cfg.Log.Infof("Could not convert ElastiCache cluster %q configuration endpoint to database resource: %v.", aws.StringValue(cluster.ReplicationGroupId), err) } else { databases = append(databases, database) @@ -164,21 +130,13 @@ func (f *elastiCacheFetcher) Get(ctx context.Context) (types.ResourcesWithLabels // there is only one node group (aka shard) with one primary endpoint // and one reader endpoint. if databasesFromNodeGroups, err := services.NewDatabasesFromElastiCacheNodeGroups(cluster, extraLabels); err != nil { - f.log.Infof("Could not convert ElastiCache cluster %q node groups to database resources: %v.", + cfg.Log.Infof("Could not convert ElastiCache cluster %q node groups to database resources: %v.", aws.StringValue(cluster.ReplicationGroupId), err) } else { databases = append(databases, databasesFromNodeGroups...) } } - - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. -func (f *elastiCacheFetcher) String() string { - return fmt.Sprintf("elastiCacheFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) + return databases, nil } // getElastiCacheClusters fetches all ElastiCache replication groups. diff --git a/lib/srv/discovery/fetchers/db/aws_elasticache_test.go b/lib/srv/discovery/fetchers/db/aws_elasticache_test.go index 7afb22efe3acd..613a54627e715 100644 --- a/lib/srv/discovery/fetchers/db/aws_elasticache_test.go +++ b/lib/srv/discovery/fetchers/db/aws_elasticache_test.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestElastiCacheFetcher(t *testing.T) { @@ -126,12 +127,14 @@ func makeElastiCacheCluster(t *testing.T, name, region, env string, opts ...func if aws.BoolValue(cluster.ClusterEnabled) { database, err := services.NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherElastiCache) return cluster, database, tags } databases, err := services.NewDatabasesFromElastiCacheNodeGroups(cluster, extraLabels) require.NoError(t, err) require.Len(t, databases, 1) + common.ApplyAWSDatabaseNameSuffix(databases[0], services.AWSMatcherElastiCache) return cluster, databases[0], tags } diff --git a/lib/srv/discovery/fetchers/db/aws_memorydb.go b/lib/srv/discovery/fetchers/db/aws_memorydb.go index f6732aeb9403f..11efb3f51662f 100644 --- a/lib/srv/discovery/fetchers/db/aws_memorydb.go +++ b/lib/srv/discovery/fetchers/db/aws_memorydb.go @@ -17,73 +17,39 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// memoryDBFetcherConfig is the MemoryDB databases fetcher configuration. -type memoryDBFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // MemoryDB is the MemoryDB API client. - MemoryDB memorydbiface.MemoryDBAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole -} +// memoryDBPlugin retrieves MemoryDB Redis databases. +type memoryDBPlugin struct{} -// CheckAndSetDefaults validates the config and sets defaults. -func (c *memoryDBFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.MemoryDB == nil { - return trace.BadParameter("missing parameter MemoryDB") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil +// newMemoryDBFetcher returns a new AWS fetcher for MemoryDB databases. +func newMemoryDBFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &memoryDBPlugin{}) } -// memoryDBFetcher retrieves MemoryDB Redis databases. -type memoryDBFetcher struct { - awsFetcher - - cfg memoryDBFetcherConfig - log logrus.FieldLogger +func (f *memoryDBPlugin) ComponentShortName() string { + return "memorydb" } -// newMemoryDBFetcher returns a new MemoryDB databases fetcher instance. -func newMemoryDBFetcher(config memoryDBFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// GetDatabases returns MemoryDB databases matching the watcher's selectors. +func (f *memoryDBPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + memDBClient, err := cfg.AWSClients.GetAWSMemoryDBClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &memoryDBFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:memorydb", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns MemoryDB databases matching the watcher's selectors. -func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - clusters, err := getMemoryDBClusters(ctx, f.cfg.MemoryDB) + clusters, err := getMemoryDBClusters(ctx, memDBClient) if err != nil { return nil, trace.Wrap(err) } @@ -91,12 +57,12 @@ func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e var eligibleClusters []*memorydb.Cluster for _, cluster := range clusters { if !services.IsMemoryDBClusterSupported(cluster) { - f.log.Debugf("MemoryDB cluster %q is not supported. Skipping.", aws.StringValue(cluster.Name)) + cfg.Log.Debugf("MemoryDB cluster %q is not supported. Skipping.", aws.StringValue(cluster.Name)) continue } if !services.IsMemoryDBClusterAvailable(cluster) { - f.log.Debugf("The current status of MemoryDB cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of MemoryDB cluster %q is %q. Skipping.", aws.StringValue(cluster.Name), aws.StringValue(cluster.Status)) continue @@ -106,47 +72,40 @@ func (f *memoryDBFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e } if len(eligibleClusters) == 0 { - return types.ResourcesWithLabels{}, nil + return nil, nil } // Fetch more information to provide extra labels. Do not fail because some // of these labels are missing. - allSubnetGroups, err := getMemoryDBSubnetGroups(ctx, f.cfg.MemoryDB) + allSubnetGroups, err := getMemoryDBSubnetGroups(ctx, memDBClient) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to describe subnet groups") + cfg.Log.WithError(err).Debug("No permissions to describe subnet groups") } else { - f.log.WithError(err).Info("Failed to describe subnet groups.") + cfg.Log.WithError(err).Info("Failed to describe subnet groups.") } } var databases types.Databases for _, cluster := range eligibleClusters { - tags, err := getMemoryDBResourceTags(ctx, f.cfg.MemoryDB, cluster.ARN) + tags, err := getMemoryDBResourceTags(ctx, memDBClient, cluster.ARN) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to list resource tags") + cfg.Log.WithError(err).Debug("No permissions to list resource tags") } else { - f.log.WithError(err).Infof("Failed to list resource tags for MemoryDB cluster %q.", aws.StringValue(cluster.Name)) + cfg.Log.WithError(err).Infof("Failed to list resource tags for MemoryDB cluster %q.", aws.StringValue(cluster.Name)) } } extraLabels := services.ExtraMemoryDBLabels(cluster, tags, allSubnetGroups) database, err := services.NewDatabaseFromMemoryDBCluster(cluster, extraLabels) if err != nil { - f.log.WithError(err).Infof("Could not convert memorydb cluster %q configuration endpoint to database resource.", aws.StringValue(cluster.Name)) + cfg.Log.WithError(err).Infof("Could not convert memorydb cluster %q configuration endpoint to database resource.", aws.StringValue(cluster.Name)) } else { databases = append(databases, database) } } - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. -func (f *memoryDBFetcher) String() string { - return fmt.Sprintf("memorydbFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) + return databases, nil } // getMemoryDBClusters fetches all MemoryDB clusters. diff --git a/lib/srv/discovery/fetchers/db/aws_memorydb_test.go b/lib/srv/discovery/fetchers/db/aws_memorydb_test.go index ab9340a6eac66..e389cdcca1fd0 100644 --- a/lib/srv/discovery/fetchers/db/aws_memorydb_test.go +++ b/lib/srv/discovery/fetchers/db/aws_memorydb_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestMemoryDBFetcher(t *testing.T) { @@ -120,5 +121,6 @@ func makeMemoryDBCluster(t *testing.T, name, region, env string, opts ...func(*m database, err := services.NewDatabaseFromMemoryDBCluster(cluster, extraLabels) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherMemoryDB) return cluster, database, tags } diff --git a/lib/srv/discovery/fetchers/db/aws_opensearch.go b/lib/srv/discovery/fetchers/db/aws_opensearch.go index 4e491b9b1c68e..f47720029d3f0 100644 --- a/lib/srv/discovery/fetchers/db/aws_opensearch.go +++ b/lib/srv/discovery/fetchers/db/aws_opensearch.go @@ -16,80 +16,46 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// openSearchFetcherConfig is the OpenSearch databases fetcher configuration. -type openSearchFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // openSearch is the OpenSearch API client. - openSearch opensearchserviceiface.OpenSearchServiceAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newOpenSearchFetcher returns a new AWS fetcher for OpenSearch databases. +func newOpenSearchFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &openSearchPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *openSearchFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.openSearch == nil { - return trace.BadParameter("missing parameter openSearch") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} - -// openSearchFetcher retrieves OpenSearch databases. -type openSearchFetcher struct { - awsFetcher +// openSearchPlugin retrieves OpenSearch databases. +type openSearchPlugin struct{} - cfg openSearchFetcherConfig - log logrus.FieldLogger +func (f *openSearchPlugin) ComponentShortName() string { + return "opensearch" } -// newOpenSearchFetcher returns a new OpenSearch databases fetcher instance. -func newOpenSearchFetcher(config openSearchFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// GetDatabases returns OpenSearch databases. +func (f *openSearchPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + opensearchClient, err := cfg.AWSClients.GetAWSOpenSearchClient(ctx, + cfg.Region, cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &openSearchFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:opensearch", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns OpenSearch databases matching the watcher's selectors. -func (f *openSearchFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - domains, err := getOpenSearchDomains(ctx, f.cfg.openSearch) + domains, err := getOpenSearchDomains(ctx, opensearchClient) if err != nil { return nil, trace.Wrap(err) } var eligibleDomains []*opensearchservice.DomainStatus for _, domain := range domains { if !services.IsOpenSearchDomainAvailable(domain) { - f.log.Debugf("OpenSearch domain %q is unavailable. Skipping.", aws.StringValue(domain.DomainName)) + cfg.Log.Debugf("OpenSearch domain %q is unavailable. Skipping.", aws.StringValue(domain.DomainName)) continue } @@ -97,37 +63,29 @@ func (f *openSearchFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, } if len(eligibleDomains) == 0 { - return types.ResourcesWithLabels{}, nil + return nil, nil } var databases types.Databases for _, domain := range eligibleDomains { - tags, err := getOpenSearchResourceTags(ctx, f.cfg.openSearch, domain.ARN) + tags, err := getOpenSearchResourceTags(ctx, opensearchClient, domain.ARN) if err != nil { if trace.IsAccessDenied(err) { - f.log.WithError(err).Debug("No permissions to list resource tags") + cfg.Log.WithError(err).Debug("No permissions to list resource tags") } else { - f.log.WithError(err).Infof("Failed to list resource tags for OpenSearch domain %q.", aws.StringValue(domain.DomainName)) + cfg.Log.WithError(err).Infof("Failed to list resource tags for OpenSearch domain %q.", aws.StringValue(domain.DomainName)) } } - dbs, err := services.NewDatabaseFromOpenSearchDomain(domain, tags) + dbs, err := services.NewDatabasesFromOpenSearchDomain(domain, tags) if err != nil { - f.log.WithError(err).Infof("Could not convert OpenSearch domain %q configuration to database resource.", aws.StringValue(domain.DomainName)) + cfg.Log.WithError(err).Infof("Could not convert OpenSearch domain %q configuration to database resource.", aws.StringValue(domain.DomainName)) } else { databases = append(databases, dbs...) } } - - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. -func (f *openSearchFetcher) String() string { - return fmt.Sprintf("openSearchFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) + return databases, nil } // getOpenSearchDomains fetches all OpenSearch domains. diff --git a/lib/srv/discovery/fetchers/db/aws_opensearch_test.go b/lib/srv/discovery/fetchers/db/aws_opensearch_test.go index 3381a1c701069..a92dce61eb5a0 100644 --- a/lib/srv/discovery/fetchers/db/aws_opensearch_test.go +++ b/lib/srv/discovery/fetchers/db/aws_opensearch_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestOpenSearchFetcher(t *testing.T) { @@ -150,8 +151,11 @@ func makeOpenSearchDomain(t *testing.T, tagMap map[string][]*opensearchservice.T tagMap[aws.StringValue(domain.ARN)] = tags - database, err := services.NewDatabaseFromOpenSearchDomain(domain, tags) + databases, err := services.NewDatabasesFromOpenSearchDomain(domain, tags) require.NoError(t, err) - return domain, database + for _, db := range databases { + common.ApplyAWSDatabaseNameSuffix(db, services.AWSMatcherOpenSearch) + } + return domain, databases } diff --git a/lib/srv/discovery/fetchers/db/aws_rds.go b/lib/srv/discovery/fetchers/db/aws_rds.go index 0d8aacf3cc7ca..a2d3ccee9c84b 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds.go +++ b/lib/srv/discovery/fetchers/db/aws_rds.go @@ -18,7 +18,6 @@ package db import ( "context" - "fmt" "strings" "github.com/aws/aws-sdk-go/aws" @@ -28,82 +27,39 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// rdsFetcherConfig is the RDS databases fetcher configuration. -type rdsFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // RDS is the RDS API client. - RDS rdsiface.RDSAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newRDSDBInstancesFetcher returns a new AWS fetcher for RDS databases. +func newRDSDBInstancesFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &rdsDBInstancesPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *rdsFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.RDS == nil { - return trace.BadParameter("missing parameter RDS") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} - -// rdsDBInstancesFetcher retrieves RDS DB instances. -type rdsDBInstancesFetcher struct { - awsFetcher +// rdsDBInstancesPlugin retrieves RDS DB instances. +type rdsDBInstancesPlugin struct{} - cfg rdsFetcherConfig - log logrus.FieldLogger +func (f *rdsDBInstancesPlugin) ComponentShortName() string { + return "rds" } -// newRDSDBInstancesFetcher returns a new RDS DB instances fetcher instance. -func newRDSDBInstancesFetcher(config rdsFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &rdsDBInstancesFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:rds", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns RDS DB instances matching the watcher's selectors. -func (f *rdsDBInstancesFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - rdsDatabases, err := f.getRDSDatabases(ctx) +// GetDatabases returns a list of database resources representing RDS instances. +func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } - - applyAssumeRoleToDatabases(rdsDatabases, f.cfg.AssumeRole) - return filterDatabasesByLabels(rdsDatabases, f.cfg.Labels, f.log).AsResources(), nil -} - -// getRDSDatabases returns a list of database resources representing RDS instances. -func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Databases, error) { - instances, err := getAllDBInstances(ctx, f.cfg.RDS, maxAWSPages, f.log) + instances, err := getAllDBInstances(ctx, rdsClient, maxAWSPages, cfg.Log) if err != nil { return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) } databases := make(types.Databases, 0, len(instances)) for _, instance := range instances { if !services.IsRDSInstanceSupported(instance) { - f.log.Debugf("RDS instance %q (engine mode %v, engine version %v) doesn't support IAM authentication. Skipping.", + cfg.Log.Debugf("RDS instance %q (engine mode %v, engine version %v) doesn't support IAM authentication. Skipping.", aws.StringValue(instance.DBInstanceIdentifier), aws.StringValue(instance.Engine), aws.StringValue(instance.EngineVersion)) @@ -111,7 +67,7 @@ func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Data } if !services.IsRDSInstanceAvailable(instance.DBInstanceStatus, instance.DBInstanceIdentifier) { - f.log.Debugf("The current status of RDS instance %q is %q. Skipping.", + cfg.Log.Debugf("The current status of RDS instance %q is %q. Skipping.", aws.StringValue(instance.DBInstanceIdentifier), aws.StringValue(instance.DBInstanceStatus)) continue @@ -119,7 +75,7 @@ func (f *rdsDBInstancesFetcher) getRDSDatabases(ctx context.Context) (types.Data database, err := services.NewDatabaseFromRDSInstance(instance) if err != nil { - f.log.Warnf("Could not convert RDS instance %q to database resource: %v.", + cfg.Log.Warnf("Could not convert RDS instance %q to database resource: %v.", aws.StringValue(instance.DBInstanceIdentifier), err) } else { databases = append(databases, database) @@ -151,57 +107,34 @@ func getAllDBInstances(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages return instances, trace.Wrap(err) } -// String returns the fetcher's string description. -func (f *rdsDBInstancesFetcher) String() string { - return fmt.Sprintf("rdsDBInstancesFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) +// newRDSAuroraClustersFetcher returns a new AWS fetcher for RDS Aurora +// databases. +func newRDSAuroraClustersFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &rdsAuroraClustersPlugin{}) } -// rdsAuroraClustersFetcher retrieves RDS Aurora clusters. -type rdsAuroraClustersFetcher struct { - awsFetcher - - cfg rdsFetcherConfig - log logrus.FieldLogger -} +// rdsAuroraClustersPlugin retrieves RDS Aurora clusters. +type rdsAuroraClustersPlugin struct{} -// newRDSAuroraClustersFetcher returns a new RDS Aurora fetcher instance. -func newRDSAuroraClustersFetcher(config rdsFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &rdsAuroraClustersFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:aurora", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil +func (f *rdsAuroraClustersPlugin) ComponentShortName() string { + return "aurora" } -// Get returns Aurora clusters matching the watcher's selectors. -func (f *rdsAuroraClustersFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - auroraDatabases, err := f.getAuroraDatabases(ctx) +// GetDatabases returns a list of database resources representing RDS clusters. +func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } - - applyAssumeRoleToDatabases(auroraDatabases, f.cfg.AssumeRole) - return filterDatabasesByLabels(auroraDatabases, f.cfg.Labels, f.log).AsResources(), nil -} - -// getAuroraDatabases returns a list of database resources representing RDS clusters. -func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (types.Databases, error) { - clusters, err := getAllDBClusters(ctx, f.cfg.RDS, maxAWSPages, f.log) + clusters, err := getAllDBClusters(ctx, rdsClient, maxAWSPages, cfg.Log) if err != nil { return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) } databases := make(types.Databases, 0, len(clusters)) for _, cluster := range clusters { if !services.IsRDSClusterSupported(cluster) { - f.log.Debugf("Aurora cluster %q (engine mode %v, engine version %v) doesn't support IAM authentication. Skipping.", + cfg.Log.Debugf("Aurora cluster %q (engine mode %v, engine version %v) doesn't support IAM authentication. Skipping.", aws.StringValue(cluster.DBClusterIdentifier), aws.StringValue(cluster.EngineMode), aws.StringValue(cluster.EngineVersion)) @@ -209,7 +142,7 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type } if !services.IsRDSClusterAvailable(cluster.Status, cluster.DBClusterIdentifier) { - f.log.Debugf("The current status of Aurora cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of Aurora cluster %q is %q. Skipping.", aws.StringValue(cluster.DBClusterIdentifier), aws.StringValue(cluster.Status)) continue @@ -234,7 +167,7 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type if cluster.Endpoint != nil && hasWriterInstance { database, err := services.NewDatabaseFromRDSCluster(cluster) if err != nil { - f.log.Warnf("Could not convert RDS cluster %q to database resource: %v.", + cfg.Log.Warnf("Could not convert RDS cluster %q to database resource: %v.", aws.StringValue(cluster.DBClusterIdentifier), err) } else { databases = append(databases, database) @@ -246,7 +179,7 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type if cluster.ReaderEndpoint != nil && hasReaderInstance { database, err := services.NewDatabaseFromRDSClusterReaderEndpoint(cluster) if err != nil { - f.log.Warnf("Could not convert RDS cluster %q reader endpoint to database resource: %v.", + cfg.Log.Warnf("Could not convert RDS cluster %q reader endpoint to database resource: %v.", aws.StringValue(cluster.DBClusterIdentifier), err) } else { databases = append(databases, database) @@ -257,7 +190,7 @@ func (f *rdsAuroraClustersFetcher) getAuroraDatabases(ctx context.Context) (type if len(cluster.CustomEndpoints) > 0 { customEndpointDatabases, err := services.NewDatabasesFromRDSClusterCustomEndpoints(cluster) if err != nil { - f.log.Warnf("Could not convert RDS cluster %q custom endpoints to database resources: %v.", + cfg.Log.Warnf("Could not convert RDS cluster %q custom endpoints to database resources: %v.", aws.StringValue(cluster.DBClusterIdentifier), err) } @@ -290,12 +223,6 @@ func getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages i return clusters, trace.Wrap(err) } -// String returns the fetcher's string description. -func (f *rdsAuroraClustersFetcher) String() string { - return fmt.Sprintf("rdsAuroraClustersFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) -} - // rdsInstanceEngines returns engines to make sure DescribeDBInstances call returns // only databases with engines Teleport supports. func rdsInstanceEngines() []string { diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go index 8286a1c6c02bb..719924ed2d1d9 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go @@ -15,82 +15,62 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// rdsDBProxyFetcher retrieves RDS Proxies and their custom endpoints. -type rdsDBProxyFetcher struct { - awsFetcher - - cfg rdsFetcherConfig - log logrus.FieldLogger +// newRDSDBProxyFetcher returns a new AWS fetcher for RDS Proxy databases. +func newRDSDBProxyFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &rdsDBProxyPlugin{}) } -// newRDSDBProxyFetcher returns a new RDS Proxy fetcher instance. -func newRDSDBProxyFetcher(config rdsFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &rdsDBProxyFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:rdsproxy", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil +// rdsDBProxyPlugin retrieves RDS Proxies and their custom endpoints. +type rdsDBProxyPlugin struct{} + +func (f *rdsDBProxyPlugin) ComponentShortName() string { + return "rdsproxy" } -// Get returns RDS Proxies and proxy endpoints matching the watcher's -// selectors. -func (f *rdsDBProxyFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - databases, err := f.getRDSProxyDatabases(ctx) +// GetDatabases returns a list of database resources representing RDS +// Proxies and custom endpoints. +func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) if err != nil { return nil, trace.Wrap(err) } - - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// getRDSProxyDatabases returns a list of database resources representing RDS -// Proxies and custom endpoints. -func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Databases, error) { // Get a list of all RDS Proxies. Each RDS Proxy has one "default" // endpoint. - rdsProxies, err := getRDSProxies(ctx, f.cfg.RDS, maxAWSPages) + rdsProxies, err := getRDSProxies(ctx, rdsClient, maxAWSPages) if err != nil { return nil, trace.Wrap(err) } // Get all RDS Proxy custom endpoints sorted by the name of the RDS Proxy // that owns the custom endpoints. - customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, f.cfg.RDS, maxAWSPages) + customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, rdsClient, maxAWSPages) if err != nil { - f.log.Debugf("Failed to get RDS Proxy endpoints: %v.", err) + cfg.Log.Debugf("Failed to get RDS Proxy endpoints: %v.", err) } var databases types.Databases for _, dbProxy := range rdsProxies { if !aws.BoolValue(dbProxy.RequireTLS) { - f.log.Debugf("RDS Proxy %q doesn't support TLS. Skipping.", aws.StringValue(dbProxy.DBProxyName)) + cfg.Log.Debugf("RDS Proxy %q doesn't support TLS. Skipping.", aws.StringValue(dbProxy.DBProxyName)) continue } if !services.IsRDSProxyAvailable(dbProxy) { - f.log.Debugf("The current status of RDS Proxy %q is %q. Skipping.", + cfg.Log.Debugf("The current status of RDS Proxy %q is %q. Skipping.", aws.StringValue(dbProxy.DBProxyName), aws.StringValue(dbProxy.Status)) continue @@ -98,23 +78,23 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat // rds.DBProxy has no port information. An extra SDK call is made to // find the port from its targets. - port, err := getRDSProxyTargetPort(ctx, f.cfg.RDS, dbProxy.DBProxyName) + port, err := getRDSProxyTargetPort(ctx, rdsClient, dbProxy.DBProxyName) if err != nil { - f.log.Debugf("Failed to get port for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) + cfg.Log.Debugf("Failed to get port for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) continue } // rds.DBProxy has no tags information. An extra SDK call is made to // fetch the tags. If failed, keep going without the tags. - tags, err := listRDSResourceTags(ctx, f.cfg.RDS, dbProxy.DBProxyArn) + tags, err := listRDSResourceTags(ctx, rdsClient, dbProxy.DBProxyArn) if err != nil { - f.log.Debugf("Failed to get tags for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) + cfg.Log.Debugf("Failed to get tags for RDS Proxy %v: %v.", aws.StringValue(dbProxy.DBProxyName), err) } // Add a database from RDS Proxy (default endpoint). database, err := services.NewDatabaseFromRDSProxy(dbProxy, port, tags) if err != nil { - f.log.Debugf("Could not convert RDS Proxy %q to database resource: %v.", + cfg.Log.Debugf("Could not convert RDS Proxy %q to database resource: %v.", aws.StringValue(dbProxy.DBProxyName), err) } else { databases = append(databases, database) @@ -123,7 +103,7 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat // Add custom endpoints. for _, customEndpoint := range customEndpointsByProxyName[aws.StringValue(dbProxy.DBProxyName)] { if !services.IsRDSProxyCustomEndpointAvailable(customEndpoint) { - f.log.Debugf("The current status of custom endpoint %q of RDS Proxy %q is %q. Skipping.", + cfg.Log.Debugf("The current status of custom endpoint %q of RDS Proxy %q is %q. Skipping.", aws.StringValue(customEndpoint.DBProxyEndpointName), aws.StringValue(customEndpoint.DBProxyName), aws.StringValue(customEndpoint.Status)) @@ -132,7 +112,7 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat database, err = services.NewDatabaseFromRDSProxyCustomEndpoint(dbProxy, customEndpoint, port, tags) if err != nil { - f.log.Debugf("Could not convert custom endpoint %q of RDS Proxy %q to database resource: %v.", + cfg.Log.Debugf("Could not convert custom endpoint %q of RDS Proxy %q to database resource: %v.", aws.StringValue(customEndpoint.DBProxyEndpointName), aws.StringValue(customEndpoint.DBProxyName), err) @@ -145,12 +125,6 @@ func (f *rdsDBProxyFetcher) getRDSProxyDatabases(ctx context.Context) (types.Dat return databases, nil } -// String returns the fetcher's string description. -func (f *rdsDBProxyFetcher) String() string { - return fmt.Sprintf("rdsDBProxyFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) -} - // getRDSProxies fetches all RDS Proxies using the provided client, up to the // specified max number of pages. func getRDSProxies(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (rdsProxies []*rds.DBProxy, err error) { diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go index 9d901d7b9aa9f..c1b8bba6f92a6 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestRDSDBProxyFetcher(t *testing.T) { @@ -80,6 +81,7 @@ func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types rdsProxyDatabase, err := services.NewDatabaseFromRDSProxy(rdsProxy, 9999, nil) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(rdsProxyDatabase, services.AWSMatcherRDSProxy) return rdsProxy, rdsProxyDatabase } @@ -94,5 +96,6 @@ func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rds.DBProxy, name, regio } rdsProxyEndpointDatabase, err := services.NewDatabaseFromRDSProxyCustomEndpoint(rdsProxy, rdsProxyEndpoint, 9999, nil) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(rdsProxyEndpointDatabase, services.AWSMatcherRDSProxy) return rdsProxyEndpoint, rdsProxyEndpointDatabase } diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go index f86f80cd4a661..9fce0508defb8 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go @@ -31,6 +31,7 @@ import ( libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) // TestRDSFetchers tests RDS instance fetcher and Aurora cluster fetcher (as @@ -223,6 +224,7 @@ func makeRDSInstance(t *testing.T, name, region string, labels map[string]string database, err := services.NewDatabaseFromRDSInstance(instance) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherRDS) return instance, database } @@ -247,6 +249,7 @@ func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, database, err := services.NewDatabaseFromRDSCluster(cluster) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherRDS) return cluster, database } @@ -291,6 +294,9 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels require.NoError(t, err) databases = append(databases, customDatabases...) + for _, db := range databases { + common.ApplyAWSDatabaseNameSuffix(db, services.AWSMatcherRDS) + } return cluster, databases } diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index 771e183644dfb..e211b08121a3c 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -18,73 +18,35 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// redshiftFetcherConfig is the Redshift databases fetcher configuration. -type redshiftFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // Redshift is the Redshift API client. - Redshift redshiftiface.RedshiftAPI - // Region is the AWS region to query databases in. - Region string - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newRedshiftFetcher returns a new AWS fetcher for Redshift databases. +func newRedshiftFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &redshiftPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *redshiftFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.Redshift == nil { - return trace.BadParameter("missing parameter Redshift") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - return nil -} - -// redshiftFetcher retrieves Redshift databases. -type redshiftFetcher struct { - awsFetcher +// redshiftPlugin retrieves Redshift databases. +type redshiftPlugin struct{} - cfg redshiftFetcherConfig - log logrus.FieldLogger -} - -// newRedshiftFetcher returns a new Redshift databases fetcher instance. -func newRedshiftFetcher(config redshiftFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// GetDatabases returns Redshift databases matching the watcher's selectors. +func (f *redshiftPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + redshiftClient, err := cfg.AWSClients.GetAWSRedshiftClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &redshiftFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:redshift", - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns Redshift and Aurora databases matching the watcher's selectors. -func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - clusters, err := getRedshiftClusters(ctx, f.cfg.Redshift) + clusters, err := getRedshiftClusters(ctx, redshiftClient) if err != nil { return nil, trace.Wrap(err) } @@ -92,7 +54,7 @@ func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e var databases types.Databases for _, cluster := range clusters { if !services.IsRedshiftClusterAvailable(cluster) { - f.log.Debugf("The current status of Redshift cluster %q is %q. Skipping.", + cfg.Log.Debugf("The current status of Redshift cluster %q is %q. Skipping.", aws.StringValue(cluster.ClusterIdentifier), aws.StringValue(cluster.ClusterStatus)) continue @@ -100,21 +62,18 @@ func (f *redshiftFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, e database, err := services.NewDatabaseFromRedshiftCluster(cluster) if err != nil { - f.log.Infof("Could not convert Redshift cluster %q to database resource: %v.", + cfg.Log.Infof("Could not convert Redshift cluster %q to database resource: %v.", aws.StringValue(cluster.ClusterIdentifier), err) continue } databases = append(databases, database) } - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil + return databases, nil } -// String returns the fetcher's string description. -func (f *redshiftFetcher) String() string { - return fmt.Sprintf("redshiftFetcher(Region=%v, Labels=%v)", - f.cfg.Region, f.cfg.Labels) +func (f *redshiftPlugin) ComponentShortName() string { + return "redshift" } // getRedshiftClusters fetches all Reshift clusters using the provided client, diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go index 6dc3535330671..c8f719d593e43 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_serverless.go @@ -18,7 +18,6 @@ package db import ( "context" - "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/redshiftserverless" @@ -27,120 +26,85 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// redshiftServerlessFetcherConfig is the Redshift Serverless databases fetcher -// configuration. -type redshiftServerlessFetcherConfig struct { - // Labels is a selector to match cloud databases. - Labels types.Labels - // Region is the AWS region to query databases in. - Region string - // Client is the Redshift Serverless API client. - Client redshiftserverlessiface.RedshiftServerlessAPI - // AssumeRole is the AWS IAM role to assume before discovering databases. - AssumeRole types.AssumeRole +// newRedshiftServerlessFetcher returns a new AWS fetcher for Redshift +// Serverless databases. +func newRedshiftServerlessFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { + return newAWSFetcher(cfg, &redshiftServerlessPlugin{}) } -// CheckAndSetDefaults validates the config and sets defaults. -func (c *redshiftServerlessFetcherConfig) CheckAndSetDefaults() error { - if len(c.Labels) == 0 { - return trace.BadParameter("missing parameter Labels") - } - if c.Region == "" { - return trace.BadParameter("missing parameter Region") - } - if c.Client == nil { - return trace.BadParameter("missing parameter Client") - } - return nil -} - -type redshiftServerlessWorkgroupWithTags struct { +type workgroupWithTags struct { *redshiftserverless.Workgroup Tags []*redshiftserverless.Tag } -// redshiftServerlessFetcher retrieves Redshift Serverless databases. -type redshiftServerlessFetcher struct { - awsFetcher +// redshiftServerlessPlugin retrieves Redshift Serverless databases. +type redshiftServerlessPlugin struct{} - cfg redshiftServerlessFetcherConfig - log logrus.FieldLogger +func (f *redshiftServerlessPlugin) ComponentShortName() string { + // (r)ed(s)hift (s)erver(<)less + return "rss<" } -// newRedshiftServerlessFetcher returns a new Redshift Serverless databases -// fetcher instance. -func newRedshiftServerlessFetcher(config redshiftServerlessFetcherConfig) (common.Fetcher, error) { - if err := config.CheckAndSetDefaults(); err != nil { +// rssAPI is a type alias for brevity alone. +type rssAPI = redshiftserverlessiface.RedshiftServerlessAPI + +// GetDatabases returns Redshift Serverless databases matching the watcher's selectors. +func (f *redshiftServerlessPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { + client, err := cfg.AWSClients.GetAWSRedshiftServerlessClient(ctx, cfg.Region, + cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID)) + if err != nil { return nil, trace.Wrap(err) } - return &redshiftServerlessFetcher{ - cfg: config, - log: logrus.WithFields(logrus.Fields{ - trace.Component: "watch:rss<", // (r)ed(s)hift (s)erver(<)less - "labels": config.Labels, - "region": config.Region, - "role": config.AssumeRole, - }), - }, nil -} - -// Get returns Redshift Serverless databases matching the watcher's selectors. -func (f *redshiftServerlessFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { - databases, workgroups, err := f.getDatabasesFromWorkgroups(ctx) + databases, workgroups, err := getDatabasesFromWorkgroups(ctx, client, cfg.Log) if err != nil { return nil, trace.Wrap(err) } if len(workgroups) > 0 { - vpcEndpointDatabases, err := f.getDatabasesFromVPCEndpoints(ctx, workgroups) + vpcEndpointDatabases, err := getDatabasesFromVPCEndpoints(ctx, workgroups, client, cfg.Log) if err != nil { if trace.IsAccessDenied(err) { - f.log.Debugf("No permission to get Redshift Serverless VPC endpoints: %v.", err) + cfg.Log.Debugf("No permission to get Redshift Serverless VPC endpoints: %v.", err) } else { - f.log.Warnf("Failed to get Redshift Serverless VPC endpoints: %v.", err) + cfg.Log.Warnf("Failed to get Redshift Serverless VPC endpoints: %v.", err) } } databases = append(databases, vpcEndpointDatabases...) } - applyAssumeRoleToDatabases(databases, f.cfg.AssumeRole) - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil -} - -// String returns the fetcher's string description. -func (f *redshiftServerlessFetcher) String() string { - return fmt.Sprintf("redshiftServerlessFetcher(Region=%v, Labels=%v)", f.cfg.Region, f.cfg.Labels) + return databases, nil } -func (f *redshiftServerlessFetcher) getDatabasesFromWorkgroups(ctx context.Context) (types.Databases, []*redshiftServerlessWorkgroupWithTags, error) { - workgroups, err := f.getWorkgroups(ctx) +func getDatabasesFromWorkgroups(ctx context.Context, client rssAPI, log logrus.FieldLogger) (types.Databases, []*workgroupWithTags, error) { + workgroups, err := getRSSWorkgroups(ctx, client) if err != nil { return nil, nil, trace.Wrap(err) } var databases types.Databases - var workgroupsWithTags []*redshiftServerlessWorkgroupWithTags + var workgroupsWithTags []*workgroupWithTags for _, workgroup := range workgroups { if !services.IsAWSResourceAvailable(workgroup, workgroup.Status) { - f.log.Debugf("The current status of Redshift Serverless workgroup %v is %v. Skipping.", aws.StringValue(workgroup.WorkgroupName), aws.StringValue(workgroup.Status)) + log.Debugf("The current status of Redshift Serverless workgroup %v is %v. Skipping.", aws.StringValue(workgroup.WorkgroupName), aws.StringValue(workgroup.Status)) continue } - tags := f.getResourceTags(ctx, workgroup.WorkgroupArn) + tags := getRSSResourceTags(ctx, workgroup.WorkgroupArn, client, log) database, err := services.NewDatabaseFromRedshiftServerlessWorkgroup(workgroup, tags) if err != nil { - f.log.WithError(err).Infof("Could not convert Redshift Serverless workgroup %q to database resource.", aws.StringValue(workgroup.WorkgroupName)) + log.WithError(err).Infof("Could not convert Redshift Serverless workgroup %q to database resource.", aws.StringValue(workgroup.WorkgroupName)) continue } databases = append(databases, database) - workgroupsWithTags = append(workgroupsWithTags, &redshiftServerlessWorkgroupWithTags{ + workgroupsWithTags = append(workgroupsWithTags, &workgroupWithTags{ Workgroup: workgroup, Tags: tags, }) @@ -148,8 +112,8 @@ func (f *redshiftServerlessFetcher) getDatabasesFromWorkgroups(ctx context.Conte return databases, workgroupsWithTags, nil } -func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*redshiftServerlessWorkgroupWithTags) (types.Databases, error) { - endpoints, err := f.getVPCEndpoints(ctx) +func getDatabasesFromVPCEndpoints(ctx context.Context, workgroups []*workgroupWithTags, client rssAPI, log logrus.FieldLogger) (types.Databases, error) { + endpoints, err := getRSSVPCEndpoints(ctx, client) if err != nil { return nil, trace.Wrap(err) } @@ -158,12 +122,12 @@ func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Con for _, endpoint := range endpoints { workgroup, found := findWorkgroupWithName(workgroups, aws.StringValue(endpoint.WorkgroupName)) if !found { - f.log.Debugf("Could not find matching workgroup for Redshift Serverless endpoint %v. Skipping.", aws.StringValue(endpoint.EndpointName)) + log.Debugf("Could not find matching workgroup for Redshift Serverless endpoint %v. Skipping.", aws.StringValue(endpoint.EndpointName)) continue } if !services.IsAWSResourceAvailable(endpoint, endpoint.EndpointStatus) { - f.log.Debugf("The current status of Redshift Serverless endpoint %v is %v. Skipping.", aws.StringValue(endpoint.EndpointName), aws.StringValue(endpoint.EndpointStatus)) + log.Debugf("The current status of Redshift Serverless endpoint %v is %v. Skipping.", aws.StringValue(endpoint.EndpointName), aws.StringValue(endpoint.EndpointStatus)) continue } @@ -171,7 +135,7 @@ func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Con // tags from the workgroups instead. database, err := services.NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint, workgroup.Workgroup, workgroup.Tags) if err != nil { - f.log.WithError(err).Infof("Could not convert Redshift Serverless endpoint %q to database resource.", aws.StringValue(endpoint.EndpointName)) + log.WithError(err).Infof("Could not convert Redshift Serverless endpoint %q to database resource.", aws.StringValue(endpoint.EndpointName)) continue } databases = append(databases, database) @@ -179,41 +143,41 @@ func (f *redshiftServerlessFetcher) getDatabasesFromVPCEndpoints(ctx context.Con return databases, nil } -func (f *redshiftServerlessFetcher) getResourceTags(ctx context.Context, arn *string) []*redshiftserverless.Tag { - output, err := f.cfg.Client.ListTagsForResourceWithContext(ctx, &redshiftserverless.ListTagsForResourceInput{ +func getRSSResourceTags(ctx context.Context, arn *string, client rssAPI, log logrus.FieldLogger) []*redshiftserverless.Tag { + output, err := client.ListTagsForResourceWithContext(ctx, &redshiftserverless.ListTagsForResourceInput{ ResourceArn: arn, }) if err != nil { // Log errors here and return nil. if trace.IsAccessDenied(err) { - f.log.WithError(err).Debugf("No Permission to get tags for %q.", aws.StringValue(arn)) + log.WithError(err).Debugf("No Permission to get tags for %q.", aws.StringValue(arn)) } else { - f.log.WithError(err).Warnf("Failed to get tags for %q.", aws.StringValue(arn)) + log.WithError(err).Warnf("Failed to get tags for %q.", aws.StringValue(arn)) } return nil } return output.Tags } -func (f *redshiftServerlessFetcher) getWorkgroups(ctx context.Context) ([]*redshiftserverless.Workgroup, error) { +func getRSSWorkgroups(ctx context.Context, client rssAPI) ([]*redshiftserverless.Workgroup, error) { var pages [][]*redshiftserverless.Workgroup - err := f.cfg.Client.ListWorkgroupsPagesWithContext(ctx, nil, func(page *redshiftserverless.ListWorkgroupsOutput, lastPage bool) bool { + err := client.ListWorkgroupsPagesWithContext(ctx, nil, func(page *redshiftserverless.ListWorkgroupsOutput, lastPage bool) bool { pages = append(pages, page.Workgroups) return len(pages) <= maxAWSPages }) return flatten(pages), libcloudaws.ConvertRequestFailureError(err) } -func (f *redshiftServerlessFetcher) getVPCEndpoints(ctx context.Context) ([]*redshiftserverless.EndpointAccess, error) { +func getRSSVPCEndpoints(ctx context.Context, client rssAPI) ([]*redshiftserverless.EndpointAccess, error) { var pages [][]*redshiftserverless.EndpointAccess - err := f.cfg.Client.ListEndpointAccessPagesWithContext(ctx, nil, func(page *redshiftserverless.ListEndpointAccessOutput, lastPage bool) bool { + err := client.ListEndpointAccessPagesWithContext(ctx, nil, func(page *redshiftserverless.ListEndpointAccessOutput, lastPage bool) bool { pages = append(pages, page.Endpoints) return len(pages) <= maxAWSPages }) return flatten(pages), libcloudaws.ConvertRequestFailureError(err) } -func findWorkgroupWithName(workgroups []*redshiftServerlessWorkgroupWithTags, name string) (*redshiftServerlessWorkgroupWithTags, bool) { +func findWorkgroupWithName(workgroups []*workgroupWithTags, name string) (*workgroupWithTags, bool) { for _, workgroup := range workgroups { if aws.StringValue(workgroup.WorkgroupName) == name { return workgroup, true diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go index 28d2118f985c3..d2383d19d57e3 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_serverless_test.go @@ -28,6 +28,7 @@ import ( libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestRedshiftServerlessFetcher(t *testing.T) { @@ -93,6 +94,7 @@ func makeRedshiftServerlessWorkgroup(t *testing.T, name, region string, labels m tags := libcloudaws.LabelsToTags[redshiftserverless.Tag](labels) database, err := services.NewDatabaseFromRedshiftServerlessWorkgroup(workgroup, tags) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherRedshiftServerless) return workgroup, database } @@ -101,5 +103,6 @@ func makeRedshiftServerlessEndpoint(t *testing.T, workgroup *redshiftserverless. tags := libcloudaws.LabelsToTags[redshiftserverless.Tag](labels) database, err := services.NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint, workgroup, tags) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherRedshiftServerless) return endpoint, database } diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_test.go index 8a6d17be2bbe0..bc7ca50187cc4 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_test.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestRedshiftFetcher(t *testing.T) { @@ -93,6 +94,7 @@ func makeRedshiftCluster(t *testing.T, region, env string, opts ...func(*redshif database, err := services.NewDatabaseFromRedshiftCluster(cluster) require.NoError(t, err) + common.ApplyAWSDatabaseNameSuffix(database, services.AWSMatcherRedshift) return cluster, database } diff --git a/lib/srv/discovery/fetchers/db/azure.go b/lib/srv/discovery/fetchers/db/azure.go index 127611294cf33..12112dccb7409 100644 --- a/lib/srv/discovery/fetchers/db/azure.go +++ b/lib/srv/discovery/fetchers/db/azure.go @@ -148,8 +148,15 @@ func (f *azureFetcher[DBType, ListClient]) Get(ctx context.Context) (types.Resou if err != nil { return nil, trace.Wrap(err) } + f.rewriteDatabases(databases) + return databases.AsResources(), nil +} - return filterDatabasesByLabels(databases, f.cfg.Labels, f.log).AsResources(), nil +// rewriteDatabases rewrites the discovered databases. +func (f *azureFetcher[DBType, ListClient]) rewriteDatabases(databases types.Databases) { + for _, db := range databases { + common.ApplyAzureDatabaseNameSuffix(db, f.cfg.Type) + } } // getSubscriptions returns the subscriptions that this fetcher is configured to query. @@ -225,7 +232,7 @@ func (f *azureFetcher[DBType, ListClient]) getDatabases(ctx context.Context) (ty databases = append(databases, database) } } - return databases, nil + return filterDatabasesByLabels(databases, f.cfg.Labels, f.log), nil } // String returns the fetcher's string description. diff --git a/lib/srv/discovery/fetchers/db/azure_dbserver_test.go b/lib/srv/discovery/fetchers/db/azure_dbserver_test.go index 67cfb43dbf132..99f96f5eb1c02 100644 --- a/lib/srv/discovery/fetchers/db/azure_dbserver_test.go +++ b/lib/srv/discovery/fetchers/db/azure_dbserver_test.go @@ -31,6 +31,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) // TestAzureDBServerFetchers tests common azureFetcher functionalities and the @@ -345,6 +346,7 @@ func makeAzureMySQLServer(t *testing.T, name, subscription, group, region string database, err := services.NewDatabaseFromAzureServer(azureDBServer) require.NoError(t, err) + common.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherMySQL) return server, database } @@ -380,6 +382,7 @@ func makeAzurePostgresServer(t *testing.T, name, subscription, group, region str database, err := services.NewDatabaseFromAzureServer(azureDBServer) require.NoError(t, err) + common.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherPostgres) return server, database } diff --git a/lib/srv/discovery/fetchers/db/azure_mysql_flex_test.go b/lib/srv/discovery/fetchers/db/azure_mysql_flex_test.go index fd629ea44824e..0f6c8220839ae 100644 --- a/lib/srv/discovery/fetchers/db/azure_mysql_flex_test.go +++ b/lib/srv/discovery/fetchers/db/azure_mysql_flex_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) // TestAzureMySQLFlexFetchers tests Azure MySQL Flexible server fetchers. @@ -86,5 +87,6 @@ func makeAzureMySQLFlexServer(t *testing.T, name, subscription, group, region st } database, err := services.NewDatabaseFromAzureMySQLFlexServer(server) require.NoError(t, err) + common.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherMySQL) return server, database } diff --git a/lib/srv/discovery/fetchers/db/azure_postgres_flex_test.go b/lib/srv/discovery/fetchers/db/azure_postgres_flex_test.go index 8114ffe1e9c98..af3e495652905 100644 --- a/lib/srv/discovery/fetchers/db/azure_postgres_flex_test.go +++ b/lib/srv/discovery/fetchers/db/azure_postgres_flex_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) // TestAzurePostgresFlexFetchers tests Azure PostgreSQL Flexible server fetchers. @@ -86,5 +87,6 @@ func makeAzurePostgresFlexServer(t *testing.T, name, subscription, group, region } database, err := services.NewDatabaseFromAzurePostgresFlexServer(server) require.NoError(t, err) + common.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherPostgres) return server, database } diff --git a/lib/srv/discovery/fetchers/db/azure_redis_test.go b/lib/srv/discovery/fetchers/db/azure_redis_test.go index 2b62678f4be7a..3966c36aad221 100644 --- a/lib/srv/discovery/fetchers/db/azure_redis_test.go +++ b/lib/srv/discovery/fetchers/db/azure_redis_test.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) // TestAzureRedisFetchers tests Azure Redis and Azure Redis Enterprise fetchers @@ -87,6 +88,7 @@ func makeAzureRedisServer(t *testing.T, name, subscription, group, region string database, err := services.NewDatabaseFromAzureRedis(resourceInfo) require.NoError(t, err) + common.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherRedis) return resourceInfo, database } @@ -113,5 +115,6 @@ func makeAzureRedisEnterpriseCluster(t *testing.T, cluster, subscription, group, database, err := services.NewDatabaseFromAzureRedisEnterprise(armCluster, armDatabase) require.NoError(t, err) + common.ApplyAzureDatabaseNameSuffix(database, services.AzureMatcherRedis) return armCluster, armDatabase, database } diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 1d6a07a8c3cee..2a756db943be3 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -29,18 +29,18 @@ import ( "github.com/gravitational/teleport/lib/srv/discovery/common" ) -type makeAWSFetcherFunc func(context.Context, cloud.AWSClients, string, types.Labels, types.AssumeRole) (common.Fetcher, error) +type makeAWSFetcherFunc func(awsFetcherConfig) (common.Fetcher, error) type makeAzureFetcherFunc func(azureFetcherConfig) (common.Fetcher, error) var ( makeAWSFetcherFuncs = map[string][]makeAWSFetcherFunc{ - services.AWSMatcherRDS: {makeRDSInstanceFetcher, makeRDSAuroraFetcher}, - services.AWSMatcherRDSProxy: {makeRDSProxyFetcher}, - services.AWSMatcherRedshift: {makeRedshiftFetcher}, - services.AWSMatcherRedshiftServerless: {makeRedshiftServerlessFetcher}, - services.AWSMatcherElastiCache: {makeElastiCacheFetcher}, - services.AWSMatcherMemoryDB: {makeMemoryDBFetcher}, - services.AWSMatcherOpenSearch: {makeOpenSearchFetcher}, + services.AWSMatcherRDS: {newRDSDBInstancesFetcher, newRDSAuroraClustersFetcher}, + services.AWSMatcherRDSProxy: {newRDSDBProxyFetcher}, + services.AWSMatcherRedshift: {newRedshiftFetcher}, + services.AWSMatcherRedshiftServerless: {newRedshiftServerlessFetcher}, + services.AWSMatcherElastiCache: {newElastiCacheFetcher}, + services.AWSMatcherMemoryDB: {newMemoryDBFetcher}, + services.AWSMatcherOpenSearch: {newOpenSearchFetcher}, } makeAzureFetcherFuncs = map[string][]makeAzureFetcherFunc{ @@ -76,7 +76,13 @@ func MakeAWSFetchers(ctx context.Context, clients cloud.AWSClients, matchers []t for _, makeFetcher := range makeFetchers { for _, region := range matcher.Regions { - fetcher, err := makeFetcher(ctx, clients, region, matcher.Tags, assumeRole) + fetcher, err := makeFetcher(awsFetcherConfig{ + AWSClients: clients, + Type: matcherType, + AssumeRole: assumeRole, + Labels: matcher.Tags, + Region: region, + }) if err != nil { return nil, trace.Wrap(err) } @@ -120,125 +126,6 @@ func MakeAzureFetchers(clients cloud.AzureClients, matchers []types.AzureMatcher return result, nil } -// makeRDSInstanceFetcher returns RDS instance fetcher for the provided region and tags. -func makeRDSInstanceFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - fetcher, err := newRDSDBInstancesFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, - AssumeRole: assumeRole, - }) - return fetcher, trace.Wrap(err) -} - -// makeRDSAuroraFetcher returns RDS Aurora fetcher for the provided region and tags. -func makeRDSAuroraFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - fetcher, err := newRDSAuroraClustersFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, - AssumeRole: assumeRole, - }) - return fetcher, trace.Wrap(err) -} - -// makeRDSProxyFetcher returns RDS proxy fetcher for the provided region and tags. -func makeRDSProxyFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - rds, err := clients.GetAWSRDSClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - return newRDSDBProxyFetcher(rdsFetcherConfig{ - Region: region, - Labels: tags, - RDS: rds, - AssumeRole: assumeRole, - }) -} - -// makeRedshiftFetcher returns Redshift fetcher for the provided region and tags. -func makeRedshiftFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - redshift, err := clients.GetAWSRedshiftClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newRedshiftFetcher(redshiftFetcherConfig{ - Region: region, - Labels: tags, - Redshift: redshift, - AssumeRole: assumeRole, - }) -} - -// makeElastiCacheFetcher returns ElastiCache fetcher for the provided region and tags. -func makeElastiCacheFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - elastiCache, err := clients.GetAWSElastiCacheClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newElastiCacheFetcher(elastiCacheFetcherConfig{ - Region: region, - Labels: tags, - ElastiCache: elastiCache, - AssumeRole: assumeRole, - }) -} - -// makeMemoryDBFetcher returns MemoryDB fetcher for the provided region and tags. -func makeMemoryDBFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - memorydb, err := clients.GetAWSMemoryDBClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newMemoryDBFetcher(memoryDBFetcherConfig{ - Region: region, - Labels: tags, - MemoryDB: memorydb, - AssumeRole: assumeRole, - }) -} - -// makeOpenSearchFetcher returns OpenSearch fetcher for the provided region and tags. -func makeOpenSearchFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - opensearch, err := clients.GetAWSOpenSearchClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - - return newOpenSearchFetcher(openSearchFetcherConfig{ - Region: region, - Labels: tags, - openSearch: opensearch, - AssumeRole: assumeRole, - }) -} - -// makeRedshiftServerlessFetcher returns Redshift Serverless fetcher for the -// provided region and tags. -func makeRedshiftServerlessFetcher(ctx context.Context, clients cloud.AWSClients, region string, tags types.Labels, assumeRole types.AssumeRole) (common.Fetcher, error) { - client, err := clients.GetAWSRedshiftServerlessClient(ctx, region, cloud.WithAssumeRole(assumeRole.RoleARN, assumeRole.ExternalID)) - if err != nil { - return nil, trace.Wrap(err) - } - return newRedshiftServerlessFetcher(redshiftServerlessFetcherConfig{ - Region: region, - Labels: tags, - Client: client, - AssumeRole: assumeRole, - }) -} - // filterDatabasesByLabels filters input databases with provided labels. func filterDatabasesByLabels(databases types.Databases, labels types.Labels, log logrus.FieldLogger) types.Databases { var matchedDatabases types.Databases @@ -255,14 +142,6 @@ func filterDatabasesByLabels(databases types.Databases, labels types.Labels, log return matchedDatabases } -// applyAssumeRoleToDatabases applies assume role settings from fetcher to databases. -func applyAssumeRoleToDatabases(databases types.Databases, assumeRole types.AssumeRole) { - for _, db := range databases { - db.SetAWSAssumeRole(assumeRole.RoleARN) - db.SetAWSExternalID(assumeRole.ExternalID) - } -} - // flatten flattens a nested slice [][]T to []T. func flatten[T any](s [][]T) (result []T) { for i := range s { diff --git a/lib/srv/discovery/fetchers/db/helpers_test.go b/lib/srv/discovery/fetchers/db/helpers_test.go index a2a81e18c3f2d..a2f089836e7bc 100644 --- a/lib/srv/discovery/fetchers/db/helpers_test.go +++ b/lib/srv/discovery/fetchers/db/helpers_test.go @@ -146,9 +146,11 @@ func copyDatabasesWithAWSAssumeRole(role types.AssumeRole, databases ...types.Da } out := make(types.Databases, 0, len(databases)) for _, db := range databases { - out = append(out, db.Copy()) + dbCopy := db.Copy() + dbCopy.SetAWSAssumeRole(role.RoleARN) + dbCopy.SetAWSExternalID(role.ExternalID) + out = append(out, dbCopy) } - applyAssumeRoleToDatabases(out, role) return out } diff --git a/lib/srv/discovery/fetchers/eks.go b/lib/srv/discovery/fetchers/eks.go index c0c0c0c42bcc5..296b9582c80b6 100644 --- a/lib/srv/discovery/fetchers/eks.go +++ b/lib/srv/discovery/fetchers/eks.go @@ -87,9 +87,17 @@ func (a *eksFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) return nil, trace.Wrap(err) } + a.rewriteKubeClusters(clusters) return clusters.AsResources(), nil } +// rewriteKubeClusters rewrites the discovered kube clusters. +func (a *eksFetcher) rewriteKubeClusters(clusters types.KubeClusters) { + for _, c := range clusters { + common.ApplyEKSNameSuffix(c) + } +} + func (a *eksFetcher) getEKSClusters(ctx context.Context) (types.KubeClusters, error) { var ( clusters types.KubeClusters diff --git a/lib/srv/discovery/fetchers/eks_test.go b/lib/srv/discovery/fetchers/eks_test.go index b099c3726a82a..a751f615b1a6b 100644 --- a/lib/srv/discovery/fetchers/eks_test.go +++ b/lib/srv/discovery/fetchers/eks_test.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestEKSFetcher(t *testing.T) { @@ -197,6 +198,7 @@ func eksClustersToResources(t *testing.T, clusters ...*eks.Cluster) types.Resour kubeCluster, err := services.NewKubeClusterFromAWSEKS(cluster) require.NoError(t, err) require.True(t, kubeCluster.IsAWS()) + common.ApplyEKSNameSuffix(kubeCluster) kubeClusters = append(kubeClusters, kubeCluster) } return kubeClusters.AsResources() diff --git a/lib/srv/discovery/fetchers/gke.go b/lib/srv/discovery/fetchers/gke.go index acccc17835347..060a00f9cb993 100644 --- a/lib/srv/discovery/fetchers/gke.go +++ b/lib/srv/discovery/fetchers/gke.go @@ -84,6 +84,7 @@ func (a *gkeFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) return nil, trace.Wrap(err) } + a.rewriteKubeClusters(clusters) return clusters.AsResources(), nil } @@ -108,6 +109,13 @@ func (a *gkeFetcher) getGKEClusters(ctx context.Context) (types.KubeClusters, er return clusters, trace.Wrap(err) } +// rewriteKubeClusters rewrites the discovered kube clusters. +func (a *gkeFetcher) rewriteKubeClusters(clusters types.KubeClusters) { + for _, c := range clusters { + common.ApplyGKENameSuffix(c) + } +} + func (a *gkeFetcher) ResourceType() string { return types.KindKubernetesCluster } diff --git a/lib/srv/discovery/fetchers/gke_test.go b/lib/srv/discovery/fetchers/gke_test.go index 10ad78a66b774..833ff224a5d7b 100644 --- a/lib/srv/discovery/fetchers/gke_test.go +++ b/lib/srv/discovery/fetchers/gke_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/discovery/common" ) func TestGKEFetcher(t *testing.T) { @@ -178,6 +179,7 @@ func gkeClustersToResources(t *testing.T, clusters ...gcp.GKECluster) types.Reso kubeCluster, err := services.NewKubeClusterFromGCPGKE(cluster) require.NoError(t, err) require.True(t, kubeCluster.IsGCP()) + common.ApplyGKENameSuffix(kubeCluster) kubeClusters = append(kubeClusters, kubeCluster) } return kubeClusters.AsResources()