diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index bbe90e194b59..007dcf14b951 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -9,11 +9,15 @@ * `runtime.NewPipeline` has a new signature that simplifies implementing custom authentication * `arm/runtime.RegistrationOptions` embeds `policy.ClientOptions` * Contents in the `log` package have been slightly renamed. +* Removed `AuthenticationOptions` in favor of `policy.BearerTokenOptions` +* Changed parameters for `NewBearerTokenPolicy()` +* Moved policy config options out of `arm/runtime` and into `arm/policy` ### Features Added * Updating Documentation * Added string typdef `arm.Endpoint` to provide a hint toward expected ARM client endpoints * `azcore.ClientOptions` contains common pipeline configuration settings +* Added support for multi-tenant authorization in `arm/runtime` ### Bug Fixes * Fixed a potential panic when creating the default Transporter. diff --git a/sdk/azcore/arm/policy/policy.go b/sdk/azcore/arm/policy/policy.go new file mode 100644 index 000000000000..f49dbc313282 --- /dev/null +++ b/sdk/azcore/arm/policy/policy.go @@ -0,0 +1,44 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +// BearerTokenOptions configures the bearer token policy's behavior. +type BearerTokenOptions struct { + // Scopes contains the list of permission scopes required for the token. + Scopes []string + // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate + // in cross-tenant applications. + AuxiliaryTenants []string +} + +// RegistrationOptions configures the registration policy's behavior. +// All zero-value fields will be initialized with their default values. +type RegistrationOptions struct { + policy.ClientOptions + + // MaxAttempts is the total number of times to attempt automatic registration + // in the event that an attempt fails. + // The default value is 3. + // Set to a value less than zero to disable the policy. + MaxAttempts int + + // PollingDelay is the amount of time to sleep between polling intervals. + // The default value is 15 seconds. + // A value less than zero means no delay between polling intervals (not recommended). + PollingDelay time.Duration + + // PollingDuration is the amount of time to wait before abandoning polling. + // The default valule is 5 minutes. + // NOTE: Setting this to a small value might cause the policy to prematurely fail. + PollingDuration time.Duration +} diff --git a/sdk/azcore/arm/runtime/pipeline.go b/sdk/azcore/arm/runtime/pipeline.go index 655b36567904..cc1974d3f2b7 100644 --- a/sdk/azcore/arm/runtime/pipeline.go +++ b/sdk/azcore/arm/runtime/pipeline.go @@ -9,9 +9,10 @@ package runtime import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" ) @@ -25,19 +26,16 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, options *a if len(ep) == 0 { ep = arm.AzurePublicCloud } - perCallPolicies := []policy.Policy{} + perCallPolicies := []azpolicy.Policy{} if !options.DisableRPRegistration { - regRPOpts := RegistrationOptions{ClientOptions: options.ClientOptions} + regRPOpts := armpolicy.RegistrationOptions{ClientOptions: options.ClientOptions} perCallPolicies = append(perCallPolicies, NewRPRegistrationPolicy(string(ep), cred, ®RPOpts)) } - perRetryPolicies := []policy.Policy{ - azruntime.NewBearerTokenPolicy(cred, azruntime.AuthenticationOptions{ - TokenRequest: policy.TokenRequestOptions{ - Scopes: []string{shared.EndpointToScope(string(ep))}, - }, + perRetryPolicies := []azpolicy.Policy{ + NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{ + Scopes: []string{shared.EndpointToScope(string(ep))}, AuxiliaryTenants: options.AuxiliaryTenants, - }, - ), + }), } return azruntime.NewPipeline(module, version, perCallPolicies, perRetryPolicies, &options.ClientOptions) } diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go new file mode 100644 index 000000000000..ada0405e8f3d --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +type acquiringResourceState struct { + ctx context.Context + p *BearerTokenPolicy + tenant string +} + +// acquire acquires or updates the resource; only one +// thread/goroutine at a time ever calls this function +func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + s := state.(acquiringResourceState) + tk, err := s.p.cred.GetToken(s.ctx, azpolicy.TokenRequestOptions{ + Scopes: s.p.options.Scopes, + TenantID: s.tenant, + }) + if err != nil { + return nil, time.Time{}, err + } + return tk, tk.ExpiresOn, nil +} + +// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. +type BearerTokenPolicy struct { + // mainResource is the resource to be retreived using the tenant specified in the credential + mainResource *shared.ExpiringResource + // auxResources are additional resources that are required for cross-tenant applications + auxResources map[string]*shared.ExpiringResource + // the following fields are read-only + cred azcore.TokenCredential + options armpolicy.BearerTokenOptions +} + +// NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. +// cred: an azcore.TokenCredential implementation such as a credential object from azidentity +// opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. +func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTokenOptions) *BearerTokenPolicy { + if opts == nil { + opts = &armpolicy.BearerTokenOptions{} + } + p := &BearerTokenPolicy{ + cred: cred, + options: *opts, + mainResource: shared.NewExpiringResource(acquire), + } + if len(opts.AuxiliaryTenants) > 0 { + p.auxResources = map[string]*shared.ExpiringResource{} + } + for _, t := range opts.AuxiliaryTenants { + p.auxResources[t] = shared.NewExpiringResource(acquire) + + } + return p +} + +// Do authorizes a request with a bearer token +func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) { + as := acquiringResourceState{ + ctx: req.Raw().Context(), + p: b, + } + tk, err := b.mainResource.GetResource(as) + if err != nil { + return nil, err + } + if token, ok := tk.(*azcore.AccessToken); ok { + req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) + } + auxTokens := []string{} + for tenant, er := range b.auxResources { + as.tenant = tenant + auxTk, err := er.GetResource(as) + if err != nil { + return nil, err + } + auxTokens = append(auxTokens, fmt.Sprintf("%s%s", shared.BearerTokenPrefix, auxTk.(*azcore.AccessToken).Token)) + } + if len(auxTokens) > 0 { + req.Raw().Header.Set(shared.HeaderAuxiliaryAuthorization, strings.Join(auxTokens, ", ")) + } + return req.Next() +} diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go new file mode 100644 index 000000000000..d17aa2813b9c --- /dev/null +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "strings" + + "errors" + "net/http" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" +) + +const ( + tokenValue = "***" + accessTokenRespSuccess = `{"access_token": "` + tokenValue + `", "expires_in": 3600}` + accessTokenRespShortLived = `{"access_token": "` + tokenValue + `", "expires_in": 0}` + scope = "scope" +) + +type mockCredential struct { + getTokenImpl func(ctx context.Context, options azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) +} + +func (mc mockCredential) GetToken(ctx context.Context, options azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) { + if mc.getTokenImpl != nil { + return mc.getTokenImpl(ctx, options) + } + return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil +} + +func (mc mockCredential) NewAuthenticationPolicy() azpolicy.Policy { + return mc +} + +func (mc mockCredential) Do(req *azpolicy.Request) (*http.Response, error) { + return nil, nil +} + +func newTestPipeline(opts *azpolicy.ClientOptions) pipeline.Pipeline { + return runtime.NewPipeline("testmodule", "v0.1.0", nil, nil, opts) +} + +func defaultTestPipeline(srv azpolicy.Transporter, scope string) pipeline.Pipeline { + retryOpts := azpolicy.RetryOptions{ + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: time.Millisecond, + } + return NewPipeline( + "testmodule", + "v0.1.0", + mockCredential{}, + &arm.ClientOptions{ + ClientOptions: azcore.ClientOptions{ + Retry: retryOpts, + Transport: srv, + }, + }) +} + +func TestBearerPolicy_SuccessGetToken(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + pipeline := defaultTestPipeline(srv, scope) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Expected nil error but received one") + } + const expectedToken = shared.BearerTokenPrefix + tokenValue + if token := resp.Request.Header.Get(shared.HeaderAuthorization); token != expectedToken { + t.Fatalf("expected token '%s', got '%s'", expectedToken, token) + } +} + +func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + expectedErr := errors.New("oops") + failCredential := mockCredential{} + failCredential.getTokenImpl = func(ctx context.Context, options azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) { + return nil, expectedErr + } + b := NewBearerTokenPolicy(failCredential, nil) + pipeline := newTestPipeline(&azpolicy.ClientOptions{ + Transport: srv, + Retry: azpolicy.RetryOptions{ + RetryDelay: 10 * time.Millisecond, + }, + PerRetryPolicies: []azpolicy.Policy{b}, + }) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err != expectedErr { + t.Fatalf("unexpected error: %v", err) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestBearerTokenPolicy_TokenExpired(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespShortLived))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + pipeline := defaultTestPipeline(srv, scope) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + _, err = pipeline.Do(req) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + _, err = pipeline.Do(req) + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + retryOpts := azpolicy.RetryOptions{ + // use a negative try timeout to trigger a deadline exceeded error causing GetToken() to fail + TryTimeout: -1 * time.Nanosecond, + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + MaxRetries: 3, + } + b := NewBearerTokenPolicy(mockCredential{}, nil) + pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatal(err) + } + resp, err := pipeline.Do(req) + if err == nil { + t.Fatal("unexpected nil error") + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse() + retryOpts := azpolicy.RetryOptions{ + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + } + b := NewBearerTokenPolicy( + mockCredential{}, + &armpolicy.BearerTokenOptions{ + Scopes: []string{scope}, + AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, + }, + ) + pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + expectedHeader := strings.Repeat(shared.BearerTokenPrefix+tokenValue+", ", 3) + expectedHeader = expectedHeader[:len(expectedHeader)-2] + if auxH := resp.Request.Header.Get(shared.HeaderAuxiliaryAuthorization); auxH != expectedHeader { + t.Fatalf("unexpected auxiliary authorization header %s", auxH) + } +} diff --git a/sdk/azcore/arm/runtime/policy_register_rp.go b/sdk/azcore/arm/runtime/policy_register_rp.go index 4aa4541a7624..f1a2a4233052 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp.go +++ b/sdk/azcore/arm/runtime/policy_register_rp.go @@ -17,9 +17,10 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) @@ -30,30 +31,8 @@ const ( LogRPRegistration log.Event = "RPRegistration" ) -// RegistrationOptions configures the registration policy's behavior. -// All zero-value fields will be initialized with their default values. -type RegistrationOptions struct { - policy.ClientOptions - - // MaxAttempts is the total number of times to attempt automatic registration - // in the event that an attempt fails. - // The default value is 3. - // Set to a value less than zero to disable the policy. - MaxAttempts int - - // PollingDelay is the amount of time to sleep between polling intervals. - // The default value is 15 seconds. - // A value less than zero means no delay between polling intervals (not recommended). - PollingDelay time.Duration - - // PollingDuration is the amount of time to wait before abandoning polling. - // The default valule is 5 minutes. - // NOTE: Setting this to a small value might cause the policy to prematurely fail. - PollingDuration time.Duration -} - // init sets any default values -func (r *RegistrationOptions) init() { +func setDefaults(r *armpolicy.RegistrationOptions) { if r.MaxAttempts == 0 { r.MaxAttempts = 3 } else if r.MaxAttempts < 0 { @@ -73,28 +52,28 @@ func (r *RegistrationOptions) init() { // credentials and options. The policy controls if an unregistered resource provider should // automatically be registered. See https://aka.ms/rps-not-found for more information. // Pass nil to accept the default options; this is the same as passing a zero-value options. -func NewRPRegistrationPolicy(endpoint string, cred azcore.TokenCredential, o *RegistrationOptions) policy.Policy { +func NewRPRegistrationPolicy(endpoint string, cred azcore.TokenCredential, o *armpolicy.RegistrationOptions) azpolicy.Policy { if o == nil { - o = &RegistrationOptions{} + o = &armpolicy.RegistrationOptions{} } - authPolicy := runtime.NewBearerTokenPolicy(cred, runtime.AuthenticationOptions{TokenRequest: policy.TokenRequestOptions{Scopes: []string{shared.EndpointToScope(endpoint)}}}) + authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{Scopes: []string{shared.EndpointToScope(endpoint)}}) p := &rpRegistrationPolicy{ endpoint: endpoint, pipeline: runtime.NewPipeline(shared.Module, shared.Version, nil, []pipeline.Policy{authPolicy}, &o.ClientOptions), options: *o, } // init the copy - p.options.init() + setDefaults(&p.options) return p } type rpRegistrationPolicy struct { endpoint string pipeline pipeline.Pipeline - options RegistrationOptions + options armpolicy.RegistrationOptions } -func (r *rpRegistrationPolicy) Do(req *policy.Request) (*http.Response, error) { +func (r *rpRegistrationPolicy) Do(req *azpolicy.Request) (*http.Response, error) { if r.options.MaxAttempts == 0 { // policy is disabled return req.Next() @@ -250,7 +229,7 @@ func (client *providersOperations) Get(ctx context.Context, resourceProviderName } // getCreateRequest creates the Get request. -func (client *providersOperations) getCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { +func (client *providersOperations) getCreateRequest(ctx context.Context, resourceProviderNamespace string) (*azpolicy.Request, error) { urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}" urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) @@ -307,7 +286,7 @@ func (client *providersOperations) Register(ctx context.Context, resourceProvide } // registerCreateRequest creates the Register request. -func (client *providersOperations) registerCreateRequest(ctx context.Context, resourceProviderNamespace string) (*policy.Request, error) { +func (client *providersOperations) registerCreateRequest(ctx context.Context, resourceProviderNamespace string) (*azpolicy.Request, error) { urlPath := "/subscriptions/{subscriptionId}/providers/{resourceProviderNamespace}/register" urlPath = strings.ReplaceAll(urlPath, "{resourceProviderNamespace}", url.PathEscape(resourceProviderNamespace)) urlPath = strings.ReplaceAll(urlPath, "{subscriptionId}", url.PathEscape(client.subID)) diff --git a/sdk/azcore/arm/runtime/policy_register_rp_test.go b/sdk/azcore/arm/runtime/policy_register_rp_test.go index d3818536777d..780762f49b32 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp_test.go +++ b/sdk/azcore/arm/runtime/policy_register_rp_test.go @@ -16,9 +16,10 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -57,11 +58,11 @@ const requestEndpoint = "/subscriptions/00000000-0000-0000-0000-000000000000/res func newTestRPRegistrationPipeline(srv *mock.Server) pipeline.Pipeline { opts := azcore.ClientOptions{Transport: srv} rp := NewRPRegistrationPolicy(srv.URL(), mockTokenCred{}, testRPRegistrationOptions(srv)) - return runtime.NewPipeline("test", "v0.1.0", []policy.Policy{rp}, nil, &opts) + return runtime.NewPipeline("test", "v0.1.0", []azpolicy.Policy{rp}, nil, &opts) } -func testRPRegistrationOptions(t policy.Transporter) *RegistrationOptions { - def := RegistrationOptions{} +func testRPRegistrationOptions(t azpolicy.Transporter) *armpolicy.RegistrationOptions { + def := armpolicy.RegistrationOptions{} def.Transport = t def.PollingDelay = 100 * time.Millisecond def.PollingDuration = 1 * time.Second @@ -70,13 +71,13 @@ func testRPRegistrationOptions(t policy.Transporter) *RegistrationOptions { type mockTokenCred struct{} -func (mockTokenCred) NewAuthenticationPolicy(runtime.AuthenticationOptions) policy.Policy { - return pipeline.PolicyFunc(func(req *policy.Request) (*http.Response, error) { +func (mockTokenCred) NewAuthenticationPolicy() azpolicy.Policy { + return pipeline.PolicyFunc(func(req *azpolicy.Request) (*http.Response, error) { return req.Next() }) } -func (mockTokenCred) GetToken(context.Context, policy.TokenRequestOptions) (*azcore.AccessToken, error) { +func (mockTokenCred) GetToken(context.Context, azpolicy.TokenRequestOptions) (*azcore.AccessToken, error) { return &azcore.AccessToken{ Token: "abc123", ExpiresOn: time.Now().Add(1 * time.Hour), @@ -294,7 +295,7 @@ func TestRPRegistrationPolicyCanCancel(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) // polling responses to Register() and Get(), in progress but slow so we have time to cancel srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(300*time.Millisecond)) - opts := RegistrationOptions{} + opts := armpolicy.RegistrationOptions{} opts.Transport = srv pl := newTestRPRegistrationPipeline(srv) // log only RP registration @@ -317,7 +318,7 @@ func TestRPRegistrationPolicyCanCancel(t *testing.T) { go func() { defer wg.Done() // create request and start pipeline - var req *policy.Request + var req *azpolicy.Request req, err = runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) if err != nil { return diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go index 06c5d32fd03c..5103d87e0053 100644 --- a/sdk/azcore/internal/shared/constants.go +++ b/sdk/azcore/internal/shared/constants.go @@ -12,19 +12,24 @@ const ( ) const ( - HeaderAzureAsync = "Azure-AsyncOperation" - HeaderContentLength = "Content-Length" - HeaderContentType = "Content-Type" - HeaderLocation = "Location" - HeaderOperationLocation = "Operation-Location" - HeaderRetryAfter = "Retry-After" - HeaderUserAgent = "User-Agent" + HeaderAuthorization = "Authorization" + HeaderAuxiliaryAuthorization = "x-ms-authorization-auxiliary" + HeaderAzureAsync = "Azure-AsyncOperation" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderLocation = "Location" + HeaderOperationLocation = "Operation-Location" + HeaderRetryAfter = "Retry-After" + HeaderUserAgent = "User-Agent" + HeaderXmsDate = "x-ms-date" ) const ( DefaultMaxRetries = 3 ) +const BearerTokenPrefix = "Bearer " + const ( // Module is the name of the calling module used in telemetry data. Module = "azcore" diff --git a/sdk/azcore/internal/shared/expiring_resource.go b/sdk/azcore/internal/shared/expiring_resource.go new file mode 100644 index 000000000000..9f97ca9559ab --- /dev/null +++ b/sdk/azcore/internal/shared/expiring_resource.go @@ -0,0 +1,99 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "sync" + "time" +) + +// AcquireResource abstracts a method for refreshing an expiring resource. +type AcquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) + +// ExpiringResource is a temporal resource (usually a credential), that requires periodic refreshing. +type ExpiringResource struct { + // cond is used to synchronize access to the shared resource embodied by the remaining fields + cond *sync.Cond + + // acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource + acquiring bool + + // resource contains the value of the shared resource + resource interface{} + + // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired + expiration time.Time + + // acquireResource is the callback function that actually acquires the resource + acquireResource AcquireResource +} + +// NewExpiringResource creates a new ExpiringResource that uses the specified AcquireResource for refreshing. +func NewExpiringResource(ar AcquireResource) *ExpiringResource { + return &ExpiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} +} + +// GetResource returns the underlying resource. +// If the resource is fresh, no refresh is performed. +func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error) { + // If the resource is expiring within this time window, update it eagerly. + // This allows other threads/goroutines to keep running by using the not-yet-expired + // resource value while one thread/goroutine updates the resource. + const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration + + now, acquire, resource := time.Now(), false, er.resource + // acquire exclusive lock + er.cond.L.Lock() + for { + if er.expiration.IsZero() || er.expiration.Before(now) { + // The resource was never acquired or has expired + if !er.acquiring { + // If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true + break + } + // Getting here means that this thread/goroutine will wait for the updated resource + } else if er.expiration.Add(-window).Before(now) { + // The resource is valid but is expiring within the time window + if !er.acquiring { + // If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true + break + } + // This thread/goroutine will use the existing resource value while another updates it + resource = er.resource + break + } else { + // The resource is not close to expiring, this thread/goroutine should use its current value + resource = er.resource + break + } + // If we get here, wait for the new resource value to be acquired/updated + er.cond.Wait() + } + er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked + + var err error + if acquire { + // This thread/goroutine has been selected to acquire/update the resource + var expiration time.Time + resource, expiration, err = er.acquireResource(state) + + // Atomically, update the shared resource's new value & expiration. + er.cond.L.Lock() + if err == nil { + // No error, update resource & expiration + er.resource, er.expiration = resource, expiration + } + er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce + + // Wake up any waiting threads/goroutines since there is a resource they can ALL use + er.cond.L.Unlock() + er.cond.Broadcast() + } + return resource, err // Return the resource this thread/goroutine can use +} diff --git a/sdk/azcore/internal/shared/expiring_resource_test.go b/sdk/azcore/internal/shared/expiring_resource_test.go new file mode 100644 index 000000000000..9ccc244ced01 --- /dev/null +++ b/sdk/azcore/internal/shared/expiring_resource_test.go @@ -0,0 +1,48 @@ +//go:build go1.16 +// +build go1.16 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package shared + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewExpiringResource(t *testing.T) { + er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + s := state.(string) + switch s { + case "initial": + return "updated", time.Now(), nil + case "updated": + return "refreshed", time.Now().Add(1 * time.Hour), nil + default: + t.Fatalf("unexpected state %s", s) + return "", time.Time{}, errors.New("unexpected") + } + }) + res, err := er.GetResource("initial") + require.NoError(t, err) + require.Equal(t, "updated", res) + res, err = er.GetResource(res) + require.NoError(t, err) + require.Equal(t, "refreshed", res) + res, err = er.GetResource(res) + require.NoError(t, err) + require.Equal(t, "refreshed", res) +} + +func TestNewExpiringResourceError(t *testing.T) { + er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + return "", time.Time{}, errors.New("failed") + }) + res, err := er.GetResource("stale") + require.Error(t, err) + require.Equal(t, "", res) +} diff --git a/sdk/azcore/policy/policy.go b/sdk/azcore/policy/policy.go index f6968794defd..d739109c323b 100644 --- a/sdk/azcore/policy/policy.go +++ b/sdk/azcore/policy/policy.go @@ -106,6 +106,11 @@ type TokenRequestOptions struct { TenantID string } +// BearerTokenOptions configures the bearer token policy's behavior. +type BearerTokenOptions struct { + // placeholder for future options +} + // WithHTTPHeader adds the specified http.Header to the parent context. // Use this to specify custom HTTP headers at the API-call level. // Any overlapping headers will have their values replaced with the values specified here. diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 2c5a22b6c401..d5ed61e14864 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -4,155 +4,55 @@ package runtime import ( - "fmt" "net/http" - "strings" - "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) -const ( - bearerTokenPrefix = "Bearer " - headerXmsDate = "x-ms-date" - headerAuthorization = "Authorization" - headerAuxiliaryAuthorization = "x-ms-authorization-auxiliary" -) - // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *expiringResource - // auxResources are additional resources that are required for cross-tenant applications - auxResources map[string]*expiringResource + mainResource *shared.ExpiringResource // the following fields are read-only - cred azcore.TokenCredential - options policy.TokenRequestOptions -} - -type expiringResource struct { - // cond is used to synchronize access to the shared resource embodied by the remaining fields - cond *sync.Cond - - // acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource - acquiring bool - - // resource contains the value of the shared resource - resource interface{} - - // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired - expiration time.Time - - // acquireResource is the callback function that actually acquires the resource - acquireResource acquireResource + cred azcore.TokenCredential + scopes []string } -type acquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) - type acquiringResourceState struct { req *policy.Request - p BearerTokenPolicy + p *BearerTokenPolicy } // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { s := state.(acquiringResourceState) - tk, err := s.p.cred.GetToken(s.req.Raw().Context(), s.p.options) + tk, err := s.p.cred.GetToken(s.req.Raw().Context(), policy.TokenRequestOptions{Scopes: s.p.scopes}) if err != nil { return nil, time.Time{}, err } return tk, tk.ExpiresOn, nil } -func newExpiringResource(ar acquireResource) *expiringResource { - return &expiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} -} - -func (er *expiringResource) GetResource(state interface{}) (interface{}, error) { - // If the resource is expiring within this time window, update it eagerly. - // This allows other threads/goroutines to keep running by using the not-yet-expired - // resource value while one thread/goroutine updates the resource. - const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration - - now, acquire, resource := time.Now(), false, er.resource - // acquire exclusive lock - er.cond.L.Lock() - for { - if er.expiration.IsZero() || er.expiration.Before(now) { - // The resource was never acquired or has expired - if !er.acquiring { - // If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it - er.acquiring, acquire = true, true - break - } - // Getting here means that this thread/goroutine will wait for the updated resource - } else if er.expiration.Add(-window).Before(now) { - // The resource is valid but is expiring within the time window - if !er.acquiring { - // If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it - er.acquiring, acquire = true, true - break - } - // This thread/goroutine will use the existing resource value while another updates it - resource = er.resource - break - } else { - // The resource is not close to expiring, this thread/goroutine should use its current value - resource = er.resource - break - } - // If we get here, wait for the new resource value to be acquired/updated - er.cond.Wait() - } - er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked - - var err error - if acquire { - // This thread/goroutine has been selected to acquire/update the resource - var expiration time.Time - resource, expiration, err = er.acquireResource(state) - - // Atomically, update the shared resource's new value & expiration. - er.cond.L.Lock() - if err == nil { - // No error, update resource & expiration - er.resource, er.expiration = resource, expiration - } - er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce - - // Wake up any waiting threads/goroutines since there is a resource they can ALL use - er.cond.L.Unlock() - er.cond.Broadcast() - } - return resource, err // Return the resource this thread/goroutine can use -} - // NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. // cred: an azcore.TokenCredential implementation such as a credential object from azidentity +// scopes: the list of permission scopes required for the token. // opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. -func NewBearerTokenPolicy(cred azcore.TokenCredential, opts AuthenticationOptions) *BearerTokenPolicy { - p := &BearerTokenPolicy{ +func NewBearerTokenPolicy(cred azcore.TokenCredential, scopes []string, opts *policy.BearerTokenOptions) *BearerTokenPolicy { + return &BearerTokenPolicy{ cred: cred, - options: opts.TokenRequest, - mainResource: newExpiringResource(acquire), - } - if len(opts.AuxiliaryTenants) > 0 { - p.auxResources = map[string]*expiringResource{} - } - for _, t := range opts.AuxiliaryTenants { - p.auxResources[t] = newExpiringResource(acquire) - + scopes: scopes, + mainResource: shared.NewExpiringResource(acquire), } - return p } // Do authorizes a request with a bearer token func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { as := acquiringResourceState{ - p: *b, + p: b, req: req, } tk, err := b.mainResource.GetResource(as) @@ -160,25 +60,7 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { return nil, err } if token, ok := tk.(*azcore.AccessToken); ok { - req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat)) - req.Raw().Header.Set(headerAuthorization, bearerTokenPrefix+token.Token) - } - auxTokens := []string{} - for tenant, er := range b.auxResources { - bCopy := *b - bCopy.options.TenantID = tenant - auxAS := acquiringResourceState{ - p: bCopy, - req: req, - } - auxTk, err := er.GetResource(auxAS) - if err != nil { - return nil, err - } - auxTokens = append(auxTokens, fmt.Sprintf("%s%s", bearerTokenPrefix, auxTk.(*azcore.AccessToken).Token)) - } - if len(auxTokens) > 0 { - req.Raw().Header.Set(headerAuxiliaryAuthorization, strings.Join(auxTokens, ", ")) + req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) } return req.Next() } diff --git a/sdk/azcore/runtime/policy_bearer_token_test.go b/sdk/azcore/runtime/policy_bearer_token_test.go index 36440cbff8dd..02f9dd3a74e7 100644 --- a/sdk/azcore/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/runtime/policy_bearer_token_test.go @@ -5,7 +5,6 @@ package runtime import ( "context" - "strings" "errors" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -36,7 +36,7 @@ func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenReque return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } -func (mc mockCredential) NewAuthenticationPolicy(options AuthenticationOptions) policy.Policy { +func (mc mockCredential) NewAuthenticationPolicy() policy.Policy { return mc } @@ -49,10 +49,7 @@ func defaultTestPipeline(srv policy.Transporter, scope string) Pipeline { MaxRetryDelay: 500 * time.Millisecond, RetryDelay: time.Millisecond, } - b := NewBearerTokenPolicy( - mockCredential{}, - AuthenticationOptions{TokenRequest: policy.TokenRequestOptions{Scopes: []string{scope}}}, - ) + b := NewBearerTokenPolicy(mockCredential{}, []string{scope}, nil) return NewPipeline( "testmodule", "v0.1.0", @@ -76,8 +73,8 @@ func TestBearerPolicy_SuccessGetToken(t *testing.T) { if err != nil { t.Fatalf("Expected nil error but received one") } - const expectedToken = bearerTokenPrefix + tokenValue - if token := resp.Request.Header.Get(headerAuthorization); token != expectedToken { + const expectedToken = shared.BearerTokenPrefix + tokenValue + if token := resp.Request.Header.Get(shared.HeaderAuthorization); token != expectedToken { t.Fatalf("expected token '%s', got '%s'", expectedToken, token) } } @@ -90,7 +87,7 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { return nil, expectedErr } - b := NewBearerTokenPolicy(failCredential, AuthenticationOptions{}) + b := NewBearerTokenPolicy(failCredential, nil, nil) pipeline := newTestPipeline(&policy.ClientOptions{ Transport: srv, Retry: policy.RetryOptions{ @@ -142,7 +139,7 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { RetryDelay: 50 * time.Millisecond, MaxRetries: 3, } - b := NewBearerTokenPolicy(mockCredential{}, AuthenticationOptions{}) + b := NewBearerTokenPolicy(mockCredential{}, nil, nil) pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { @@ -156,43 +153,3 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { t.Fatal("expected nil response") } } - -func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { - srv, close := mock.NewTLSServer() - defer close() - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse() - retryOpts := policy.RetryOptions{ - MaxRetryDelay: 500 * time.Millisecond, - RetryDelay: 50 * time.Millisecond, - } - b := NewBearerTokenPolicy( - mockCredential{}, - AuthenticationOptions{ - TokenRequest: policy.TokenRequestOptions{ - Scopes: []string{scope}, - }, - AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, - }, - ) - pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) - req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - resp, err := pipeline.Do(req) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status code: %d", resp.StatusCode) - } - expectedHeader := strings.Repeat(bearerTokenPrefix+tokenValue+", ", 3) - expectedHeader = expectedHeader[:len(expectedHeader)-2] - if auxH := resp.Request.Header.Get(headerAuxiliaryAuthorization); auxH != expectedHeader { - t.Fatalf("unexpected auxiliary authorization header %s", auxH) - } -} diff --git a/sdk/azcore/runtime/transport_default_http_client.go b/sdk/azcore/runtime/transport_default_http_client.go index d8bb8643c2ae..f7f3ca9c14ed 100644 --- a/sdk/azcore/runtime/transport_default_http_client.go +++ b/sdk/azcore/runtime/transport_default_http_client.go @@ -11,8 +11,6 @@ import ( "net" "net/http" "time" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) var defaultHTTPClient *http.Client @@ -37,14 +35,3 @@ func init() { Transport: defaultTransport, } } - -// AuthenticationOptions contains various options used to create a credential policy. -type AuthenticationOptions struct { - // TokenRequest is a TokenRequestOptions that includes a scopes field which contains - // the list of OAuth2 authentication scopes used when requesting a token. - // This field is ignored for other forms of authentication (e.g. shared key). - TokenRequest policy.TokenRequestOptions - // AuxiliaryTenants contains a list of additional tenant IDs to be used to authenticate - // in cross-tenant applications. - AuxiliaryTenants []string -}