diff --git a/go.mod b/go.mod index b368ee8247455..4a5958fc7f599 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3 v3.0.1 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.2.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers v1.2.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.2.0 diff --git a/go.sum b/go.sum index 87cac2378c729..9964060e139d0 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal v1.1.2 h1:mLY+pNL github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal v1.1.2/go.mod h1:FbdwsQ2EzwvXxOPcMFYO8ogEc9uMMIj3YkmCdXdAFmk= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0 h1:z4YeiSXxnUI+PqB46Yj6MZA3nwb1CcJIkEMDrzUd8Cs= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0/go.mod h1:rko9SzMxcMk0NJsNAxALEGaTYyy79bNRwxgJfrH0Spw= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.2.0 h1:dhywcZH9yPDIje9aTqwy6psZSPzI6CJLYEprDahIBSQ= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.2.0/go.mod h1:6z3b+JdBLH0eMzfBex/cvEIoEFVEwXuB0wbgdfN11iM= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers v1.2.0 h1:3jDMffAwnvs6qmOqhjNVHB29AKxs6brnzJeo65E1YwM= diff --git a/lib/cloud/azure/mocks.go b/lib/cloud/azure/mocks.go index 4a2372fad8ba6..59133720d6f27 100644 --- a/lib/cloud/azure/mocks.go +++ b/lib/cloud/azure/mocks.go @@ -20,12 +20,15 @@ package azure import ( "context" + "fmt" + "log/slog" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql" @@ -664,3 +667,50 @@ func (m *ARMPostgresFlexServerMock) NewListByResourceGroupPager(group string, _ }, nil }) } + +// ARMUserAssignedIdentitiesMock implements ARMUserAssignedIdentities. +type ARMUserAssignedIdentitiesMock struct { + identitiesMap map[string]armmsi.Identity +} + +// NewARMUserAssignedIdentitiesMock creates a new ARMUserAssignedIdentitiesMock. +func NewARMUserAssignedIdentitiesMock(identities ...armmsi.Identity) *ARMUserAssignedIdentitiesMock { + identitiesMap := make(map[string]armmsi.Identity) + for _, identity := range identities { + id, err := arm.ParseResourceID(*identity.ID) + if err == nil { + identitiesMap[id.ResourceGroupName+"+"+id.Name] = identity + } else { + slog.With("error", err).WarnContext(context.Background(), "Failed to add identity to mock.") + } + } + return &ARMUserAssignedIdentitiesMock{ + identitiesMap: identitiesMap, + } +} + +func (m *ARMUserAssignedIdentitiesMock) Get(ctx context.Context, resourceGroupName, resourceName string, options *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) { + if m == nil || m.identitiesMap == nil { + return armmsi.UserAssignedIdentitiesClientGetResponse{}, trace.AccessDenied("access denied") + } + + identity, found := m.identitiesMap[resourceGroupName+"+"+resourceName] + if !found { + return armmsi.UserAssignedIdentitiesClientGetResponse{}, trace.NotFound("%s of group %s not found", resourceName, resourceGroupName) + } + return armmsi.UserAssignedIdentitiesClientGetResponse{ + Identity: identity, + }, nil +} + +// NewUserAssignedIdentity creates an armmsi.Identity. +func NewUserAssignedIdentity(subscription, resourceGroupName, resourceName, clientID string) armmsi.Identity { + id := fmt.Sprintf("/subscriptions/%s/resourcegroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscription, resourceGroupName, resourceName) + return armmsi.Identity{ + ID: &id, + Name: &resourceName, + Properties: &armmsi.UserAssignedIdentityProperties{ + ClientID: &clientID, + }, + } +} diff --git a/lib/cloud/azure/user_identities.go b/lib/cloud/azure/user_identities.go new file mode 100644 index 0000000000000..9ad1805fff660 --- /dev/null +++ b/lib/cloud/azure/user_identities.go @@ -0,0 +1,72 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azure + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" + "github.com/gravitational/trace" +) + +// ARMUserAssignedIdentities provides an interface for +// armmsi.UserAssignedIdentitiesClient. +type ARMUserAssignedIdentities interface { + Get(ctx context.Context, resourceGroupName, resourceName string, options *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) +} + +// UserAssignedIdentitiesClient wraps the armmsi.UserAssignedIdentitiesClient to fetch +// identity info. +type UserAssignedIdentitiesClient struct { + api ARMUserAssignedIdentities +} + +// NewUserAssignedIdentitiesClient creates a new UserAssignedIdentitiesClient +// by subscription and credential. +func NewUserAssignedIdentitiesClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (*UserAssignedIdentitiesClient, error) { + api, err := armmsi.NewUserAssignedIdentitiesClient(subscription, cred, options) + if err != nil { + return nil, trace.Wrap(err) + } + return NewUserAssignedIdentitiesClientByAPI(api), nil +} + +// NewUserAssignedIdentitiesClientByAPI creates a new +// UserAssignedIdentitiesClient by ARMUserAssignedIdentities interface. +func NewUserAssignedIdentitiesClientByAPI(api ARMUserAssignedIdentities) *UserAssignedIdentitiesClient { + return &UserAssignedIdentitiesClient{ + api: api, + } +} + +// GetClientID returns the client ID for the provided identity. +func (c *UserAssignedIdentitiesClient) GetClientID(ctx context.Context, resourceGroupName, resourceName string) (string, error) { + identity, err := c.api.Get(ctx, resourceGroupName, resourceName, nil) + if err != nil { + return "", trace.Wrap(ConvertResponseError(err)) + } + + if identity.Properties == nil || identity.Properties.ClientID == nil { + return "", trace.BadParameter("cannot find ClientID from identity %s", resourceName) + } + + return *identity.Properties.ClientID, nil +} diff --git a/lib/cloud/azure/user_identities_test.go b/lib/cloud/azure/user_identities_test.go new file mode 100644 index 0000000000000..4176c72a614e4 --- /dev/null +++ b/lib/cloud/azure/user_identities_test.go @@ -0,0 +1,67 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azure + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUserAssignedIdentitiesClient(t *testing.T) { + t.Parallel() + + bot1 := NewUserAssignedIdentity("my-sub", "my-group", "bot1", "bot1-id") + mockAPI := NewARMUserAssignedIdentitiesMock(bot1) + + tests := []struct { + name string + inputResourceGroupName string + inputUserName string + wantError bool + wantClientID string + }{ + { + name: "success", + inputResourceGroupName: "my-group", + inputUserName: "bot1", + wantClientID: "bot1-id", + }, + { + name: "not found", + inputResourceGroupName: "my-group", + inputUserName: "bot5", + wantError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client := NewUserAssignedIdentitiesClientByAPI(mockAPI) + actualClientID, err := client.GetClientID(context.Background(), test.inputResourceGroupName, test.inputUserName) + if test.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, test.wantClientID, actualClientID) + }) + } +} diff --git a/lib/service/service.go b/lib/service/service.go index 8f7899a02cb58..1bf6c82469cc5 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -5482,6 +5482,7 @@ func (process *TeleportProcess) initApps() { ConnectedProxyGetter: proxyGetter, Emitter: asyncEmitter, ConnectionMonitor: connMonitor, + Logger: logger, }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/app/azure/credential.go b/lib/srv/app/azure/credential.go new file mode 100644 index 0000000000000..1275924d88a6e --- /dev/null +++ b/lib/srv/app/azure/credential.go @@ -0,0 +1,191 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azure + +import ( + "context" + "log/slog" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/gravitational/trace" + + cloudazure "github.com/gravitational/teleport/lib/cloud/azure" + "github.com/gravitational/teleport/lib/utils" +) + +// credentialProvider defines an interface that manages a particular type of +// credential. +type credentialProvider interface { + // MakeCredential creates an azcore.TokenCredential for provided identity. + MakeCredential(ctx context.Context, userRequestedIdentity string) (azcore.TokenCredential, error) + // MapScope maps the input scope if necessary. + MapScope(scope string) string +} + +func getAccessTokenFromCredentialProvider(credProvider credentialProvider) getAccessTokenFunc { + return func(ctx context.Context, userRequestedIdentity string, scope string) (*azcore.AccessToken, error) { + credential, err := credProvider.MakeCredential(ctx, userRequestedIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + + opts := policy.TokenRequestOptions{ + Scopes: []string{credProvider.MapScope(scope)}, + } + token, err := credential.GetToken(ctx, opts) + if err != nil { + return nil, trace.Wrap(err) + } + return &token, nil + } +} + +func findDefaultCredentialProvider(ctx context.Context, logger *slog.Logger) (credentialProvider, error) { + // Check if default workload identity is available: the clientID/tenantID + // for the default workload identity and the token file path are required + // from environment variables. + defaultWorkloadIdentity, err := azidentity.NewWorkloadIdentityCredential(nil) + if err != nil { + // If no workload identity is found, fall back to regular managed identity. + logger.With("error", err).InfoContext(ctx, "Failed to load workload identity. Using managed identity.") + return managedIdentityCredentialProvider{}, nil + } + + logger.InfoContext(ctx, "Using workload identity.") + credProvider, err := newWorloadIdentityCredentialProvider(ctx, defaultWorkloadIdentity) + return credProvider, trace.Wrap(err) +} + +// managedIdentityCredentialProvider implements credentialProvider for using +// managed identities assigned to the host machine. Identities are usually +// checked against the IMDS service available in the local network. +type managedIdentityCredentialProvider struct { +} + +func (m managedIdentityCredentialProvider) MakeCredential(ctx context.Context, userRequestedIdentity string) (azcore.TokenCredential, error) { + credenial, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ + ID: azidentity.ResourceID(userRequestedIdentity), + }) + return credenial, trace.Wrap(err) +} + +func (m managedIdentityCredentialProvider) MapScope(scope string) string { + // No scope needs to be mapped. + return scope +} + +// workloadIdentityCredentialProvider implements credentialProvider for using +// workload identities assigned to the host machine. +// +// https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview +// +// When running on AKS, multiple workload identities can be associated to the +// same service account attached to the pod. Assuming a workload identity +// requires a client ID of that identity but only the default Client ID is +// provided through environment variable. We assume that the default workload +// identity (mapped by the default client ID) is the "app-service" identity +// with msi permissions so the client IDs for other "user-requested" identity +// can be retrieved using the default identity. +type workloadIdentityCredentialProvider struct { + cache *utils.FnCache + defaultAgentIdentity azcore.TokenCredential + + // newClient defaults to cloudazure.NewUserAssignedIdentitiesClient. Can be + // overridden for test. + newClient func(string, azcore.TokenCredential, *arm.ClientOptions) (*cloudazure.UserAssignedIdentitiesClient, error) + // newCredential defaults to newWorkloadIdentityCredentialForClientID. Can + // be overridden for test. + newCredential func(string) (azcore.TokenCredential, error) +} + +func newWorloadIdentityCredentialProvider(ctx context.Context, defaultAgentIdentity azcore.TokenCredential) (*workloadIdentityCredentialProvider, error) { + if defaultAgentIdentity == nil { + return nil, trace.BadParameter("missing defaultAgentIdentity") + } + cache, err := utils.NewFnCache(utils.FnCacheConfig{ + Context: ctx, + TTL: clientIDCacheTTL, + ReloadOnErr: true, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return &workloadIdentityCredentialProvider{ + cache: cache, + defaultAgentIdentity: defaultAgentIdentity, + newClient: cloudazure.NewUserAssignedIdentitiesClient, + newCredential: newWorkloadIdentityCredentialForClientID, + }, nil +} + +func newWorkloadIdentityCredentialForClientID(clientID string) (azcore.TokenCredential, error) { + cred, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ + ClientID: clientID, + }) + return cred, trace.Wrap(err) +} + +func (w *workloadIdentityCredentialProvider) MakeCredential(ctx context.Context, userRequestedIdentity string) (azcore.TokenCredential, error) { + clientID, err := w.getClientID(ctx, userRequestedIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + + credential, err := w.newCredential(clientID) + return credential, trace.Wrap(err) +} + +func (w *workloadIdentityCredentialProvider) MapScope(scope string) string { + // This scope ("https://management.core.windows.net/") from `az` CLI tool + // will fail for workload identity as workload identity is only expected to + // be used with compatible SDKs, whereas the SDK adds ".default" to the + // audience: + // + // https://github.com/Azure/azure-sdk-for-go/blob/9e78ee2b86f0f4989098dd7e545b73841fc8df47/sdk/azcore/arm/runtime/pipeline.go#L35 + if scope == "https://management.core.windows.net/" { + return scope + ".default" + } + return scope +} + +func (w *workloadIdentityCredentialProvider) getClientID(ctx context.Context, identityResourceID string) (string, error) { + clientID, err := utils.FnCacheGet(ctx, w.cache, identityResourceID, func(ctx context.Context) (string, error) { + resourceID, err := arm.ParseResourceID(identityResourceID) + if err != nil { + return "", trace.Wrap(err) + } + + client, err := w.newClient(resourceID.SubscriptionID, w.defaultAgentIdentity, nil) + if err != nil { + return "", trace.Wrap(err) + } + + clientID, err := client.GetClientID(ctx, resourceID.ResourceGroupName, resourceID.Name) + return clientID, trace.Wrap(err) + }) + return clientID, trace.Wrap(err) +} + +// clientIDCacheTTL defines how long client IDs should be cached. Client IDs +// should never change for an identity so use a longer cache TTL. +var clientIDCacheTTL = 30 * time.Minute diff --git a/lib/srv/app/azure/credential_test.go b/lib/srv/app/azure/credential_test.go new file mode 100644 index 0000000000000..8f5fc6111ea16 --- /dev/null +++ b/lib/srv/app/azure/credential_test.go @@ -0,0 +1,130 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azure + +import ( + "context" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + cloudazure "github.com/gravitational/teleport/lib/cloud/azure" +) + +type fakeTokenCredential struct { + lastSeenScope string +} + +func (f *fakeTokenCredential) GetToken(_ context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { + if len(opts.Scopes) != 1 { + return azcore.AccessToken{}, trace.BadParameter("expect one scope but got %v", opts.Scopes) + } + + f.lastSeenScope = opts.Scopes[0] + return azcore.AccessToken{ + Token: "fake-token", + ExpiresOn: time.Now().Add(time.Hour), + }, nil +} + +type fakeCredentialProvider struct { + cred fakeTokenCredential + lastSeenIdentity string +} + +func (f *fakeCredentialProvider) MakeCredential(_ context.Context, userRequestedIdentity string) (azcore.TokenCredential, error) { + f.lastSeenIdentity = userRequestedIdentity + return &f.cred, nil +} + +func (f *fakeCredentialProvider) MapScope(scope string) string { + return scope + ".mapped" +} + +func Test_getAccessTokenFromCredentialProvider(t *testing.T) { + fakeCredProvider := &fakeCredentialProvider{} + userRequestedIdentity := "/subscriptions/my-sub/resourcegroups/my-group/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-name" + ctx := context.Background() + + token, err := getAccessTokenFromCredentialProvider(fakeCredProvider)(ctx, userRequestedIdentity, "test-scope") + require.NoError(t, err) + require.Equal(t, "fake-token", token.Token) + require.Equal(t, userRequestedIdentity, fakeCredProvider.lastSeenIdentity) + require.Equal(t, "test-scope.mapped", fakeCredProvider.cred.lastSeenScope) +} + +func Test_workloadIdentityCredentialProvider(t *testing.T) { + ctx := context.Background() + fakeAgentIdentity := &fakeTokenCredential{} + credProvider, err := newWorloadIdentityCredentialProvider(ctx, fakeAgentIdentity) + require.NoError(t, err) + + // Hook up more mocks. + fakeWorkloadIdentityCredential := &fakeTokenCredential{} + userRequestedIdentity := cloudazure.NewUserAssignedIdentity("my-sub", "my-group", "my-name", "my-client-id") + mockAPI := cloudazure.NewARMUserAssignedIdentitiesMock(userRequestedIdentity) + credProvider.newClient = func(string, azcore.TokenCredential, *arm.ClientOptions) (*cloudazure.UserAssignedIdentitiesClient, error) { + return cloudazure.NewUserAssignedIdentitiesClientByAPI(mockAPI), nil + } + credProvider.newCredential = func(clientID string) (azcore.TokenCredential, error) { + if clientID != "my-client-id" { + return nil, trace.BadParameter("expect my-client-id but got %s", clientID) + } + return fakeWorkloadIdentityCredential, nil + } + + t.Run("MakeCredential", func(t *testing.T) { + t.Run("success", func(t *testing.T) { + actualCredential, err := credProvider.MakeCredential(ctx, *userRequestedIdentity.ID) + require.NoError(t, err) + require.Same(t, fakeWorkloadIdentityCredential, actualCredential) + }) + t.Run("fail to get client ID", func(t *testing.T) { + notFoundIdentity := "/subscriptions/my-sub/resourcegroups/my-group/providers/Microsoft.ManagedIdentity/userAssignedIdentities/not-my-name" + _, err := credProvider.MakeCredential(ctx, notFoundIdentity) + require.Error(t, err) + }) + }) + + t.Run("MapScope", func(t *testing.T) { + tests := []struct { + inputScope string + outputScope string + }{ + { + inputScope: "https://management.core.windows.net/", + outputScope: "https://management.core.windows.net/.default", + }, + { + inputScope: "some-other-scope", + outputScope: "some-other-scope", + }, + } + for _, test := range tests { + t.Run(test.inputScope, func(t *testing.T) { + require.Equal(t, test.outputScope, credProvider.MapScope(test.inputScope)) + }) + } + }) +} diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 65603609f1970..04c9f3b0bedb1 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -23,13 +23,12 @@ import ( "context" "crypto" "crypto/x509" + "log/slog" "net/http" "strings" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" @@ -45,12 +44,18 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +// ComponentKey is the Teleport component key for this handler. +const ComponentKey = "azure:fwd" + // HandlerConfig is the configuration for an Azure app-access handler. type HandlerConfig struct { // RoundTripper is the underlying transport given to an oxy Forwarder. RoundTripper http.RoundTripper // Log is the Logger. + // TODO(greedy52) replace with slog. Log logrus.FieldLogger + // Logger is the slog.Logger. + Logger *slog.Logger // Clock is used to override time in tests. Clock clockwork.Clock @@ -59,7 +64,7 @@ type HandlerConfig struct { } // CheckAndSetDefaults validates the HandlerConfig. -func (s *HandlerConfig) CheckAndSetDefaults() error { +func (s *HandlerConfig) CheckAndSetDefaults(ctx context.Context) error { if s.RoundTripper == nil { tr, err := defaults.Transport() if err != nil { @@ -71,10 +76,17 @@ func (s *HandlerConfig) CheckAndSetDefaults() error { s.Clock = clockwork.NewRealClock() } if s.Log == nil { - s.Log = logrus.WithField(teleport.ComponentKey, "azure:fwd") + s.Log = logrus.WithField(teleport.ComponentKey, ComponentKey) + } + if s.Logger == nil { + s.Logger = slog.Default().With(teleport.ComponentKey, ComponentKey) } if s.getAccessToken == nil { - s.getAccessToken = getAccessTokenManagedIdentity + credProvider, err := findDefaultCredentialProvider(ctx, s.Logger) + if err != nil { + return trace.Wrap(err) + } + s.getAccessToken = getAccessTokenFromCredentialProvider(credProvider) } return nil } @@ -99,7 +111,7 @@ func NewAzureHandler(ctx context.Context, config HandlerConfig) (http.Handler, e // newAzureHandler creates a new instance of a handler for Azure requests. Used by NewAzureHandler and in tests. func newAzureHandler(ctx context.Context, config HandlerConfig) (*handler, error) { - if err := config.CheckAndSetDefaults(); err != nil { + if err := config.CheckAndSetDefaults(ctx); err != nil { return nil, trace.Wrap(err) } @@ -262,20 +274,6 @@ func (s *handler) parseAuthHeader(token string, pubKey crypto.PublicKey) (*jwt.A type getAccessTokenFunc func(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) -func getAccessTokenManagedIdentity(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { - identityCredential, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(managedIdentity)}) - if err != nil { - return nil, trace.Wrap(err) - } - - opts := policy.TokenRequestOptions{Scopes: []string{scope}} - token, err := identityCredential.GetToken(ctx, opts) - if err != nil { - return nil, trace.Wrap(err) - } - return &token, nil -} - type cacheKey struct { managedIdentity string scope string diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index a6a14c6e3f980..b122cf09ab06b 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -26,6 +26,7 @@ import ( "crypto/tls" "crypto/x509" "errors" + "log/slog" "net" "net/http" "strconv" @@ -131,6 +132,9 @@ type Config struct { // ConnectionMonitor monitors connections and terminates any if // any session controls prevent them. ConnectionMonitor ConnMonitor + + // Logger is the slog.Logger. + Logger *slog.Logger } // CheckAndSetDefaults makes sure the configuration has the minimum required @@ -180,6 +184,9 @@ func (c *Config) CheckAndSetDefaults() error { if c.ConnectedProxyGetter == nil { c.ConnectedProxyGetter = reversetunnel.NewConnectedProxyGetter() } + if c.Logger == nil { + c.Logger = slog.Default().With(teleport.ComponentKey, teleport.Component(teleport.ComponentApp)) + } return nil } @@ -289,7 +296,9 @@ func New(ctx context.Context, c *Config) (*Server, error) { return nil, trace.Wrap(err) } - azureHandler, err := appazure.NewAzureHandler(closeContext, appazure.HandlerConfig{}) + azureHandler, err := appazure.NewAzureHandler(closeContext, appazure.HandlerConfig{ + Logger: c.Logger.With(teleport.ComponentKey, appazure.ComponentKey), + }) if err != nil { return nil, trace.Wrap(err) } @@ -301,6 +310,7 @@ func New(ctx context.Context, c *Config) (*Server, error) { s := &Server{ c: c, + // TODO(greedy52) replace with slog from Config.Logger. log: logrus.WithFields(logrus.Fields{ teleport.ComponentKey: teleport.ComponentApp, }),