diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index cf3ee5fbfc8d..e2fe89610e65 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -3,6 +3,10 @@ ## 1.3.0-beta.3 (Unreleased) ### Features Added +* By default, credentials set client capability "CP1" to enable support for + [Continuous Access Evaluation (CAE)](https://docs.microsoft.com/azure/active-directory/develop/app-resilience-continuous-access-evaluation). + This indicates to Azure Active Directory that your application can handle CAE claims challenges. + You can disable this behavior by setting the environment variable "AZURE_IDENTITY_DISABLE_CP1" to "true". * `InteractiveBrowserCredentialOptions.LoginHint` enables pre-populating the login prompt with a username ([#15599](https://github.com/Azure/azure-sdk-for-go/pull/15599)) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index f0deced24c60..5be975682348 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -45,6 +45,12 @@ const ( tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names" ) +var ( + // capability CP1 indicates the client application is capable of handling CAE claims challenges + cp1 = []string{"CP1"} + disableCP1 = strings.ToLower(os.Getenv("AZURE_IDENTITY_DISABLE_CP1")) == "true" +) + var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidentialClient, error) { if !validTenantID(tenantID) { return confidential.Client{}, errors.New(tenantIDValidationErr) @@ -58,6 +64,9 @@ var getConfidentialClient = func(clientID, tenantID string, cred confidential.Cr confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), confidential.WithHTTPClient(newPipelineAdapter(co)), } + if !disableCP1 { + o = append(o, confidential.WithClientCapabilities(cp1)) + } o = append(o, additionalOpts...) if strings.ToLower(tenantID) == "adfs" { o = append(o, confidential.WithInstanceDiscovery(false)) @@ -73,11 +82,13 @@ var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions, if err != nil { return public.Client{}, err } - o := []public.Option{ public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), public.WithHTTPClient(newPipelineAdapter(co)), } + if !disableCP1 { + o = append(o, public.WithClientCapabilities(cp1)) + } o = append(o, additionalOpts...) if strings.ToLower(tenantID) == "adfs" { o = append(o, public.WithInstanceDiscovery(false)) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index e1d9532292fd..721fdafeb52a 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -11,7 +11,6 @@ import ( "crypto/x509" "errors" "fmt" - "io" "net/http" "os" "strings" @@ -136,16 +135,20 @@ func getTenantDiscoveryResponse(tenant string) []byte { func validateX5C(t *testing.T, certs []*x509.Certificate) mock.ResponsePredicate { return func(req *http.Request) bool { - body, err := io.ReadAll(req.Body) + err := req.ParseForm() if err != nil { - t.Fatal("Expected a request with the JWT in the body.") + t.Fatal("expected a form body") } - bodystr := string(body) - kvps := strings.Split(bodystr, "&") - assertion := strings.Split(kvps[0], "=") - token, _ := jwt.Parse(assertion[1], nil) + assertion, ok := req.PostForm["client_assertion"] + if !ok { + t.Fatal("expected a client_assertion field") + } + if len(assertion) != 1 { + t.Fatalf(`unexpected client_assertion "%v"`, assertion) + } + token, _ := jwt.Parse(assertion[0], nil) if token == nil { - t.Fatalf("Failed to parse the JWT token: %s.", assertion[1]) + t.Fatalf("failed to parse the assertion: %s", assertion) } if v, ok := token.Header["x5c"].([]any); !ok { t.Fatal("missing x5c header") @@ -525,6 +528,107 @@ func TestAdditionallyAllowedTenants(t *testing.T) { } } +func TestClaims(t *testing.T) { + realCP1 := disableCP1 + t.Cleanup(func() { disableCP1 = realCP1 }) + claim := `"test":"pass"` + for _, test := range []struct { + ctor func(azcore.ClientOptions) (azcore.TokenCredential, error) + name string + }{ + { + name: credNameAssertion, + ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { + o := ClientAssertionCredentialOptions{ClientOptions: co} + return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o) + }, + }, + { + name: credNameCert, + ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { + o := ClientCertificateCredentialOptions{ClientOptions: co} + return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o) + }, + }, + { + name: credNameDeviceCode, + ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { + o := DeviceCodeCredentialOptions{ + ClientOptions: co, + UserPrompt: func(context.Context, DeviceCodeMessage) error { return nil }, + } + return NewDeviceCodeCredential(&o) + }, + }, + { + name: credNameOBO, + ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { + o := OnBehalfOfCredentialOptions{ClientOptions: co} + return NewOnBehalfOfCredentialFromSecret(fakeTenantID, fakeClientID, "assertion", fakeSecret, &o) + }, + }, + { + name: credNameSecret, + ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { + o := ClientSecretCredentialOptions{ClientOptions: co} + return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o) + }, + }, + { + name: credNameUserPassword, + ctor: func(co azcore.ClientOptions) (azcore.TokenCredential, error) { + o := UsernamePasswordCredentialOptions{ClientOptions: co} + return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o) + }, + }, + } { + for _, d := range []bool{true, false} { + name := test.name + if d { + name += " disableCP1" + } + t.Run(name, func(t *testing.T) { + disableCP1 = d + reqs := 0 + sts := mockSTS{ + tokenRequestCallback: func(r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Error(err) + } + reqs++ + // If the disableCP1 flag isn't set, both requests should specify CP1. The second + // GetToken call specifies claims we should find in the following token request. + // We check only for substrings because MSAL is responsible for formatting claims. + actual := fmt.Sprint(r.Form["claims"]) + if strings.Contains(actual, "CP1") == disableCP1 { + t.Fatalf(`unexpected claims "%v"`, actual) + } + if reqs == 2 { + if !strings.Contains(strings.ReplaceAll(actual, " ", ""), claim) { + t.Fatalf(`unexpected claims "%v"`, actual) + } + } + }, + } + o := azcore.ClientOptions{Transport: &sts} + cred, err := test.ctor(o) + if err != nil { + t.Fatal(err) + } + if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"A"}}); err != nil { + t.Fatal(err) + } + if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Claims: fmt.Sprintf("{%s}", claim), Scopes: []string{"B"}}); err != nil { + t.Fatal(err) + } + if reqs != 2 { + t.Fatalf("expected %d token requests, got %d", 2, reqs) + } + }) + } + } +} + func TestResolveTenant(t *testing.T) { defaultTenant := "default-tenant" otherTenant := "other-tenant" diff --git a/sdk/azidentity/client_assertion_credential.go b/sdk/azidentity/client_assertion_credential.go index 15bebaa16932..1bbcfc6a5be5 100644 --- a/sdk/azidentity/client_assertion_credential.go +++ b/sdk/azidentity/client_assertion_credential.go @@ -77,13 +77,13 @@ func (c *ClientAssertionCredential) GetToken(ctx context.Context, opts policy.To if err != nil { return azcore.AccessToken{}, err } - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant)) + ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant)) if err == nil { logGetTokenSuccessImpl(c.name, opts) return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err } - ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant)) + ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant)) if err != nil { return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(c.name, err) } diff --git a/sdk/azidentity/client_certificate_credential.go b/sdk/azidentity/client_certificate_credential.go index 4f0122f17944..16bed018b276 100644 --- a/sdk/azidentity/client_certificate_credential.go +++ b/sdk/azidentity/client_certificate_credential.go @@ -81,13 +81,13 @@ func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts policy. if err != nil { return azcore.AccessToken{}, err } - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant)) + ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant)) if err == nil { logGetTokenSuccess(c, opts) return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err } - ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant)) + ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant)) if err != nil { return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameCert, err) } diff --git a/sdk/azidentity/client_secret_credential.go b/sdk/azidentity/client_secret_credential.go index bfdc353d26e8..c020f1d8cb89 100644 --- a/sdk/azidentity/client_secret_credential.go +++ b/sdk/azidentity/client_secret_credential.go @@ -65,13 +65,13 @@ func (c *ClientSecretCredential) GetToken(ctx context.Context, opts policy.Token if err != nil { return azcore.AccessToken{}, err } - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithTenantID(tenant)) + ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant)) if err == nil { logGetTokenSuccess(c, opts) return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err } - ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithTenantID(tenant)) + ar, err = c.client.AcquireTokenByCredential(ctx, opts.Scopes, confidential.WithClaims(opts.Claims), confidential.WithTenantID(tenant)) if err != nil { return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameSecret, err) } diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go index 49c46a15528f..e7fb79c884ee 100644 --- a/sdk/azidentity/device_code_credential.go +++ b/sdk/azidentity/device_code_credential.go @@ -108,11 +108,15 @@ func (c *DeviceCodeCredential) GetToken(ctx context.Context, opts policy.TokenRe if err != nil { return azcore.AccessToken{}, err } - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant)) + ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, + public.WithClaims(opts.Claims), + public.WithSilentAccount(c.account), + public.WithTenantID(tenant), + ) if err == nil { return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err } - dc, err := c.client.AcquireTokenByDeviceCode(ctx, opts.Scopes, public.WithTenantID(tenant)) + dc, err := c.client.AcquireTokenByDeviceCode(ctx, opts.Scopes, public.WithClaims(opts.Claims), public.WithTenantID(tenant)) if err != nil { return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameDeviceCode, err) } diff --git a/sdk/azidentity/interactive_browser_credential.go b/sdk/azidentity/interactive_browser_credential.go index b95f850fbe67..cf1771fb44b1 100644 --- a/sdk/azidentity/interactive_browser_credential.go +++ b/sdk/azidentity/interactive_browser_credential.go @@ -87,13 +87,18 @@ func (c *InteractiveBrowserCredential) GetToken(ctx context.Context, opts policy if err != nil { return azcore.AccessToken{}, err } - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant)) + ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, + public.WithClaims(opts.Claims), + public.WithSilentAccount(c.account), + public.WithTenantID(tenant), + ) if err == nil { logGetTokenSuccess(c, opts) return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err } ar, err = c.client.AcquireTokenInteractive(ctx, opts.Scopes, + public.WithClaims(opts.Claims), public.WithLoginHint(c.options.LoginHint), public.WithRedirectURI(c.options.RedirectURL), public.WithTenantID(tenant), diff --git a/sdk/azidentity/on_behalf_of_credential.go b/sdk/azidentity/on_behalf_of_credential.go index ac020d1facca..a18e9adc4c3c 100644 --- a/sdk/azidentity/on_behalf_of_credential.go +++ b/sdk/azidentity/on_behalf_of_credential.go @@ -96,7 +96,10 @@ func (o *OnBehalfOfCredential) GetToken(ctx context.Context, opts policy.TokenRe if err != nil { return azcore.AccessToken{}, err } - ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes, confidential.WithTenantID(tenant)) + ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes, + confidential.WithClaims(opts.Claims), + confidential.WithTenantID(tenant), + ) if err != nil { return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameOBO, err) } diff --git a/sdk/azidentity/username_password_credential.go b/sdk/azidentity/username_password_credential.go index b4a84199c216..107aa9d20ef2 100644 --- a/sdk/azidentity/username_password_credential.go +++ b/sdk/azidentity/username_password_credential.go @@ -68,12 +68,16 @@ func (c *UsernamePasswordCredential) GetToken(ctx context.Context, opts policy.T if err != nil { return azcore.AccessToken{}, err } - ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, public.WithSilentAccount(c.account), public.WithTenantID(tenant)) + ar, err := c.client.AcquireTokenSilent(ctx, opts.Scopes, + public.WithClaims(opts.Claims), + public.WithSilentAccount(c.account), + public.WithTenantID(tenant), + ) if err == nil { logGetTokenSuccess(c, opts) return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err } - ar, err = c.client.AcquireTokenByUsernamePassword(ctx, opts.Scopes, c.username, c.password, public.WithTenantID(tenant)) + ar, err = c.client.AcquireTokenByUsernamePassword(ctx, opts.Scopes, c.username, c.password, public.WithClaims(opts.Claims), public.WithTenantID(tenant)) if err != nil { return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameUserPassword, err) }