Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -645,10 +645,10 @@ func (c *cloudClients) getAWSSessionForRegion(region string) (*awssession.Sessio

// getAWSSessionForRole returns AWS session for the specified region and role.
func (c *cloudClients) getAWSSessionForRole(ctx context.Context, region string, options awsAssumeRoleOpts) (*awssession.Session, error) {
assumeRoler := sts.New(options.baseSession)
cacheKey := fmt.Sprintf("Region[%s]:RoleARN[%s]:ExternalID[%s]", region, options.assumeRoleARN, options.assumeRoleExternalID)
return utils.FnCacheGet(ctx, c.awsSessionsCache, cacheKey, func(ctx context.Context) (*awssession.Session, error) {
return newSessionWithRole(ctx, assumeRoler, region, options.assumeRoleARN, options.assumeRoleExternalID)
stsClient := sts.New(options.baseSession)
return newSessionWithRole(ctx, stsClient, region, options.assumeRoleARN, options.assumeRoleExternalID)
})
}

Expand Down
7 changes: 4 additions & 3 deletions lib/integrations/awsoidc/listdatabases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/aws/aws-sdk-go-v2/service/rds"
rdsTypes "github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -210,7 +211,7 @@ func TestListDatabases(t *testing.T) {
},
)
require.NoError(t, err)
require.Equal(t, expectedDB, ldr.Databases[0])
require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0]))
},
errCheck: noErrorFunc,
},
Expand Down Expand Up @@ -275,7 +276,7 @@ func TestListDatabases(t *testing.T) {
},
)
require.NoError(t, err)
require.Equal(t, expectedDB, ldr.Databases[0])
require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0]))
},
errCheck: noErrorFunc,
},
Expand Down Expand Up @@ -328,7 +329,7 @@ func TestListDatabases(t *testing.T) {
},
)
require.NoError(t, err)
require.Equal(t, expectedDB, ldr.Databases[0])
require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0]))
},
errCheck: noErrorFunc,
},
Expand Down
6 changes: 2 additions & 4 deletions lib/srv/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,10 +811,8 @@ func (s *Server) Start() error {
if s.gcpWatcher != nil {
go s.handleGCPDiscovery()
}
if len(s.kubeFetchers) > 0 {
if err := s.startKubeWatchers(); err != nil {
return trace.Wrap(err)
}
if err := s.startKubeWatchers(); err != nil {
return trace.Wrap(err)
}
if err := s.startDatabaseWatchers(); err != nil {
return trace.Wrap(err)
Expand Down
12 changes: 6 additions & 6 deletions lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func TestDiscoveryKube(t *testing.T) {
mustConvertEKSToKubeCluster(t, eksMockClusters[0], mainDiscoveryGroup),
mustConvertEKSToKubeCluster(t, eksMockClusters[1], mainDiscoveryGroup),
},
clustersNotUpdated: []string{"eks-cluster1"},
clustersNotUpdated: []string{mustConvertEKSToKubeCluster(t, eksMockClusters[0], mainDiscoveryGroup).GetName()},
wantEvents: 1,
},
{
Expand Down Expand Up @@ -613,7 +613,7 @@ func TestDiscoveryKube(t *testing.T) {
mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], mainDiscoveryGroup),
mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][1], mainDiscoveryGroup),
},
clustersNotUpdated: []string{"aks-cluster1"},
clustersNotUpdated: []string{mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], mainDiscoveryGroup).GetName()},
wantEvents: 2,
},
{
Expand Down Expand Up @@ -1099,7 +1099,7 @@ func TestDiscoveryDatabase(t *testing.T) {
name: "update existing database",
existingDatabases: []types.Database{
mustNewDatabase(t, types.Metadata{
Name: "aws-redshift",
Name: awsRedshiftDB.GetName(),
Description: "should be updated",
Labels: map[string]string{types.OriginLabel: types.OriginCloud, types.TeleportInternalDiscoveryGroupName: mainDiscoveryGroup},
}, types.DatabaseSpecV3{
Expand All @@ -1123,7 +1123,7 @@ func TestDiscoveryDatabase(t *testing.T) {
name: "update existing database with assumed role",
existingDatabases: []types.Database{
mustNewDatabase(t, types.Metadata{
Name: "aws-rds",
Name: awsRDSDBWithRole.GetName(),
Description: "should be updated",
Labels: map[string]string{types.OriginLabel: types.OriginCloud, types.TeleportInternalDiscoveryGroupName: mainDiscoveryGroup},
}, types.DatabaseSpecV3{
Expand All @@ -1143,7 +1143,7 @@ func TestDiscoveryDatabase(t *testing.T) {
name: "delete existing database",
existingDatabases: []types.Database{
mustNewDatabase(t, types.Metadata{
Name: "aws-redshift",
Name: awsRedshiftDB.GetName(),
Description: "should not be deleted",
Labels: map[string]string{types.OriginLabel: types.OriginCloud},
}, types.DatabaseSpecV3{
Expand All @@ -1158,7 +1158,7 @@ func TestDiscoveryDatabase(t *testing.T) {
}},
expectDatabases: []types.Database{
mustNewDatabase(t, types.Metadata{
Name: "aws-redshift",
Name: awsRedshiftDB.GetName(),
Description: "should not be deleted",
Labels: map[string]string{types.OriginLabel: types.OriginCloud},
}, types.DatabaseSpecV3{
Expand Down
7 changes: 7 additions & 0 deletions lib/srv/discovery/fetchers/aks.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,16 @@ func (a *aksFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error)

kubeClusters = append(kubeClusters, kubeCluster)
}

a.rewriteKubeClusters(kubeClusters)
return kubeClusters.AsResources(), nil
}

// rewriteKubeClusters rewrites the discovered kube clusters.
func (a *aksFetcher) rewriteKubeClusters(clusters types.KubeClusters) {
// no-op
}

func (a *aksFetcher) getAKSClusters(ctx context.Context) ([]*azure.AKSCluster, error) {
var (
clusters []*azure.AKSCluster
Expand Down
116 changes: 116 additions & 0 deletions lib/srv/discovery/fetchers/db/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,121 @@ limitations under the License.
package db

import (
"context"
"fmt"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/srv/discovery/common"
)

// awsFetcherPlugin defines an interface that provides database type specific
// functions for use by the common AWS database fetcher.
type awsFetcherPlugin interface {
// GetDatabases fetches databases from AWS API and converts the results to
// Teleport types.Databases.
GetDatabases(context.Context, *awsFetcherConfig) (types.Databases, error)
// ComponentShortName provides the plugin's short component name for
// logging purposes.
ComponentShortName() string
}

// awsFetcherConfig is the AWS database fetcher configuration.
type awsFetcherConfig struct {
// AWSClients are the AWS API clients.
AWSClients cloud.AWSClients
// Type is the type of DB matcher, for example "rds", "redshift", etc.
Type string
// AssumeRole provides a role ARN and ExternalID to assume an AWS role
// when fetching databases.
AssumeRole types.AssumeRole
// Labels is a selector to match cloud database tags.
Labels types.Labels
// Region is the AWS region selector to match cloud databases.
Region string
// Log is a field logger to provide structured logging for each matcher,
// based on its config settings by default.
Log logrus.FieldLogger
}

// CheckAndSetDefaults validates the config and sets defaults.
func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error {
if cfg.AWSClients == nil {
return trace.BadParameter("missing parameter AWSClients")
}
if cfg.Type == "" {
return trace.BadParameter("missing parameter Type")
}
if len(cfg.Labels) == 0 {
return trace.BadParameter("missing parameter Labels")
}
if cfg.Region == "" {
return trace.BadParameter("missing parameter Region")
}
if cfg.Log == nil {
cfg.Log = logrus.WithFields(logrus.Fields{
trace.Component: "watch:" + component,
"labels": cfg.Labels,
"region": cfg.Region,
"role": cfg.AssumeRole,
})
}
return nil
}

// newAWSFetcher returns a AWS database fetcher for the provided selectors
// and AWS database-type specific fetcher plugin.
func newAWSFetcher(cfg awsFetcherConfig, plugin awsFetcherPlugin) (*awsFetcher, error) {
if err := cfg.CheckAndSetDefaults(plugin.ComponentShortName()); err != nil {
return nil, trace.Wrap(err)
}
return &awsFetcher{cfg: cfg, plugin: plugin}, nil
}

// awsFetcher is the common base for AWS database fetchers.
type awsFetcher struct {
// cfg is the awsFetcher configuration.
cfg awsFetcherConfig
// plugin does AWS database type specific API calls fetch databases.
plugin awsFetcherPlugin
}

// awsFetcher implements common.Fetcher.
var _ common.Fetcher = (*awsFetcher)(nil)

// Get returns AWS databases matching the fetcher's selectors.
func (f *awsFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) {
databases, err := f.getDatabases(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
f.rewriteDatabases(databases)
return databases.AsResources(), nil
}

func (f *awsFetcher) getDatabases(ctx context.Context) (types.Databases, error) {
databases, err := f.plugin.GetDatabases(ctx, &f.cfg)
if err != nil {
return nil, trace.Wrap(err)
}
return filterDatabasesByLabels(databases, f.cfg.Labels, f.cfg.Log), nil
}

// rewriteDatabases rewrites the discovered databases.
func (f *awsFetcher) rewriteDatabases(databases types.Databases) {
for _, db := range databases {
f.applyAssumeRole(db)
}
}

// applyAssumeRole sets the database AWS AssumeRole metadata to match the
// fetcher's setting.
func (f *awsFetcher) applyAssumeRole(db types.Database) {
db.SetAWSAssumeRole(f.cfg.AssumeRole.RoleARN)
db.SetAWSExternalID(f.cfg.AssumeRole.ExternalID)
}

// Cloud returns the cloud the fetcher is operating.
Expand All @@ -34,6 +144,12 @@ func (f *awsFetcher) ResourceType() string {
return types.KindDatabase
}

// String returns the fetcher's string description.
func (f *awsFetcher) String() string {
return fmt.Sprintf("awsFetcher(Type: %v, Region=%v, Labels=%v)",
f.cfg.Type, f.cfg.Region, f.cfg.Labels)
}

// maxAWSPages is the maximum number of pages to iterate over when fetching aws
// databases.
const maxAWSPages = 10
Loading