diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index c20fa9f3e8ecd..62dfb93ee2e74 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -241,36 +241,19 @@ func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error { // Add kube fetchers. for _, matcher := range otherMatchers { - matcherAssumeRole := &types.AssumeRole{} + matcherAssumeRole := types.AssumeRole{} if matcher.AssumeRole != nil { - matcherAssumeRole = matcher.AssumeRole + matcherAssumeRole = *matcher.AssumeRole } for _, t := range matcher.Types { for _, region := range matcher.Regions { switch t { case services.AWSMatcherEKS: - client, err := s.Clients.GetAWSEKSClient( - s.ctx, - region, - cloud.WithAssumeRole( - matcherAssumeRole.RoleARN, - matcherAssumeRole.ExternalID, - ), - ) + fetcher, err := s.getEKSFetcher(region, matcherAssumeRole, matcher.Tags) if err != nil { - return trace.Wrap(err) - } - fetcher, err := fetchers.NewEKSFetcher( - fetchers.EKSFetcherConfig{ - Client: client, - Region: region, - FilterLabels: matcher.Tags, - Log: s.Log, - }, - ) - if err != nil { - return trace.Wrap(err) + s.Log.WithError(err).Warnf("Could not initialize EKS fetcher(Region=%q, Labels=%q, AssumeRole=%q), skipping.", region, matcher.Tags, matcherAssumeRole.RoleARN) + continue } s.kubeFetchers = append(s.kubeFetchers, fetcher) } @@ -281,6 +264,19 @@ func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error { return nil } +func (s *Server) getEKSFetcher(region string, assumeRole types.AssumeRole, tags types.Labels) (common.Fetcher, error) { + fetcher, err := fetchers.NewEKSFetcher( + fetchers.EKSFetcherConfig{ + EKSClientGetter: s.Clients, + AssumeRole: assumeRole, + Region: region, + FilterLabels: tags, + Log: s.Log, + }, + ) + return fetcher, trace.Wrap(err) +} + // initAzureWatchers starts Azure resource watchers based on types provided. func (s *Server) initAzureWatchers(ctx context.Context, matchers []types.AzureMatcher) error { vmMatchers, otherMatchers := splitMatchers(matchers, func(matcherType string) bool { diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 1854c90aef841..dbdade57d9763 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -754,8 +754,8 @@ func TestDiscoveryKube(t *testing.T) { return len(clustersNotUpdated) == 0 && clustersFoundInAuth }, 5*time.Second, 200*time.Millisecond) - require.Equal(t, tc.expectedAssumedRoles, sts.GetAssumedRoleARNs(), "roles incorrectly assumed") - require.Equal(t, tc.expectedExternalIDs, sts.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed") + require.ElementsMatch(t, tc.expectedAssumedRoles, sts.GetAssumedRoleARNs(), "roles incorrectly assumed") + require.ElementsMatch(t, tc.expectedExternalIDs, sts.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed") if tc.wantEvents > 0 { require.Eventually(t, func() bool { @@ -770,6 +770,101 @@ func TestDiscoveryKube(t *testing.T) { } } +func TestDiscoveryServer_New(t *testing.T) { + t.Parallel() + testCases := []struct { + desc string + cloudClients cloud.Clients + matchers []types.AWSMatcher + errAssertion require.ErrorAssertionFunc + discServerAssertion require.ValueAssertionFunc + }{ + { + desc: "no matchers error", + cloudClients: &cloud.TestCloudClients{STS: &mocks.STSMock{}}, + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorIs(t, err, &trace.BadParameterError{Message: "no matchers configured for discovery"}) + }, + discServerAssertion: require.Nil, + }, + { + desc: "success with EKS matcher", + cloudClients: &cloud.TestCloudClients{STS: &mocks.STSMock{}, EKS: &mocks.EKSMock{}}, + matchers: []types.AWSMatcher{ + { + Types: []string{"eks"}, + Regions: []string{"eu-west-1"}, + Tags: map[string]utils.Strings{"env": {"prod"}}, + AssumeRole: &types.AssumeRole{ + RoleARN: "arn:aws:iam::123456789012:role/teleport-role", + ExternalID: "external-id", + }, + }, + }, + errAssertion: require.NoError, + discServerAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.NotNil(t, i) + val, ok := i.(*Server) + require.True(t, ok) + require.Len(t, val.kubeFetchers, 1, "unexpected amount of kube fetchers") + }, + }, + { + desc: "EKS fetcher is skipped on initialization error (missing region)", + cloudClients: &cloud.TestCloudClients{ + STS: &mocks.STSMock{}, + EKS: &mocks.EKSMock{}, + }, + matchers: []types.AWSMatcher{ + { + Types: []string{"eks"}, + Regions: []string{}, + Tags: map[string]utils.Strings{"env": {"prod"}}, + AssumeRole: &types.AssumeRole{ + RoleARN: "arn:aws:iam::123456789012:role/teleport-role", + ExternalID: "external-id", + }, + }, + { + Types: []string{"eks"}, + Regions: []string{"eu-west-1"}, + Tags: map[string]utils.Strings{"env": {"staging"}}, + AssumeRole: &types.AssumeRole{ + RoleARN: "arn:aws:iam::55555555555:role/teleport-role", + ExternalID: "external-id2", + }, + }, + }, + errAssertion: require.NoError, + discServerAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.NotNil(t, i) + val, ok := i.(*Server) + require.True(t, ok) + require.Len(t, val.kubeFetchers, 1, "unexpected amount of kube fetchers") + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.desc, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + discServer, err := New( + ctx, + &Config{ + Clients: nil, + AccessPoint: newFakeAccessPoint(), + AWSMatchers: tt.matchers, + Emitter: &mockEmitter{}, + }) + + tt.errAssertion(t, err) + tt.discServerAssertion(t, discServer) + }) + } +} + type mockAKSAPI struct { azure.AKSClient group map[string][]*azure.AKSCluster @@ -1923,3 +2018,31 @@ func (f *fakeAccessPoint) UpsertServerInfo(ctx context.Context, si types.ServerI f.upsertedServerInfos <- si return nil } + +func (f *fakeAccessPoint) NewWatcher(ctx context.Context, watch types.Watch) (types.Watcher, error) { + return newFakeWatcher(), nil +} + +type fakeWatcher struct { +} + +func newFakeWatcher() fakeWatcher { + + return fakeWatcher{} +} + +func (m fakeWatcher) Events() <-chan types.Event { + return make(chan types.Event) +} + +func (m fakeWatcher) Done() <-chan struct{} { + return make(chan struct{}) +} + +func (m fakeWatcher) Close() error { + return nil +} + +func (m fakeWatcher) Error() error { + return nil +} diff --git a/lib/srv/discovery/fetchers/eks.go b/lib/srv/discovery/fetchers/eks.go index 8e458369bdd5f..9542b5a9ab4e3 100644 --- a/lib/srv/discovery/fetchers/eks.go +++ b/lib/srv/discovery/fetchers/eks.go @@ -29,6 +29,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -39,12 +40,24 @@ const ( type eksFetcher struct { EKSFetcherConfig + + mu sync.Mutex + client eksiface.EKSAPI +} + +// EKSClientGetter is an interface for getting an EKS client. +type EKSClientGetter interface { + // GetAWSEKSClient returns AWS EKS client for the specified region. + GetAWSEKSClient(ctx context.Context, region string, opts ...cloud.AWSAssumeRoleOptionFn) (eksiface.EKSAPI, error) } // EKSFetcherConfig configures the EKS fetcher. type EKSFetcherConfig struct { - // Client is the AWS eKS client. - Client eksiface.EKSAPI + // EKSClientGetter retrieves an EKS client. + EKSClientGetter EKSClientGetter + // AssumeRole provides a role ARN and ExternalID to assume an AWS role + // when fetching clusters. + AssumeRole types.AssumeRole // Region is the region where the clusters should be located. Region string // FilterLabels are the filter criteria. @@ -55,8 +68,8 @@ type EKSFetcherConfig struct { // CheckAndSetDefaults validates and sets the defaults values. func (c *EKSFetcherConfig) CheckAndSetDefaults() error { - if c.Client == nil { - return trace.BadParameter("missing Client field") + if c.EKSClientGetter == nil { + return trace.BadParameter("missing EKSClientGetter field") } if len(c.Region) == 0 { return trace.BadParameter("missing Region field") @@ -78,7 +91,31 @@ func NewEKSFetcher(cfg EKSFetcherConfig) (common.Fetcher, error) { return nil, trace.Wrap(err) } - return &eksFetcher{cfg}, nil + return &eksFetcher{EKSFetcherConfig: cfg}, nil +} + +func (a *eksFetcher) getClient(ctx context.Context) (eksiface.EKSAPI, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.client != nil { + return a.client, nil + } + + client, err := a.EKSClientGetter.GetAWSEKSClient( + ctx, + a.Region, + cloud.WithAssumeRole( + a.AssumeRole.RoleARN, + a.AssumeRole.ExternalID, + ), + ) + if err != nil { + return nil, trace.Wrap(err) + } + a.client = client + + return a.client, nil } func (a *eksFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, error) { @@ -104,7 +141,12 @@ func (a *eksFetcher) getEKSClusters(ctx context.Context) (types.KubeClusters, er ) group.SetLimit(concurrencyLimit) - err := a.Client.ListClustersPagesWithContext(ctx, + client, err := a.getClient(ctx) + if err != nil { + return nil, trace.Wrap(err, "failed getting AWS EKS client") + } + + err = client.ListClustersPagesWithContext(ctx, &eks.ListClustersInput{ Include: nil, // For now we should only list EKS clusters }, @@ -159,7 +201,12 @@ func (a *eksFetcher) String() string { // If any cluster does not match the filtering criteria, this function returns a “trace.CompareFailed“ error // to distinguish filtering and operational errors. func (a *eksFetcher) getMatchingKubeCluster(ctx context.Context, clusterName string) (types.KubeCluster, error) { - rsp, err := a.Client.DescribeClusterWithContext( + client, err := a.getClient(ctx) + if err != nil { + return nil, trace.Wrap(err, "failed getting AWS EKS client") + } + + rsp, err := client.DescribeClusterWithContext( ctx, &eks.DescribeClusterInput{ Name: aws.String(clusterName), diff --git a/lib/srv/discovery/fetchers/eks_test.go b/lib/srv/discovery/fetchers/eks_test.go index b099c3726a82a..e21e862758fa2 100644 --- a/lib/srv/discovery/fetchers/eks_test.go +++ b/lib/srv/discovery/fetchers/eks_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/services" ) @@ -98,10 +99,10 @@ func TestEKSFetcher(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := EKSFetcherConfig{ - Client: newPopulatedEKSMock(), - FilterLabels: tt.args.filterLabels, - Region: tt.args.region, - Log: logrus.New(), + EKSClientGetter: &mockEKSClientGetter{}, + FilterLabels: tt.args.filterLabels, + Region: tt.args.region, + Log: logrus.New(), } fetcher, err := NewEKSFetcher(cfg) require.NoError(t, err) @@ -113,6 +114,12 @@ func TestEKSFetcher(t *testing.T) { } } +type mockEKSClientGetter struct{} + +func (e *mockEKSClientGetter) GetAWSEKSClient(ctx context.Context, region string, opts ...cloud.AWSAssumeRoleOptionFn) (eksiface.EKSAPI, error) { + return newPopulatedEKSMock(), nil +} + type mockEKSAPI struct { eksiface.EKSAPI clusters []*eks.Cluster