Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* Added `BearerTokenOptions.AuthorizationHandler` to enable extending `runtime.BearerTokenPolicy`
with custom authorization logic
* Added `Client` types and matching constructors to the `azcore` and `arm` packages. These represent a basic client for HTTP and ARM respectively.
* Added support for ARM cross-tenant authentication. Set the `AuxiliaryTenants` field of `arm.ClientOptions` to enable.
* Added `TenantID` field to `policy.TokenRequestOptions`.

### Breaking Changes

Expand Down
12 changes: 12 additions & 0 deletions sdk/azcore/arm/policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import (

// BearerTokenOptions configures the bearer token policy's behavior.
type BearerTokenOptions struct {
// AuxiliaryTenants are additional tenant IDs for authenticating cross-tenant requests.
// The policy will add a token from each of these tenants to every request. The
// authenticating user or service principal must be a guest in these tenants, and the
// policy's credential must support multitenant authentication.
AuxiliaryTenants []string

// Scopes contains the list of permission scopes required for the token.
Scopes []string
}
Expand Down Expand Up @@ -44,6 +50,12 @@ type RegistrationOptions struct {
type ClientOptions struct {
policy.ClientOptions

// AuxiliaryTenants are additional tenant IDs for authenticating cross-tenant requests.
// The client will add a token from each of these tenants to every request. The
// authenticating user or service principal must be a guest in these tenants, and the
// client's credential must support multitenant authentication.
AuxiliaryTenants []string

// DisableRPRegistration disables the auto-RP registration policy. Defaults to false.
DisableRPRegistration bool
}
5 changes: 4 additions & 1 deletion sdk/azcore/arm/runtime/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr
if err != nil {
return azruntime.Pipeline{}, err
}
authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{Scopes: []string{conf.Audience + "/.default"}})
authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{
AuxiliaryTenants: options.AuxiliaryTenants,
Scopes: []string{conf.Audience + "/.default"},
})
perRetry := make([]azpolicy.Policy, 0, len(plOpts.PerRetry)+1)
copy(perRetry, plOpts.PerRetry)
plOpts.PerRetry = append(perRetry, authPolicy)
Expand Down
63 changes: 39 additions & 24 deletions sdk/azcore/arm/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ import (
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/temporal"
)

const headerAuxiliaryAuthorization = "x-ms-authorization-auxiliary"

// acquiringResourceState holds data for an auxiliary token request
type acquiringResourceState struct {
ctx context.Context
p *BearerTokenPolicy
Expand All @@ -26,7 +30,10 @@ type acquiringResourceState struct {
// acquire acquires or updates the resource; only one
// thread/goroutine at a time ever calls this function
func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) {
tk, err := state.p.cred.GetToken(state.ctx, azpolicy.TokenRequestOptions{Scopes: state.p.options.Scopes})
tk, err := state.p.cred.GetToken(state.ctx, azpolicy.TokenRequestOptions{
Scopes: state.p.scopes,
TenantID: state.tenant,
})
if err != nil {
return azcore.AccessToken{}, time.Time{}, err
}
Expand All @@ -35,13 +42,10 @@ func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newE

// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential.
type BearerTokenPolicy struct {
// mainResource is the resource to be retreived using the tenant specified in the credential
mainResource *temporal.Resource[azcore.AccessToken, acquiringResourceState]
// auxResources are additional resources that are required for cross-tenant applications
auxResources map[string]*temporal.Resource[azcore.AccessToken, acquiringResourceState]
// the following fields are read-only
cred azcore.TokenCredential
options armpolicy.BearerTokenOptions
btp *azruntime.BearerTokenPolicy
cred azcore.TokenCredential
scopes []string
}

// NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens.
Expand All @@ -51,36 +55,47 @@ func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTok
if opts == nil {
opts = &armpolicy.BearerTokenOptions{}
}
p := &BearerTokenPolicy{
cred: cred,
options: *opts,
mainResource: temporal.NewResource(acquire),
p := &BearerTokenPolicy{cred: cred}
p.auxResources = make(map[string]*temporal.Resource[azcore.AccessToken, acquiringResourceState], len(opts.AuxiliaryTenants))
for _, t := range opts.AuxiliaryTenants {
p.auxResources[t] = temporal.NewResource(acquire)
}
p.scopes = make([]string, len(opts.Scopes))
copy(p.scopes, opts.Scopes)
p.btp = azruntime.NewBearerTokenPolicy(cred, opts.Scopes, &azpolicy.BearerTokenOptions{
AuthorizationHandler: azpolicy.AuthorizationHandler{
OnRequest: p.onRequest,
},
})
return p
}

// Do authorizes a request with a bearer token
func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) {
// onRequest authorizes requests with one or more bearer tokens
func (b *BearerTokenPolicy) onRequest(req *azpolicy.Request, authNZ func(azpolicy.TokenRequestOptions) error) error {
// authorize the request with a token for the primary tenant
err := authNZ(azpolicy.TokenRequestOptions{Scopes: b.scopes})
if err != nil || len(b.auxResources) == 0 {
return err
}
// add tokens for auxiliary tenants
as := acquiringResourceState{
ctx: req.Raw().Context(),
p: b,
}
tk, err := b.mainResource.Get(as)
if err != nil {
return nil, err
}
req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+tk.Token)
auxTokens := []string{}
auxTokens := make([]string, 0, len(b.auxResources))
for tenant, er := range b.auxResources {
as.tenant = tenant
auxTk, err := er.Get(as)
if err != nil {
return nil, err
return err
}
auxTokens = append(auxTokens, fmt.Sprintf("%s%s", shared.BearerTokenPrefix, auxTk.Token))
}
if len(auxTokens) > 0 {
req.Raw().Header.Set(shared.HeaderAuxiliaryAuthorization, strings.Join(auxTokens, ", "))
}
return req.Next()
req.Raw().Header.Set(headerAuxiliaryAuthorization, strings.Join(auxTokens, ", "))
return nil
}

// Do authorizes a request with a bearer token
func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) {
return b.btp.Do(req)
}
84 changes: 41 additions & 43 deletions sdk/azcore/arm/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ package runtime

import (
"context"
"errors"
"strings"

"errors"
"net/http"
"testing"
"time"
Expand All @@ -17,7 +17,9 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/stretchr/testify/require"
)

const (
Expand Down Expand Up @@ -90,10 +92,10 @@ func TestBearerPolicy_SuccessGetToken(t *testing.T) {
func TestBearerPolicy_CredentialFailGetToken(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
expectedErr := errors.New("oops")
expectedErr := "oops"
failCredential := mockCredential{}
failCredential.getTokenImpl = func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) {
return azcore.AccessToken{}, expectedErr
return azcore.AccessToken{}, errors.New(expectedErr)
}
b := NewBearerTokenPolicy(failCredential, nil)
pipeline := newTestPipeline(&azpolicy.ClientOptions{
Expand All @@ -104,16 +106,11 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) {
PerRetryPolicies: []azpolicy.Policy{b},
})
req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
resp, err := pipeline.Do(req)
if err != expectedErr {
t.Fatalf("unexpected error: %v", err)
}
if resp != nil {
t.Fatal("expected nil response")
}
require.EqualError(t, err, expectedErr)
require.Nil(t, resp)
require.Implements(t, (*errorinfo.NonRetriable)(nil), err)
}

func TestBearerTokenPolicy_TokenExpired(t *testing.T) {
Expand Down Expand Up @@ -165,41 +162,42 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) {
}
}

func TestBearerTokenWithAuxiliaryTenants(t *testing.T) {
t.Skip("this feature isn't implemented yet")
func TestAuxiliaryTenants(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse()
retryOpts := azpolicy.RetryOptions{
MaxRetryDelay: 500 * time.Millisecond,
RetryDelay: 50 * time.Millisecond,
}
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
primary := "primary"
auxTenants := []string{"aux1", "aux2", "aux3"}
expectCache := false
b := NewBearerTokenPolicy(
mockCredential{},
&armpolicy.BearerTokenOptions{
Scopes: []string{scope},
//AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"},
mockCredential{
// getTokenImpl returns a token whose value equals the requested tenant so the test can validate how the policy handles tenants
// i.e., primary tenant token goes in Authorization header and aux tenant tokens go in x-ms-authorization-auxiliary
getTokenImpl: func(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) {
require.False(t, expectCache, "client should have used a cached token instead of requesting another")
tenant := primary
if options.TenantID != "" {
tenant = options.TenantID
}
return azcore.AccessToken{Token: tenant, ExpiresOn: time.Now().Add(time.Hour).UTC()}, nil
},
},
&armpolicy.BearerTokenOptions{AuxiliaryTenants: auxTenants, Scopes: []string{scope}},
)
pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []azpolicy.Policy{b}})
req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
resp, err := pipeline.Do(req)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
expectedHeader := strings.Repeat(shared.BearerTokenPrefix+tokenValue+", ", 3)
expectedHeader = expectedHeader[:len(expectedHeader)-2]
if auxH := resp.Request.Header.Get(shared.HeaderAuxiliaryAuthorization); auxH != expectedHeader {
t.Fatalf("unexpected auxiliary authorization header %s", auxH)
pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, PerRetryPolicies: []azpolicy.Policy{b}})
expected := strings.Split(shared.BearerTokenPrefix+strings.Join(auxTenants, ","+shared.BearerTokenPrefix), ",")
for i := 0; i < 3; i++ {
if i == 1 {
// policy should have a cached token after the first iteration
expectCache = true
}
req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL())
require.NoError(t, err)
resp, err := pipeline.Do(req)
require.NoError(t, err)
require.Equal(t, shared.BearerTokenPrefix+primary, resp.Request.Header.Get(shared.HeaderAuthorization), "Authorization header must contain primary tenant token")
actual := strings.Split(resp.Request.Header.Get(headerAuxiliaryAuthorization), ", ")
// auxiliary tokens may appear in arbitrary order
require.ElementsMatch(t, expected, actual)
}
}
4 changes: 4 additions & 0 deletions sdk/azcore/internal/exported/exported.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ type AccessToken struct {
type TokenRequestOptions struct {
// Scopes contains the list of permission scopes required for the token.
Scopes []string

// TenantID identifies the tenant from which to request the token. azidentity credentials authenticate in
// their configured default tenants when this field isn't set.
TenantID string
}

// TokenCredential represents a credential capable of providing an OAuth token.
Expand Down
17 changes: 8 additions & 9 deletions sdk/azcore/internal/shared/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ const (
)

const (
HeaderAuthorization = "Authorization"
HeaderAuxiliaryAuthorization = "x-ms-authorization-auxiliary"
HeaderAzureAsync = "Azure-AsyncOperation"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderLocation = "Location"
HeaderOperationLocation = "Operation-Location"
HeaderRetryAfter = "Retry-After"
HeaderUserAgent = "User-Agent"
HeaderAuthorization = "Authorization"
HeaderAzureAsync = "Azure-AsyncOperation"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderLocation = "Location"
HeaderOperationLocation = "Operation-Location"
HeaderRetryAfter = "Retry-After"
HeaderUserAgent = "User-Agent"
)

const BearerTokenPrefix = "Bearer "
Expand Down