Skip to content
5 changes: 5 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
## 1.4.0-beta.5 (Unreleased)

### Features Added
* Service principal credentials can request CAE tokens

### Breaking Changes
> These changes affect only code written against a beta version such as v1.4.0-beta.4
* Whether `GetToken` requests a CAE token is now determined by `TokenRequestOptions.EnableCAE`. Azure
SDK clients which support CAE will set this option automatically. Credentials no longer request CAE
tokens by default or observe the environment variable "AZURE_IDENTITY_DISABLE_CP1".

### Bugs Fixed

Expand Down
106 changes: 43 additions & 63 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"regexp"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
Expand All @@ -41,73 +41,18 @@ const (
organizationsTenantID = "organizations"
developerSignOnClientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
defaultSuffix = "/.default"
tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names"
)

var (
// capability CP1 indicates the client application is capable of handling CAE claims challenges
cp1 = []string{"CP1"}
disableCP1 = strings.ToLower(os.Getenv("AZURE_IDENTITY_DISABLE_CP1")) == "true"
cp1 = []string{"CP1"}
errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://learn.microsoft.com/partner-center/find-ids-and-domain-names")
)

type msalClientOptions struct {
azcore.ClientOptions

DisableInstanceDiscovery bool
// SendX5C applies only to confidential clients authenticating with a cert
SendX5C bool
}

var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, opts msalClientOptions) (confidentialClient, error) {
if !validTenantID(tenantID) {
return confidential.Client{}, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(opts.Cloud)
if err != nil {
return confidential.Client{}, err
}
authority := runtime.JoinPaths(authorityHost, tenantID)
o := []confidential.Option{
confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)),
confidential.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)),
}
if !disableCP1 {
o = append(o, confidential.WithClientCapabilities(cp1))
}
if opts.SendX5C {
o = append(o, confidential.WithX5C())
}
if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" {
o = append(o, confidential.WithInstanceDiscovery(false))
}
return confidential.New(authority, clientID, cred, o...)
}

var getPublicClient = func(clientID, tenantID string, opts msalClientOptions) (public.Client, error) {
if !validTenantID(tenantID) {
return public.Client{}, errors.New(tenantIDValidationErr)
}
authorityHost, err := setAuthorityHost(opts.Cloud)
if err != nil {
return public.Client{}, err
}
o := []public.Option{
public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)),
public.WithHTTPClient(newPipelineAdapter(&opts.ClientOptions)),
}
if !disableCP1 {
o = append(o, public.WithClientCapabilities(cp1))
}
if opts.DisableInstanceDiscovery || strings.ToLower(tenantID) == "adfs" {
o = append(o, public.WithInstanceDiscovery(false))
}
return public.New(clientID, o...)
}

// setAuthorityHost initializes the authority host for credentials. Precedence is:
// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user
// 2. value of AZURE_AUTHORITY_HOST
// 3. default: Azure Public Cloud
// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user
// 2. value of AZURE_AUTHORITY_HOST
// 3. default: Azure Public Cloud
func setAuthorityHost(cc cloud.Configuration) (string, error) {
host := cc.ActiveDirectoryAuthorityHost
if host == "" {
Expand All @@ -129,6 +74,41 @@ func setAuthorityHost(cc cloud.Configuration) (string, error) {
return host, nil
}

// resolveAdditionalTenants returns a copy of tenants, simplified when tenants contains a wildcard
func resolveAdditionalTenants(tenants []string) []string {
if len(tenants) == 0 {
return nil
}
for _, t := range tenants {
// a wildcard makes all other values redundant
if t == "*" {
return []string{"*"}
}
}
cp := make([]string, len(tenants))
copy(cp, tenants)
return cp
}

// resolveTenant returns the correct tenant for a token request
func resolveTenant(defaultTenant, specified, credName string, additionalTenants []string) (string, error) {
if specified == "" || specified == defaultTenant {
return defaultTenant, nil
}
if defaultTenant == "adfs" {
return "", errors.New("ADFS doesn't support tenants")
}
if !validTenantID(specified) {
return "", errInvalidTenantID
}
for _, t := range additionalTenants {
if t == "*" || t == specified {
return specified, nil
}
}
return "", fmt.Errorf(`%s isn't configured to acquire tokens for tenant %q. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on the credential options, or add "*" to allow acquiring tokens for any tenant`, credName, specified)
}

// validTenantID return true is it receives a valid tenantID, returns false otherwise
func validTenantID(tenantID string) bool {
match, err := regexp.MatchString("^[0-9a-zA-Z-.]+$", tenantID)
Expand Down Expand Up @@ -181,15 +161,15 @@ func (p pipelineAdapter) Do(r *http.Request) (*http.Response, error) {
}

// enables fakes for test scenarios
type confidentialClient interface {
type msalConfidentialClient interface {
AcquireTokenSilent(ctx context.Context, scopes []string, options ...confidential.AcquireSilentOption) (confidential.AuthResult, error)
AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, options ...confidential.AcquireByAuthCodeOption) (confidential.AuthResult, error)
AcquireTokenByCredential(ctx context.Context, scopes []string, options ...confidential.AcquireByCredentialOption) (confidential.AuthResult, error)
AcquireTokenOnBehalfOf(ctx context.Context, userAssertion string, scopes []string, options ...confidential.AcquireOnBehalfOfOption) (confidential.AuthResult, error)
}

// enables fakes for test scenarios
type publicClient interface {
type msalPublicClient interface {
AcquireTokenSilent(ctx context.Context, scopes []string, options ...public.AcquireSilentOption) (public.AuthResult, error)
AcquireTokenByUsernamePassword(ctx context.Context, scopes []string, username string, password string, options ...public.AcquireByUsernamePasswordOption) (public.AuthResult, error)
AcquireTokenByDeviceCode(ctx context.Context, scopes []string, options ...public.AcquireByDeviceCodeOption) (public.DeviceCode, error)
Expand Down
91 changes: 77 additions & 14 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
called := false
for _, source := range c.chain.sources {
if cli, ok := source.(*AzureCLICredential); ok {
cli.tokenProvider = func(ctx context.Context, resource, tenantID string) ([]byte, error) {
cli.opts.tokenProvider = func(ctx context.Context, resource, tenantID string) ([]byte, error) {
called = true
if tenantID != test.expected {
t.Fatalf(`unexpected tenantID "%s"`, tenantID)
Expand All @@ -446,8 +446,6 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
}

func TestClaims(t *testing.T) {
realCP1 := disableCP1
t.Cleanup(func() { disableCP1 = realCP1 })
claim := `"test":"pass"`
for _, test := range []struct {
ctor func(azcore.ClientOptions) (azcore.TokenCredential, error)
Expand Down Expand Up @@ -498,29 +496,39 @@ func TestClaims(t *testing.T) {
return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o)
},
},
{
name: credNameWorkloadIdentity,
ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) {
tokenFile := filepath.Join(t.TempDir(), "token")
if err := os.WriteFile(tokenFile, []byte(tokenValue), os.ModePerm); err != nil {
t.Fatalf("failed to write token file: %v", err)
}
o := WorkloadIdentityCredentialOptions{ClientID: fakeClientID, ClientOptions: co, TenantID: fakeTenantID, TokenFilePath: tokenFile}
return NewWorkloadIdentityCredential(&o)
},
},
} {
for _, d := range []bool{true, false} {
for _, enableCAE := range []bool{true, false} {
name := test.name
if d {
name += " disableCP1"
if enableCAE {
name += " CAE"
}
t.Run(name, func(t *testing.T) {
disableCP1 = d
reqs := 0
sts := mockSTS{
tokenRequestCallback: func(r *http.Request) *http.Response {
if err := r.ParseForm(); err != nil {
t.Error(err)
}
reqs++
// If the disableCP1 flag isn't set, both requests should specify CP1. The second
// GetToken call specifies claims we should find in the following token request.
// Both requests should specify CP1 when CAE is enabled for the token.
// We check only for substrings because MSAL is responsible for formatting claims.
actual := fmt.Sprint(r.Form["claims"])
if strings.Contains(actual, "CP1") == disableCP1 {
if strings.Contains(actual, "CP1") != enableCAE {
t.Fatalf(`unexpected claims "%v"`, actual)
}
if reqs == 2 {
// the second GetToken call specifies claims we should find in the following token request
if !strings.Contains(strings.ReplaceAll(actual, " ", ""), claim) {
t.Fatalf(`unexpected claims "%v"`, actual)
}
Expand All @@ -533,10 +541,12 @@ func TestClaims(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"A"}}); err != nil {
tro := policy.TokenRequestOptions{EnableCAE: enableCAE, Scopes: []string{"A"}}
if _, err = cred.GetToken(context.Background(), tro); err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Claims: fmt.Sprintf("{%s}", claim), Scopes: []string{"B"}}); err != nil {
tro = policy.TokenRequestOptions{Claims: fmt.Sprintf("{%s}", claim), EnableCAE: enableCAE, Scopes: []string{"B"}}
if _, err = cred.GetToken(context.Background(), tro); err != nil {
t.Fatal(err)
}
if reqs != 2 {
Expand All @@ -547,6 +557,59 @@ func TestClaims(t *testing.T) {
}
}

func TestResolveTenant(t *testing.T) {
credName := "testcred"
defaultTenant := "default-tenant"
otherTenant := "other-tenant"
for _, test := range []struct {
allowed []string
expected, tenant string
expectError bool
}{
// no alternate tenant specified -> should get default
{expected: defaultTenant},
{allowed: []string{""}, expected: defaultTenant},
{allowed: []string{"*"}, expected: defaultTenant},
{allowed: []string{otherTenant}, expected: defaultTenant},

// alternate tenant specified and allowed -> should get that tenant
{allowed: []string{"*"}, expected: otherTenant, tenant: otherTenant},
{allowed: []string{otherTenant}, expected: otherTenant, tenant: otherTenant},
{allowed: []string{"not-" + otherTenant, otherTenant}, expected: otherTenant, tenant: otherTenant},
{allowed: []string{"not-" + otherTenant, "*"}, expected: otherTenant, tenant: otherTenant},

// invalid or not allowed tenant -> should get an error
{tenant: otherTenant, expectError: true},
{allowed: []string{""}, tenant: otherTenant, expectError: true},
{allowed: []string{defaultTenant}, tenant: otherTenant, expectError: true},
{tenant: badTenantID, expectError: true},
{allowed: []string{""}, tenant: badTenantID, expectError: true},
{allowed: []string{"*", badTenantID}, tenant: badTenantID, expectError: true},
{tenant: "invalid@tenant", expectError: true},
{tenant: "invalid/tenant", expectError: true},
{tenant: "invalid(tenant", expectError: true},
{tenant: "invalid:tenant", expectError: true},
} {
t.Run("", func(t *testing.T) {
tenant, err := resolveTenant(defaultTenant, test.tenant, credName, test.allowed)
if err != nil {
if test.expectError {
if validTenantID(test.tenant) && !strings.Contains(err.Error(), credName) {
t.Fatalf("expected error to contain %q, got %q", credName, err.Error())
}
return
}
t.Fatal(err)
} else if test.expectError {
t.Fatal("expected an error")
}
if tenant != test.expected {
t.Fatalf(`expected "%s", got "%s"`, test.expected, tenant)
}
})
}
}

// ==================================================================================================================================

type fakeConfidentialClient struct {
Expand Down Expand Up @@ -592,7 +655,7 @@ func (f fakeConfidentialClient) AcquireTokenOnBehalfOf(ctx context.Context, user
return f.returnResult()
}

var _ confidentialClient = (*fakeConfidentialClient)(nil)
var _ msalConfidentialClient = (*fakeConfidentialClient)(nil)

// ==================================================================================================================================

Expand Down Expand Up @@ -643,4 +706,4 @@ func (f fakePublicClient) AcquireTokenInteractive(ctx context.Context, scopes []
return f.returnResult()
}

var _ publicClient = (*fakePublicClient)(nil)
var _ msalPublicClient = (*fakePublicClient)(nil)
33 changes: 16 additions & 17 deletions sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ import (
"regexp"
"runtime"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)

const (
Expand Down Expand Up @@ -53,8 +55,8 @@ func (o *AzureCLICredentialOptions) init() {

// AzureCLICredential authenticates as the identity logged in to the Azure CLI.
type AzureCLICredential struct {
s *syncer
tokenProvider azureCLITokenProvider
mu *sync.Mutex
Comment thread
jhendrixMSFT marked this conversation as resolved.
opts AzureCLICredentialOptions
}

// NewAzureCLICredential constructs an AzureCLICredential. Pass nil to accept default options.
Expand All @@ -64,15 +66,8 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent
cp = *options
}
cp.init()
c := AzureCLICredential{tokenProvider: cp.tokenProvider}
c.s = newSyncer(
credNameAzureCLI,
cp.TenantID,
c.requestToken,
nil, // this credential doesn't have a silent auth method because the CLI handles caching
syncerOptions{AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants},
)
return &c, nil
cp.AdditionallyAllowedTenants = resolveAdditionalTenants(cp.AdditionallyAllowedTenants)
return &AzureCLICredential{mu: &sync.Mutex{}, opts: cp}, nil
}

// GetToken requests a token from the Azure CLI. This credential doesn't cache tokens, so every call invokes the CLI.
Expand All @@ -81,20 +76,24 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ
if len(opts.Scopes) != 1 {
return azcore.AccessToken{}, errors.New(credNameAzureCLI + ": GetToken() requires exactly one scope")
}
// CLI expects an AAD v1 resource, not a v2 scope
tenant, err := resolveTenant(c.opts.TenantID, opts.TenantID, credNameAzureCLI, c.opts.AdditionallyAllowedTenants)
if err != nil {
return azcore.AccessToken{}, err
}
// pass the CLI an AAD v1 resource because we don't know which CLI version is installed and older ones don't support v2 scopes
opts.Scopes = []string{strings.TrimSuffix(opts.Scopes[0], defaultSuffix)}
return c.s.GetToken(ctx, opts)
}

func (c *AzureCLICredential) requestToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
b, err := c.tokenProvider(ctx, opts.Scopes[0], opts.TenantID)
c.mu.Lock()
defer c.mu.Unlock()
b, err := c.opts.tokenProvider(ctx, opts.Scopes[0], tenant)
if err != nil {
return azcore.AccessToken{}, err
}
at, err := c.createAccessToken(b)
if err != nil {
return azcore.AccessToken{}, err
}
msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", credNameAzureCLI, strings.Join(opts.Scopes, ", "))
log.Write(EventAuthentication, msg)
return at, nil
}

Expand Down
Loading