Skip to content
53 changes: 49 additions & 4 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/cloud/gcp"
)
Expand Down Expand Up @@ -89,8 +90,41 @@ type AzureClients interface {
GetAzureRunCommandClient(subscription string) (azure.RunCommandClient, error)
}

// AzureClientsOption is an option to pass to NewAzureClients
type AzureClientsOption func(clients *azureClients)

type azureOIDCCredentials interface {
GenerateAzureOIDCToken(ctx context.Context, integration string) (string, error)
GetIntegration(ctx context.Context, name string) (types.Integration, error)
}

// WithAzureIntegrationCredentials configures Azure cloud clients to use integration credentials.
func WithAzureIntegrationCredentials(integrationName string, auth azureOIDCCredentials) AzureClientsOption {
return func(clt *azureClients) {
clt.newAzureCredentialFunc = func() (azcore.TokenCredential, error) {
ctx := context.TODO()
Comment thread
Tener marked this conversation as resolved.
integration, err := auth.GetIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}
spec := integration.GetAzureOIDCIntegrationSpec()
if spec == nil {
return nil, trace.BadParameter("expected %q to be an %q integration, was %q instead", integration.GetName(), types.IntegrationSubKindAzureOIDC, integration.GetSubKind())
}
cred, err := azidentity.NewClientAssertionCredential(spec.TenantID, spec.ClientID, func(ctx context.Context) (string, error) {
return auth.GenerateAzureOIDCToken(ctx, integrationName)
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
}, nil)
Comment thread
greedy52 marked this conversation as resolved.
if err != nil {
return nil, trace.Wrap(err)
}
return cred, nil
}
}
}

// NewAzureClients returns a new instance of Azure SDK clients.
func NewAzureClients() (AzureClients, error) {
func NewAzureClients(opts ...AzureClientsOption) (AzureClients, error) {
azClients := &azureClients{
azureMySQLClients: make(map[string]azure.DBServersClient),
azurePostgresClients: make(map[string]azure.DBServersClient),
Expand Down Expand Up @@ -130,6 +164,15 @@ func NewAzureClients() (AzureClients, error) {
return nil, trace.Wrap(err)
}

azClients.newAzureCredentialFunc = func() (azcore.TokenCredential, error) {
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
return azidentity.NewDefaultAzureCredential(nil)
}

for _, opt := range opts {
opt(azClients)
}

return azClients, nil
}

Expand Down Expand Up @@ -168,8 +211,11 @@ type azureClients struct {
// mtx is used for locking.
mtx sync.RWMutex

// newAzureCredentialFunc creates new Azure credential.
newAzureCredentialFunc func() (azcore.TokenCredential, error)
// azureCredential is the cached Azure credential.
azureCredential azcore.TokenCredential

// azureMySQLClients is the cached Azure MySQL Server clients.
azureMySQLClients map[string]azure.DBServersClient
// azurePostgresClients is the cached Azure Postgres Server clients.
Expand Down Expand Up @@ -378,9 +424,8 @@ func (c *azureClients) initAzureCredential() (azcore.TokenCredential, error) {
if c.azureCredential != nil { // If some other thread already got here first.
return c.azureCredential, nil
}
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
options := &azidentity.DefaultAzureCredentialOptions{}
cred, err := azidentity.NewDefaultAzureCredential(options)

cred, err := c.newAzureCredentialFunc()
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
123 changes: 123 additions & 0 deletions lib/cloud/clients_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package cloud

import (
"context"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
)

type testAzureOIDCCredentials struct {
integration types.Integration
}

func (m *testAzureOIDCCredentials) GenerateAzureOIDCToken(ctx context.Context, integration string) (string, error) {
return "dummy-oidc-token", nil
}

func (m *testAzureOIDCCredentials) GetIntegration(ctx context.Context, name string) (types.Integration, error) {
if m.integration == nil || m.integration.GetName() != name {
return nil, trace.NotFound("integration %q not found", name)
}
return m.integration, nil
}

func TestWithAzureIntegrationCredentials(t *testing.T) {
const integrationName = "azure"

tests := []struct {
name string
integration types.Integration
wantErr string
}{
{
name: "valid azure integration",
integration: &types.IntegrationV1{
ResourceHeader: types.ResourceHeader{
Kind: types.KindIntegration,
SubKind: types.IntegrationSubKindAzureOIDC,
Version: types.V1,
Metadata: types.Metadata{
Name: integrationName,
Namespace: defaults.Namespace,
},
},
Spec: types.IntegrationSpecV1{
SubKindSpec: &types.IntegrationSpecV1_AzureOIDC{
AzureOIDC: &types.AzureOIDCIntegrationSpecV1{
ClientID: "baz-quux",
TenantID: "foo-bar",
},
},
},
},
},
{
name: "integration not found",
integration: nil,
wantErr: `integration "azure" not found`,
},
{
name: "invalid integration type",
integration: &types.IntegrationV1{
ResourceHeader: types.ResourceHeader{
Kind: types.KindIntegration,
SubKind: types.IntegrationSubKindAWSOIDC,
Version: types.V1,
Metadata: types.Metadata{
Name: "azure",
Namespace: defaults.Namespace,
},
},
Spec: types.IntegrationSpecV1{
SubKindSpec: &types.IntegrationSpecV1_AWSOIDC{
AWSOIDC: &types.AWSOIDCIntegrationSpecV1{
RoleARN: "arn:aws:iam::123456789012:role/teleport",
},
},
},
},
wantErr: `expected "azure" to be an "azure-oidc" integration, was "aws-oidc" instead`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opt := WithAzureIntegrationCredentials(integrationName, &testAzureOIDCCredentials{
integration: tt.integration,
})
clients, err := NewAzureClients(opt)
require.NoError(t, err)

credential, err := clients.GetAzureCredential()

if tt.wantErr == "" {
require.NoError(t, err)
require.NotNil(t, credential)
} else {
require.ErrorContains(t, err, tt.wantErr)
require.Nil(t, credential)
}
})
}
}
2 changes: 2 additions & 0 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,7 @@ func applyDiscoveryConfig(fc *FileConfig, cfg *servicecfg.Config) error {
Types: matcher.Types,
Regions: matcher.Regions,
ResourceTags: matcher.ResourceTags,
Integration: matcher.Integration,
Params: installParams,
}
if err := serviceMatcher.CheckAndSetDefaults(); err != nil {
Expand Down Expand Up @@ -1835,6 +1836,7 @@ func applyDatabasesConfig(fc *FileConfig, cfg *servicecfg.Config) error {
Types: matcher.Types,
Regions: matcher.Regions,
ResourceTags: matcher.ResourceTags,
Integration: matcher.Integration,
})
}
for _, database := range fc.Databases.Databases {
Expand Down
17 changes: 13 additions & 4 deletions lib/config/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ func TestConfigReading(t *testing.T) {
},
ResourceGroups: []string{"group1"},
Subscriptions: []string{"sub1"},
Integration: "integration1",
},
},
GCPMatchers: []GCPMatcher{
Expand Down Expand Up @@ -520,6 +521,7 @@ func TestConfigReading(t *testing.T) {
ResourceGroups: []string{"rg1", "rg2"},
Types: []string{"mysql"},
Regions: []string{"eastus", "westus"},
Integration: "integration1",
ResourceTags: map[string]apiutils.Strings{
"a": {"b"},
},
Expand Down Expand Up @@ -878,6 +880,7 @@ SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7
ResourceGroups: []string{"group1", "group2"},
Types: []string{"postgres", "mysql"},
Regions: []string{"eastus", "centralus"},
Integration: "integration123",
ResourceTags: map[string]apiutils.Strings{
"a": {"b"},
},
Expand Down Expand Up @@ -1602,6 +1605,7 @@ func makeConfigFixture() string {
},
ResourceGroups: []string{"group1"},
Subscriptions: []string{"sub1"},
Integration: "integration1",
},
}

Expand Down Expand Up @@ -1721,6 +1725,7 @@ func makeConfigFixture() string {
ResourceTags: map[string]apiutils.Strings{
"a": {"b"},
},
Integration: "integration1",
},
{
Subscriptions: []string{"sub3", "sub4"},
Expand Down Expand Up @@ -3900,6 +3905,7 @@ func TestApplyDiscoveryConfig(t *testing.T) {
},
Suffix: "blue",
},
Integration: "integration123",
},
},
},
Expand All @@ -3921,6 +3927,7 @@ func TestApplyDiscoveryConfig(t *testing.T) {
},
Regions: []string{"*"},
ResourceTags: types.Labels{"*": []string{"*"}},
Integration: "integration123",
ResourceGroups: []string{"*"},
},
},
Expand Down Expand Up @@ -4628,8 +4635,9 @@ func TestDiscoveryConfig(t *testing.T) {
cfg["discovery_service"].(cfgMap)["enabled"] = "yes"
cfg["discovery_service"].(cfgMap)["azure"] = []cfgMap{
{
"types": []string{"aks"},
"regions": []string{"eucentral1"},
"types": []string{"aks"},
"regions": []string{"eucentral1"},
"integration": "integration1",
"tags": cfgMap{
"discover_teleport": "yes",
},
Expand All @@ -4639,8 +4647,9 @@ func TestDiscoveryConfig(t *testing.T) {
}
},
expectedAzureMatchers: []types.AzureMatcher{{
Types: []string{"aks"},
Regions: []string{"eucentral1"},
Types: []string{"aks"},
Regions: []string{"eucentral1"},
Integration: "integration1",
ResourceTags: map[string]apiutils.Strings{
"discover_teleport": []string{"yes"},
},
Expand Down
2 changes: 2 additions & 0 deletions lib/config/fileconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -1952,6 +1952,8 @@ type AzureMatcher struct {
Regions []string `yaml:"regions,omitempty"`
// ResourceTags are Azure tags on resources to match.
ResourceTags map[string]apiutils.Strings `yaml:"tags,omitempty"`
// Integration is the Azure Integration name.
Integration string `yaml:"integration,omitempty"`
// InstallParams sets the join method when installing on
// discovered Azure nodes.
InstallParams *InstallParams `yaml:"install,omitempty"`
Expand Down
1 change: 1 addition & 0 deletions lib/config/testdata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ db_service:
resource_groups: ["group1", "group2"]
types: ["postgres", "mysql"]
regions: ["eastus", "centralus"]
integration: integration123
tags:
"a": "b"
- types: ["postgres", "mysql"]
Expand Down
15 changes: 7 additions & 8 deletions lib/services/matchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,13 @@ func SimplifyAzureMatchers(matchers []types.AzureMatcher) []types.AzureMatcher {
regions[i] = azureutils.NormalizeLocation(region)
}
}
result = append(result, types.AzureMatcher{
Subscriptions: subs,
ResourceGroups: groups,
Regions: regions,
Types: ts,
ResourceTags: m.ResourceTags,
Params: m.Params,
})
elem := m
elem.Subscriptions = subs
elem.ResourceGroups = groups
elem.Regions = regions
elem.Types = ts

result = append(result, elem)
}
return result
}
Expand Down
Loading
Loading