Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
### Breaking Changes

### Bugs Fixed
* Credentials now synchronize within `GetToken()` so a single instance can be shared among goroutines
([#20044](https://github.com/Azure/azure-sdk-for-go/issues/20044))

### Other Changes

Expand Down
36 changes: 0 additions & 36 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -96,41 +95,6 @@ var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions,
return public.New(clientID, o...)
}

// resolveAdditionallyAllowedTenants returns a copy of tenants, simplified when tenants contains a wildcard
func resolveAdditionallyAllowedTenants(tenants []string) []string {
if len(tenants) == 0 {
return nil
}
for _, t := range tenants {
// a wildcard makes all other values redundant
if t == "*" {
return []string{"*"}
}
}
cp := make([]string, len(tenants))
copy(cp, tenants)
return cp
}

// resolveTenant returns the correct tenant for a token request given a credential's configuration
func resolveTenant(defaultTenant, reqTenant string, allowedTenants []string) (string, error) {
if reqTenant == "" || reqTenant == defaultTenant {
return defaultTenant, nil
}
if defaultTenant == "adfs" {
return "", errors.New("ADFS doesn't support tenants")
}
if !validTenantID(reqTenant) {
return "", errors.New(tenantIDValidationErr)
}
for _, tenant := range allowedTenants {
if tenant == "*" || tenant == reqTenant {
return reqTenant, nil
}
}
return "", fmt.Errorf(`this credential isn't configured to acquire tokens for tenant "%s". To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on the credential options, or add "*" to allow acquiring tokens for any tenant`, reqTenant)
}

// setAuthorityHost initializes the authority host for credentials. Precedence is:
// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user
// 2. value of AZURE_AUTHORITY_HOST
Expand Down
51 changes: 1 addition & 50 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
AdditionallyAllowedTenants: test.allowed,
tokenProvider: func(ctx context.Context, resource, tenantID string) ([]byte, error) {
if tenantID != test.expected {
t.Fatalf(`unexpected tenantID "%s"`, tenantID)
t.Errorf(`unexpected tenantID "%s"`, tenantID)
}
return mockCLITokenProviderSuccess(ctx, resource, tenantID)
},
Expand Down Expand Up @@ -655,55 +655,6 @@ func TestClaims(t *testing.T) {
}
}

func TestResolveTenant(t *testing.T) {
defaultTenant := "default-tenant"
otherTenant := "other-tenant"
for _, test := range []struct {
allowed []string
expected, tenant string
expectError bool
}{
// no alternate tenant specified -> should get default
{expected: defaultTenant},
{allowed: []string{""}, expected: defaultTenant},
{allowed: []string{"*"}, expected: defaultTenant},
{allowed: []string{otherTenant}, expected: defaultTenant},

// alternate tenant specified and allowed -> should get that tenant
{allowed: []string{"*"}, expected: otherTenant, tenant: otherTenant},
{allowed: []string{otherTenant}, expected: otherTenant, tenant: otherTenant},
{allowed: []string{"not-" + otherTenant, otherTenant}, expected: otherTenant, tenant: otherTenant},
{allowed: []string{"not-" + otherTenant, "*"}, expected: otherTenant, tenant: otherTenant},

// invalid or not allowed tenant -> should get an error
{tenant: otherTenant, expectError: true},
{allowed: []string{""}, tenant: otherTenant, expectError: true},
{allowed: []string{defaultTenant}, tenant: otherTenant, expectError: true},
{tenant: badTenantID, expectError: true},
{allowed: []string{""}, tenant: badTenantID, expectError: true},
{allowed: []string{"*", badTenantID}, tenant: badTenantID, expectError: true},
{tenant: "invalid@tenant", expectError: true},
{tenant: "invalid/tenant", expectError: true},
{tenant: "invalid(tenant", expectError: true},
{tenant: "invalid:tenant", expectError: true},
} {
t.Run("", func(t *testing.T) {
tenant, err := resolveTenant(defaultTenant, test.tenant, test.allowed)
if err != nil {
if test.expectError {
return
}
t.Fatal(err)
} else if test.expectError {
t.Fatal("expected an error")
}
if tenant != test.expected {
t.Fatalf(`expected "%s", got "%s"`, test.expected, tenant)
}
})
}
}

// ==================================================================================================================================

type fakeConfidentialClient struct {
Expand Down
35 changes: 16 additions & 19 deletions sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

const credNameAzureCLI = "AzureCLICredential"
const (
credNameAzureCLI = "AzureCLICredential"
timeoutCLIRequest = 10 * time.Second
)

// used by tests to fake invoking the CLI
type azureCLITokenProvider func(ctx context.Context, resource string, tenantID string) ([]byte, error)
Expand All @@ -50,9 +53,8 @@ func (o *AzureCLICredentialOptions) init() {

// AzureCLICredential authenticates as the identity logged in to the Azure CLI.
type AzureCLICredential struct {
additionallyAllowedTenants []string
tenantID string
tokenProvider azureCLITokenProvider
s *syncer
tokenProvider azureCLITokenProvider
}

// NewAzureCLICredential constructs an AzureCLICredential. Pass nil to accept default options.
Expand All @@ -62,11 +64,9 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent
cp = *options
}
cp.init()
return &AzureCLICredential{
additionallyAllowedTenants: resolveAdditionallyAllowedTenants(cp.AdditionallyAllowedTenants),
tenantID: cp.TenantID,
tokenProvider: cp.tokenProvider,
}, nil
c := AzureCLICredential{tokenProvider: cp.tokenProvider}
c.s = newSyncer(credNameAzureCLI, cp.TenantID, cp.AdditionallyAllowedTenants, c.requestToken, c.requestToken)
return &c, nil
}

// GetToken requests a token from the Azure CLI. This credential doesn't cache tokens, so every call invokes the CLI.
Expand All @@ -75,26 +75,23 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ
if len(opts.Scopes) != 1 {
return azcore.AccessToken{}, errors.New(credNameAzureCLI + ": GetToken() requires exactly one scope")
}
tenant, err := resolveTenant(c.tenantID, opts.TenantID, c.additionallyAllowedTenants)
if err != nil {
return azcore.AccessToken{}, err
}
// CLI expects an AAD v1 resource, not a v2 resource
resource := strings.TrimSuffix(opts.Scopes[0], defaultSuffix)
b, err := c.tokenProvider(ctx, resource, tenant)
// CLI expects an AAD v1 resource, not a v2 scope
opts.Scopes = []string{strings.TrimSuffix(opts.Scopes[0], defaultSuffix)}
return c.s.GetToken(ctx, opts)
}

func (c *AzureCLICredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
b, err := c.tokenProvider(ctx, opts.Scopes[0], opts.TenantID)
if err != nil {
return azcore.AccessToken{}, err
}
at, err := c.createAccessToken(b)
if err != nil {
return azcore.AccessToken{}, err
}
logGetTokenSuccess(c, opts)
return at, nil
}

const timeoutCLIRequest = 10 * time.Second

func defaultTokenProvider() func(ctx context.Context, resource string, tenantID string) ([]byte, error) {
return func(ctx context.Context, resource string, tenantID string) ([]byte, error) {
match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource)
Expand Down
42 changes: 14 additions & 28 deletions sdk/azidentity/client_assertion_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@ const credNameAssertion = "ClientAssertionCredential"
//
// [Azure AD documentation]: https://docs.microsoft.com/azure/active-directory/develop/active-directory-certificate-credentials#assertion-format
type ClientAssertionCredential struct {
additionallyAllowedTenants []string
client confidentialClient
// name enables replacing "ClientAssertionCredential" with "WorkloadIdentityCredential" in log messages
name string
tenant string
client confidentialClient
s *syncer
}

// ClientAssertionCredentialOptions contains optional parameters for ClientAssertionCredential.
Expand Down Expand Up @@ -60,34 +57,23 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c
if err != nil {
return nil, err
}
return &ClientAssertionCredential{
additionallyAllowedTenants: resolveAdditionallyAllowedTenants(options.AdditionallyAllowedTenants),
client: c,
name: credNameAssertion,
tenant: tenantID,
}, nil
cac := ClientAssertionCredential{client: c}
cac.s = newSyncer(credNameAssertion, tenantID, options.AdditionallyAllowedTenants, cac.requestToken, cac.silentAuth)
return &cac, nil
}

// GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients.
func (c *ClientAssertionCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
if len(opts.Scopes) == 0 {
return azcore.AccessToken{}, errors.New(credNameAssertion + ": GetToken() requires at least one scope")
}
tenant, err := resolveTenant(c.tenant, opts.TenantID, c.additionallyAllowedTenants)
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccessImpl(c.name, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
return c.s.GetToken(ctx, opts)
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(c.name, err)
}
logGetTokenSuccessImpl(c.name, opts)
func (c *ClientAssertionCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID))
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

func (c *ClientAssertionCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID))
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

Expand Down
39 changes: 14 additions & 25 deletions sdk/azidentity/client_certificate_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ type ClientCertificateCredentialOptions struct {

// ClientCertificateCredential authenticates a service principal with a certificate.
type ClientCertificateCredential struct {
additionallyAllowedTenants []string
client confidentialClient
tenant string
client confidentialClient
s *syncer
}

// NewClientCertificateCredential constructs a ClientCertificateCredential. Pass nil for options to accept defaults.
Expand All @@ -65,33 +64,23 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x
if err != nil {
return nil, err
}
return &ClientCertificateCredential{
additionallyAllowedTenants: resolveAdditionallyAllowedTenants(options.AdditionallyAllowedTenants),
client: c,
tenant: tenantID,
}, nil
cc := ClientCertificateCredential{client: c}
cc.s = newSyncer(credNameCert, tenantID, options.AdditionallyAllowedTenants, cc.requestToken, cc.silentAuth)
return &cc, nil
}

// GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients.
func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
if len(opts.Scopes) == 0 {
return azcore.AccessToken{}, errors.New(credNameCert + ": GetToken() requires at least one scope")
}
tenant, err := resolveTenant(c.tenant, opts.TenantID, c.additionallyAllowedTenants)
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
return c.s.GetToken(ctx, opts)
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameCert, err)
}
logGetTokenSuccess(c, opts)
func (c *ClientCertificateCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID))
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

func (c *ClientCertificateCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID))
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/client_certificate_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func TestClientCertificateCredential_InvalidCertLive(t *testing.T) {
t.Fatalf("expected AuthenticationFailedError, received %T", err)
}
if !strings.HasPrefix(err.Error(), credNameCert) {
t.Fatal("missing credential type prefix")
t.Fatalf("error is missing credential type prefix: %q", err.Error())
}
}

Expand Down
40 changes: 14 additions & 26 deletions sdk/azidentity/client_secret_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package azidentity

import (
"context"
"errors"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand All @@ -31,9 +30,8 @@ type ClientSecretCredentialOptions struct {

// ClientSecretCredential authenticates an application with a client secret.
type ClientSecretCredential struct {
additionallyAllowedTenants []string
client confidentialClient
tenant string
client confidentialClient
s *syncer
}

// NewClientSecretCredential constructs a ClientSecretCredential. Pass nil for options to accept defaults.
Expand All @@ -49,33 +47,23 @@ func NewClientSecretCredential(tenantID string, clientID string, clientSecret st
if err != nil {
return nil, err
}
return &ClientSecretCredential{
additionallyAllowedTenants: resolveAdditionallyAllowedTenants(options.AdditionallyAllowedTenants),
client: c,
tenant: tenantID,
}, nil
csc := ClientSecretCredential{client: c}
csc.s = newSyncer(credNameSecret, tenantID, options.AdditionallyAllowedTenants, csc.requestToken, csc.silentAuth)
return &csc, nil
}

// GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients.
func (c *ClientSecretCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
if len(opts.Scopes) == 0 {
return azcore.AccessToken{}, errors.New(credNameSecret + ": GetToken() requires at least one scope")
}
tenant, err := resolveTenant(c.tenant, opts.TenantID, c.additionallyAllowedTenants)
if err != nil {
return azcore.AccessToken{}, err
}
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
return c.s.GetToken(ctx, opts)
}

ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant))
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameSecret, err)
}
logGetTokenSuccess(c, opts)
func (c *ClientSecretCredential) silentAuth(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID))
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

func (c *ClientSecretCredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
ar, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(opts.TenantID))
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}

Expand Down
1 change: 0 additions & 1 deletion sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
additionalTenants = strings.Split(tenants, ";")
}
}
additionalTenants = resolveAdditionallyAllowedTenants(additionalTenants)

envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{
ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery, additionallyAllowedTenants: additionalTenants},
Expand Down
Loading