Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/cloud/gcp"
)
Expand Down Expand Up @@ -89,8 +90,41 @@ type AzureClients interface {
GetAzureRunCommandClient(subscription string) (azure.RunCommandClient, error)
}

// AzureClientsOption is an option to pass to NewAzureClients
type AzureClientsOption func(clients *azureClients)

type azureOIDCCredentials interface {
GenerateAzureOIDCToken(ctx context.Context, integration string) (string, error)
GetIntegration(ctx context.Context, name string) (types.Integration, error)
}

// WithAzureIntegrationCredentials configures Azure cloud clients to use integration credentials.
func WithAzureIntegrationCredentials(integrationName string, auth azureOIDCCredentials) AzureClientsOption {
return func(clt *azureClients) {
clt.newAzureCredentialFunc = func() (azcore.TokenCredential, error) {
ctx := context.TODO()
integration, err := auth.GetIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}
spec := integration.GetAzureOIDCIntegrationSpec()
if spec == nil {
return nil, trace.BadParameter("expected %q to be an %q integration, was %q instead", integration.GetName(), types.IntegrationSubKindAzureOIDC, integration.GetSubKind())
}
cred, err := azidentity.NewClientAssertionCredential(spec.TenantID, spec.ClientID, func(ctx context.Context) (string, error) {
return auth.GenerateAzureOIDCToken(ctx, integrationName)
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
}, nil)
if err != nil {
return nil, trace.Wrap(err)
}
return cred, nil
}
}
}

// NewAzureClients returns a new instance of Azure SDK clients.
func NewAzureClients() (AzureClients, error) {
func NewAzureClients(opts ...AzureClientsOption) (AzureClients, error) {
azClients := &azureClients{
azureMySQLClients: make(map[string]azure.DBServersClient),
azurePostgresClients: make(map[string]azure.DBServersClient),
Expand Down Expand Up @@ -130,6 +164,15 @@ func NewAzureClients() (AzureClients, error) {
return nil, trace.Wrap(err)
}

azClients.newAzureCredentialFunc = func() (azcore.TokenCredential, error) {
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
return azidentity.NewDefaultAzureCredential(nil)
}

for _, opt := range opts {
opt(azClients)
}

return azClients, nil
}

Expand Down Expand Up @@ -168,8 +211,11 @@ type azureClients struct {
// mtx is used for locking.
mtx sync.RWMutex

// newAzureCredentialFunc creates new Azure credential.
newAzureCredentialFunc func() (azcore.TokenCredential, error)
// azureCredential is the cached Azure credential.
azureCredential azcore.TokenCredential

// azureMySQLClients is the cached Azure MySQL Server clients.
azureMySQLClients map[string]azure.DBServersClient
// azurePostgresClients is the cached Azure Postgres Server clients.
Expand Down Expand Up @@ -378,9 +424,8 @@ func (c *azureClients) initAzureCredential() (azcore.TokenCredential, error) {
if c.azureCredential != nil { // If some other thread already got here first.
return c.azureCredential, nil
}
// TODO(gavin): if/when we support AzureChina/AzureGovernment, we will need to specify the cloud in these options
options := &azidentity.DefaultAzureCredentialOptions{}
cred, err := azidentity.NewDefaultAzureCredential(options)

cred, err := c.newAzureCredentialFunc()
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
123 changes: 123 additions & 0 deletions lib/cloud/clients_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package cloud

import (
"context"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
)

type testAzureOIDCCredentials struct {
integration types.Integration
}

func (m *testAzureOIDCCredentials) GenerateAzureOIDCToken(ctx context.Context, integration string) (string, error) {
return "dummy-oidc-token", nil
}

func (m *testAzureOIDCCredentials) GetIntegration(ctx context.Context, name string) (types.Integration, error) {
if m.integration == nil || m.integration.GetName() != name {
return nil, trace.NotFound("integration %q not found", name)
}
return m.integration, nil
}

func TestWithAzureIntegrationCredentials(t *testing.T) {
const integrationName = "azure"

tests := []struct {
name string
integration types.Integration
wantErr string
}{
{
name: "valid azure integration",
integration: &types.IntegrationV1{
ResourceHeader: types.ResourceHeader{
Kind: types.KindIntegration,
SubKind: types.IntegrationSubKindAzureOIDC,
Version: types.V1,
Metadata: types.Metadata{
Name: integrationName,
Namespace: defaults.Namespace,
},
},
Spec: types.IntegrationSpecV1{
SubKindSpec: &types.IntegrationSpecV1_AzureOIDC{
AzureOIDC: &types.AzureOIDCIntegrationSpecV1{
ClientID: "baz-quux",
TenantID: "foo-bar",
},
},
},
},
},
{
name: "integration not found",
integration: nil,
wantErr: `integration "azure" not found`,
},
{
name: "invalid integration type",
integration: &types.IntegrationV1{
ResourceHeader: types.ResourceHeader{
Kind: types.KindIntegration,
SubKind: types.IntegrationSubKindAWSOIDC,
Version: types.V1,
Metadata: types.Metadata{
Name: "azure",
Namespace: defaults.Namespace,
},
},
Spec: types.IntegrationSpecV1{
SubKindSpec: &types.IntegrationSpecV1_AWSOIDC{
AWSOIDC: &types.AWSOIDCIntegrationSpecV1{
RoleARN: "arn:aws:iam::123456789012:role/teleport",
},
},
},
},
wantErr: `expected "azure" to be an "azure-oidc" integration, was "aws-oidc" instead`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opt := WithAzureIntegrationCredentials(integrationName, &testAzureOIDCCredentials{
integration: tt.integration,
})
clients, err := NewAzureClients(opt)
require.NoError(t, err)

credential, err := clients.GetAzureCredential()

if tt.wantErr == "" {
require.NoError(t, err)
require.NotNil(t, credential)
} else {
require.ErrorContains(t, err, tt.wantErr)
require.Nil(t, credential)
}
})
}
}
2 changes: 2 additions & 0 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,7 @@ func applyDiscoveryConfig(fc *FileConfig, cfg *servicecfg.Config) error {
Types: matcher.Types,
Regions: matcher.Regions,
ResourceTags: matcher.ResourceTags,
Integration: matcher.Integration,
Params: installParams,
}
if err := serviceMatcher.CheckAndSetDefaults(); err != nil {
Expand Down Expand Up @@ -1834,6 +1835,7 @@ func applyDatabasesConfig(fc *FileConfig, cfg *servicecfg.Config) error {
Types: matcher.Types,
Regions: matcher.Regions,
ResourceTags: matcher.ResourceTags,
Integration: matcher.Integration,
})
}
for _, database := range fc.Databases.Databases {
Expand Down
17 changes: 13 additions & 4 deletions lib/config/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ func TestConfigReading(t *testing.T) {
},
ResourceGroups: []string{"group1"},
Subscriptions: []string{"sub1"},
Integration: "integration1",
},
},
GCPMatchers: []GCPMatcher{
Expand Down Expand Up @@ -520,6 +521,7 @@ func TestConfigReading(t *testing.T) {
ResourceGroups: []string{"rg1", "rg2"},
Types: []string{"mysql"},
Regions: []string{"eastus", "westus"},
Integration: "integration1",
ResourceTags: map[string]apiutils.Strings{
"a": {"b"},
},
Expand Down Expand Up @@ -878,6 +880,7 @@ SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7
ResourceGroups: []string{"group1", "group2"},
Types: []string{"postgres", "mysql"},
Regions: []string{"eastus", "centralus"},
Integration: "integration123",
ResourceTags: map[string]apiutils.Strings{
"a": {"b"},
},
Expand Down Expand Up @@ -1602,6 +1605,7 @@ func makeConfigFixture() string {
},
ResourceGroups: []string{"group1"},
Subscriptions: []string{"sub1"},
Integration: "integration1",
},
}

Expand Down Expand Up @@ -1721,6 +1725,7 @@ func makeConfigFixture() string {
ResourceTags: map[string]apiutils.Strings{
"a": {"b"},
},
Integration: "integration1",
},
{
Subscriptions: []string{"sub3", "sub4"},
Expand Down Expand Up @@ -3901,6 +3906,7 @@ func TestApplyDiscoveryConfig(t *testing.T) {
},
Suffix: "blue",
},
Integration: "integration123",
},
},
},
Expand All @@ -3922,6 +3928,7 @@ func TestApplyDiscoveryConfig(t *testing.T) {
},
Regions: []string{"*"},
ResourceTags: types.Labels{"*": []string{"*"}},
Integration: "integration123",
ResourceGroups: []string{"*"},
},
},
Expand Down Expand Up @@ -4629,8 +4636,9 @@ func TestDiscoveryConfig(t *testing.T) {
cfg["discovery_service"].(cfgMap)["enabled"] = "yes"
cfg["discovery_service"].(cfgMap)["azure"] = []cfgMap{
{
"types": []string{"aks"},
"regions": []string{"eucentral1"},
"types": []string{"aks"},
"regions": []string{"eucentral1"},
"integration": "integration1",
"tags": cfgMap{
"discover_teleport": "yes",
},
Expand All @@ -4640,8 +4648,9 @@ func TestDiscoveryConfig(t *testing.T) {
}
},
expectedAzureMatchers: []types.AzureMatcher{{
Types: []string{"aks"},
Regions: []string{"eucentral1"},
Types: []string{"aks"},
Regions: []string{"eucentral1"},
Integration: "integration1",
ResourceTags: map[string]apiutils.Strings{
"discover_teleport": []string{"yes"},
},
Expand Down
2 changes: 2 additions & 0 deletions lib/config/fileconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -1959,6 +1959,8 @@ type AzureMatcher struct {
Regions []string `yaml:"regions,omitempty"`
// ResourceTags are Azure tags on resources to match.
ResourceTags map[string]apiutils.Strings `yaml:"tags,omitempty"`
// Integration is the Azure Integration name.
Integration string `yaml:"integration,omitempty"`
// InstallParams sets the join method when installing on
// discovered Azure nodes.
InstallParams *InstallParams `yaml:"install,omitempty"`
Expand Down
1 change: 1 addition & 0 deletions lib/config/testdata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ db_service:
resource_groups: ["group1", "group2"]
types: ["postgres", "mysql"]
regions: ["eastus", "centralus"]
integration: integration123
tags:
"a": "b"
- types: ["postgres", "mysql"]
Expand Down
15 changes: 7 additions & 8 deletions lib/services/matchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,13 @@ func SimplifyAzureMatchers(matchers []types.AzureMatcher) []types.AzureMatcher {
regions[i] = azureutils.NormalizeLocation(region)
}
}
result = append(result, types.AzureMatcher{
Subscriptions: subs,
ResourceGroups: groups,
Regions: regions,
Types: ts,
ResourceTags: m.ResourceTags,
Params: m.Params,
})
elem := m
elem.Subscriptions = subs
elem.ResourceGroups = groups
elem.Regions = regions
elem.Types = ts

result = append(result, elem)
}
return result
}
Expand Down
Loading
Loading