diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index 6650a195bc774..7000c4fd367ec 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -33,6 +33,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" ) @@ -89,8 +90,41 @@ type AzureClients interface { GetAzureRunCommandClient(subscription string) (azure.RunCommandClient, error) } +// AzureClientsOption is an option to pass to NewAzureClients +type AzureClientsOption func(clients *azureClients) + +type azureOIDCCredentials interface { + GenerateAzureOIDCToken(ctx context.Context, integration string) (string, error) + GetIntegration(ctx context.Context, name string) (types.Integration, error) +} + +// WithAzureIntegrationCredentials configures Azure cloud clients to use integration credentials. +func WithAzureIntegrationCredentials(integrationName string, auth azureOIDCCredentials) AzureClientsOption { + return func(clt *azureClients) { + clt.newAzureCredentialFunc = func() (azcore.TokenCredential, error) { + ctx := context.TODO() + integration, err := auth.GetIntegration(ctx, integrationName) + if err != nil { + return nil, trace.Wrap(err) + } + spec := integration.GetAzureOIDCIntegrationSpec() + if spec == nil { + return nil, trace.BadParameter("expected %q to be an %q integration, was %q instead", integration.GetName(), types.IntegrationSubKindAzureOIDC, integration.GetSubKind()) + } + cred, err := azidentity.NewClientAssertionCredential(spec.TenantID, spec.ClientID, func(ctx context.Context) (string, error) { + return auth.GenerateAzureOIDCToken(ctx, integrationName) + // TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options + }, nil) + if err != nil { + return nil, trace.Wrap(err) + } + return cred, nil + } + } +} + // NewAzureClients returns a new instance of Azure SDK clients. -func NewAzureClients() (AzureClients, error) { +func NewAzureClients(opts ...AzureClientsOption) (AzureClients, error) { azClients := &azureClients{ azureMySQLClients: make(map[string]azure.DBServersClient), azurePostgresClients: make(map[string]azure.DBServersClient), @@ -130,6 +164,15 @@ func NewAzureClients() (AzureClients, error) { return nil, trace.Wrap(err) } + azClients.newAzureCredentialFunc = func() (azcore.TokenCredential, error) { + // TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options + return azidentity.NewDefaultAzureCredential(nil) + } + + for _, opt := range opts { + opt(azClients) + } + return azClients, nil } @@ -168,8 +211,11 @@ type azureClients struct { // mtx is used for locking. mtx sync.RWMutex + // newAzureCredentialFunc creates new Azure credential. + newAzureCredentialFunc func() (azcore.TokenCredential, error) // azureCredential is the cached Azure credential. azureCredential azcore.TokenCredential + // azureMySQLClients is the cached Azure MySQL Server clients. azureMySQLClients map[string]azure.DBServersClient // azurePostgresClients is the cached Azure Postgres Server clients. @@ -378,9 +424,8 @@ func (c *azureClients) initAzureCredential() (azcore.TokenCredential, error) { if c.azureCredential != nil { // If some other thread already got here first. return c.azureCredential, nil } - // TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options - options := &azidentity.DefaultAzureCredentialOptions{} - cred, err := azidentity.NewDefaultAzureCredential(options) + + cred, err := c.newAzureCredentialFunc() if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/cloud/clients_test.go b/lib/cloud/clients_test.go new file mode 100644 index 0000000000000..f5faf1c5727e3 --- /dev/null +++ b/lib/cloud/clients_test.go @@ -0,0 +1,123 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cloud + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" +) + +type testAzureOIDCCredentials struct { + integration types.Integration +} + +func (m *testAzureOIDCCredentials) GenerateAzureOIDCToken(ctx context.Context, integration string) (string, error) { + return "dummy-oidc-token", nil +} + +func (m *testAzureOIDCCredentials) GetIntegration(ctx context.Context, name string) (types.Integration, error) { + if m.integration == nil || m.integration.GetName() != name { + return nil, trace.NotFound("integration %q not found", name) + } + return m.integration, nil +} + +func TestWithAzureIntegrationCredentials(t *testing.T) { + const integrationName = "azure" + + tests := []struct { + name string + integration types.Integration + wantErr string + }{ + { + name: "valid azure integration", + integration: &types.IntegrationV1{ + ResourceHeader: types.ResourceHeader{ + Kind: types.KindIntegration, + SubKind: types.IntegrationSubKindAzureOIDC, + Version: types.V1, + Metadata: types.Metadata{ + Name: integrationName, + Namespace: defaults.Namespace, + }, + }, + Spec: types.IntegrationSpecV1{ + SubKindSpec: &types.IntegrationSpecV1_AzureOIDC{ + AzureOIDC: &types.AzureOIDCIntegrationSpecV1{ + ClientID: "baz-quux", + TenantID: "foo-bar", + }, + }, + }, + }, + }, + { + name: "integration not found", + integration: nil, + wantErr: `integration "azure" not found`, + }, + { + name: "invalid integration type", + integration: &types.IntegrationV1{ + ResourceHeader: types.ResourceHeader{ + Kind: types.KindIntegration, + SubKind: types.IntegrationSubKindAWSOIDC, + Version: types.V1, + Metadata: types.Metadata{ + Name: "azure", + Namespace: defaults.Namespace, + }, + }, + Spec: types.IntegrationSpecV1{ + SubKindSpec: &types.IntegrationSpecV1_AWSOIDC{ + AWSOIDC: &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:iam::123456789012:role/teleport", + }, + }, + }, + }, + wantErr: `expected "azure" to be an "azure-oidc" integration, was "aws-oidc" instead`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opt := WithAzureIntegrationCredentials(integrationName, &testAzureOIDCCredentials{ + integration: tt.integration, + }) + clients, err := NewAzureClients(opt) + require.NoError(t, err) + + credential, err := clients.GetAzureCredential() + + if tt.wantErr == "" { + require.NoError(t, err) + require.NotNil(t, credential) + } else { + require.ErrorContains(t, err, tt.wantErr) + require.Nil(t, credential) + } + }) + } +} diff --git a/lib/config/configuration.go b/lib/config/configuration.go index c70c1844f2fc3..eb2ca63c36b9d 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -1632,6 +1632,7 @@ func applyDiscoveryConfig(fc *FileConfig, cfg *servicecfg.Config) error { Types: matcher.Types, Regions: matcher.Regions, ResourceTags: matcher.ResourceTags, + Integration: matcher.Integration, Params: installParams, } if err := serviceMatcher.CheckAndSetDefaults(); err != nil { @@ -1834,6 +1835,7 @@ func applyDatabasesConfig(fc *FileConfig, cfg *servicecfg.Config) error { Types: matcher.Types, Regions: matcher.Regions, ResourceTags: matcher.ResourceTags, + Integration: matcher.Integration, }) } for _, database := range fc.Databases.Databases { diff --git a/lib/config/configuration_test.go b/lib/config/configuration_test.go index cb3d0c406638e..bec404ba4e49a 100644 --- a/lib/config/configuration_test.go +++ b/lib/config/configuration_test.go @@ -387,6 +387,7 @@ func TestConfigReading(t *testing.T) { }, ResourceGroups: []string{"group1"}, Subscriptions: []string{"sub1"}, + Integration: "integration1", }, }, GCPMatchers: []GCPMatcher{ @@ -520,6 +521,7 @@ func TestConfigReading(t *testing.T) { ResourceGroups: []string{"rg1", "rg2"}, Types: []string{"mysql"}, Regions: []string{"eastus", "westus"}, + Integration: "integration1", ResourceTags: map[string]apiutils.Strings{ "a": {"b"}, }, @@ -878,6 +880,7 @@ SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7 ResourceGroups: []string{"group1", "group2"}, Types: []string{"postgres", "mysql"}, Regions: []string{"eastus", "centralus"}, + Integration: "integration123", ResourceTags: map[string]apiutils.Strings{ "a": {"b"}, }, @@ -1602,6 +1605,7 @@ func makeConfigFixture() string { }, ResourceGroups: []string{"group1"}, Subscriptions: []string{"sub1"}, + Integration: "integration1", }, } @@ -1721,6 +1725,7 @@ func makeConfigFixture() string { ResourceTags: map[string]apiutils.Strings{ "a": {"b"}, }, + Integration: "integration1", }, { Subscriptions: []string{"sub3", "sub4"}, @@ -3901,6 +3906,7 @@ func TestApplyDiscoveryConfig(t *testing.T) { }, Suffix: "blue", }, + Integration: "integration123", }, }, }, @@ -3922,6 +3928,7 @@ func TestApplyDiscoveryConfig(t *testing.T) { }, Regions: []string{"*"}, ResourceTags: types.Labels{"*": []string{"*"}}, + Integration: "integration123", ResourceGroups: []string{"*"}, }, }, @@ -4629,8 +4636,9 @@ func TestDiscoveryConfig(t *testing.T) { cfg["discovery_service"].(cfgMap)["enabled"] = "yes" cfg["discovery_service"].(cfgMap)["azure"] = []cfgMap{ { - "types": []string{"aks"}, - "regions": []string{"eucentral1"}, + "types": []string{"aks"}, + "regions": []string{"eucentral1"}, + "integration": "integration1", "tags": cfgMap{ "discover_teleport": "yes", }, @@ -4640,8 +4648,9 @@ func TestDiscoveryConfig(t *testing.T) { } }, expectedAzureMatchers: []types.AzureMatcher{{ - Types: []string{"aks"}, - Regions: []string{"eucentral1"}, + Types: []string{"aks"}, + Regions: []string{"eucentral1"}, + Integration: "integration1", ResourceTags: map[string]apiutils.Strings{ "discover_teleport": []string{"yes"}, }, diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index c49c9f7518cc1..b6bb485b02919 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -1959,6 +1959,8 @@ type AzureMatcher struct { Regions []string `yaml:"regions,omitempty"` // ResourceTags are Azure tags on resources to match. ResourceTags map[string]apiutils.Strings `yaml:"tags,omitempty"` + // Integration is the Azure Integration name. + Integration string `yaml:"integration,omitempty"` // InstallParams sets the join method when installing on // discovered Azure nodes. InstallParams *InstallParams `yaml:"install,omitempty"` diff --git a/lib/config/testdata_test.go b/lib/config/testdata_test.go index 83268682874b2..5571fa08268c1 100644 --- a/lib/config/testdata_test.go +++ b/lib/config/testdata_test.go @@ -181,6 +181,7 @@ db_service: resource_groups: ["group1", "group2"] types: ["postgres", "mysql"] regions: ["eastus", "centralus"] + integration: integration123 tags: "a": "b" - types: ["postgres", "mysql"] diff --git a/lib/services/matchers.go b/lib/services/matchers.go index 13edcd07e3402..9108b29c185e4 100644 --- a/lib/services/matchers.go +++ b/lib/services/matchers.go @@ -98,14 +98,13 @@ func SimplifyAzureMatchers(matchers []types.AzureMatcher) []types.AzureMatcher { regions[i] = azureutils.NormalizeLocation(region) } } - result = append(result, types.AzureMatcher{ - Subscriptions: subs, - ResourceGroups: groups, - Regions: regions, - Types: ts, - ResourceTags: m.ResourceTags, - Params: m.Params, - }) + elem := m + elem.Subscriptions = subs + elem.ResourceGroups = groups + elem.Regions = regions + elem.Types = ts + + result = append(result, elem) } return result } diff --git a/lib/services/matchers_test.go b/lib/services/matchers_test.go index 8ac2ca73eaf07..9f8c4b36c48b0 100644 --- a/lib/services/matchers_test.go +++ b/lib/services/matchers_test.go @@ -139,6 +139,63 @@ func TestMatchResourceLabels(t *testing.T) { } } +func TestSimplifyAzureMatchers(t *testing.T) { + matchers := []types.AzureMatcher{ + { + Subscriptions: []string{"sub-1", types.Wildcard, "sub-1"}, + Regions: []string{"eu-west-1", "eu-west-2"}, + Types: []string{"mysql", "mysql", "postgres"}, + ResourceTags: types.Labels{"env": []string{"prod"}}, + Params: &types.InstallerParams{ + JoinMethod: types.JoinMethodAzure, + JoinToken: "token-1", + Azure: &types.AzureInstallerParams{ + ClientID: "client-1", + }, + }, + Integration: "integration-1", + }, + { + ResourceGroups: []string{ + "rg-1", + types.Wildcard, + "rg-1", + }, + Types: []string{"redis"}, + Integration: "integration-2", + }, + } + + simplified := SimplifyAzureMatchers(matchers) + + want := []types.AzureMatcher{ + { + Subscriptions: []string{types.Wildcard}, + ResourceGroups: []string{types.Wildcard}, + Regions: []string{"eu-west-1", "eu-west-2"}, + Types: []string{"mysql", "postgres"}, + ResourceTags: types.Labels{"env": []string{"prod"}}, + Params: &types.InstallerParams{ + JoinMethod: types.JoinMethodAzure, + JoinToken: "token-1", + Azure: &types.AzureInstallerParams{ + ClientID: "client-1", + }, + }, + Integration: "integration-1", + }, + { + Subscriptions: []string{types.Wildcard}, + ResourceGroups: []string{types.Wildcard}, + Regions: []string{types.Wildcard}, + Types: []string{"redis"}, + Integration: "integration-2", + }, + } + + require.Equal(t, want, simplified) +} + func TestMatchResourceByFilters_Helper(t *testing.T) { t.Parallel() diff --git a/lib/srv/db/watcher.go b/lib/srv/db/watcher.go index 1ff2e2cc10aef..d54a1fa72cca6 100644 --- a/lib/srv/db/watcher.go +++ b/lib/srv/db/watcher.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/readonly" discovery "github.com/gravitational/teleport/lib/srv/discovery/common" @@ -117,7 +118,12 @@ func (s *Server) startCloudWatcher(ctx context.Context) error { if err != nil { return trace.Wrap(err) } - azureFetchers, err := dbfetchers.MakeAzureFetchers(s.cfg.AzureClients, s.cfg.AzureMatchers, "" /* discovery config */) + azureFetchers, err := dbfetchers.MakeAzureFetchers(ctx, func(ctx context.Context, integration string) (cloud.AzureClients, error) { + if integration != "" { + return nil, trace.NotImplemented("db_service discovery does not support Azure OIDC authentication; use discovery_service instead.") + } + return s.cfg.AzureClients, nil + }, s.cfg.AzureMatchers, "" /* discovery config */) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/discovery/config_test.go b/lib/srv/discovery/config_test.go index 68b790a6d423d..affd1c4680250 100644 --- a/lib/srv/discovery/config_test.go +++ b/lib/srv/discovery/config_test.go @@ -49,7 +49,7 @@ func TestConfigCheckAndSetDefaults(t *testing.T) { errAssertFunc: require.NoError, cfgChange: func(c *Config) {}, postCheckAndSetDefaultsFunc: func(t *testing.T, c *Config) { - require.NotNil(t, c.azureClients) + require.NotNil(t, c.initAzureClients) require.NotNil(t, c.gcpClients) require.NotNil(t, c.AWSConfigProvider) require.NotNil(t, c.AWSDatabaseFetcherFactory) diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index d1070c04272fb..7c5c294cee6b2 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -64,6 +64,7 @@ import ( azure_sync "github.com/gravitational/teleport/lib/srv/discovery/fetchers/azuresync" "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" "github.com/gravitational/teleport/lib/srv/server" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/aws/stsutils" logutils "github.com/gravitational/teleport/lib/utils/log" libslices "github.com/gravitational/teleport/lib/utils/slices" @@ -187,8 +188,8 @@ type Config struct { // It is used to add Expiration times to Resources that don't support Heartbeats (eg EICE Nodes). jitter retryutils.Jitter - // azureClients is a reference to Azure clients. - azureClients cloud.AzureClients + // initAzureClients initializes an instance of Azure clients with particular options. + initAzureClients func(opts ...cloud.AzureClientsOption) (cloud.AzureClients, error) // gcpClients is a reference to GCP clients. gcpClients cloud.GCPClients } @@ -241,12 +242,8 @@ func (c *Config) CheckAndSetDefaults() error { kubernetes matchers are present.`) } - if c.azureClients == nil { - azureClients, err := cloud.NewAzureClients() - if err != nil { - return trace.Wrap(err) - } - c.azureClients = azureClients + if c.initAzureClients == nil { + c.initAzureClients = cloud.NewAzureClients } if c.gcpClients == nil { @@ -460,6 +457,9 @@ type Server struct { // usageEventCache keeps track of which instances the server has emitted // usage events for. usageEventCache map[string]struct{} + + // azureClientCache caches instances of integration-specific Azure clients. + azureClientCache *utils.FnCache } // New initializes a discovery Server @@ -738,7 +738,7 @@ func (s *Server) azureServerFetchersFromMatchers(matchers []types.AzureMatcher, return matcherType == types.AzureMatcherVM }) - return server.MatchersToAzureInstanceFetchers(s.Log, serverMatchers, s.azureClients, discoveryConfigName) + return server.MatchersToAzureInstanceFetchers(s.Log, serverMatchers, s.getAzureClients, discoveryConfigName) } // gcpServerFetchersFromMatchers converts Matchers into a set of GCP Servers Fetchers. @@ -782,7 +782,7 @@ func (s *Server) databaseFetchersFromMatchers(matchers Matchers, discoveryConfig // Azure azureDatabaseMatchers, _ := splitMatchers(matchers.Azure, db.IsAzureMatcherType) if len(azureDatabaseMatchers) > 0 { - databaseFetchers, err := db.MakeAzureFetchers(s.azureClients, azureDatabaseMatchers, discoveryConfigName) + databaseFetchers, err := db.MakeAzureFetchers(s.ctx, s.getAzureClients, azureDatabaseMatchers, discoveryConfigName) if err != nil { return nil, trace.Wrap(err) } @@ -815,13 +815,54 @@ func (s *Server) kubeFetchersFromMatchers(matchers Matchers, discoveryConfigName return result, nil } +// getAzureClients returns an instance of AzureClients made to work with particular integration. +// If integration argument is empty, ambient credentials will be used instead. This is the default mode. +// +// The returned instance is cached for a period of time, so subsequent calls may return the same object. +func (s *Server) getAzureClients(ctx context.Context, integration string) (cloud.AzureClients, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.azureClientCache == nil { + azureClientCache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: time.Minute * 15, + Clock: s.clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } + s.azureClientCache = azureClientCache + } + + // sanity check: this shouldn't happen as matchers are pre-filtered when running in integration-credentials-only mode. + if integration == "" && s.IntegrationOnlyCredentials { + return nil, trace.BadParameter("cannot create Azure clients with ambient credentials due configuration (this is a bug)") + } + + out, err := utils.FnCacheGet(ctx, s.azureClientCache, integration, func(ctx context.Context) (cloud.AzureClients, error) { + var opts []cloud.AzureClientsOption + if integration != "" { + opts = append(opts, cloud.WithAzureIntegrationCredentials(integration, s.AccessPoint)) + } + azureClients, err := s.initAzureClients(opts...) + if err != nil { + return nil, trace.Wrap(err) + } + return azureClients, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + return out, nil +} + // initAzureWatchers starts Azure resource watchers based on types provided. func (s *Server) initAzureWatchers(ctx context.Context, matchers []types.AzureMatcher, discoveryConfigName string) error { vmMatchers, otherMatchers := splitMatchers(matchers, func(matcherType string) bool { return matcherType == types.AzureMatcherVM }) - s.staticServerAzureFetchers = server.MatchersToAzureInstanceFetchers(s.Log, vmMatchers, s.azureClients, discoveryConfigName) + s.staticServerAzureFetchers = server.MatchersToAzureInstanceFetchers(s.Log, vmMatchers, s.getAzureClients, discoveryConfigName) // VM watcher. var err error @@ -845,7 +886,7 @@ func (s *Server) initAzureWatchers(ctx context.Context, matchers []types.AzureMa // Add kube fetchers. for _, matcher := range otherMatchers { - subscriptions, err := s.getAzureSubscriptions(ctx, matcher.Subscriptions) + subscriptions, err := s.getAzureSubscriptions(ctx, matcher.Integration, matcher.Subscriptions) if err != nil { return trace.Wrap(err) } @@ -853,10 +894,15 @@ func (s *Server) initAzureWatchers(ctx context.Context, matchers []types.AzureMa for _, t := range matcher.Types { switch t { case types.AzureMatcherKubernetes: - kubeClient, err := s.azureClients.GetAzureKubernetesClient(subscription) + azureClients, err := s.getAzureClients(ctx, matcher.Integration) if err != nil { return trace.Wrap(err) } + kubeClient, err := azureClients.GetAzureKubernetesClient(subscription) + if err != nil { + return trace.Wrap(err) + } + fetcher, err := fetchers.NewAKSFetcher(fetchers.AKSFetcherConfig{ Client: kubeClient, Regions: matcher.Regions, @@ -864,6 +910,7 @@ func (s *Server) initAzureWatchers(ctx context.Context, matchers []types.AzureMa ResourceGroups: matcher.ResourceGroups, Logger: s.Log, DiscoveryConfigName: discoveryConfigName, + Integration: matcher.Integration, }) if err != nil { return trace.Wrap(err) @@ -1358,7 +1405,12 @@ outer: } func (s *Server) handleAzureInstances(instances *server.AzureInstances) error { - runClient, err := s.azureClients.GetAzureRunCommandClient(instances.SubscriptionID) + azureClients, err := s.getAzureClients(s.ctx, instances.Integration) + if err != nil { + return trace.Wrap(err) + } + + runClient, err := azureClients.GetAzureRunCommandClient(instances.SubscriptionID) if err != nil { return trace.Wrap(err) } @@ -1877,19 +1929,19 @@ func (s *Server) upsertDynamicMatchers(ctx context.Context, dc *discoveryconfig. return nil } -// discardUnsupportedMatchers drops any matcher that is not supported in the current DiscoveryService. -// Discarded Matchers: -// - when running in IntegrationOnlyCredentials mode, any Matcher that doesn't have an Integration is discarded. func (s *Server) discardUnsupportedMatchers(m *Matchers) { - if !s.IntegrationOnlyCredentials { - return + if s.IntegrationOnlyCredentials { + discardAmbientCredentialMatchers(s.ctx, s.Log, m) } +} +// discardAmbientCredentialMatchers drops any matcher that depends on ambient credentials (and not integration). +func discardAmbientCredentialMatchers(ctx context.Context, log *slog.Logger, m *Matchers) { // Discard all matchers that don't have an Integration validAWSMatchers := make([]types.AWSMatcher, 0, len(m.AWS)) for i, m := range m.AWS { if m.Integration == "" { - s.Log.WarnContext(s.ctx, "Discarding AWS matcher - missing integration", "matcher_pos", i) + log.WarnContext(ctx, "Discarding AWS matcher - missing integration", "matcher_pos", i) continue } validAWSMatchers = append(validAWSMatchers, m) @@ -1897,17 +1949,21 @@ func (s *Server) discardUnsupportedMatchers(m *Matchers) { m.AWS = validAWSMatchers if len(m.GCP) > 0 { - s.Log.WarnContext(s.ctx, "Discarding GCP matchers - missing integration") + log.WarnContext(ctx, "Discarding GCP matchers - missing integration") m.GCP = []types.GCPMatcher{} } - if len(m.Azure) > 0 { - s.Log.WarnContext(s.ctx, "Discarding Azure matchers - missing integration") - m.Azure = []types.AzureMatcher{} + filtered := slices.DeleteFunc(m.Azure, func(matcher types.AzureMatcher) bool { + return matcher.Integration == "" + }) + discarded := len(m.Azure) - len(filtered) + if discarded > 0 { + m.Azure = filtered + log.WarnContext(ctx, "Discarded Azure matchers without integration", "count", discarded) } if len(m.Kubernetes) > 0 { - s.Log.WarnContext(s.ctx, "Discarding Kubernetes matchers - missing integration") + log.WarnContext(ctx, "Discarding Kubernetes matchers - missing integration") m.Kubernetes = []types.KubernetesMatcher{} } } @@ -1943,10 +1999,14 @@ func (s *Server) Wait() error { return nil } -func (s *Server) getAzureSubscriptions(ctx context.Context, subs []string) ([]string, error) { +func (s *Server) getAzureSubscriptions(ctx context.Context, integration string, subs []string) ([]string, error) { subscriptionIds := subs if slices.Contains(subs, types.Wildcard) { - subsClient, err := s.azureClients.GetAzureSubscriptionClient() + azureClients, err := s.getAzureClients(ctx, integration) + if err != nil { + return nil, trace.Wrap(err) + } + subsClient, err := azureClients.GetAzureSubscriptionClient() if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 068e94e93aaf3..272dc9f25b5c0 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -84,6 +84,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authtest" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/cloudtest" @@ -100,6 +101,7 @@ import ( "github.com/gravitational/teleport/lib/srv/discovery/fetchers/db" "github.com/gravitational/teleport/lib/srv/server" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" + libutils "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/log/logtest" "github.com/gravitational/teleport/lib/utils/testutils/synctest" ) @@ -1546,13 +1548,14 @@ func TestDiscoveryInCloudKube(t *testing.T) { Regions: []string{types.Wildcard}, ResourceGroups: []string{types.Wildcard}, Subscriptions: []string{"sub1"}, + Integration: "dummy-azure-integration", }, }, expectedClustersToExistInAuth: []types.KubeCluster{ mustConvertEKSToKubeCluster(t, eksMockClusters[0], rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup}), mustConvertEKSToKubeCluster(t, eksMockClusters[1], rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup}), - mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup}), - mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][1], rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup}), + mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup, integration: "dummy-azure-integration"}), + mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][1], rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup, integration: "dummy-azure-integration"}), }, clustersNotUpdated: []string{mustConvertAKSToKubeCluster(t, aksMockClusters["group1"][0], rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup}).GetName()}, wantEvents: 2, @@ -1598,8 +1601,10 @@ func TestDiscoveryInCloudKube(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - azureClients := &cloudtest.AzureClients{ - AzureAKSClient: newPopulatedAKSMock(), + initAzureClients := func(opts ...cloud.AzureClientsOption) (cloud.AzureClients, error) { + return &cloudtest.AzureClients{ + AzureAKSClient: newPopulatedAKSMock(), + }, nil } gcpClients := &cloudtest.GCPClients{ @@ -1675,7 +1680,7 @@ func TestDiscoveryInCloudKube(t *testing.T) { discServer, err := New( authz.ContextWithUser(ctx, identity.I), &Config{ - azureClients: azureClients, + initAzureClients: initAzureClients, gcpClients: gcpClients, AWSFetchersClients: mockedClients, ClusterFeatures: func() proto.Features { return proto.Features{} }, @@ -1692,12 +1697,12 @@ func TestDiscoveryInCloudKube(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, discServer.Start()) t.Cleanup(discServer.Stop) - go discServer.Start() clustersNotUpdatedMap := sliceToSet(tc.clustersNotUpdated) clustersFoundInAuth := false - require.Eventually(t, func() bool { + require.EventuallyWithT(t, func(c *assert.CollectT) { loop: for { select { @@ -1708,13 +1713,13 @@ func TestDiscoveryInCloudKube(t *testing.T) { delete(clustersNotUpdatedMap, cluster) default: kubeClusters, err := tlsServer.Auth().GetKubernetesClusters(ctx) - require.NoError(t, err) + require.NoError(c, err) if len(kubeClusters) == len(tc.expectedClustersToExistInAuth) { - c1 := types.KubeClusters(kubeClusters).ToMap() - c2 := types.KubeClusters(tc.expectedClustersToExistInAuth).ToMap() + c1 := types.KubeClusters(tc.expectedClustersToExistInAuth).ToMap() + c2 := types.KubeClusters(kubeClusters).ToMap() for k := range c1 { if services.CompareResources(c1[k], c2[k]) != services.Equal { - return false + require.Equal(c, c1[k], c2[k], "expected no differences") } } clustersFoundInAuth = true @@ -1722,7 +1727,8 @@ func TestDiscoveryInCloudKube(t *testing.T) { break loop } } - return len(clustersNotUpdated) == 0 && clustersFoundInAuth + require.Empty(c, clustersNotUpdated) + require.True(c, clustersFoundInAuth) }, 5*time.Second, 200*time.Millisecond) require.ElementsMatch(t, tc.expectedAssumedRoles, mockedClients.STSClient.GetAssumedRoleARNs(), "roles incorrectly assumed") @@ -1737,6 +1743,15 @@ func TestDiscoveryInCloudKube(t *testing.T) { return reporter.ResourceCreateEventCount() != 0 }, time.Second, 100*time.Millisecond) } + + // verify usage of integration credentials. + for _, matcher := range tc.azureMatchers { + require.NotNil(t, discServer.azureClientCache) + _, err = libutils.FnCacheGet(t.Context(), discServer.azureClientCache, matcher.Integration, func(ctx context.Context) (cloud.AzureClients, error) { + return nil, trace.NotFound("cache key %q not found", matcher.Integration) + }) + require.NoError(t, err) + } }) } } @@ -2194,6 +2209,10 @@ func TestDiscoveryDatabase(t *testing.T) { _, awsRedshiftDBWithDiscoveryConfig := makeRedshiftCluster(t, "aws-redshift", "us-east-1", rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup, discoveryConfigName: discoveryConfigName}) awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1", rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup}) azRedisResource, azRedisDB := makeAzureRedisServer(t, "az-redis", "sub1", "group1", "East US", rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup}) + + azRedisDBWithIntegration := azRedisDB.Copy() + rewriteCloudResource(t, azRedisDBWithIntegration, rewriteDiscoveryLabelsParams{integration: integrationName}) + _, azRedisDBWithDiscoveryConfig := makeAzureRedisServer(t, "az-redis", "sub1", "group1", "East US", rewriteDiscoveryLabelsParams{discoveryGroup: mainDiscoveryGroup, discoveryConfigName: discoveryConfigName}) role := types.AssumeRole{RoleARN: "arn:aws:iam::123456789012:role/test-role", ExternalID: "test123"} @@ -2278,6 +2297,20 @@ func TestDiscoveryDatabase(t *testing.T) { expectDatabases: []types.Database{azRedisDB}, wantEvents: 1, }, + { + name: "discover Azure database with integration", + azureMatchers: []types.AzureMatcher{{ + Types: []string{types.AzureMatcherRedis}, + ResourceTags: map[string]utils.Strings{types.Wildcard: {types.Wildcard}}, + Regions: []string{types.Wildcard}, + ResourceGroups: []string{types.Wildcard}, + Subscriptions: []string{"sub1"}, + Integration: integrationName, + }}, + expectDatabases: []types.Database{azRedisDBWithIntegration}, + wantEvents: 1, + integrationsOnlyCredentials: true, + }, { name: "update existing database", existingDatabases: []types.Database{ @@ -2631,7 +2664,9 @@ func TestDiscoveryDatabase(t *testing.T) { AWSConfigProvider: *fakeConfigProvider, eksClusters: []*ekstypes.Cluster{eksAWSResource}, }, - azureClients: azureClients, + initAzureClients: func(opts ...cloud.AzureClientsOption) (cloud.AzureClients, error) { + return azureClients, nil + }, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), @@ -2666,8 +2701,8 @@ func TestDiscoveryDatabase(t *testing.T) { }, 1*time.Second, 100*time.Millisecond) } + require.NoError(t, srv.Start()) t.Cleanup(srv.Stop) - go srv.Start() select { case <-waitForReconcile: @@ -2709,6 +2744,15 @@ func TestDiscoveryDatabase(t *testing.T) { if tc.userTasksCheck != nil { tc.userTasksCheck(t, tlsServer.Auth()) } + + // verify usage of integration credentials. + for _, matcher := range tc.azureMatchers { + require.NotNil(t, srv.azureClientCache) + _, err = libutils.FnCacheGet(t.Context(), srv.azureClientCache, matcher.Integration, func(ctx context.Context) (cloud.AzureClients, error) { + return nil, trace.NotFound("cache key %q not found", matcher.Integration) + }) + require.NoError(t, err) + } }) } } @@ -3009,18 +3053,33 @@ func (m *mockAzureInstaller) GetInstalledInstances() []string { func TestAzureVMDiscovery(t *testing.T) { t.Parallel() - defaultDiscoveryGroup := "dc001" + const defaultDiscoveryGroup = "dc001" + + const noIntegration = "" + const dummyIntegration = "dummy" vmMatcherFn := func() Matchers { return Matchers{ - Azure: []types.AzureMatcher{{ - Types: []string{"vm"}, - Subscriptions: []string{"testsub"}, - ResourceGroups: []string{"testrg"}, - Regions: []string{"westcentralus"}, - ResourceTags: types.Labels{"teleport": {"yes"}}, - Params: &types.InstallerParams{}, - }}, + Azure: []types.AzureMatcher{ + { + Types: []string{"vm"}, + Subscriptions: []string{"testsub"}, + ResourceGroups: []string{"testrg"}, + Regions: []string{"westcentralus"}, + ResourceTags: types.Labels{"teleport": {"yes"}}, + Params: &types.InstallerParams{}, + Integration: noIntegration, + }, + { + Types: []string{"vm"}, + Subscriptions: []string{"testsub"}, + ResourceGroups: []string{"testrg"}, + Regions: []string{"westcentralus"}, + ResourceTags: types.Labels{"teleport-integration": {"yes"}}, + Params: &types.InstallerParams{}, + Integration: dummyIntegration, + }, + }, } } @@ -3037,128 +3096,88 @@ func TestAzureVMDiscovery(t *testing.T) { ) require.NoError(t, err) + foundAzureVMs := func() []*armcompute.VirtualMachine { + return []*armcompute.VirtualMachine{ + { + ID: aws.String((&arm.ResourceID{ + SubscriptionID: "testsub", + ResourceGroupName: "rg", + Name: "testvm", + }).String()), + Name: aws.String("testvm"), + Location: aws.String("westcentralus"), + Tags: map[string]*string{ + "teleport": aws.String("yes"), + }, + Properties: &armcompute.VirtualMachineProperties{ + VMID: aws.String("test-vmid"), + }, + }, + { + ID: aws.String((&arm.ResourceID{ + SubscriptionID: "testsub", + ResourceGroupName: "rg", + Name: "testvm-integration", + }).String()), + Name: aws.String("testvm-integration"), + Location: aws.String("westcentralus"), + Tags: map[string]*string{ + "teleport-integration": aws.String("yes"), + }, + Properties: &armcompute.VirtualMachineProperties{ + VMID: aws.String("test-vmid-integration"), + }, + }, + } + } + + presentNode := &types.ServerV2{ + Kind: types.KindNode, + Metadata: types.Metadata{ + Name: "name", + Labels: map[string]string{ + "teleport.internal/subscription-id": "testsub", + "teleport.internal/vm-id": "test-vmid", + }, + Namespace: defaults.Namespace, + }, + } + + presentNodeAlt := presentNode.DeepCopy().(*types.ServerV2) + presentNodeAlt.Metadata.Labels["teleport.internal/vm-id"] = "alternate-vmid" + tests := []struct { - name string - presentVMs []types.Server - foundAzureVMs []*armcompute.VirtualMachine - discoveryConfig *discoveryconfig.DiscoveryConfig - staticMatchers Matchers - wantInstalledInstances []string + name string + presentVMs []types.Server + discoveryConfig *discoveryconfig.DiscoveryConfig + staticMatchers Matchers + wantInstalledInstances []string + expectedIntegrationNames []string }{ { - name: "no nodes present, 1 found", - presentVMs: []types.Server{}, - foundAzureVMs: []*armcompute.VirtualMachine{ - { - ID: aws.String((&arm.ResourceID{ - SubscriptionID: "testsub", - ResourceGroupName: "rg", - Name: "testvm", - }).String()), - Name: aws.String("testvm"), - Location: aws.String("westcentralus"), - Tags: map[string]*string{ - "teleport": aws.String("yes"), - }, - Properties: &armcompute.VirtualMachineProperties{ - VMID: aws.String("test-vmid"), - }, - }, - }, + name: "no nodes present, 1 found", + presentVMs: []types.Server{}, staticMatchers: vmMatcherFn(), - wantInstalledInstances: []string{"testvm"}, + wantInstalledInstances: []string{"testvm", "testvm-integration"}, }, { - name: "nodes present, instance filtered", - presentVMs: []types.Server{ - &types.ServerV2{ - Kind: types.KindNode, - Metadata: types.Metadata{ - Name: "name", - Labels: map[string]string{ - "teleport.internal/subscription-id": "testsub", - "teleport.internal/vm-id": "test-vmid", - }, - Namespace: defaults.Namespace, - }, - }, - }, - staticMatchers: vmMatcherFn(), - foundAzureVMs: []*armcompute.VirtualMachine{ - { - ID: aws.String((&arm.ResourceID{ - SubscriptionID: "testsub", - ResourceGroupName: "rg", - Name: "testvm", - }).String()), - Location: aws.String("westcentralus"), - Tags: map[string]*string{ - "teleport": aws.String("yes"), - }, - Properties: &armcompute.VirtualMachineProperties{ - VMID: aws.String("test-vmid"), - }, - }, - }, + name: "nodes present, instance filtered", + presentVMs: []types.Server{presentNode}, + staticMatchers: vmMatcherFn(), + wantInstalledInstances: []string{"testvm-integration"}, }, { - name: "nodes present, instance not filtered", - presentVMs: []types.Server{ - &types.ServerV2{ - Kind: types.KindNode, - Metadata: types.Metadata{ - Name: "name", - Labels: map[string]string{ - "teleport.internal/subscription-id": "testsub", - "teleport.internal/vm-id": "alternate-vmid", - }, - Namespace: defaults.Namespace, - }, - }, - }, - staticMatchers: vmMatcherFn(), - foundAzureVMs: []*armcompute.VirtualMachine{ - { - ID: aws.String((&arm.ResourceID{ - SubscriptionID: "testsub", - ResourceGroupName: "rg", - Name: "testvm", - }).String()), - Name: aws.String("testvm"), - Location: aws.String("westcentralus"), - Tags: map[string]*string{ - "teleport": aws.String("yes"), - }, - Properties: &armcompute.VirtualMachineProperties{ - VMID: aws.String("test-vmid"), - }, - }, - }, - wantInstalledInstances: []string{"testvm"}, + name: "nodes present, instance not filtered", + presentVMs: []types.Server{presentNodeAlt}, + staticMatchers: vmMatcherFn(), + wantInstalledInstances: []string{"testvm", "testvm-integration"}, }, { - name: "no nodes present, 1 found using dynamic matchers", - presentVMs: []types.Server{}, - foundAzureVMs: []*armcompute.VirtualMachine{ - { - ID: aws.String((&arm.ResourceID{ - SubscriptionID: "testsub", - ResourceGroupName: "rg", - Name: "testvm", - }).String()), - Name: aws.String("testvm"), - Location: aws.String("westcentralus"), - Tags: map[string]*string{ - "teleport": aws.String("yes"), - }, - Properties: &armcompute.VirtualMachineProperties{ - VMID: aws.String("test-vmid"), - }, - }, - }, + name: "no nodes present, 1 found using dynamic matchers", + presentVMs: []types.Server{}, discoveryConfig: defaultDiscoveryConfig, staticMatchers: Matchers{}, - wantInstalledInstances: []string{"testvm"}, + wantInstalledInstances: []string{"testvm", "testvm-integration"}, }, } @@ -3167,11 +3186,13 @@ func TestAzureVMDiscovery(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - testAzureClients := &cloudtest.AzureClients{ - AzureVirtualMachines: &mockAzureClient{ - vms: tc.foundAzureVMs, - }, - AzureRunCommand: &mockAzureRunCommandClient{}, + initAzureClients := func(opts ...cloud.AzureClientsOption) (cloud.AzureClients, error) { + return &cloudtest.AzureClients{ + AzureVirtualMachines: &mockAzureClient{ + vms: foundAzureVMs(), + }, + AzureRunCommand: &mockAzureRunCommandClient{}, + }, nil } ctx := context.Background() @@ -3196,6 +3217,7 @@ func TestAzureVMDiscovery(t *testing.T) { require.NoError(t, err) } + logtest.InitLogger(func() bool { return true }) logger := logtest.NewLogger() emitter := &mockEmitter{} @@ -3205,7 +3227,7 @@ func TestAzureVMDiscovery(t *testing.T) { } tlsServer.Auth().SetUsageReporter(reporter) server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ - azureClients: testAzureClients, + initAzureClients: initAzureClients, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), @@ -3229,22 +3251,28 @@ func TestAzureVMDiscovery(t *testing.T) { server.muDynamicServerAzureFetchers.RLock() defer server.muDynamicServerAzureFetchers.RUnlock() return len(server.dynamicServerAzureFetchers) > 0 - }, 1*time.Second, 100*time.Millisecond) + }, 1*time.Second, 50*time.Millisecond) } - go server.Start() + require.NoError(t, server.Start()) t.Cleanup(server.Stop) - if len(tc.wantInstalledInstances) > 0 { - require.Eventually(t, func() bool { - instances := installer.GetInstalledInstances() - slices.Sort(instances) - return slices.Equal(tc.wantInstalledInstances, instances) && len(tc.wantInstalledInstances) == reporter.ResourceCreateEventCount() - }, 500*time.Millisecond, 50*time.Millisecond) - } else { - require.Never(t, func() bool { - return len(installer.GetInstalledInstances()) > 0 || reporter.ResourceCreateEventCount() > 0 - }, 500*time.Millisecond, 50*time.Millisecond) + require.EventuallyWithT(t, func(c *assert.CollectT) { + require.ElementsMatch(c, tc.wantInstalledInstances, installer.GetInstalledInstances()) + + // all current tests install at least one VM, so this cannot be zero. + // multiple installations will trigger just one event. + const expectedEventCount = 1 + require.Equal(c, expectedEventCount, reporter.ResourceCreateEventCount()) + }, 500*time.Millisecond, 50*time.Millisecond) + + // make sure azure client cache has expected entries + for _, integrationName := range tc.expectedIntegrationNames { + require.NotNil(t, server.azureClientCache) + _, err = libutils.FnCacheGet(t.Context(), server.azureClientCache, integrationName, func(ctx context.Context) (cloud.AzureClients, error) { + return nil, trace.NotFound("cache key %q not found", integrationName) + }) + require.NoError(t, err) } }) @@ -3674,13 +3702,51 @@ func TestServer_onCreate(t *testing.T) { }) } -func TestEmitUsageEvents(t *testing.T) { +func TestDiscardAmbientCredentialMatchers(t *testing.T) { t.Parallel() - azureClients := &cloudtest.AzureClients{ - AzureVirtualMachines: &mockAzureClient{}, - AzureRunCommand: &mockAzureRunCommandClient{}, + + logger := logtest.NewLogger() + + awsDrop := types.AWSMatcher{ + Types: []string{"ec2"}, + Regions: []string{"us-west-1"}, + } + awsKeep := types.AWSMatcher{ + Types: []string{"ec2"}, + Regions: []string{"us-west-2"}, + Integration: "aws-int", + } + + azureDrop := types.AzureMatcher{ + Subscriptions: []string{"dummy"}, + } + azureKeep := types.AzureMatcher{ + Subscriptions: []string{"dummy"}, + Integration: "azure-int", + } + + matchers := Matchers{ + AWS: []types.AWSMatcher{awsDrop, awsKeep}, + Azure: []types.AzureMatcher{azureDrop, azureKeep}, + GCP: []types.GCPMatcher{ + {ProjectIDs: []string{"proj"}}, + }, + Kubernetes: []types.KubernetesMatcher{ + {Types: []string{"app"}}, + }, } + discardAmbientCredentialMatchers(t.Context(), logger, &matchers) + + require.Equal(t, []types.AWSMatcher{awsKeep}, matchers.AWS) + require.Equal(t, []types.AzureMatcher{azureKeep}, matchers.Azure) + require.Empty(t, matchers.GCP) + require.Empty(t, matchers.Kubernetes) +} + +func TestEmitUsageEvents(t *testing.T) { + t.Parallel() + testAuthServer, err := authtest.NewAuthServer(authtest.AuthServerConfig{ Dir: t.TempDir(), }) @@ -3701,7 +3767,6 @@ func TestEmitUsageEvents(t *testing.T) { tlsServer.Auth().SetUsageReporter(reporter) server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ - azureClients: azureClients, ClusterFeatures: func() proto.Features { return proto.Features{} }, AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), Matchers: Matchers{ diff --git a/lib/srv/discovery/fetchers/aks.go b/lib/srv/discovery/fetchers/aks.go index 0acdc5ad5e353..f183b695d6ab0 100644 --- a/lib/srv/discovery/fetchers/aks.go +++ b/lib/srv/discovery/fetchers/aks.go @@ -51,6 +51,8 @@ type AKSFetcherConfig struct { Logger *slog.Logger // DiscoveryConfigName is the name of the DiscoveryConfig that created this Fetcher. DiscoveryConfigName string + // Integration is the name of Azure integration used for auth. + Integration string } // CheckAndSetDefaults validates and sets the defaults values. @@ -155,7 +157,7 @@ func (a *aksFetcher) Cloud() string { } func (a *aksFetcher) IntegrationName() string { - return "" + return a.Integration } func (a *aksFetcher) GetDiscoveryConfigName() string { diff --git a/lib/srv/discovery/fetchers/db/azure.go b/lib/srv/discovery/fetchers/db/azure.go index a1674015af2e2..f1aeff9a0e243 100644 --- a/lib/srv/discovery/fetchers/db/azure.go +++ b/lib/srv/discovery/fetchers/db/azure.go @@ -93,6 +93,8 @@ type azureFetcherConfig struct { regionSet map[string]struct{} // DiscoveryConfigName is the name of the discovery config which originated the resource. DiscoveryConfigName string + // Integration is the name of Azure integration used for auth. + Integration string } // regionMatches returns whether a given region matches the configured Regions selector @@ -154,8 +156,7 @@ func (f *azureFetcher[DBType, ListClient]) FetcherType() string { // IntegrationName returns the integration name. func (f *azureFetcher[DBType, ListClient]) IntegrationName() string { - // There is currently no integration that supports Auto Discover for Azure resources. - return "" + return f.cfg.Integration } // GetDiscoveryConfigName is the name of the discovery config which originated the resource. diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 282258f3147aa..42528599ab58c 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -197,8 +197,13 @@ func (f *AWSFetcherFactory) MakeFetchers(ctx context.Context, matchers []types.A } // MakeAzureFetchers creates new Azure database fetchers. -func MakeAzureFetchers(clients cloud.AzureClients, matchers []types.AzureMatcher, discoveryConfigName string) (result []common.Fetcher, err error) { +func MakeAzureFetchers(ctx context.Context, getAzureClients func(ctx context.Context, integration string) (cloud.AzureClients, error), matchers []types.AzureMatcher, discoveryConfigName string) (result []common.Fetcher, err error) { for _, matcher := range services.SimplifyAzureMatchers(matchers) { + azureClients, err := getAzureClients(ctx, matcher.Integration) + if err != nil { + return nil, trace.Wrap(err) + } + for _, matcherType := range matcher.Types { makeFetchers, found := makeAzureFetcherFuncs[matcherType] if !found { @@ -211,13 +216,14 @@ func MakeAzureFetchers(clients cloud.AzureClients, matchers []types.AzureMatcher for _, sub := range matcher.Subscriptions { for _, group := range matcher.ResourceGroups { fetcher, err := makeFetcher(azureFetcherConfig{ - AzureClients: clients, + AzureClients: azureClients, Type: matcherType, Subscription: sub, ResourceGroup: group, Labels: matcher.ResourceTags, Regions: matcher.Regions, DiscoveryConfigName: discoveryConfigName, + Integration: matcher.Integration, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/discovery/fetchers/db/helpers_test.go b/lib/srv/discovery/fetchers/db/helpers_test.go index bc89c1616af9f..04fd2d98aa401 100644 --- a/lib/srv/discovery/fetchers/db/helpers_test.go +++ b/lib/srv/discovery/fetchers/db/helpers_test.go @@ -23,6 +23,7 @@ import ( "os" "testing" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -79,7 +80,12 @@ func mustMakeAWSFetchers(t *testing.T, cfg AWSFetcherFactoryConfig, matchers []t func mustMakeAzureFetchers(t *testing.T, clients cloud.AzureClients, matchers []types.AzureMatcher) []common.Fetcher { t.Helper() - fetchers, err := MakeAzureFetchers(clients, matchers, "" /* discovery config */) + fetchers, err := MakeAzureFetchers(t.Context(), func(ctx context.Context, integration string) (cloud.AzureClients, error) { + if integration != "" { + return nil, trace.NotImplemented("expected empty integration, got %q", integration) + } + return clients, nil + }, matchers, "" /* discovery config */) require.NoError(t, err) require.NotEmpty(t, fetchers) diff --git a/lib/srv/server/azure_watcher.go b/lib/srv/server/azure_watcher.go index 25fac4160191a..43d707f174bd6 100644 --- a/lib/srv/server/azure_watcher.go +++ b/lib/srv/server/azure_watcher.go @@ -32,6 +32,7 @@ import ( usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/installers" + "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/services" ) @@ -50,12 +51,14 @@ type AzureInstances struct { InstallerParams *types.InstallerParams // Instances is a list of discovered Azure virtual machines. Instances []*armcompute.VirtualMachine + // Integration is the optional name of the integration to use for auth. + Integration string } // MakeEvents generates MakeEvents for these instances. func (instances *AzureInstances) MakeEvents() map[string]*usageeventsv1.ResourceCreateEvent { resourceType := types.DiscoveredResourceNode - if instances.InstallerParams.ScriptName == installers.InstallerScriptNameAgentless { + if instances.InstallerParams != nil && instances.InstallerParams.ScriptName == installers.InstallerScriptNameAgentless { resourceType = types.DiscoveredResourceAgentlessNode } events := make(map[string]*usageeventsv1.ResourceCreateEvent, len(instances.Instances)) @@ -69,9 +72,7 @@ func (instances *AzureInstances) MakeEvents() map[string]*usageeventsv1.Resource return events } -type azureClientGetter interface { - GetAzureVirtualMachinesClient(subscription string) (azure.VirtualMachinesClient, error) -} +type azureClientGetter func(ctx context.Context, integration string) (cloud.AzureClients, error) // NewAzureWatcher creates a new Azure watcher instance. func NewAzureWatcher(ctx context.Context, fetchersFn func() []Fetcher, opts ...Option) (*Watcher, error) { @@ -92,7 +93,7 @@ func NewAzureWatcher(ctx context.Context, fetchersFn func() []Fetcher, opts ...O } // MatchersToAzureInstanceFetchers converts a list of Azure VM Matchers into a list of Azure VM Fetchers. -func MatchersToAzureInstanceFetchers(logger *slog.Logger, matchers []types.AzureMatcher, clients azureClientGetter, discoveryConfigName string) []Fetcher { +func MatchersToAzureInstanceFetchers(logger *slog.Logger, matchers []types.AzureMatcher, getClient azureClientGetter, discoveryConfigName string) []Fetcher { ret := make([]Fetcher, 0) for _, matcher := range matchers { for _, subscription := range matcher.Subscriptions { @@ -101,7 +102,7 @@ func MatchersToAzureInstanceFetchers(logger *slog.Logger, matchers []types.Azure Matcher: matcher, Subscription: subscription, ResourceGroup: resourceGroup, - AzureClientGetter: clients, + AzureClientGetter: getClient, DiscoveryConfigName: discoveryConfigName, Logger: logger, }) @@ -118,7 +119,6 @@ type azureFetcherConfig struct { ResourceGroup string AzureClientGetter azureClientGetter DiscoveryConfigName string - Integration string Logger *slog.Logger } @@ -143,7 +143,7 @@ func newAzureInstanceFetcher(cfg azureFetcherConfig) *azureInstanceFetcher { ResourceGroup: cfg.ResourceGroup, Labels: cfg.Matcher.ResourceTags, DiscoveryConfigName: cfg.DiscoveryConfigName, - Integration: cfg.Integration, + Integration: cfg.Matcher.Integration, Logger: cfg.Logger, } } @@ -169,7 +169,12 @@ type resourceGroupLocation struct { // GetInstances fetches all Azure virtual machines matching configured filters. func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]Instances, error) { - client, err := f.AzureClientGetter.GetAzureVirtualMachinesClient(f.Subscription) + azureClients, err := f.AzureClientGetter(ctx, f.IntegrationName()) + if err != nil { + return nil, trace.Wrap(err) + } + + client, err := azureClients.GetAzureVirtualMachinesClient(f.Subscription) if err != nil { return nil, trace.Wrap(err) } @@ -231,6 +236,7 @@ func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]Inst Region: batchGroup.location, ResourceGroup: batchGroup.resourceGroup, Instances: vms, + Integration: f.Integration, InstallerParams: f.InstallerParams, }}) } diff --git a/lib/srv/server/azure_watcher_test.go b/lib/srv/server/azure_watcher_test.go index c0f1b448bd853..a2508c1b06f26 100644 --- a/lib/srv/server/azure_watcher_test.go +++ b/lib/srv/server/azure_watcher_test.go @@ -158,7 +158,9 @@ func TestAzureWatcher(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) t.Cleanup(cancel) watcher, err := NewAzureWatcher(ctx, func() []Fetcher { - return MatchersToAzureInstanceFetchers(logger, []types.AzureMatcher{tc.matcher}, &clients, "" /* discovery config */) + return MatchersToAzureInstanceFetchers(logger, []types.AzureMatcher{tc.matcher}, func(ctx context.Context, integration string) (cloud.AzureClients, error) { + return &clients, nil + }, "" /* discovery config */) }) require.NoError(t, err)