From 0ad1673c15f747c2a0a6b039865e29e38ccd650b Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 3 Aug 2023 22:41:02 +0000 Subject: [PATCH 01/13] remove disable CP1 env var; disabled is the new default --- sdk/azidentity/azidentity.go | 4 +--- sdk/azidentity/azidentity_test.go | 21 +++++++++---------- .../client_certificate_credential_test.go | 6 ------ 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 7b0a0f861f50..c3298bcbfcf5 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -15,7 +15,6 @@ import ( "net/url" "os" "regexp" - "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" @@ -46,8 +45,7 @@ const ( var ( // capability CP1 indicates the client application is capable of handling CAE claims challenges - cp1 = []string{"CP1"} - disableCP1 = strings.ToLower(os.Getenv("AZURE_IDENTITY_DISABLE_CP1")) == "true" + cp1 = []string{"CP1"} ) type msalClientOptions struct { diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 9be3a774e6cc..32dd90cfd265 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -446,8 +446,6 @@ func TestAdditionallyAllowedTenants(t *testing.T) { } func TestClaims(t *testing.T) { - realCP1 := disableCP1 - t.Cleanup(func() { disableCP1 = realCP1 }) claim := `"test":"pass"` for _, test := range []struct { ctor func(azcore.ClientOptions) (azcore.TokenCredential, error) @@ -499,13 +497,12 @@ func TestClaims(t *testing.T) { }, }, } { - for _, d := range []bool{true, false} { + for _, enableCAE := range []bool{true, false} { name := test.name - if d { - name += " disableCP1" + if enableCAE { + name += " CAE" } t.Run(name, func(t *testing.T) { - disableCP1 = d reqs := 0 sts := mockSTS{ tokenRequestCallback: func(r *http.Request) *http.Response { @@ -513,14 +510,14 @@ func TestClaims(t *testing.T) { t.Error(err) } reqs++ - // If the disableCP1 flag isn't set, both requests should specify CP1. The second - // GetToken call specifies claims we should find in the following token request. + // Both requests should specify CP1 when CAE is enabled for the token. // We check only for substrings because MSAL is responsible for formatting claims. actual := fmt.Sprint(r.Form["claims"]) - if strings.Contains(actual, "CP1") == disableCP1 { + if strings.Contains(actual, "CP1") != enableCAE { t.Fatalf(`unexpected claims "%v"`, actual) } if reqs == 2 { + // the second GetToken call specifies claims we should find in the following token request if !strings.Contains(strings.ReplaceAll(actual, " ", ""), claim) { t.Fatalf(`unexpected claims "%v"`, actual) } @@ -533,10 +530,12 @@ func TestClaims(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"A"}}); err != nil { + tro := policy.TokenRequestOptions{EnableCAE: enableCAE, Scopes: []string{"A"}} + if _, err = cred.GetToken(context.Background(), tro); err != nil { t.Fatal(err) } - if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Claims: fmt.Sprintf("{%s}", claim), Scopes: []string{"B"}}); err != nil { + tro = policy.TokenRequestOptions{Claims: fmt.Sprintf("{%s}", claim), EnableCAE: enableCAE, Scopes: []string{"B"}} + if _, err = cred.GetToken(context.Background(), tro); err != nil { t.Fatal(err) } if reqs != 2 { diff --git a/sdk/azidentity/client_certificate_credential_test.go b/sdk/azidentity/client_certificate_credential_test.go index 7372ba2fbffc..cf16da52b687 100644 --- a/sdk/azidentity/client_certificate_credential_test.go +++ b/sdk/azidentity/client_certificate_credential_test.go @@ -297,12 +297,6 @@ func TestClientCertificateCredential_Regional(t *testing.T) { if err != nil { t.Fatal(err) } - - // regional STS returns an error for CP1 - before := disableCP1 - defer func() { disableCP1 = before }() - disableCP1 = true - cred, err := NewClientCertificateCredential( liveSP.tenantID, liveSP.clientID, cert, key, &ClientCertificateCredentialOptions{SendCertificateChain: true, ClientOptions: opts}, ) From aae29387c5ab6823fc4baff02f617ca8ee222c5e Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 3 Aug 2023 23:07:18 +0000 Subject: [PATCH 02/13] rename confidential/publicClient interfaces --- sdk/azidentity/azidentity.go | 4 ++-- sdk/azidentity/azidentity_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index c3298bcbfcf5..f84da46789c3 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -179,7 +179,7 @@ func (p pipelineAdapter) Do(r *http.Request) (*http.Response, error) { } // enables fakes for test scenarios -type confidentialClient interface { +type msalConfidentialClient interface { AcquireTokenSilent(ctx context.Context, scopes []string, options ...confidential.AcquireSilentOption) (confidential.AuthResult, error) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, options ...confidential.AcquireByAuthCodeOption) (confidential.AuthResult, error) AcquireTokenByCredential(ctx context.Context, scopes []string, options ...confidential.AcquireByCredentialOption) (confidential.AuthResult, error) @@ -187,7 +187,7 @@ type confidentialClient interface { } // enables fakes for test scenarios -type publicClient interface { +type msalPublicClient interface { AcquireTokenSilent(ctx context.Context, scopes []string, options ...public.AcquireSilentOption) (public.AuthResult, error) AcquireTokenByUsernamePassword(ctx context.Context, scopes []string, username string, password string, options ...public.AcquireByUsernamePasswordOption) (public.AuthResult, error) AcquireTokenByDeviceCode(ctx context.Context, scopes []string, options ...public.AcquireByDeviceCodeOption) (public.DeviceCode, error) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 32dd90cfd265..036bd3680bdc 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -591,7 +591,7 @@ func (f fakeConfidentialClient) AcquireTokenOnBehalfOf(ctx context.Context, user return f.returnResult() } -var _ confidentialClient = (*fakeConfidentialClient)(nil) +var _ msalConfidentialClient = (*fakeConfidentialClient)(nil) // ================================================================================================================================== @@ -642,4 +642,4 @@ func (f fakePublicClient) AcquireTokenInteractive(ctx context.Context, scopes [] return f.returnResult() } -var _ publicClient = (*fakePublicClient)(nil) +var _ msalPublicClient = (*fakePublicClient)(nil) From 0f54af0e6668d2f581e4716df8aa4543a77f8fdb Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 3 Aug 2023 23:31:22 +0000 Subject: [PATCH 03/13] refactor tenant resolver --- sdk/azidentity/azidentity.go | 24 ++++++++++++-- sdk/azidentity/azidentity_test.go | 53 +++++++++++++++++++++++++++++++ sdk/azidentity/syncer.go | 16 +--------- sdk/azidentity/syncer_test.go | 50 ----------------------------- 4 files changed, 76 insertions(+), 67 deletions(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index f84da46789c3..4e693a823167 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -10,6 +10,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net/http" "net/url" @@ -40,12 +41,12 @@ const ( organizationsTenantID = "organizations" developerSignOnClientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" defaultSuffix = "/.default" - tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names" ) var ( // capability CP1 indicates the client application is capable of handling CAE claims challenges - cp1 = []string{"CP1"} + cp1 = []string{"CP1"} + errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names") ) type msalClientOptions struct { @@ -127,6 +128,25 @@ func setAuthorityHost(cc cloud.Configuration) (string, error) { return host, nil } +// resolveTenant returns the correct tenant for a token request +func resolveTenant(defaultTenant, specified, credName string, additionalTenants []string) (string, error) { + if specified == "" || specified == defaultTenant { + return defaultTenant, nil + } + if defaultTenant == "adfs" { + return "", errors.New("ADFS doesn't support tenants") + } + if !validTenantID(specified) { + return "", errInvalidTenantID + } + for _, t := range additionalTenants { + if t == "*" || t == specified { + return specified, nil + } + } + return "", fmt.Errorf(`%s isn't configured to acquire tokens for tenant %q. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on the credential options, or add "*" to allow acquiring tokens for any tenant`, credName, specified) +} + // validTenantID return true is it receives a valid tenantID, returns false otherwise func validTenantID(tenantID string) bool { match, err := regexp.MatchString("^[0-9a-zA-Z-.]+$", tenantID) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 036bd3680bdc..d6e3a9461b30 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -546,6 +546,59 @@ func TestClaims(t *testing.T) { } } +func TestResolveTenant(t *testing.T) { + credName := "testcred" + defaultTenant := "default-tenant" + otherTenant := "other-tenant" + for _, test := range []struct { + allowed []string + expected, tenant string + expectError bool + }{ + // no alternate tenant specified -> should get default + {expected: defaultTenant}, + {allowed: []string{""}, expected: defaultTenant}, + {allowed: []string{"*"}, expected: defaultTenant}, + {allowed: []string{otherTenant}, expected: defaultTenant}, + + // alternate tenant specified and allowed -> should get that tenant + {allowed: []string{"*"}, expected: otherTenant, tenant: otherTenant}, + {allowed: []string{otherTenant}, expected: otherTenant, tenant: otherTenant}, + {allowed: []string{"not-" + otherTenant, otherTenant}, expected: otherTenant, tenant: otherTenant}, + {allowed: []string{"not-" + otherTenant, "*"}, expected: otherTenant, tenant: otherTenant}, + + // invalid or not allowed tenant -> should get an error + {tenant: otherTenant, expectError: true}, + {allowed: []string{""}, tenant: otherTenant, expectError: true}, + {allowed: []string{defaultTenant}, tenant: otherTenant, expectError: true}, + {tenant: badTenantID, expectError: true}, + {allowed: []string{""}, tenant: badTenantID, expectError: true}, + {allowed: []string{"*", badTenantID}, tenant: badTenantID, expectError: true}, + {tenant: "invalid@tenant", expectError: true}, + {tenant: "invalid/tenant", expectError: true}, + {tenant: "invalid(tenant", expectError: true}, + {tenant: "invalid:tenant", expectError: true}, + } { + t.Run("", func(t *testing.T) { + tenant, err := resolveTenant(defaultTenant, test.tenant, credName, test.allowed) + if err != nil { + if test.expectError { + if validTenantID(test.tenant) && !strings.Contains(err.Error(), credName) { + t.Fatalf("expected error to contain %q, got %q", credName, err.Error()) + } + return + } + t.Fatal(err) + } else if test.expectError { + t.Fatal("expected an error") + } + if tenant != test.expected { + t.Fatalf(`expected "%s", got "%s"`, test.expected, tenant) + } + }) + } +} + // ================================================================================================================================== type fakeConfidentialClient struct { diff --git a/sdk/azidentity/syncer.go b/sdk/azidentity/syncer.go index 867490963efa..fe4f146d51cd 100644 --- a/sdk/azidentity/syncer.go +++ b/sdk/azidentity/syncer.go @@ -86,21 +86,7 @@ func (s *syncer) GetToken(ctx context.Context, opts policy.TokenRequestOptions) // resolveTenant returns the correct tenant for a token request given the credential's // configuration, or an error when the specified tenant isn't allowed by that configuration func (s *syncer) resolveTenant(requested string) (string, error) { - if requested == "" || requested == s.tenant { - return s.tenant, nil - } - if s.tenant == "adfs" { - return "", errors.New("ADFS doesn't support tenants") - } - if !validTenantID(requested) { - return "", errors.New(tenantIDValidationErr) - } - for _, t := range s.addlTenants { - if t == "*" || t == requested { - return requested, nil - } - } - return "", fmt.Errorf(`%s isn't configured to acquire tokens for tenant %q. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on the credential options, or add "*" to allow acquiring tokens for any tenant`, s.name, requested) + return resolveTenant(s.tenant, requested, s.name, s.addlTenants) } // resolveAdditionalTenants returns a copy of tenants, simplified when tenants contains a wildcard diff --git a/sdk/azidentity/syncer_test.go b/sdk/azidentity/syncer_test.go index 137556b4eb8f..2c60ed445c8f 100644 --- a/sdk/azidentity/syncer_test.go +++ b/sdk/azidentity/syncer_test.go @@ -16,56 +16,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) -func TestResolveTenant(t *testing.T) { - defaultTenant := "default-tenant" - otherTenant := "other-tenant" - for _, test := range []struct { - allowed []string - expected, tenant string - expectError bool - }{ - // no alternate tenant specified -> should get default - {expected: defaultTenant}, - {allowed: []string{""}, expected: defaultTenant}, - {allowed: []string{"*"}, expected: defaultTenant}, - {allowed: []string{otherTenant}, expected: defaultTenant}, - - // alternate tenant specified and allowed -> should get that tenant - {allowed: []string{"*"}, expected: otherTenant, tenant: otherTenant}, - {allowed: []string{otherTenant}, expected: otherTenant, tenant: otherTenant}, - {allowed: []string{"not-" + otherTenant, otherTenant}, expected: otherTenant, tenant: otherTenant}, - {allowed: []string{"not-" + otherTenant, "*"}, expected: otherTenant, tenant: otherTenant}, - - // invalid or not allowed tenant -> should get an error - {tenant: otherTenant, expectError: true}, - {allowed: []string{""}, tenant: otherTenant, expectError: true}, - {allowed: []string{defaultTenant}, tenant: otherTenant, expectError: true}, - {tenant: badTenantID, expectError: true}, - {allowed: []string{""}, tenant: badTenantID, expectError: true}, - {allowed: []string{"*", badTenantID}, tenant: badTenantID, expectError: true}, - {tenant: "invalid@tenant", expectError: true}, - {tenant: "invalid/tenant", expectError: true}, - {tenant: "invalid(tenant", expectError: true}, - {tenant: "invalid:tenant", expectError: true}, - } { - t.Run("", func(t *testing.T) { - s := newSyncer("", defaultTenant, nil, nil, syncerOptions{AdditionallyAllowedTenants: test.allowed}) - tenant, err := s.resolveTenant(test.tenant) - if err != nil { - if test.expectError { - return - } - t.Fatal(err) - } else if test.expectError { - t.Fatal("expected an error") - } - if tenant != test.expected { - t.Fatalf(`expected "%s", got "%s"`, test.expected, tenant) - } - }) - } -} - func TestSyncer(t *testing.T) { silentAuths, tokenRequests := 0, 0 s := newSyncer("", "tenant", From 888bd66e0ff2a93e3db44229daa01351b774b266 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 4 Aug 2023 17:19:12 +0000 Subject: [PATCH 04/13] fix doc comment list format --- sdk/azidentity/azidentity.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 4e693a823167..16511d9edaef 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -104,9 +104,9 @@ var getPublicClient = func(clientID, tenantID string, opts msalClientOptions) (p } // setAuthorityHost initializes the authority host for credentials. Precedence is: -// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user -// 2. value of AZURE_AUTHORITY_HOST -// 3. default: Azure Public Cloud +// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user +// 2. value of AZURE_AUTHORITY_HOST +// 3. default: Azure Public Cloud func setAuthorityHost(cc cloud.Configuration) (string, error) { host := cc.ActiveDirectoryAuthorityHost if host == "" { From 1f6d1b9a9c8bdcab1bdb711dd66dfafe549ad61a Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 4 Aug 2023 17:35:43 +0000 Subject: [PATCH 05/13] confidential/public wrapper clients --- sdk/azidentity/azidentity.go | 54 -------- sdk/azidentity/confidential_client.go | 153 ++++++++++++++++++++++ sdk/azidentity/public_client.go | 175 ++++++++++++++++++++++++++ 3 files changed, 328 insertions(+), 54 deletions(-) create mode 100644 sdk/azidentity/confidential_client.go create mode 100644 sdk/azidentity/public_client.go diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 16511d9edaef..f0b22649859e 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -49,60 +49,6 @@ var ( errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names") ) -type msalClientOptions struct { - azcore.ClientOptions - - DisableInstanceDiscovery bool - // SendX5C applies only to confidential clients authenticating with a cert - SendX5C bool -} - -var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, opts msalClientOptions) (confidentialClient, error) { - if !validTenantID(tenantID) { - return confidential.Client{}, errors.New(tenantIDValidationErr) - } - authorityHost, err := setAuthorityHost(opts.Cloud) - if err != nil { - return confidential.Client{}, err - } - authority := runtime.JoinPaths(authorityHost, tenantID) - o := []confidential.Option{ - confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), - confidential.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)), - } - if !disableCP1 { - o = append(o, confidential.WithClientCapabilities(cp1)) - } - if opts.SendX5C { - o = append(o, confidential.WithX5C()) - } - if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" { - o = append(o, confidential.WithInstanceDiscovery(false)) - } - return confidential.New(authority, clientID, cred, o...) -} - -var getPublicClient = func(clientID, tenantID string, opts msalClientOptions) (public.Client, error) { - if !validTenantID(tenantID) { - return public.Client{}, errors.New(tenantIDValidationErr) - } - authorityHost, err := setAuthorityHost(opts.Cloud) - if err != nil { - return public.Client{}, err - } - o := []public.Option{ - public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), - public.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)), - } - if !disableCP1 { - o = append(o, public.WithClientCapabilities(cp1)) - } - if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" { - o = append(o, public.WithInstanceDiscovery(false)) - } - return public.New(clientID, o...) -} - // setAuthorityHost initializes the authority host for credentials. Precedence is: // 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user // 2. value of AZURE_AUTHORITY_HOST diff --git a/sdk/azidentity/confidential_client.go b/sdk/azidentity/confidential_client.go new file mode 100644 index 000000000000..13ca474fe171 --- /dev/null +++ b/sdk/azidentity/confidential_client.go @@ -0,0 +1,153 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "sync" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" +) + +type confidentialClientOptions struct { + azcore.ClientOptions + + AdditionallyAllowedTenants []string + // Assertion for on-behalf-of authentication + Assertion string + DisableInstanceDiscovery, SendX5C bool +} + +// confidentialClient wraps the MSAL confidential client +type confidentialClient struct { + cae, noCAE msalConfidentialClient + caeMu, clientMu, noCAEMu *sync.Mutex + clientID, tenantID string + cred confidential.Credential + host string + name string + opts confidentialClientOptions +} + +func newConfidentialClient(tenantID, clientID, name string, cred confidential.Credential, opts confidentialClientOptions) (*confidentialClient, error) { + if !validTenantID(tenantID) { + return nil, errInvalidTenantID + } + host, err := setAuthorityHost(opts.Cloud) + if err != nil { + return nil, err + } + return &confidentialClient{ + caeMu: &sync.Mutex{}, + clientID: clientID, + clientMu: &sync.Mutex{}, + cred: cred, + host: host, + name: name, + noCAEMu: &sync.Mutex{}, + opts: opts, + tenantID: tenantID, + }, nil +} + +// GetToken requests an access token from MSAL, checking the cache first. +func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + if len(tro.Scopes) < 1 { + return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", c.name) + } + client, mu, err := c.client(ctx, tro) + if err != nil { + return azcore.AccessToken{}, err + } + // we don't resolve the tenant for managed identities because they acquire tokens only from their home tenants + if c.name != credNameManagedIdentity { + tenant, err := c.resolveTenant(tro.TenantID) + if err != nil { + return azcore.AccessToken{}, err + } + tro.TenantID = tenant + } + mu.Lock() + defer mu.Unlock() + var ar confidential.AuthResult + if c.opts.Assertion != "" { + ar, err = client.AcquireTokenOnBehalfOf(ctx, c.opts.Assertion, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID)) + } else { + ar, err = client.AcquireTokenSilent(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID)) + if err != nil { + ar, err = client.AcquireTokenByCredential(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID)) + } + } + if err != nil { + // We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code. + // We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError. + var unavailableErr *credentialUnavailableError + if !errors.As(err, &unavailableErr) { + res := getResponseFromError(err) + err = newAuthenticationFailedError(c.name, err.Error(), res, err) + } + } else { + msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", c.name, strings.Join(ar.GrantedScopes, ", ")) + log.Write(EventAuthentication, msg) + } + return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err +} + +func (c *confidentialClient) client(ctx context.Context, tro policy.TokenRequestOptions) (msalConfidentialClient, *sync.Mutex, error) { + c.clientMu.Lock() + defer c.clientMu.Unlock() + if tro.EnableCAE { + if c.cae == nil { + client, err := c.newMSALClient(true) + if err != nil { + return nil, nil, err + } + c.cae = client + } + return c.cae, c.caeMu, nil + } + if c.noCAE == nil { + client, err := c.newMSALClient(false) + if err != nil { + return nil, nil, err + } + c.noCAE = client + } + return c.noCAE, c.noCAEMu, nil +} + +func (c *confidentialClient) newMSALClient(enableCAE bool) (msalConfidentialClient, error) { + authority := runtime.JoinPaths(c.host, c.tenantID) + o := []confidential.Option{ + confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), + confidential.WithHTTPClient(newPipelineAdapter(&c.opts.ClientOptions)), + } + if enableCAE { + o = append(o, confidential.WithClientCapabilities(cp1)) + } + if c.opts.SendX5C { + o = append(o, confidential.WithX5C()) + } + if c.opts.DisableInstanceDiscovery || strings.ToLower(c.tenantID) == "adfs" { + o = append(o, confidential.WithInstanceDiscovery(false)) + } + return confidential.New(authority, c.clientID, c.cred, o...) +} + +// resolveTenant returns the correct tenant for a token request given the client's +// configuration, or an error when that configuration doesn't allow the specified tenant +func (c *confidentialClient) resolveTenant(specified string) (string, error) { + return resolveTenant(c.tenantID, specified, c.name, c.opts.AdditionallyAllowedTenants) +} diff --git a/sdk/azidentity/public_client.go b/sdk/azidentity/public_client.go new file mode 100644 index 000000000000..dfd75861179c --- /dev/null +++ b/sdk/azidentity/public_client.go @@ -0,0 +1,175 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" +) + +type publicClientOptions struct { + azcore.ClientOptions + + AdditionallyAllowedTenants []string + DeviceCodePrompt func(context.Context, DeviceCodeMessage) error + DisableInstanceDiscovery bool + LoginHint, RedirectURL string + Username, Password string +} + +// publicClient wraps the MSAL public client +type publicClient struct { + account public.Account + cae, noCAE msalPublicClient + clientID, tenantID string + clientMu, caeMu, noCAEMu *sync.Mutex + host string + name string + opts publicClientOptions +} + +func newPublicClient(tenantID, clientID, name string, o publicClientOptions) (*publicClient, error) { + if !validTenantID(tenantID) { + return nil, errInvalidTenantID + } + host, err := setAuthorityHost(o.Cloud) + if err != nil { + return nil, err + } + return &publicClient{ + caeMu: &sync.Mutex{}, + clientID: clientID, + clientMu: &sync.Mutex{}, + host: host, + name: name, + noCAEMu: &sync.Mutex{}, + opts: o, + tenantID: tenantID, + }, nil +} + +// GetToken requests an access token from MSAL, checking the cache first. +func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + if len(tro.Scopes) < 1 { + return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", p.name) + } + tenant, err := p.resolveTenant(tro.TenantID) + if err != nil { + return azcore.AccessToken{}, err + } + client, mu, err := p.client(tro) + if err != nil { + return azcore.AccessToken{}, err + } + mu.Lock() + defer mu.Unlock() + ar, err := client.AcquireTokenSilent(ctx, tro.Scopes, public.WithSilentAccount(p.account), public.WithClaims(tro.Claims), public.WithTenantID(tenant)) + if err == nil { + return p.token(ar, err) + } + at, err := p.reqToken(ctx, client, tro) + if err == nil { + msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", p.name, strings.Join(ar.GrantedScopes, ", ")) + log.Write(EventAuthentication, msg) + } + return at, err +} + +// reqToken requests a token from the MSAL public client. It's separate from GetToken() to enable Authenticate() to bypass the cache. +func (m *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + tenant, err := m.resolveTenant(tro.TenantID) + if err != nil { + return azcore.AccessToken{}, err + } + var ar public.AuthResult + switch { + case m.opts.DeviceCodePrompt != nil: + dc, e := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant)) + if e != nil { + return azcore.AccessToken{}, e + } + err = m.opts.DeviceCodePrompt(ctx, DeviceCodeMessage{ + Message: dc.Result.Message, + UserCode: dc.Result.UserCode, + VerificationURL: dc.Result.VerificationURL, + }) + if err == nil { + ar, err = dc.AuthenticationResult(ctx) + } + case m.opts.Username != "" && m.opts.Password != "": + ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, m.opts.Username, m.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant)) + default: + ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes, + public.WithClaims(tro.Claims), + public.WithLoginHint(m.opts.LoginHint), + public.WithRedirectURI(m.opts.RedirectURL), + public.WithTenantID(tenant), + ) + } + return m.token(ar, err) +} + +func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, *sync.Mutex, error) { + p.clientMu.Lock() + defer p.clientMu.Unlock() + if tro.EnableCAE { + if p.cae == nil { + client, err := p.newMSALClient(true) + if err != nil { + return nil, nil, err + } + p.cae = client + } + return p.cae, p.caeMu, nil + } + if p.noCAE == nil { + client, err := p.newMSALClient(false) + if err != nil { + return nil, nil, err + } + p.noCAE = client + } + return p.noCAE, p.noCAEMu, nil +} + +func (p *publicClient) newMSALClient(enableCAE bool) (msalPublicClient, error) { + o := []public.Option{ + public.WithAuthority(runtime.JoinPaths(p.host, p.tenantID)), + public.WithHTTPClient(newPipelineAdapter(&p.opts.ClientOptions)), + } + if enableCAE { + o = append(o, public.WithClientCapabilities(cp1)) + } + if p.opts.DisableInstanceDiscovery || strings.ToLower(p.tenantID) == "adfs" { + o = append(o, public.WithInstanceDiscovery(false)) + } + return public.New(p.clientID, o...) +} + +func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToken, error) { + if err == nil { + p.account = ar.Account + } else { + res := getResponseFromError(err) + err = newAuthenticationFailedError(p.name, err.Error(), res, err) + } + return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err +} + +// resolveTenant returns the correct tenant for a token request given the client's +// configuration, or an error when that configuration doesn't allow the specified tenant +func (m *publicClient) resolveTenant(specified string) (string, error) { + return resolveTenant(m.tenantID, specified, m.name, m.opts.AdditionallyAllowedTenants) +} From ed2134543506442702f008d08e5b200db6c97ac5 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 4 Aug 2023 18:49:42 +0000 Subject: [PATCH 06/13] refactor credentials to use the new clients --- sdk/azidentity/client_assertion_credential.go | 34 +++-------- .../client_certificate_credential.go | 36 +++-------- .../client_certificate_credential_test.go | 6 +- sdk/azidentity/client_secret_credential.go | 34 +++-------- .../client_secret_credential_test.go | 2 +- .../default_azure_credential_test.go | 2 +- sdk/azidentity/device_code_credential.go | 61 +++---------------- sdk/azidentity/device_code_credential_test.go | 4 +- .../interactive_browser_credential.go | 48 +++------------ .../interactive_browser_credential_test.go | 19 +----- sdk/azidentity/managed_identity_credential.go | 26 ++------ sdk/azidentity/on_behalf_of_credential.go | 31 +++------- .../on_behalf_of_credential_test.go | 2 - .../username_password_credential.go | 48 +++------------ .../username_password_credential_test.go | 12 ---- sdk/azidentity/workload_identity.go | 2 +- 16 files changed, 79 insertions(+), 288 deletions(-) diff --git a/sdk/azidentity/client_assertion_credential.go b/sdk/azidentity/client_assertion_credential.go index 6dff48a1f32c..303d5fc0925c 100644 --- a/sdk/azidentity/client_assertion_credential.go +++ b/sdk/azidentity/client_assertion_credential.go @@ -24,8 +24,7 @@ const credNameAssertion = "ClientAssertionCredential" // // [Azure AD documentation]: https://docs.microsoft.com/azure/active-directory/develop/active-directory-certificate-credentials#assertion-format type ClientAssertionCredential struct { - client confidentialClient - s *syncer + client *confidentialClient } // ClientAssertionCredentialOptions contains optional parameters for ClientAssertionCredential. @@ -56,38 +55,21 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c return getAssertion(ctx) }, ) - msalOpts := msalClientOptions{ - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, + msalOpts := confidentialClientOptions{ + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, } - c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) + c, err := newConfidentialClient(tenantID, clientID, credNameAssertion, cred, msalOpts) if err != nil { return nil, err } - cac := ClientAssertionCredential{client: c} - cac.s = newSyncer( - credNameAssertion, - tenantID, - cac.requestToken, - cac.silentAuth, - syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants}, - ) - return &cac, nil + return &ClientAssertionCredential{client: c}, nil } // GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients. func (c *ClientAssertionCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - return c.s.GetToken(ctx, opts) -} - -func (c *ClientAssertionCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID)) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err -} - -func (c *ClientAssertionCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID)) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return c.client.GetToken(ctx, opts) } var _ azcore.TokenCredential = (*ClientAssertionCredential)(nil) diff --git a/sdk/azidentity/client_certificate_credential.go b/sdk/azidentity/client_certificate_credential.go index 2f2db8e2cd37..d3300e3053bd 100644 --- a/sdk/azidentity/client_certificate_credential.go +++ b/sdk/azidentity/client_certificate_credential.go @@ -42,8 +42,7 @@ type ClientCertificateCredentialOptions struct { // ClientCertificateCredential authenticates a service principal with a certificate. type ClientCertificateCredential struct { - client confidentialClient - s *syncer + client *confidentialClient } // NewClientCertificateCredential constructs a ClientCertificateCredential. Pass nil for options to accept defaults. @@ -58,39 +57,22 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x if err != nil { return nil, err } - msalOpts := msalClientOptions{ - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, - SendX5C: options.SendCertificateChain, + msalOpts := confidentialClientOptions{ + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + SendX5C: options.SendCertificateChain, } - c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) + c, err := newConfidentialClient(tenantID, clientID, credNameCert, cred, msalOpts) if err != nil { return nil, err } - cc := ClientCertificateCredential{client: c} - cc.s = newSyncer( - credNameCert, - tenantID, - cc.requestToken, - cc.silentAuth, - syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants}, - ) - return &cc, nil + return &ClientCertificateCredential{client: c}, nil } // GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients. func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - return c.s.GetToken(ctx, opts) -} - -func (c *ClientCertificateCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID)) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err -} - -func (c *ClientCertificateCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID)) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return c.client.GetToken(ctx, opts) } // ParseCertificates loads certificates and a private key, in PEM or PKCS12 format, for use with NewClientCertificateCredential. diff --git a/sdk/azidentity/client_certificate_credential_test.go b/sdk/azidentity/client_certificate_credential_test.go index cf16da52b687..ff9ba1d908c8 100644 --- a/sdk/azidentity/client_certificate_credential_test.go +++ b/sdk/azidentity/client_certificate_credential_test.go @@ -81,7 +81,7 @@ func TestClientCertificateCredential_GetTokenSuccess(t *testing.T) { if err != nil { t.Fatalf("Expected an empty error but received: %s", err.Error()) } - cred.client = fakeConfidentialClient{} + cred.client.noCAE = fakeConfidentialClient{} _, err = cred.GetToken(context.Background(), testTRO) if err != nil { t.Fatalf("Expected an empty error but received: %s", err.Error()) @@ -98,7 +98,7 @@ func TestClientCertificateCredential_GetTokenSuccess_withCertificateChain(t *tes if err != nil { t.Fatalf("Expected an empty error but received: %s", err.Error()) } - cred.client = fakeConfidentialClient{} + cred.client.noCAE = fakeConfidentialClient{} _, err = cred.GetToken(context.Background(), testTRO) if err != nil { t.Fatalf("Expected an empty error but received: %s", err.Error()) @@ -132,7 +132,7 @@ func TestClientCertificateCredential_GetTokenCheckPrivateKeyBlocks(t *testing.T) if err != nil { t.Fatalf("Expected an empty error but received: %s", err.Error()) } - cred.client = fakeConfidentialClient{} + cred.client.noCAE = fakeConfidentialClient{} _, err = cred.GetToken(context.Background(), testTRO) if err != nil { t.Fatalf("Expected an empty error but received: %s", err.Error()) diff --git a/sdk/azidentity/client_secret_credential.go b/sdk/azidentity/client_secret_credential.go index f9ec5bef58e1..d2ff7582b997 100644 --- a/sdk/azidentity/client_secret_credential.go +++ b/sdk/azidentity/client_secret_credential.go @@ -33,8 +33,7 @@ type ClientSecretCredentialOptions struct { // ClientSecretCredential authenticates an application with a client secret. type ClientSecretCredential struct { - client confidentialClient - s *syncer + client *confidentialClient } // NewClientSecretCredential constructs a ClientSecretCredential. Pass nil for options to accept defaults. @@ -46,38 +45,21 @@ func NewClientSecretCredential(tenantID string, clientID string, clientSecret st if err != nil { return nil, err } - msalOpts := msalClientOptions{ - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, + msalOpts := confidentialClientOptions{ + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, } - c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) + c, err := newConfidentialClient(tenantID, clientID, credNameSecret, cred, msalOpts) if err != nil { return nil, err } - csc := ClientSecretCredential{client: c} - csc.s = newSyncer( - credNameSecret, - tenantID, - csc.requestToken, - csc.silentAuth, - syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants}, - ) - return &csc, nil + return &ClientSecretCredential{c}, nil } // GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients. func (c *ClientSecretCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - return c.s.GetToken(ctx, opts) -} - -func (c *ClientSecretCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID)) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err -} - -func (c *ClientSecretCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID)) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return c.client.GetToken(ctx, opts) } var _ azcore.TokenCredential = (*ClientSecretCredential)(nil) diff --git a/sdk/azidentity/client_secret_credential_test.go b/sdk/azidentity/client_secret_credential_test.go index b494ef98a7d7..868c05d3df91 100644 --- a/sdk/azidentity/client_secret_credential_test.go +++ b/sdk/azidentity/client_secret_credential_test.go @@ -32,7 +32,7 @@ func TestClientSecretCredential_GetTokenSuccess(t *testing.T) { if err != nil { t.Fatalf("Unable to create credential. Received: %v", err) } - cred.client = fakeConfidentialClient{} + cred.client.noCAE = fakeConfidentialClient{} _, err = cred.GetToken(context.Background(), testTRO) if err != nil { t.Fatalf("Expected an empty error but received: %v", err) diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 1945ea82a68a..0287a2a88808 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -32,7 +32,7 @@ func TestDefaultAzureCredential_GetTokenSuccess(t *testing.T) { t.Fatalf("Unable to create credential. Received: %v", err) } c := cred.chain.sources[0].(*EnvironmentCredential) - c.cred.(*ClientSecretCredential).client = fakeConfidentialClient{} + c.cred.(*ClientSecretCredential).client.noCAE = fakeConfidentialClient{} _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"scope"}}) if err != nil { t.Fatalf("GetToken error: %v", err) diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go index 0e635a017ff7..d245c269a760 100644 --- a/sdk/azidentity/device_code_credential.go +++ b/sdk/azidentity/device_code_credential.go @@ -12,7 +12,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) const credNameDeviceCode = "DeviceCodeCredential" @@ -74,10 +73,7 @@ type DeviceCodeMessage struct { // If a web browser is available, InteractiveBrowserCredential is more convenient because it // automatically opens a browser to the login page. type DeviceCodeCredential struct { - account public.Account - client publicClient - s *syncer - prompt func(context.Context, DeviceCodeMessage) error + client *publicClient } // NewDeviceCodeCredential creates a DeviceCodeCredential. Pass nil to accept default options. @@ -87,61 +83,24 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC cp = *options } cp.init() - msalOpts := msalClientOptions{ - ClientOptions: cp.ClientOptions, - DisableInstanceDiscovery: cp.DisableInstanceDiscovery, + msalOpts := publicClientOptions{ + AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants, + ClientOptions: cp.ClientOptions, + DeviceCodePrompt: cp.UserPrompt, + DisableInstanceDiscovery: cp.DisableInstanceDiscovery, } - c, err := getPublicClient(cp.ClientID, cp.TenantID, msalOpts) + c, err := newPublicClient(cp.TenantID, cp.ClientID, credNameDeviceCode, msalOpts) if err != nil { return nil, err } - cred := DeviceCodeCredential{client: c, prompt: cp.UserPrompt} - cred.s = newSyncer( - credNameDeviceCode, - cp.TenantID, - cred.requestToken, - cred.silentAuth, - syncerOptions{ - AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants, - }, - ) - return &cred, nil + c.name = credNameDeviceCode + return &DeviceCodeCredential{client: c}, nil } // GetToken requests an access token from Azure Active Directory. It will begin the device code flow and poll until the user completes authentication. // This method is called automatically by Azure SDK clients. func (c *DeviceCodeCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - return c.s.GetToken(ctx, opts) -} - -func (c *DeviceCodeCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - dc, err := c.client.AcquireTokenByDeviceCode(ctx, opts.Scopes, public.WithClaims(opts.Claims), public.WithTenantID(opts.TenantID)) - if err != nil { - return azcore.AccessToken{}, err - } - err = c.prompt(ctx, DeviceCodeMessage{ - Message: dc.Result.Message, - UserCode: dc.Result.UserCode, - VerificationURL: dc.Result.VerificationURL, - }) - if err != nil { - return azcore.AccessToken{}, err - } - ar, err := dc.AuthenticationResult(ctx) - if err != nil { - return azcore.AccessToken{}, err - } - c.account = ar.Account - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err -} - -func (c *DeviceCodeCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, - public.WithClaims(opts.Claims), - public.WithSilentAccount(c.account), - public.WithTenantID(opts.TenantID), - ) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return c.client.GetToken(ctx, opts) } var _ azcore.TokenCredential = (*DeviceCodeCredential)(nil) diff --git a/sdk/azidentity/device_code_credential_test.go b/sdk/azidentity/device_code_credential_test.go index d680fa35f049..ab2fe1679014 100644 --- a/sdk/azidentity/device_code_credential_test.go +++ b/sdk/azidentity/device_code_credential_test.go @@ -33,7 +33,7 @@ func TestDeviceCodeCredential_GetTokenInvalidCredentials(t *testing.T) { if err != nil { t.Fatalf("Unable to create credential. Received: %v", err) } - cred.client = fakePublicClient{err: errors.New("invalid credentials")} + cred.client.noCAE = fakePublicClient{err: errors.New("invalid credentials")} _, err = cred.GetToken(context.Background(), testTRO) if err == nil { t.Fatalf("Expected an error but did not receive one.") @@ -67,7 +67,7 @@ func TestDeviceCodeCredential_UserPromptError(t *testing.T) { if err != nil { t.Fatalf("Unable to create credential: %v", err) } - cred.client = fakePublicClient{ + cred.client.noCAE = fakePublicClient{ dc: public.DeviceCode{ Result: public.DeviceCodeResult{ Message: expected.Message, diff --git a/sdk/azidentity/interactive_browser_credential.go b/sdk/azidentity/interactive_browser_credential.go index b5c0a8d9be56..08f3efbf3ec4 100644 --- a/sdk/azidentity/interactive_browser_credential.go +++ b/sdk/azidentity/interactive_browser_credential.go @@ -11,7 +11,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) const credNameBrowser = "InteractiveBrowserCredential" @@ -56,10 +55,7 @@ func (o *InteractiveBrowserCredentialOptions) init() { // InteractiveBrowserCredential opens a browser to interactively authenticate a user. type InteractiveBrowserCredential struct { - account public.Account - client publicClient - options InteractiveBrowserCredentialOptions - s *syncer + client *publicClient } // NewInteractiveBrowserCredential constructs a new InteractiveBrowserCredential. Pass nil to accept default options. @@ -69,52 +65,22 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption cp = *options } cp.init() - msalOpts := msalClientOptions{ + msalOpts := publicClientOptions{ ClientOptions: cp.ClientOptions, DisableInstanceDiscovery: cp.DisableInstanceDiscovery, + LoginHint: cp.LoginHint, + RedirectURL: cp.RedirectURL, } - c, err := getPublicClient(cp.ClientID, cp.TenantID, msalOpts) + c, err := newPublicClient(cp.TenantID, cp.ClientID, credNameBrowser, msalOpts) if err != nil { return nil, err } - ibc := InteractiveBrowserCredential{client: c, options: cp} - ibc.s = newSyncer( - credNameBrowser, - cp.TenantID, - ibc.requestToken, - ibc.silentAuth, - syncerOptions{ - AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants, - }, - ) - return &ibc, nil + return &InteractiveBrowserCredential{client: c}, nil } // GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients. func (c *InteractiveBrowserCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - return c.s.GetToken(ctx, opts) -} - -func (c *InteractiveBrowserCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenInteractive(ctx, opts.Scopes, - public.WithClaims(opts.Claims), - public.WithLoginHint(c.options.LoginHint), - public.WithRedirectURI(c.options.RedirectURL), - public.WithTenantID(opts.TenantID), - ) - if err == nil { - c.account = ar.Account - } - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err -} - -func (c *InteractiveBrowserCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, - public.WithClaims(opts.Claims), - public.WithSilentAccount(c.account), - public.WithTenantID(opts.TenantID), - ) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return c.client.GetToken(ctx, opts) } var _ azcore.TokenCredential = (*InteractiveBrowserCredential)(nil) diff --git a/sdk/azidentity/interactive_browser_credential_test.go b/sdk/azidentity/interactive_browser_credential_test.go index a47323a06cec..26b061d99a4f 100644 --- a/sdk/azidentity/interactive_browser_credential_test.go +++ b/sdk/azidentity/interactive_browser_credential_test.go @@ -36,7 +36,7 @@ func TestInteractiveBrowserCredential_GetTokenSuccess(t *testing.T) { if err != nil { t.Fatalf("Unable to create credential. Received: %v", err) } - cred.client = fakePublicClient{ + cred.client.noCAE = fakePublicClient{ ar: public.AuthResult{ AccessToken: tokenValue, ExpiresOn: time.Now().Add(1 * time.Hour), @@ -51,19 +51,6 @@ func TestInteractiveBrowserCredential_GetTokenSuccess(t *testing.T) { } } -func TestInteractiveBrowserCredential_CreateWithNilOptions(t *testing.T) { - cred, err := NewInteractiveBrowserCredential(nil) - if err != nil { - t.Fatalf("Failed to create interactive browser credential: %v", err) - } - if cred.options.ClientID != developerSignOnClientID { - t.Fatalf("Wrong clientID set. Expected: %s, Received: %s", developerSignOnClientID, cred.options.ClientID) - } - if cred.options.TenantID != organizationsTenantID { - t.Fatalf("Wrong tenantID set. Expected: %s, Received: %s", organizationsTenantID, cred.options.TenantID) - } -} - // instanceDiscoveryPolicy fails the test when the client requests instance metadata type instanceDiscoveryPolicy struct { t *testing.T @@ -89,7 +76,7 @@ func TestInteractiveBrowserCredential_Live(t *testing.T) { }) t.Run("LoginHint", func(t *testing.T) { upn := "test@pass" - fmt.Printf("\t%s: consider this test passing when %q appears in the login prompt", t.Name(), upn) + fmt.Printf("\t%s: consider this test passing when %q appears in the login prompt\n", t.Name(), upn) cred, err := NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{LoginHint: upn}) if err != nil { t.Fatal(err) @@ -98,7 +85,7 @@ func TestInteractiveBrowserCredential_Live(t *testing.T) { }) t.Run("RedirectURL", func(t *testing.T) { url := "http://localhost:8180" - fmt.Printf("\t%s: consider this test passing when AAD redirects to %s", t.Name(), url) + fmt.Printf("\t%s: consider this test passing when AAD redirects to %s\n", t.Name(), url) cred, err := NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{RedirectURL: url}) if err != nil { t.Fatal(err) diff --git a/sdk/azidentity/managed_identity_credential.go b/sdk/azidentity/managed_identity_credential.go index 68c41a701532..35c5e6725cda 100644 --- a/sdk/azidentity/managed_identity_credential.go +++ b/sdk/azidentity/managed_identity_credential.go @@ -8,7 +8,6 @@ package azidentity import ( "context" - "errors" "fmt" "strings" @@ -71,9 +70,8 @@ type ManagedIdentityCredentialOptions struct { // user-assigned identity. See Azure Active Directory documentation for more information about managed identities: // https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview type ManagedIdentityCredential struct { - client confidentialClient + client *confidentialClient mic *managedIdentityClient - s *syncer } // NewManagedIdentityCredential creates a ManagedIdentityCredential. Pass nil to accept default options. @@ -93,35 +91,23 @@ func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*M if options.ID != nil { clientID = options.ID.String() } - // similarly, it's okay to give MSAL an incorrect authority URL because that URL won't be used - c, err := confidential.New("https://login.microsoftonline.com/common", clientID, cred) + // similarly, it's okay to give MSAL an incorrect tenant because MSAL won't use the value + c, err := newConfidentialClient("common", clientID, credNameManagedIdentity, cred, confidentialClientOptions{}) if err != nil { return nil, err } - m := ManagedIdentityCredential{client: c, mic: mic} - m.s = newSyncer(credNameManagedIdentity, "", m.requestToken, m.silentAuth, syncerOptions{}) - return &m, nil + return &ManagedIdentityCredential{client: c, mic: mic}, nil } // GetToken requests an access token from the hosting environment. This method is called automatically by Azure SDK clients. func (c *ManagedIdentityCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { if len(opts.Scopes) != 1 { - err := errors.New(credNameManagedIdentity + ": GetToken() requires exactly one scope") + err := fmt.Errorf("%s.GetToken() requires exactly one scope", credNameManagedIdentity) return azcore.AccessToken{}, err } // managed identity endpoints require an AADv1 resource (i.e. token audience), not a v2 scope, so we remove "/.default" here opts.Scopes = []string{strings.TrimSuffix(opts.Scopes[0], defaultSuffix)} - return c.s.GetToken(ctx, opts) -} - -func (c *ManagedIdentityCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err -} - -func (c *ManagedIdentityCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return c.client.GetToken(ctx, opts) } var _ azcore.TokenCredential = (*ManagedIdentityCredential)(nil) diff --git a/sdk/azidentity/on_behalf_of_credential.go b/sdk/azidentity/on_behalf_of_credential.go index d2c106347c5c..2b360b681df1 100644 --- a/sdk/azidentity/on_behalf_of_credential.go +++ b/sdk/azidentity/on_behalf_of_credential.go @@ -25,9 +25,7 @@ const credNameOBO = "OnBehalfOfCredential" // // [Azure Active Directory documentation]: https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow type OnBehalfOfCredential struct { - assertion string - client confidentialClient - s *syncer + client *confidentialClient } // OnBehalfOfCredentialOptions contains optional parameters for OnBehalfOfCredential @@ -72,32 +70,23 @@ func newOnBehalfOfCredential(tenantID, clientID, userAssertion string, cred conf if options == nil { options = &OnBehalfOfCredentialOptions{} } - msalOpts := msalClientOptions{ - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, - SendX5C: options.SendCertificateChain, + opts := confidentialClientOptions{ + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + Assertion: userAssertion, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + SendX5C: options.SendCertificateChain, } - c, err := getConfidentialClient(clientID, tenantID, cred, msalOpts) + c, err := newConfidentialClient(tenantID, clientID, credNameOBO, cred, opts) if err != nil { return nil, err } - obo := OnBehalfOfCredential{assertion: userAssertion, client: c} - // this credential doesn't have a silent auth method because MSAL implements that in AcquireTokenOnBehalfOf; GetToken should just call that method, once - obo.s = newSyncer(credNameOBO, tenantID, obo.requestToken, nil, syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants}) - return &obo, nil + return &OnBehalfOfCredential{c}, nil } // GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients. func (o *OnBehalfOfCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - return o.s.GetToken(ctx, opts) -} - -func (o *OnBehalfOfCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes, - confidential.WithClaims(opts.Claims), - confidential.WithTenantID(opts.TenantID), - ) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return o.client.GetToken(ctx, opts) } var _ azcore.TokenCredential = (*OnBehalfOfCredential)(nil) diff --git a/sdk/azidentity/on_behalf_of_credential_test.go b/sdk/azidentity/on_behalf_of_credential_test.go index 390a2f4e36b4..1dd39c900dc0 100644 --- a/sdk/azidentity/on_behalf_of_credential_test.go +++ b/sdk/azidentity/on_behalf_of_credential_test.go @@ -18,8 +18,6 @@ import ( ) func TestOnBehalfOfCredential(t *testing.T) { - realGetClient := getConfidentialClient - t.Cleanup(func() { getConfidentialClient = realGetClient }) expectedAssertion := "user-assertion" certs, key := allCertTests[0].certs, allCertTests[0].key for _, test := range []struct { diff --git a/sdk/azidentity/username_password_credential.go b/sdk/azidentity/username_password_credential.go index f031687464ca..f787ec0ce18f 100644 --- a/sdk/azidentity/username_password_credential.go +++ b/sdk/azidentity/username_password_credential.go @@ -11,7 +11,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) const credNameUserPassword = "UsernamePasswordCredential" @@ -36,10 +35,7 @@ type UsernamePasswordCredentialOptions struct { // with any form of multi-factor authentication, and the application must already have user or admin consent. // This credential can only authenticate work and school accounts; it can't authenticate Microsoft accounts. type UsernamePasswordCredential struct { - account public.Account - client publicClient - password, username string - s *syncer + client *publicClient } // NewUsernamePasswordCredential creates a UsernamePasswordCredential. clientID is the ID of the application the user @@ -48,47 +44,23 @@ func NewUsernamePasswordCredential(tenantID string, clientID string, username st if options == nil { options = &UsernamePasswordCredentialOptions{} } - msalOpts := msalClientOptions{ - ClientOptions: options.ClientOptions, - DisableInstanceDiscovery: options.DisableInstanceDiscovery, + opts := publicClientOptions{ + AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, + ClientOptions: options.ClientOptions, + DisableInstanceDiscovery: options.DisableInstanceDiscovery, + Password: password, + Username: username, } - c, err := getPublicClient(clientID, tenantID, msalOpts) + c, err := newPublicClient(tenantID, clientID, credNameUserPassword, opts) if err != nil { return nil, err } - upc := UsernamePasswordCredential{client: c, password: password, username: username} - upc.s = newSyncer( - credNameUserPassword, - tenantID, - upc.requestToken, - upc.silentAuth, - syncerOptions{ - AdditionallyAllowedTenants: options.AdditionallyAllowedTenants, - }, - ) - return &upc, nil + return &UsernamePasswordCredential{client: c}, err } // GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients. func (c *UsernamePasswordCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - return c.s.GetToken(ctx, opts) -} - -func (c *UsernamePasswordCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenByUsernamePassword(ctx, opts.Scopes, c.username, c.password, public.WithClaims(opts.Claims), public.WithTenantID(opts.TenantID)) - if err == nil { - c.account = ar.Account - } - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err -} - -func (c *UsernamePasswordCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, - public.WithClaims(opts.Claims), - public.WithSilentAccount(c.account), - public.WithTenantID(opts.TenantID), - ) - return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err + return c.client.GetToken(ctx, opts) } var _ azcore.TokenCredential = (*UsernamePasswordCredential)(nil) diff --git a/sdk/azidentity/username_password_credential_test.go b/sdk/azidentity/username_password_credential_test.go index b029d5029b00..64960ab7f4e3 100644 --- a/sdk/azidentity/username_password_credential_test.go +++ b/sdk/azidentity/username_password_credential_test.go @@ -25,18 +25,6 @@ func TestUsernamePasswordCredential_InvalidTenantID(t *testing.T) { } } -func TestUsernamePasswordCredential_GetTokenSuccess(t *testing.T) { - cred, err := NewUsernamePasswordCredential(fakeTenantID, fakeClientID, "username", "password", nil) - if err != nil { - t.Fatalf("Unable to create credential. Received: %v", err) - } - cred.client = fakePublicClient{} - _, err = cred.GetToken(context.Background(), testTRO) - if err != nil { - t.Fatalf("Expected an empty error but received: %s", err.Error()) - } -} - func TestUsernamePasswordCredential_Live(t *testing.T) { for _, disabledID := range []bool{true, false} { name := "default options" diff --git a/sdk/azidentity/workload_identity.go b/sdk/azidentity/workload_identity.go index dd424c1e103e..7e016324d229 100644 --- a/sdk/azidentity/workload_identity.go +++ b/sdk/azidentity/workload_identity.go @@ -88,7 +88,7 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) ( return nil, err } // we want "WorkloadIdentityCredential" in log messages, not "ClientAssertionCredential" - cred.s.name = credNameWorkloadIdentity + cred.client.name = credNameWorkloadIdentity w.cred = cred return &w, nil } From 7cca0769897224cbf7ac16bca4cc90f00d0d8b5e Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 4 Aug 2023 22:08:30 +0000 Subject: [PATCH 07/13] add workload identity test --- sdk/azidentity/azidentity_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index d6e3a9461b30..6c4fd61b19a6 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -496,6 +496,17 @@ func TestClaims(t *testing.T) { return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o) }, }, + { + name: credNameWorkloadIdentity, + ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { + tokenFile := filepath.Join(t.TempDir(), "token") + if err := os.WriteFile(tokenFile, []byte(tokenValue), os.ModePerm); err != nil { + t.Fatalf("failed to write token file: %v", err) + } + o := WorkloadIdentityCredentialOptions{ClientID: fakeClientID, ClientOptions: co, TenantID: fakeTenantID, TokenFilePath: tokenFile} + return NewWorkloadIdentityCredential(&o) + }, + }, } { for _, enableCAE := range []bool{true, false} { name := test.name From b371c8f2e900577de83c50c7e2ec323323269904 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 4 Aug 2023 22:37:32 +0000 Subject: [PATCH 08/13] AzureCLICredential can synchronize itself --- sdk/azidentity/azidentity_test.go | 2 +- sdk/azidentity/azure_cli_credential.go | 32 ++++---- sdk/azidentity/syncer.go | 106 ------------------------- sdk/azidentity/syncer_test.go | 55 ------------- 4 files changed, 16 insertions(+), 179 deletions(-) delete mode 100644 sdk/azidentity/syncer.go delete mode 100644 sdk/azidentity/syncer_test.go diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 6c4fd61b19a6..ef721d363072 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -420,7 +420,7 @@ func TestAdditionallyAllowedTenants(t *testing.T) { called := false for _, source := range c.chain.sources { if cli, ok := source.(*AzureCLICredential); ok { - cli.tokenProvider = func(ctx context.Context, resource, tenantID string) ([]byte, error) { + cli.opts.tokenProvider = func(ctx context.Context, resource, tenantID string) ([]byte, error) { called = true if tenantID != test.expected { t.Fatalf(`unexpected tenantID "%s"`, tenantID) diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index 3604c9597fff..c71368af6342 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -17,10 +17,12 @@ import ( "regexp" "runtime" "strings" + "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) const ( @@ -53,8 +55,8 @@ func (o *AzureCLICredentialOptions) init() { // AzureCLICredential authenticates as the identity logged in to the Azure CLI. type AzureCLICredential struct { - s *syncer - tokenProvider azureCLITokenProvider + mu *sync.Mutex + opts AzureCLICredentialOptions } // NewAzureCLICredential constructs an AzureCLICredential. Pass nil to accept default options. @@ -64,15 +66,7 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent cp = *options } cp.init() - c := AzureCLICredential{tokenProvider: cp.tokenProvider} - c.s = newSyncer( - credNameAzureCLI, - cp.TenantID, - c.requestToken, - nil, // this credential doesn't have a silent auth method because the CLI handles caching - syncerOptions{AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants}, - ) - return &c, nil + return &AzureCLICredential{mu: &sync.Mutex{}, opts: cp}, nil } // GetToken requests a token from the Azure CLI. This credential doesn't cache tokens, so every call invokes the CLI. @@ -81,13 +75,15 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ if len(opts.Scopes) != 1 { return azcore.AccessToken{}, errors.New(credNameAzureCLI + ": GetToken() requires exactly one scope") } - // CLI expects an AAD v1 resource, not a v2 scope + tenant, err := resolveTenant(c.opts.TenantID, opts.TenantID, credNameAzureCLI, c.opts.AdditionallyAllowedTenants) + if err != nil { + return azcore.AccessToken{}, err + } + // pass the CLI an AAD v1 resource because we don't know which CLI version is installed and older ones don't support v2 scopes opts.Scopes = []string{strings.TrimSuffix(opts.Scopes[0], defaultSuffix)} - return c.s.GetToken(ctx, opts) -} - -func (c *AzureCLICredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - b, err := c.tokenProvider(ctx, opts.Scopes[0], opts.TenantID) + c.mu.Lock() + defer c.mu.Unlock() + b, err := c.opts.tokenProvider(ctx, opts.Scopes[0], tenant) if err != nil { return azcore.AccessToken{}, err } @@ -95,6 +91,8 @@ func (c *AzureCLICredential) requestToken(ctx context.Context, opts policy.Token if err != nil { return azcore.AccessToken{}, err } + msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", credNameAzureCLI, strings.Join(opts.Scopes, ", ")) + log.Write(EventAuthentication, msg) return at, nil } diff --git a/sdk/azidentity/syncer.go b/sdk/azidentity/syncer.go deleted file mode 100644 index fe4f146d51cd..000000000000 --- a/sdk/azidentity/syncer.go +++ /dev/null @@ -1,106 +0,0 @@ -//go:build go1.18 -// +build go1.18 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azidentity - -import ( - "context" - "errors" - "fmt" - "strings" - "sync" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/internal/log" -) - -type authFn func(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error) - -// syncer synchronizes authentication calls so that goroutines can share a credential instance -type syncer struct { - addlTenants []string - mu *sync.Mutex - reqToken, silent authFn - name, tenant string -} - -type syncerOptions struct { - // AdditionallyAllowedTenants syncer may authenticate to - AdditionallyAllowedTenants []string -} - -func newSyncer(name, tenant string, reqToken, silentAuth authFn, opts syncerOptions) *syncer { - return &syncer{ - addlTenants: resolveAdditionalTenants(opts.AdditionallyAllowedTenants), - mu: &sync.Mutex{}, - name: name, - reqToken: reqToken, - silent: silentAuth, - tenant: tenant, - } -} - -// GetToken ensures that only one goroutine authenticates at a time -func (s *syncer) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - at := azcore.AccessToken{} - if len(opts.Scopes) == 0 { - return at, errors.New(s.name + ".GetToken() requires at least one scope") - } - // we don't resolve the tenant for managed identities because they can acquire tokens only from their home tenants - if s.name != credNameManagedIdentity { - tenant, err := s.resolveTenant(opts.TenantID) - if err != nil { - return at, err - } - opts.TenantID = tenant - } - var err error - s.mu.Lock() - defer s.mu.Unlock() - if s.silent == nil { - at, err = s.reqToken(ctx, opts) - } else if at, err = s.silent(ctx, opts); err != nil { - // cache miss; request a new token - at, err = s.reqToken(ctx, opts) - } - if err != nil { - // Return credentialUnavailableError directly because that type affects the behavior of credential chains. - // Otherwise, return AuthenticationFailedError. - var unavailableErr *credentialUnavailableError - if !errors.As(err, &unavailableErr) { - res := getResponseFromError(err) - err = newAuthenticationFailedError(s.name, err.Error(), res, err) - } - } else if log.Should(EventAuthentication) { - scope := strings.Join(opts.Scopes, ", ") - msg := fmt.Sprintf(`%s.GetToken() acquired a token for scope "%s"\n`, s.name, scope) - log.Write(EventAuthentication, msg) - } - return at, err -} - -// resolveTenant returns the correct tenant for a token request given the credential's -// configuration, or an error when the specified tenant isn't allowed by that configuration -func (s *syncer) resolveTenant(requested string) (string, error) { - return resolveTenant(s.tenant, requested, s.name, s.addlTenants) -} - -// resolveAdditionalTenants returns a copy of tenants, simplified when tenants contains a wildcard -func resolveAdditionalTenants(tenants []string) []string { - if len(tenants) == 0 { - return nil - } - for _, t := range tenants { - // a wildcard makes all other values redundant - if t == "*" { - return []string{"*"} - } - } - cp := make([]string, len(tenants)) - copy(cp, tenants) - return cp -} diff --git a/sdk/azidentity/syncer_test.go b/sdk/azidentity/syncer_test.go deleted file mode 100644 index 2c60ed445c8f..000000000000 --- a/sdk/azidentity/syncer_test.go +++ /dev/null @@ -1,55 +0,0 @@ -//go:build go1.18 -// +build go1.18 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azidentity - -import ( - "context" - "errors" - "sync" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" -) - -func TestSyncer(t *testing.T) { - silentAuths, tokenRequests := 0, 0 - s := newSyncer("", "tenant", - func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { - tokenRequests++ - return azcore.AccessToken{}, nil - }, - func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { - var err error - if tokenRequests == 0 { - err = errors.New("cache miss") - } - silentAuths++ - return azcore.AccessToken{}, err - }, - syncerOptions{}, - ) - goroutines := 50 - wg := sync.WaitGroup{} - for i := 0; i < goroutines; i++ { - wg.Add(1) - go func() { - _, err := s.GetToken(context.Background(), testTRO) - if err != nil { - t.Error(err) - } - wg.Done() - }() - } - wg.Wait() - if tokenRequests != 1 { - t.Errorf("expected 1 token request, got %d", tokenRequests) - } - if silentAuths != goroutines { - t.Errorf("expected %d silent auth attempts, got %d", goroutines, silentAuths) - } -} From d7c89c66a521a753467afc362d71c49e58137cc6 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 17 Aug 2023 11:00:47 -0700 Subject: [PATCH 09/13] thanks, Scott Co-authored-by: Scott Addie <10702007+scottaddie@users.noreply.github.com> --- sdk/azidentity/azidentity.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index f0b22649859e..2ae9d0d3ab2b 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -46,7 +46,7 @@ const ( var ( // capability CP1 indicates the client application is capable of handling CAE claims challenges cp1 = []string{"CP1"} - errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names") + errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://learn.microsoft.com/partner-center/find-ids-and-domain-names") ) // setAuthorityHost initializes the authority host for credentials. Precedence is: From 6eab27748626ac9bbfc8fca9c810d6493f336c43 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:23:39 -0700 Subject: [PATCH 10/13] restore resolveAdditionalTenants --- sdk/azidentity/azidentity.go | 16 ++++++++++++++++ sdk/azidentity/azure_cli_credential.go | 1 + sdk/azidentity/confidential_client.go | 1 + sdk/azidentity/public_client.go | 1 + 4 files changed, 19 insertions(+) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 2ae9d0d3ab2b..10b742ce1a13 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -74,6 +74,22 @@ func setAuthorityHost(cc cloud.Configuration) (string, error) { return host, nil } +// resolveAdditionalTenants returns a copy of tenants, simplified when tenants contains a wildcard +func resolveAdditionalTenants(tenants []string) []string { + if len(tenants) == 0 { + return nil + } + for _, t := range tenants { + // a wildcard makes all other values redundant + if t == "*" { + return []string{"*"} + } + } + cp := make([]string, len(tenants)) + copy(cp, tenants) + return cp +} + // resolveTenant returns the correct tenant for a token request func resolveTenant(defaultTenant, specified, credName string, additionalTenants []string) (string, error) { if specified == "" || specified == defaultTenant { diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index c71368af6342..55a0d654347e 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -66,6 +66,7 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent cp = *options } cp.init() + cp.AdditionallyAllowedTenants = resolveAdditionalTenants(cp.AdditionallyAllowedTenants) return &AzureCLICredential{mu: &sync.Mutex{}, opts: cp}, nil } diff --git a/sdk/azidentity/confidential_client.go b/sdk/azidentity/confidential_client.go index 13ca474fe171..5d18bb2114c8 100644 --- a/sdk/azidentity/confidential_client.go +++ b/sdk/azidentity/confidential_client.go @@ -49,6 +49,7 @@ func newConfidentialClient(tenantID, clientID, name string, cred confidential.Cr if err != nil { return nil, err } + opts.AdditionallyAllowedTenants = resolveAdditionalTenants(opts.AdditionallyAllowedTenants) return &confidentialClient{ caeMu: &sync.Mutex{}, clientID: clientID, diff --git a/sdk/azidentity/public_client.go b/sdk/azidentity/public_client.go index dfd75861179c..87ca0644f796 100644 --- a/sdk/azidentity/public_client.go +++ b/sdk/azidentity/public_client.go @@ -48,6 +48,7 @@ func newPublicClient(tenantID, clientID, name string, o publicClientOptions) (*p if err != nil { return nil, err } + o.AdditionallyAllowedTenants = resolveAdditionalTenants(o.AdditionallyAllowedTenants) return &publicClient{ caeMu: &sync.Mutex{}, clientID: clientID, From b55504dc531e1d05416640384c9ffc948c395534 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:48:00 -0700 Subject: [PATCH 11/13] thanks, Joel --- sdk/azidentity/confidential_client.go | 10 +++++----- sdk/azidentity/public_client.go | 24 ++++++++++++------------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sdk/azidentity/confidential_client.go b/sdk/azidentity/confidential_client.go index 5d18bb2114c8..e8670e020d89 100644 --- a/sdk/azidentity/confidential_client.go +++ b/sdk/azidentity/confidential_client.go @@ -33,7 +33,7 @@ type confidentialClientOptions struct { // confidentialClient wraps the MSAL confidential client type confidentialClient struct { cae, noCAE msalConfidentialClient - caeMu, clientMu, noCAEMu *sync.Mutex + caeMu, noCAEMu, clientMu *sync.Mutex clientID, tenantID string cred confidential.Credential host string @@ -68,10 +68,6 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque if len(tro.Scopes) < 1 { return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", c.name) } - client, mu, err := c.client(ctx, tro) - if err != nil { - return azcore.AccessToken{}, err - } // we don't resolve the tenant for managed identities because they acquire tokens only from their home tenants if c.name != credNameManagedIdentity { tenant, err := c.resolveTenant(tro.TenantID) @@ -80,6 +76,10 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque } tro.TenantID = tenant } + client, mu, err := c.client(ctx, tro) + if err != nil { + return azcore.AccessToken{}, err + } mu.Lock() defer mu.Unlock() var ar confidential.AuthResult diff --git a/sdk/azidentity/public_client.go b/sdk/azidentity/public_client.go index 87ca0644f796..8b362f07f89c 100644 --- a/sdk/azidentity/public_client.go +++ b/sdk/azidentity/public_client.go @@ -33,8 +33,8 @@ type publicClientOptions struct { type publicClient struct { account public.Account cae, noCAE msalPublicClient + caeMu, noCAEMu, clientMu *sync.Mutex clientID, tenantID string - clientMu, caeMu, noCAEMu *sync.Mutex host string name string opts publicClientOptions @@ -89,19 +89,19 @@ func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOpti } // reqToken requests a token from the MSAL public client. It's separate from GetToken() to enable Authenticate() to bypass the cache. -func (m *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { - tenant, err := m.resolveTenant(tro.TenantID) +func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + tenant, err := p.resolveTenant(tro.TenantID) if err != nil { return azcore.AccessToken{}, err } var ar public.AuthResult switch { - case m.opts.DeviceCodePrompt != nil: + case p.opts.DeviceCodePrompt != nil: dc, e := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant)) if e != nil { return azcore.AccessToken{}, e } - err = m.opts.DeviceCodePrompt(ctx, DeviceCodeMessage{ + err = p.opts.DeviceCodePrompt(ctx, DeviceCodeMessage{ Message: dc.Result.Message, UserCode: dc.Result.UserCode, VerificationURL: dc.Result.VerificationURL, @@ -109,17 +109,17 @@ func (m *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro pol if err == nil { ar, err = dc.AuthenticationResult(ctx) } - case m.opts.Username != "" && m.opts.Password != "": - ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, m.opts.Username, m.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant)) + case p.opts.Username != "" && p.opts.Password != "": + ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, p.opts.Username, p.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant)) default: ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes, public.WithClaims(tro.Claims), - public.WithLoginHint(m.opts.LoginHint), - public.WithRedirectURI(m.opts.RedirectURL), + public.WithLoginHint(p.opts.LoginHint), + public.WithRedirectURI(p.opts.RedirectURL), public.WithTenantID(tenant), ) } - return m.token(ar, err) + return p.token(ar, err) } func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, *sync.Mutex, error) { @@ -171,6 +171,6 @@ func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToke // resolveTenant returns the correct tenant for a token request given the client's // configuration, or an error when that configuration doesn't allow the specified tenant -func (m *publicClient) resolveTenant(specified string) (string, error) { - return resolveTenant(m.tenantID, specified, m.name, m.opts.AdditionallyAllowedTenants) +func (p *publicClient) resolveTenant(specified string) (string, error) { + return resolveTenant(p.tenantID, specified, p.name, p.opts.AdditionallyAllowedTenants) } From 4a026a3b85f3b9d9ec6df4ae7bf4de963d9b2db0 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:57:32 -0700 Subject: [PATCH 12/13] publicClient switches on credential name --- sdk/azidentity/public_client.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/sdk/azidentity/public_client.go b/sdk/azidentity/public_client.go index 8b362f07f89c..6512d3e25fd8 100644 --- a/sdk/azidentity/public_client.go +++ b/sdk/azidentity/public_client.go @@ -95,8 +95,15 @@ func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro pol return azcore.AccessToken{}, err } var ar public.AuthResult - switch { - case p.opts.DeviceCodePrompt != nil: + switch p.name { + case credNameBrowser: + ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes, + public.WithClaims(tro.Claims), + public.WithLoginHint(p.opts.LoginHint), + public.WithRedirectURI(p.opts.RedirectURL), + public.WithTenantID(tenant), + ) + case credNameDeviceCode: dc, e := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant)) if e != nil { return azcore.AccessToken{}, e @@ -109,15 +116,10 @@ func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro pol if err == nil { ar, err = dc.AuthenticationResult(ctx) } - case p.opts.Username != "" && p.opts.Password != "": + case credNameUserPassword: ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, p.opts.Username, p.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant)) default: - ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes, - public.WithClaims(tro.Claims), - public.WithLoginHint(p.opts.LoginHint), - public.WithRedirectURI(p.opts.RedirectURL), - public.WithTenantID(tenant), - ) + return azcore.AccessToken{}, fmt.Errorf("unknown credential %q", p.name) } return p.token(ar, err) } From 2d1789952570678b1bc19b874e33cce5ddd01b2f Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 21 Aug 2023 16:38:33 -0700 Subject: [PATCH 13/13] changelog --- sdk/azidentity/CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 551dbe4b1ecc..782a1061505f 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -3,8 +3,13 @@ ## 1.4.0-beta.5 (Unreleased) ### Features Added +* Service principal credentials can request CAE tokens ### Breaking Changes +> These changes affect only code written against a beta version such as v1.4.0-beta.4 +* Whether `GetToken` requests a CAE token is now determined by `TokenRequestOptions.EnableCAE`. Azure + SDK clients which support CAE will set this option automatically. Credentials no longer request CAE + tokens by default or observe the environment variable "AZURE_IDENTITY_DISABLE_CP1". ### Bugs Fixed