diff --git a/go.mod b/go.mod
index b368ee8247455..4a5958fc7f599 100644
--- a/go.mod
+++ b/go.mod
@@ -17,6 +17,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3 v3.0.1
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0
+ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.2.0
diff --git a/go.sum b/go.sum
index 87cac2378c729..9964060e139d0 100644
--- a/go.sum
+++ b/go.sum
@@ -61,6 +61,8 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal v1.1.2 h1:mLY+pNL
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal v1.1.2/go.mod h1:FbdwsQ2EzwvXxOPcMFYO8ogEc9uMMIj3YkmCdXdAFmk=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0 h1:z4YeiSXxnUI+PqB46Yj6MZA3nwb1CcJIkEMDrzUd8Cs=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0/go.mod h1:rko9SzMxcMk0NJsNAxALEGaTYyy79bNRwxgJfrH0Spw=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.2.0 h1:dhywcZH9yPDIje9aTqwy6psZSPzI6CJLYEprDahIBSQ=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.2.0/go.mod h1:6z3b+JdBLH0eMzfBex/cvEIoEFVEwXuB0wbgdfN11iM=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers v1.2.0 h1:3jDMffAwnvs6qmOqhjNVHB29AKxs6brnzJeo65E1YwM=
diff --git a/lib/cloud/azure/mocks.go b/lib/cloud/azure/mocks.go
index 4a2372fad8ba6..59133720d6f27 100644
--- a/lib/cloud/azure/mocks.go
+++ b/lib/cloud/azure/mocks.go
@@ -20,12 +20,15 @@ package azure
import (
"context"
+ "fmt"
+ "log/slog"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2"
+ "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql"
@@ -664,3 +667,50 @@ func (m *ARMPostgresFlexServerMock) NewListByResourceGroupPager(group string, _
}, nil
})
}
+
+// ARMUserAssignedIdentitiesMock implements ARMUserAssignedIdentities.
+type ARMUserAssignedIdentitiesMock struct {
+ identitiesMap map[string]armmsi.Identity
+}
+
+// NewARMUserAssignedIdentitiesMock creates a new ARMUserAssignedIdentitiesMock.
+func NewARMUserAssignedIdentitiesMock(identities ...armmsi.Identity) *ARMUserAssignedIdentitiesMock {
+ identitiesMap := make(map[string]armmsi.Identity)
+ for _, identity := range identities {
+ id, err := arm.ParseResourceID(*identity.ID)
+ if err == nil {
+ identitiesMap[id.ResourceGroupName+"+"+id.Name] = identity
+ } else {
+ slog.With("error", err).WarnContext(context.Background(), "Failed to add identity to mock.")
+ }
+ }
+ return &ARMUserAssignedIdentitiesMock{
+ identitiesMap: identitiesMap,
+ }
+}
+
+func (m *ARMUserAssignedIdentitiesMock) Get(ctx context.Context, resourceGroupName, resourceName string, options *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
+ if m == nil || m.identitiesMap == nil {
+ return armmsi.UserAssignedIdentitiesClientGetResponse{}, trace.AccessDenied("access denied")
+ }
+
+ identity, found := m.identitiesMap[resourceGroupName+"+"+resourceName]
+ if !found {
+ return armmsi.UserAssignedIdentitiesClientGetResponse{}, trace.NotFound("%s of group %s not found", resourceName, resourceGroupName)
+ }
+ return armmsi.UserAssignedIdentitiesClientGetResponse{
+ Identity: identity,
+ }, nil
+}
+
+// NewUserAssignedIdentity creates an armmsi.Identity.
+func NewUserAssignedIdentity(subscription, resourceGroupName, resourceName, clientID string) armmsi.Identity {
+ id := fmt.Sprintf("/subscriptions/%s/resourcegroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscription, resourceGroupName, resourceName)
+ return armmsi.Identity{
+ ID: &id,
+ Name: &resourceName,
+ Properties: &armmsi.UserAssignedIdentityProperties{
+ ClientID: &clientID,
+ },
+ }
+}
diff --git a/lib/cloud/azure/user_identities.go b/lib/cloud/azure/user_identities.go
new file mode 100644
index 0000000000000..9ad1805fff660
--- /dev/null
+++ b/lib/cloud/azure/user_identities.go
@@ -0,0 +1,72 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package azure
+
+import (
+ "context"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
+ "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
+ "github.com/gravitational/trace"
+)
+
+// ARMUserAssignedIdentities provides an interface for
+// armmsi.UserAssignedIdentitiesClient.
+type ARMUserAssignedIdentities interface {
+ Get(ctx context.Context, resourceGroupName, resourceName string, options *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error)
+}
+
+// UserAssignedIdentitiesClient wraps the armmsi.UserAssignedIdentitiesClient to fetch
+// identity info.
+type UserAssignedIdentitiesClient struct {
+ api ARMUserAssignedIdentities
+}
+
+// NewUserAssignedIdentitiesClient creates a new UserAssignedIdentitiesClient
+// by subscription and credential.
+func NewUserAssignedIdentitiesClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (*UserAssignedIdentitiesClient, error) {
+ api, err := armmsi.NewUserAssignedIdentitiesClient(subscription, cred, options)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return NewUserAssignedIdentitiesClientByAPI(api), nil
+}
+
+// NewUserAssignedIdentitiesClientByAPI creates a new
+// UserAssignedIdentitiesClient by ARMUserAssignedIdentities interface.
+func NewUserAssignedIdentitiesClientByAPI(api ARMUserAssignedIdentities) *UserAssignedIdentitiesClient {
+ return &UserAssignedIdentitiesClient{
+ api: api,
+ }
+}
+
+// GetClientID returns the client ID for the provided identity.
+func (c *UserAssignedIdentitiesClient) GetClientID(ctx context.Context, resourceGroupName, resourceName string) (string, error) {
+ identity, err := c.api.Get(ctx, resourceGroupName, resourceName, nil)
+ if err != nil {
+ return "", trace.Wrap(ConvertResponseError(err))
+ }
+
+ if identity.Properties == nil || identity.Properties.ClientID == nil {
+ return "", trace.BadParameter("cannot find ClientID from identity %s", resourceName)
+ }
+
+ return *identity.Properties.ClientID, nil
+}
diff --git a/lib/cloud/azure/user_identities_test.go b/lib/cloud/azure/user_identities_test.go
new file mode 100644
index 0000000000000..4176c72a614e4
--- /dev/null
+++ b/lib/cloud/azure/user_identities_test.go
@@ -0,0 +1,67 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package azure
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserAssignedIdentitiesClient(t *testing.T) {
+ t.Parallel()
+
+ bot1 := NewUserAssignedIdentity("my-sub", "my-group", "bot1", "bot1-id")
+ mockAPI := NewARMUserAssignedIdentitiesMock(bot1)
+
+ tests := []struct {
+ name string
+ inputResourceGroupName string
+ inputUserName string
+ wantError bool
+ wantClientID string
+ }{
+ {
+ name: "success",
+ inputResourceGroupName: "my-group",
+ inputUserName: "bot1",
+ wantClientID: "bot1-id",
+ },
+ {
+ name: "not found",
+ inputResourceGroupName: "my-group",
+ inputUserName: "bot5",
+ wantError: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ client := NewUserAssignedIdentitiesClientByAPI(mockAPI)
+ actualClientID, err := client.GetClientID(context.Background(), test.inputResourceGroupName, test.inputUserName)
+ if test.wantError {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ require.Equal(t, test.wantClientID, actualClientID)
+ })
+ }
+}
diff --git a/lib/service/service.go b/lib/service/service.go
index 8f7899a02cb58..1bf6c82469cc5 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -5482,6 +5482,7 @@ func (process *TeleportProcess) initApps() {
ConnectedProxyGetter: proxyGetter,
Emitter: asyncEmitter,
ConnectionMonitor: connMonitor,
+ Logger: logger,
})
if err != nil {
return trace.Wrap(err)
diff --git a/lib/srv/app/azure/credential.go b/lib/srv/app/azure/credential.go
new file mode 100644
index 0000000000000..1275924d88a6e
--- /dev/null
+++ b/lib/srv/app/azure/credential.go
@@ -0,0 +1,191 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package azure
+
+import (
+ "context"
+ "log/slog"
+ "time"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
+ "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
+ "github.com/gravitational/trace"
+
+ cloudazure "github.com/gravitational/teleport/lib/cloud/azure"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+// credentialProvider defines an interface that manages a particular type of
+// credential.
+type credentialProvider interface {
+ // MakeCredential creates an azcore.TokenCredential for provided identity.
+ MakeCredential(ctx context.Context, userRequestedIdentity string) (azcore.TokenCredential, error)
+ // MapScope maps the input scope if necessary.
+ MapScope(scope string) string
+}
+
+func getAccessTokenFromCredentialProvider(credProvider credentialProvider) getAccessTokenFunc {
+ return func(ctx context.Context, userRequestedIdentity string, scope string) (*azcore.AccessToken, error) {
+ credential, err := credProvider.MakeCredential(ctx, userRequestedIdentity)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ opts := policy.TokenRequestOptions{
+ Scopes: []string{credProvider.MapScope(scope)},
+ }
+ token, err := credential.GetToken(ctx, opts)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &token, nil
+ }
+}
+
+func findDefaultCredentialProvider(ctx context.Context, logger *slog.Logger) (credentialProvider, error) {
+ // Check if default workload identity is available: the clientID/tenantID
+ // for the default workload identity and the token file path are required
+ // from environment variables.
+ defaultWorkloadIdentity, err := azidentity.NewWorkloadIdentityCredential(nil)
+ if err != nil {
+ // If no workload identity is found, fall back to regular managed identity.
+ logger.With("error", err).InfoContext(ctx, "Failed to load workload identity. Using managed identity.")
+ return managedIdentityCredentialProvider{}, nil
+ }
+
+ logger.InfoContext(ctx, "Using workload identity.")
+ credProvider, err := newWorloadIdentityCredentialProvider(ctx, defaultWorkloadIdentity)
+ return credProvider, trace.Wrap(err)
+}
+
+// managedIdentityCredentialProvider implements credentialProvider for using
+// managed identities assigned to the host machine. Identities are usually
+// checked against the IMDS service available in the local network.
+type managedIdentityCredentialProvider struct {
+}
+
+func (m managedIdentityCredentialProvider) MakeCredential(ctx context.Context, userRequestedIdentity string) (azcore.TokenCredential, error) {
+ credenial, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
+ ID: azidentity.ResourceID(userRequestedIdentity),
+ })
+ return credenial, trace.Wrap(err)
+}
+
+func (m managedIdentityCredentialProvider) MapScope(scope string) string {
+ // No scope needs to be mapped.
+ return scope
+}
+
+// workloadIdentityCredentialProvider implements credentialProvider for using
+// workload identities assigned to the host machine.
+//
+// https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
+//
+// When running on AKS, multiple workload identities can be associated to the
+// same service account attached to the pod. Assuming a workload identity
+// requires a client ID of that identity but only the default Client ID is
+// provided through environment variable. We assume that the default workload
+// identity (mapped by the default client ID) is the "app-service" identity
+// with msi permissions so the client IDs for other "user-requested" identity
+// can be retrieved using the default identity.
+type workloadIdentityCredentialProvider struct {
+ cache *utils.FnCache
+ defaultAgentIdentity azcore.TokenCredential
+
+ // newClient defaults to cloudazure.NewUserAssignedIdentitiesClient. Can be
+ // overridden for test.
+ newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (*cloudazure.UserAssignedIdentitiesClient, error)
+ // newCredential defaults to newWorkloadIdentityCredentialForClientID. Can
+ // be overridden for test.
+ newCredential func(string) (azcore.TokenCredential, error)
+}
+
+func newWorloadIdentityCredentialProvider(ctx context.Context, defaultAgentIdentity azcore.TokenCredential) (*workloadIdentityCredentialProvider, error) {
+ if defaultAgentIdentity == nil {
+ return nil, trace.BadParameter("missing defaultAgentIdentity")
+ }
+ cache, err := utils.NewFnCache(utils.FnCacheConfig{
+ Context: ctx,
+ TTL: clientIDCacheTTL,
+ ReloadOnErr: true,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &workloadIdentityCredentialProvider{
+ cache: cache,
+ defaultAgentIdentity: defaultAgentIdentity,
+ newClient: cloudazure.NewUserAssignedIdentitiesClient,
+ newCredential: newWorkloadIdentityCredentialForClientID,
+ }, nil
+}
+
+func newWorkloadIdentityCredentialForClientID(clientID string) (azcore.TokenCredential, error) {
+ cred, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
+ ClientID: clientID,
+ })
+ return cred, trace.Wrap(err)
+}
+
+func (w *workloadIdentityCredentialProvider) MakeCredential(ctx context.Context, userRequestedIdentity string) (azcore.TokenCredential, error) {
+ clientID, err := w.getClientID(ctx, userRequestedIdentity)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ credential, err := w.newCredential(clientID)
+ return credential, trace.Wrap(err)
+}
+
+func (w *workloadIdentityCredentialProvider) MapScope(scope string) string {
+ // This scope ("https://management.core.windows.net/") from `az` CLI tool
+ // will fail for workload identity as workload identity is only expected to
+ // be used with compatible SDKs, whereas the SDK adds ".default" to the
+ // audience:
+ //
+ // https://github.com/Azure/azure-sdk-for-go/blob/9e78ee2b86f0f4989098dd7e545b73841fc8df47/sdk/azcore/arm/runtime/pipeline.go#L35
+ if scope == "https://management.core.windows.net/" {
+ return scope + ".default"
+ }
+ return scope
+}
+
+func (w *workloadIdentityCredentialProvider) getClientID(ctx context.Context, identityResourceID string) (string, error) {
+ clientID, err := utils.FnCacheGet(ctx, w.cache, identityResourceID, func(ctx context.Context) (string, error) {
+ resourceID, err := arm.ParseResourceID(identityResourceID)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ client, err := w.newClient(resourceID.SubscriptionID, w.defaultAgentIdentity, nil)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ clientID, err := client.GetClientID(ctx, resourceID.ResourceGroupName, resourceID.Name)
+ return clientID, trace.Wrap(err)
+ })
+ return clientID, trace.Wrap(err)
+}
+
+// clientIDCacheTTL defines how long client IDs should be cached. Client IDs
+// should never change for an identity so use a longer cache TTL.
+var clientIDCacheTTL = 30 * time.Minute
diff --git a/lib/srv/app/azure/credential_test.go b/lib/srv/app/azure/credential_test.go
new file mode 100644
index 0000000000000..8f5fc6111ea16
--- /dev/null
+++ b/lib/srv/app/azure/credential_test.go
@@ -0,0 +1,130 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package azure
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+
+ cloudazure "github.com/gravitational/teleport/lib/cloud/azure"
+)
+
+type fakeTokenCredential struct {
+ lastSeenScope string
+}
+
+func (f *fakeTokenCredential) GetToken(_ context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
+ if len(opts.Scopes) != 1 {
+ return azcore.AccessToken{}, trace.BadParameter("expect one scope but got %v", opts.Scopes)
+ }
+
+ f.lastSeenScope = opts.Scopes[0]
+ return azcore.AccessToken{
+ Token: "fake-token",
+ ExpiresOn: time.Now().Add(time.Hour),
+ }, nil
+}
+
+type fakeCredentialProvider struct {
+ cred fakeTokenCredential
+ lastSeenIdentity string
+}
+
+func (f *fakeCredentialProvider) MakeCredential(_ context.Context, userRequestedIdentity string) (azcore.TokenCredential, error) {
+ f.lastSeenIdentity = userRequestedIdentity
+ return &f.cred, nil
+}
+
+func (f *fakeCredentialProvider) MapScope(scope string) string {
+ return scope + ".mapped"
+}
+
+func Test_getAccessTokenFromCredentialProvider(t *testing.T) {
+ fakeCredProvider := &fakeCredentialProvider{}
+ userRequestedIdentity := "/subscriptions/my-sub/resourcegroups/my-group/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-name"
+ ctx := context.Background()
+
+ token, err := getAccessTokenFromCredentialProvider(fakeCredProvider)(ctx, userRequestedIdentity, "test-scope")
+ require.NoError(t, err)
+ require.Equal(t, "fake-token", token.Token)
+ require.Equal(t, userRequestedIdentity, fakeCredProvider.lastSeenIdentity)
+ require.Equal(t, "test-scope.mapped", fakeCredProvider.cred.lastSeenScope)
+}
+
+func Test_workloadIdentityCredentialProvider(t *testing.T) {
+ ctx := context.Background()
+ fakeAgentIdentity := &fakeTokenCredential{}
+ credProvider, err := newWorloadIdentityCredentialProvider(ctx, fakeAgentIdentity)
+ require.NoError(t, err)
+
+ // Hook up more mocks.
+ fakeWorkloadIdentityCredential := &fakeTokenCredential{}
+ userRequestedIdentity := cloudazure.NewUserAssignedIdentity("my-sub", "my-group", "my-name", "my-client-id")
+ mockAPI := cloudazure.NewARMUserAssignedIdentitiesMock(userRequestedIdentity)
+ credProvider.newClient = func(string, azcore.TokenCredential, *arm.ClientOptions) (*cloudazure.UserAssignedIdentitiesClient, error) {
+ return cloudazure.NewUserAssignedIdentitiesClientByAPI(mockAPI), nil
+ }
+ credProvider.newCredential = func(clientID string) (azcore.TokenCredential, error) {
+ if clientID != "my-client-id" {
+ return nil, trace.BadParameter("expect my-client-id but got %s", clientID)
+ }
+ return fakeWorkloadIdentityCredential, nil
+ }
+
+ t.Run("MakeCredential", func(t *testing.T) {
+ t.Run("success", func(t *testing.T) {
+ actualCredential, err := credProvider.MakeCredential(ctx, *userRequestedIdentity.ID)
+ require.NoError(t, err)
+ require.Same(t, fakeWorkloadIdentityCredential, actualCredential)
+ })
+ t.Run("fail to get client ID", func(t *testing.T) {
+ notFoundIdentity := "/subscriptions/my-sub/resourcegroups/my-group/providers/Microsoft.ManagedIdentity/userAssignedIdentities/not-my-name"
+ _, err := credProvider.MakeCredential(ctx, notFoundIdentity)
+ require.Error(t, err)
+ })
+ })
+
+ t.Run("MapScope", func(t *testing.T) {
+ tests := []struct {
+ inputScope string
+ outputScope string
+ }{
+ {
+ inputScope: "https://management.core.windows.net/",
+ outputScope: "https://management.core.windows.net/.default",
+ },
+ {
+ inputScope: "some-other-scope",
+ outputScope: "some-other-scope",
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.inputScope, func(t *testing.T) {
+ require.Equal(t, test.outputScope, credProvider.MapScope(test.inputScope))
+ })
+ }
+ })
+}
diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go
index 65603609f1970..04c9f3b0bedb1 100644
--- a/lib/srv/app/azure/handler.go
+++ b/lib/srv/app/azure/handler.go
@@ -23,13 +23,12 @@ import (
"context"
"crypto"
"crypto/x509"
+ "log/slog"
"net/http"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
- "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
- "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
@@ -45,12 +44,18 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
+// ComponentKey is the Teleport component key for this handler.
+const ComponentKey = "azure:fwd"
+
// HandlerConfig is the configuration for an Azure app-access handler.
type HandlerConfig struct {
// RoundTripper is the underlying transport given to an oxy Forwarder.
RoundTripper http.RoundTripper
// Log is the Logger.
+ // TODO(greedy52) replace with slog.
Log logrus.FieldLogger
+ // Logger is the slog.Logger.
+ Logger *slog.Logger
// Clock is used to override time in tests.
Clock clockwork.Clock
@@ -59,7 +64,7 @@ type HandlerConfig struct {
}
// CheckAndSetDefaults validates the HandlerConfig.
-func (s *HandlerConfig) CheckAndSetDefaults() error {
+func (s *HandlerConfig) CheckAndSetDefaults(ctx context.Context) error {
if s.RoundTripper == nil {
tr, err := defaults.Transport()
if err != nil {
@@ -71,10 +76,17 @@ func (s *HandlerConfig) CheckAndSetDefaults() error {
s.Clock = clockwork.NewRealClock()
}
if s.Log == nil {
- s.Log = logrus.WithField(teleport.ComponentKey, "azure:fwd")
+ s.Log = logrus.WithField(teleport.ComponentKey, ComponentKey)
+ }
+ if s.Logger == nil {
+ s.Logger = slog.Default().With(teleport.ComponentKey, ComponentKey)
}
if s.getAccessToken == nil {
- s.getAccessToken = getAccessTokenManagedIdentity
+ credProvider, err := findDefaultCredentialProvider(ctx, s.Logger)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ s.getAccessToken = getAccessTokenFromCredentialProvider(credProvider)
}
return nil
}
@@ -99,7 +111,7 @@ func NewAzureHandler(ctx context.Context, config HandlerConfig) (http.Handler, e
// newAzureHandler creates a new instance of a handler for Azure requests. Used by NewAzureHandler and in tests.
func newAzureHandler(ctx context.Context, config HandlerConfig) (*handler, error) {
- if err := config.CheckAndSetDefaults(); err != nil {
+ if err := config.CheckAndSetDefaults(ctx); err != nil {
return nil, trace.Wrap(err)
}
@@ -262,20 +274,6 @@ func (s *handler) parseAuthHeader(token string, pubKey crypto.PublicKey) (*jwt.A
type getAccessTokenFunc func(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error)
-func getAccessTokenManagedIdentity(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) {
- identityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(managedIdentity)})
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- opts := policy.TokenRequestOptions{Scopes: []string{scope}}
- token, err := identityCredential.GetToken(ctx, opts)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return &token, nil
-}
-
type cacheKey struct {
managedIdentity string
scope string
diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go
index a6a14c6e3f980..b122cf09ab06b 100644
--- a/lib/srv/app/server.go
+++ b/lib/srv/app/server.go
@@ -26,6 +26,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
+ "log/slog"
"net"
"net/http"
"strconv"
@@ -131,6 +132,9 @@ type Config struct {
// ConnectionMonitor monitors connections and terminates any if
// any session controls prevent them.
ConnectionMonitor ConnMonitor
+
+ // Logger is the slog.Logger.
+ Logger *slog.Logger
}
// CheckAndSetDefaults makes sure the configuration has the minimum required
@@ -180,6 +184,9 @@ func (c *Config) CheckAndSetDefaults() error {
if c.ConnectedProxyGetter == nil {
c.ConnectedProxyGetter = reversetunnel.NewConnectedProxyGetter()
}
+ if c.Logger == nil {
+ c.Logger = slog.Default().With(teleport.ComponentKey, teleport.Component(teleport.ComponentApp))
+ }
return nil
}
@@ -289,7 +296,9 @@ func New(ctx context.Context, c *Config) (*Server, error) {
return nil, trace.Wrap(err)
}
- azureHandler, err := appazure.NewAzureHandler(closeContext, appazure.HandlerConfig{})
+ azureHandler, err := appazure.NewAzureHandler(closeContext, appazure.HandlerConfig{
+ Logger: c.Logger.With(teleport.ComponentKey, appazure.ComponentKey),
+ })
if err != nil {
return nil, trace.Wrap(err)
}
@@ -301,6 +310,7 @@ func New(ctx context.Context, c *Config) (*Server, error) {
s := &Server{
c: c,
+ // TODO(greedy52) replace with slog from Config.Logger.
log: logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.ComponentApp,
}),