diff --git a/lib/cloud/azure/client_map.go b/lib/cloud/azure/client_map.go index b23c9a5076f55..c4683fb909cf2 100644 --- a/lib/cloud/azure/client_map.go +++ b/lib/cloud/azure/client_map.go @@ -17,59 +17,76 @@ limitations under the License. package azure import ( - "sync" + "context" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/lib/utils" ) +const clientExpireTime = time.Hour + // ClientMap is a generic map that caches a collection of Azure clients by // subscriptions. type ClientMap[ClientType any] struct { - mu sync.RWMutex - clients map[string]ClientType + clients *utils.FnCache newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error) } -// NewClientMap creates a new ClientMap. -func NewClientMap[ClientType any](newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error)) ClientMap[ClientType] { - return ClientMap[ClientType]{ - clients: make(map[string]ClientType), - newClient: newClient, - } +// ClientMapOptions defines options for creating a client map. +type ClientMapOptions struct { + clock clockwork.Clock } -// Get returns an Azure client by subscription. A new client is created if the -// subscription is not found in the map. -func (m *ClientMap[ClientType]) Get(subscription string, getCredentials func() (azcore.TokenCredential, error)) (client ClientType, err error) { - m.mu.RLock() - if client, ok := m.clients[subscription]; ok { - m.mu.RUnlock() - return client, nil - } - m.mu.RUnlock() +// ClientMapOption allows setting options as functional arguments to NewClientMap. +type ClientMapOption func(*ClientMapOptions) - m.mu.Lock() - defer m.mu.Unlock() - - // If some other thread already got here first. - if client, ok := m.clients[subscription]; ok { - return client, nil +func withClock(clock clockwork.Clock) ClientMapOption { + return func(opts *ClientMapOptions) { + opts.clock = clock } +} - cred, err := getCredentials() - if err != nil { - return client, trace.Wrap(err) +// NewClientMap creates a new ClientMap. +func NewClientMap[ClientType any]( + newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error), + opts ...ClientMapOption, +) (ClientMap[ClientType], error) { + options := &ClientMapOptions{} + for _, opt := range opts { + opt(options) } - // TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options - options := &arm.ClientOptions{} - client, err = m.newClient(subscription, cred, options) + cache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: clientExpireTime, + Clock: options.clock, + }) if err != nil { - return client, trace.Wrap(err) + return ClientMap[ClientType]{}, trace.Wrap(err) } + return ClientMap[ClientType]{ + clients: cache, + newClient: newClient, + }, nil +} - m.clients[subscription] = client - return client, nil +// Get returns an Azure client by subscription. A new client is created if the +// subscription is not found in the map. +func (m *ClientMap[ClientType]) Get(subscription string, getCredentials func() (azcore.TokenCredential, error)) (ClientType, error) { + client, err := utils.FnCacheGet[ClientType](context.Background(), m.clients, subscription, func(ctx context.Context) (client ClientType, err error) { + cred, err := getCredentials() + if err != nil { + return client, trace.Wrap(err) + } + + // TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options + options := &arm.ClientOptions{} + client, err = m.newClient(subscription, cred, options) + return client, trace.Wrap(err) + }) + return client, trace.Wrap(err) } diff --git a/lib/cloud/azure/client_map_test.go b/lib/cloud/azure/client_map_test.go index 7ebd1702e243b..824bc4ccf6369 100644 --- a/lib/cloud/azure/client_map_test.go +++ b/lib/cloud/azure/client_map_test.go @@ -22,6 +22,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" ) @@ -34,7 +35,9 @@ func TestClientMap(t *testing.T) { } return nil, trace.BadParameter("failed to create") } - clientMap := NewClientMap(mockNewClientFunc) + clock := clockwork.NewFakeClock() + clientMap, err := NewClientMap(mockNewClientFunc, withClock(clock)) + require.NoError(t, err) // Note that some test cases (e.g. "get from cache") depend on previous // test cases. Thus running in sequence. @@ -73,4 +76,21 @@ func TestClientMap(t *testing.T) { require.NotNil(t, client) require.IsType(t, NewRedisClientByAPI(nil), client) }) + + t.Run("expire from cache", func(t *testing.T) { + oldClient, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) { + return nil, nil + }) + require.NoError(t, err) + require.NotNil(t, oldClient) + + clock.Advance(2 * clientExpireTime) + newClient, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) { + return nil, nil + }) + require.NoError(t, err) + require.NotNil(t, newClient) + require.NotSame(t, oldClient, newClient) + }) + } diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index d11a77caa22ab..c4abc3b7aee61 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -187,6 +187,49 @@ func (c *clientCache[T]) GetClient(ctx context.Context) (T, error) { return c.client, trace.Wrap(c.err) } +func newAzureClients() (*azureClients, error) { + azClients := &azureClients{ + azureMySQLClients: make(map[string]azure.DBServersClient), + azurePostgresClients: make(map[string]azure.DBServersClient), + azureKubernetesClient: make(map[string]azure.AKSClient), + } + var err error + azClients.azureRedisClients, err = azure.NewClientMap(azure.NewRedisClient) + if err != nil { + return nil, trace.Wrap(err) + } + azClients.azureRedisEnterpriseClients, err = azure.NewClientMap(azure.NewRedisEnterpriseClient) + if err != nil { + return nil, trace.Wrap(err) + } + azClients.azureVirtualMachinesClients, err = azure.NewClientMap(azure.NewVirtualMachinesClient) + if err != nil { + return nil, trace.Wrap(err) + } + azClients.azureSQLServerClients, err = azure.NewClientMap(azure.NewSQLClient) + if err != nil { + return nil, trace.Wrap(err) + } + azClients.azureManagedSQLServerClients, err = azure.NewClientMap(azure.NewManagedSQLClient) + if err != nil { + return nil, trace.Wrap(err) + } + azClients.azureMySQLFlexServersClients, err = azure.NewClientMap(azure.NewMySQLFlexServersClient) + if err != nil { + return nil, trace.Wrap(err) + } + azClients.azurePostgresFlexServersClients, err = azure.NewClientMap(azure.NewPostgresFlexServersClient) + if err != nil { + return nil, trace.Wrap(err) + } + azClients.azureRunCommandClients, err = azure.NewClientMap(azure.NewRunCommandClient) + if err != nil { + return nil, trace.Wrap(err) + } + + return azClients, nil +} + // NewClients returns a new instance of cloud clients retriever. func NewClients() (Clients, error) { awsSessionsCache, err := utils.NewFnCache(utils.FnCacheConfig{ @@ -195,6 +238,10 @@ func NewClients() (Clients, error) { if err != nil { return nil, trace.Wrap(err) } + azClients, err := newAzureClients() + if err != nil { + return nil, trace.Wrap(err) + } return &cloudClients{ awsSessionsCache: awsSessionsCache, gcpClients: gcpClients{ @@ -202,19 +249,7 @@ func NewClients() (Clients, error) { gcpGKE: newClientCache[gcp.GKEClient](gcp.NewGKEClient), gcpInstances: newClientCache[gcp.InstancesClient](gcp.NewInstancesClient), }, - azureClients: azureClients{ - azureMySQLClients: make(map[string]azure.DBServersClient), - azurePostgresClients: make(map[string]azure.DBServersClient), - azureRedisClients: azure.NewClientMap(azure.NewRedisClient), - azureRedisEnterpriseClients: azure.NewClientMap(azure.NewRedisEnterpriseClient), - azureKubernetesClient: make(map[string]azure.AKSClient), - azureVirtualMachinesClients: azure.NewClientMap(azure.NewVirtualMachinesClient), - azureSQLServerClients: azure.NewClientMap(azure.NewSQLClient), - azureManagedSQLServerClients: azure.NewClientMap(azure.NewManagedSQLClient), - azureMySQLFlexServersClients: azure.NewClientMap(azure.NewMySQLFlexServersClient), - azurePostgresFlexServersClients: azure.NewClientMap(azure.NewPostgresFlexServersClient), - azureRunCommandClients: azure.NewClientMap(azure.NewRunCommandClient), - }, + azureClients: azClients, }, nil } @@ -230,7 +265,7 @@ type cloudClients struct { // gcpClients contains GCP-specific clients. gcpClients // azureClients contains Azure-specific clients. - azureClients + *azureClients // mtx is used for locking. mtx sync.RWMutex } diff --git a/lib/srv/server/azure_watcher.go b/lib/srv/server/azure_watcher.go index c3c2c4dd7d6fe..b8ee220a9c515 100644 --- a/lib/srv/server/azure_watcher.go +++ b/lib/srv/server/azure_watcher.go @@ -71,6 +71,10 @@ func (instances *AzureInstances) MakeEvents() map[string]*usageeventsv1.Resource return events } +type azureClientGetter interface { + GetAzureVirtualMachinesClient(subscription string) (azure.VirtualMachinesClient, error) +} + // NewAzureWatcher creates a new Azure watcher instance. func NewAzureWatcher(ctx context.Context, matchers []types.AzureMatcher, clients cloud.Clients, opts ...Option) (*Watcher, error) { cancelCtx, cancelFn := context.WithCancel(ctx) @@ -87,15 +91,11 @@ func NewAzureWatcher(ctx context.Context, matchers []types.AzureMatcher, clients for _, matcher := range matchers { for _, subscription := range matcher.Subscriptions { for _, resourceGroup := range matcher.ResourceGroups { - cl, err := clients.GetAzureVirtualMachinesClient(subscription) - if err != nil { - return nil, trace.Wrap(err) - } fetcher := newAzureInstanceFetcher(azureFetcherConfig{ - Matcher: matcher, - Subscription: subscription, - ResourceGroup: resourceGroup, - AzureClient: cl, + Matcher: matcher, + Subscription: subscription, + ResourceGroup: resourceGroup, + AzureClientGetter: clients, }) watcher.fetchers = append(watcher.fetchers, fetcher) } @@ -105,28 +105,28 @@ func NewAzureWatcher(ctx context.Context, matchers []types.AzureMatcher, clients } type azureFetcherConfig struct { - Matcher types.AzureMatcher - Subscription string - ResourceGroup string - AzureClient azure.VirtualMachinesClient + Matcher types.AzureMatcher + Subscription string + ResourceGroup string + AzureClientGetter azureClientGetter } type azureInstanceFetcher struct { - Azure azure.VirtualMachinesClient - Regions []string - Subscription string - ResourceGroup string - Labels types.Labels - Parameters map[string]string + AzureClientGetter azureClientGetter + Regions []string + Subscription string + ResourceGroup string + Labels types.Labels + Parameters map[string]string } func newAzureInstanceFetcher(cfg azureFetcherConfig) *azureInstanceFetcher { ret := &azureInstanceFetcher{ - Azure: cfg.AzureClient, - Regions: cfg.Matcher.Regions, - Subscription: cfg.Subscription, - ResourceGroup: cfg.ResourceGroup, - Labels: cfg.Matcher.ResourceTags, + AzureClientGetter: cfg.AzureClientGetter, + Regions: cfg.Matcher.Regions, + Subscription: cfg.Subscription, + ResourceGroup: cfg.ResourceGroup, + Labels: cfg.Matcher.ResourceTags, } if cfg.Matcher.Params != nil { @@ -146,12 +146,16 @@ func (*azureInstanceFetcher) GetMatchingInstances(_ []types.Server, _ bool) ([]I // GetInstances fetches all Azure virtual machines matching configured filters. func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]Instances, error) { + client, err := f.AzureClientGetter.GetAzureVirtualMachinesClient(f.Subscription) + if err != nil { + return nil, trace.Wrap(err) + } instancesByRegion := make(map[string][]*armcompute.VirtualMachine) for _, region := range f.Regions { instancesByRegion[region] = []*armcompute.VirtualMachine{} } - vms, err := f.Azure.ListVirtualMachines(ctx, f.ResourceGroup) + vms, err := client.ListVirtualMachines(ctx, f.ResourceGroup) if err != nil { return nil, trace.Wrap(err) }