Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 50 additions & 33 deletions lib/cloud/azure/client_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,59 +17,76 @@ limitations under the License.
package azure

import (
"sync"
"context"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/lib/utils"
)

const clientExpireTime = time.Hour

// ClientMap is a generic map that caches a collection of Azure clients by
// subscriptions.
type ClientMap[ClientType any] struct {
mu sync.RWMutex
clients map[string]ClientType
clients *utils.FnCache
newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error)
}

// NewClientMap creates a new ClientMap.
func NewClientMap[ClientType any](newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error)) ClientMap[ClientType] {
return ClientMap[ClientType]{
clients: make(map[string]ClientType),
newClient: newClient,
}
// ClientMapOptions defines options for creating a client map.
type ClientMapOptions struct {
clock clockwork.Clock
}

// Get returns an Azure client by subscription. A new client is created if the
// subscription is not found in the map.
func (m *ClientMap[ClientType]) Get(subscription string, getCredentials func() (azcore.TokenCredential, error)) (client ClientType, err error) {
m.mu.RLock()
if client, ok := m.clients[subscription]; ok {
m.mu.RUnlock()
return client, nil
}
m.mu.RUnlock()
// ClientMapOption allows setting options as functional arguments to NewClientMap.
type ClientMapOption func(*ClientMapOptions)

m.mu.Lock()
defer m.mu.Unlock()

// If some other thread already got here first.
if client, ok := m.clients[subscription]; ok {
return client, nil
func withClock(clock clockwork.Clock) ClientMapOption {
return func(opts *ClientMapOptions) {
opts.clock = clock
}
}

cred, err := getCredentials()
if err != nil {
return client, trace.Wrap(err)
// NewClientMap creates a new ClientMap.
func NewClientMap[ClientType any](
newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (ClientType, error),
opts ...ClientMapOption,
) (ClientMap[ClientType], error) {
options := &ClientMapOptions{}
for _, opt := range opts {
opt(options)
}

// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
options := &arm.ClientOptions{}
client, err = m.newClient(subscription, cred, options)
cache, err := utils.NewFnCache(utils.FnCacheConfig{
TTL: clientExpireTime,
Clock: options.clock,
})
if err != nil {
return client, trace.Wrap(err)
return ClientMap[ClientType]{}, trace.Wrap(err)
}
return ClientMap[ClientType]{
clients: cache,
newClient: newClient,
}, nil
}

m.clients[subscription] = client
return client, nil
// Get returns an Azure client by subscription. A new client is created if the
// subscription is not found in the map.
func (m *ClientMap[ClientType]) Get(subscription string, getCredentials func() (azcore.TokenCredential, error)) (ClientType, error) {
client, err := utils.FnCacheGet[ClientType](context.Background(), m.clients, subscription, func(ctx context.Context) (client ClientType, err error) {
cred, err := getCredentials()
if err != nil {
return client, trace.Wrap(err)
}

// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
options := &arm.ClientOptions{}
client, err = m.newClient(subscription, cred, options)
return client, trace.Wrap(err)
})
return client, trace.Wrap(err)
}
22 changes: 21 additions & 1 deletion lib/cloud/azure/client_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
)

Expand All @@ -34,7 +35,9 @@ func TestClientMap(t *testing.T) {
}
return nil, trace.BadParameter("failed to create")
}
clientMap := NewClientMap(mockNewClientFunc)
clock := clockwork.NewFakeClock()
clientMap, err := NewClientMap(mockNewClientFunc, withClock(clock))
require.NoError(t, err)

// Note that some test cases (e.g. "get from cache") depend on previous
// test cases. Thus running in sequence.
Expand Down Expand Up @@ -73,4 +76,21 @@ func TestClientMap(t *testing.T) {
require.NotNil(t, client)
require.IsType(t, NewRedisClientByAPI(nil), client)
})

t.Run("expire from cache", func(t *testing.T) {
oldClient, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) {
return nil, nil
})
require.NoError(t, err)
require.NotNil(t, oldClient)

clock.Advance(2 * clientExpireTime)
newClient, err := clientMap.Get("good-sub", func() (azcore.TokenCredential, error) {
return nil, nil
})
require.NoError(t, err)
require.NotNil(t, newClient)
require.NotSame(t, oldClient, newClient)
})

}
63 changes: 49 additions & 14 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,49 @@ func (c *clientCache[T]) GetClient(ctx context.Context) (T, error) {
return c.client, trace.Wrap(c.err)
}

func newAzureClients() (*azureClients, error) {
azClients := &azureClients{
azureMySQLClients: make(map[string]azure.DBServersClient),
azurePostgresClients: make(map[string]azure.DBServersClient),
azureKubernetesClient: make(map[string]azure.AKSClient),
}
var err error
azClients.azureRedisClients, err = azure.NewClientMap(azure.NewRedisClient)
if err != nil {
return nil, trace.Wrap(err)
}
azClients.azureRedisEnterpriseClients, err = azure.NewClientMap(azure.NewRedisEnterpriseClient)
if err != nil {
return nil, trace.Wrap(err)
}
azClients.azureVirtualMachinesClients, err = azure.NewClientMap(azure.NewVirtualMachinesClient)
if err != nil {
return nil, trace.Wrap(err)
}
azClients.azureSQLServerClients, err = azure.NewClientMap(azure.NewSQLClient)
if err != nil {
return nil, trace.Wrap(err)
}
azClients.azureManagedSQLServerClients, err = azure.NewClientMap(azure.NewManagedSQLClient)
if err != nil {
return nil, trace.Wrap(err)
}
azClients.azureMySQLFlexServersClients, err = azure.NewClientMap(azure.NewMySQLFlexServersClient)
if err != nil {
return nil, trace.Wrap(err)
}
azClients.azurePostgresFlexServersClients, err = azure.NewClientMap(azure.NewPostgresFlexServersClient)
if err != nil {
return nil, trace.Wrap(err)
}
azClients.azureRunCommandClients, err = azure.NewClientMap(azure.NewRunCommandClient)
if err != nil {
return nil, trace.Wrap(err)
}

return azClients, nil
}

// NewClients returns a new instance of cloud clients retriever.
func NewClients() (Clients, error) {
awsSessionsCache, err := utils.NewFnCache(utils.FnCacheConfig{
Expand All @@ -195,26 +238,18 @@ func NewClients() (Clients, error) {
if err != nil {
return nil, trace.Wrap(err)
}
azClients, err := newAzureClients()
if err != nil {
return nil, trace.Wrap(err)
}
return &cloudClients{
awsSessionsCache: awsSessionsCache,
gcpClients: gcpClients{
gcpSQLAdmin: newClientCache[gcp.SQLAdminClient](gcp.NewSQLAdminClient),
gcpGKE: newClientCache[gcp.GKEClient](gcp.NewGKEClient),
gcpInstances: newClientCache[gcp.InstancesClient](gcp.NewInstancesClient),
},
azureClients: azureClients{
azureMySQLClients: make(map[string]azure.DBServersClient),
azurePostgresClients: make(map[string]azure.DBServersClient),
azureRedisClients: azure.NewClientMap(azure.NewRedisClient),
azureRedisEnterpriseClients: azure.NewClientMap(azure.NewRedisEnterpriseClient),
azureKubernetesClient: make(map[string]azure.AKSClient),
azureVirtualMachinesClients: azure.NewClientMap(azure.NewVirtualMachinesClient),
azureSQLServerClients: azure.NewClientMap(azure.NewSQLClient),
azureManagedSQLServerClients: azure.NewClientMap(azure.NewManagedSQLClient),
azureMySQLFlexServersClients: azure.NewClientMap(azure.NewMySQLFlexServersClient),
azurePostgresFlexServersClients: azure.NewClientMap(azure.NewPostgresFlexServersClient),
azureRunCommandClients: azure.NewClientMap(azure.NewRunCommandClient),
},
azureClients: azClients,
}, nil
}

Expand All @@ -230,7 +265,7 @@ type cloudClients struct {
// gcpClients contains GCP-specific clients.
gcpClients
// azureClients contains Azure-specific clients.
azureClients
*azureClients
// mtx is used for locking.
mtx sync.RWMutex
}
Expand Down
52 changes: 28 additions & 24 deletions lib/srv/server/azure_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (instances *AzureInstances) MakeEvents() map[string]*usageeventsv1.Resource
return events
}

type azureClientGetter interface {
GetAzureVirtualMachinesClient(subscription string) (azure.VirtualMachinesClient, error)
}

// NewAzureWatcher creates a new Azure watcher instance.
func NewAzureWatcher(ctx context.Context, matchers []types.AzureMatcher, clients cloud.Clients) (*Watcher, error) {
cancelCtx, cancelFn := context.WithCancel(ctx)
Expand All @@ -84,15 +88,11 @@ func NewAzureWatcher(ctx context.Context, matchers []types.AzureMatcher, clients
for _, matcher := range matchers {
for _, subscription := range matcher.Subscriptions {
for _, resourceGroup := range matcher.ResourceGroups {
cl, err := clients.GetAzureVirtualMachinesClient(subscription)
if err != nil {
return nil, trace.Wrap(err)
}
fetcher := newAzureInstanceFetcher(azureFetcherConfig{
Matcher: matcher,
Subscription: subscription,
ResourceGroup: resourceGroup,
AzureClient: cl,
Matcher: matcher,
Subscription: subscription,
ResourceGroup: resourceGroup,
AzureClientGetter: clients,
})
watcher.fetchers = append(watcher.fetchers, fetcher)
}
Expand All @@ -102,28 +102,28 @@ func NewAzureWatcher(ctx context.Context, matchers []types.AzureMatcher, clients
}

type azureFetcherConfig struct {
Matcher types.AzureMatcher
Subscription string
ResourceGroup string
AzureClient azure.VirtualMachinesClient
Matcher types.AzureMatcher
Subscription string
ResourceGroup string
AzureClientGetter azureClientGetter
}

type azureInstanceFetcher struct {
Azure azure.VirtualMachinesClient
Regions []string
Subscription string
ResourceGroup string
Labels types.Labels
Parameters map[string]string
AzureClientGetter azureClientGetter
Regions []string
Subscription string
ResourceGroup string
Labels types.Labels
Parameters map[string]string
}

func newAzureInstanceFetcher(cfg azureFetcherConfig) *azureInstanceFetcher {
ret := &azureInstanceFetcher{
Azure: cfg.AzureClient,
Regions: cfg.Matcher.Regions,
Subscription: cfg.Subscription,
ResourceGroup: cfg.ResourceGroup,
Labels: cfg.Matcher.ResourceTags,
AzureClientGetter: cfg.AzureClientGetter,
Regions: cfg.Matcher.Regions,
Subscription: cfg.Subscription,
ResourceGroup: cfg.ResourceGroup,
Labels: cfg.Matcher.ResourceTags,
}

if cfg.Matcher.Params != nil {
Expand All @@ -143,12 +143,16 @@ func (*azureInstanceFetcher) GetMatchingInstances(_ []types.Server, _ bool) ([]I

// GetInstances fetches all Azure virtual machines matching configured filters.
func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]Instances, error) {
client, err := f.AzureClientGetter.GetAzureVirtualMachinesClient(f.Subscription)
if err != nil {
return nil, trace.Wrap(err)
}
instancesByRegion := make(map[string][]*armcompute.VirtualMachine)
for _, region := range f.Regions {
instancesByRegion[region] = []*armcompute.VirtualMachine{}
}

vms, err := f.Azure.ListVirtualMachines(ctx, f.ResourceGroup)
vms, err := client.ListVirtualMachines(ctx, f.ResourceGroup)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down