Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## 1.3.0-beta.2 (Unreleased)

### Features Added
* Added `OnBehalfOfCredential` to support the on-behalf-of flow
([#16642](https://github.com/Azure/azure-sdk-for-go/issues/16642))

### Breaking Changes

Expand Down
5 changes: 3 additions & 2 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const (
tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names"
)

func getConfidentialClient(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidential.Client, error) {
var getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidentialClient, error) {
if !validTenantID(tenantID) {
return confidential.Client{}, errors.New(tenantIDValidationErr)
}
Expand All @@ -59,7 +59,7 @@ func getConfidentialClient(clientID, tenantID string, cred confidential.Credenti
return confidential.New(clientID, cred, o...)
}

func getPublicClient(clientID, tenantID string, co *azcore.ClientOptions) (public.Client, error) {
var getPublicClient = func(clientID, tenantID string, co *azcore.ClientOptions) (public.Client, error) {
if !validTenantID(tenantID) {
return public.Client{}, errors.New(tenantIDValidationErr)
}
Expand Down Expand Up @@ -154,6 +154,7 @@ type confidentialClient interface {
AcquireTokenSilent(ctx context.Context, scopes []string, options ...confidential.AcquireTokenSilentOption) (confidential.AuthResult, error)
AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, options ...confidential.AcquireTokenByAuthCodeOption) (confidential.AuthResult, error)
AcquireTokenByCredential(ctx context.Context, scopes []string) (confidential.AuthResult, error)
AcquireTokenOnBehalfOf(ctx context.Context, userAssertion string, scopes []string) (confidential.AuthResult, error)
}

// enables fakes for test scenarios
Expand Down
10 changes: 10 additions & 0 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ type fakeConfidentialClient struct {

// set true to have silent auth succeed
silentAuth bool

// optional callbacks for validating MSAL call args
oboCallback func(context.Context, string, []string)
}

func (f fakeConfidentialClient) returnResult() (confidential.AuthResult, error) {
Expand All @@ -331,6 +334,13 @@ func (f fakeConfidentialClient) AcquireTokenByCredential(ctx context.Context, sc
return f.returnResult()
}

func (f fakeConfidentialClient) AcquireTokenOnBehalfOf(ctx context.Context, userAssertion string, scopes []string) (confidential.AuthResult, error) {
if f.oboCallback != nil {
f.oboCallback(ctx, userAssertion, scopes)
}
return f.returnResult()
}

var _ confidentialClient = (*fakeConfidentialClient)(nil)

// ==================================================================================================================================
Expand Down
22 changes: 22 additions & 0 deletions sdk/azidentity/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
)

func ExampleNewOnBehalfOfCredentialFromCertificate() {
data, err := os.ReadFile(certPath)
if err != nil {
// TODO: handle error
}

// NewOnBehalfOfCredentialFromCertificate requires at least one *x509.Certificate, and a crypto.PrivateKey.
// ParseCertificates returns these given certificate data in PEM or PKCS12 format. It handles common
// scenarios but has limitations, for example it doesn't load PEM encrypted private keys.
certs, key, err := azidentity.ParseCertificates(data, nil)
if err != nil {
// TODO: handle error
}

cred, err = azidentity.NewClientCertificateCredential(tenantID, clientID, certs, key, nil)
if err != nil {
// TODO: handle error
}

// Output:
}

func ExampleNewClientCertificateCredential() {
data, err := os.ReadFile(certPath)
handleError(err)
Expand Down
88 changes: 88 additions & 0 deletions sdk/azidentity/on_behalf_of_credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azidentity

import (
"context"
"crypto"
"crypto/x509"
"errors"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

const credNameOBO = "OnBehalfOfCredential"

// OnBehalfOfCredential authenticates a service principal via the on-behalf-of flow. This is typically used by
// middle-tier services that authorize requests to other services with a delegated user identity. Because this
// is not an interactive authentication flow, an application using it must have admin consent for any delegated
// permissions before requesting tokens for them. See [Azure Active Directory documentation] for more details.
//
// [Azure Active Directory documentation]: https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow
type OnBehalfOfCredential struct {
assertion string
client confidentialClient
}

// OnBehalfOfCredentialOptions contains optional parameters for OnBehalfOfCredential
type OnBehalfOfCredentialOptions struct {
azcore.ClientOptions

// SendCertificateChain applies only when the credential is configured to authenticate with a certificate.
// This setting controls whether the credential sends the public certificate chain in the x5c header of each
// token request's JWT. This is required for, and only used in, Subject Name/Issuer (SNI) authentication.
SendCertificateChain bool
}

// NewOnBehalfOfCredentialFromCertificate constructs an OnBehalfOfCredential that authenticates with a certificate.
// See [ParseCertificates] for help loading a certificate.
func NewOnBehalfOfCredentialFromCertificate(tenantID, clientID, userAssertion string, certs []*x509.Certificate, key crypto.PrivateKey, options *OnBehalfOfCredentialOptions) (*OnBehalfOfCredential, error) {
cred, err := confidential.NewCredFromCertChain(certs, key)
if err != nil {
return nil, err
}
return newOnBehalfOfCredential(tenantID, clientID, userAssertion, cred, options)
}

// NewOnBehalfOfCredentialFromSecret constructs an OnBehalfOfCredential that authenticates with a client secret.
func NewOnBehalfOfCredentialFromSecret(tenantID, clientID, userAssertion, clientSecret string, options *OnBehalfOfCredentialOptions) (*OnBehalfOfCredential, error) {
cred, err := confidential.NewCredFromSecret(clientSecret)
if err != nil {
return nil, err
}
return newOnBehalfOfCredential(tenantID, clientID, userAssertion, cred, options)
}

func newOnBehalfOfCredential(tenantID, clientID, userAssertion string, cred confidential.Credential, options *OnBehalfOfCredentialOptions) (*OnBehalfOfCredential, error) {
if options == nil {
options = &OnBehalfOfCredentialOptions{}
}
opts := []confidential.Option{}
if options.SendCertificateChain {
opts = append(opts, confidential.WithX5C())
}
c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, opts...)
if err != nil {
return nil, err
}
return &OnBehalfOfCredential{assertion: userAssertion, client: c}, nil
}

// GetToken requests an access token from Azure Active Directory. This method is called automatically by Azure SDK clients.
func (o *OnBehalfOfCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
if len(opts.Scopes) == 0 {
return azcore.AccessToken{}, errors.New(credNameSecret + ": GetToken() requires at least one scope")
}
ar, err := o.client.AcquireTokenOnBehalfOf(ctx, o.assertion, opts.Scopes)
if err != nil {
return azcore.AccessToken{}, newAuthenticationFailedErrorFromMSALError(credNameOBO, err)
}
logGetTokenSuccess(o, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, nil
}
107 changes: 107 additions & 0 deletions sdk/azidentity/on_behalf_of_credential_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azidentity

import (
"context"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

func TestOnBehalfOfCredential(t *testing.T) {
realGetClient := getConfidentialClient
t.Cleanup(func() { getConfidentialClient = realGetClient })
expectedAssertion := "user-assertion"
for _, test := range []struct {
ctor func() (*OnBehalfOfCredential, error)
name string
sendX5C bool
}{
{
ctor: func() (*OnBehalfOfCredential, error) {
certs, key := allCertTests[0].certs, allCertTests[0].key
return NewOnBehalfOfCredentialFromCertificate(fakeTenantID, fakeClientID, expectedAssertion, certs, key, nil)
},
name: "certificate",
},
{
ctor: func() (*OnBehalfOfCredential, error) {
certs, key := allCertTests[0].certs, allCertTests[0].key
return NewOnBehalfOfCredentialFromCertificate(fakeTenantID, fakeClientID, expectedAssertion, certs, key, &OnBehalfOfCredentialOptions{SendCertificateChain: true})
},
name: "certificate_SNI",
sendX5C: true,
},
{
ctor: func() (*OnBehalfOfCredential, error) {
return NewOnBehalfOfCredentialFromSecret(fakeTenantID, fakeClientID, expectedAssertion, "secret", nil)
},
name: "secret",
},
} {
t.Run(test.name, func(t *testing.T) {
called := false
key := struct{}{}
ctx := context.WithValue(context.Background(), key, true)
fake := fakeConfidentialClient{
ar: confidential.AuthResult{AccessToken: tokenValue, ExpiresOn: time.Now().Add(time.Hour)},
oboCallback: func(c context.Context, assertion string, scopes []string) {
called = true
if v := c.Value(key); v == nil || !v.(bool) {
t.Error("AcquireTokenOnBehalfOf received unexpected Context")
}
if len(scopes) != 1 || scopes[0] != liveTestScope {
t.Errorf(`unexpected scopes "%v"`, scopes)
}
if assertion != expectedAssertion {
t.Errorf(`unexpected assertion "%s"`, assertion)
}
},
}
getConfidentialClient = func(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, opts ...confidential.Option) (confidentialClient, error) {
if clientID != fakeClientID {
t.Errorf(`unexpected clientID "%s"`, clientID)
}
if tenantID != fakeTenantID {
t.Errorf(`unexpected tenantID "%s"`, tenantID)
}
msalOpts := confidential.Options{}
for _, o := range opts {
o(&msalOpts)
}
if test.sendX5C != msalOpts.SendX5C {
t.Fatal("incorrect value for SendX5C")
}
return fake, nil
}
cred, err := test.ctor()
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Errorf(`unexpected token "%s"`, tk.Token)
}
if tk.ExpiresOn.Before(time.Now()) {
t.Error("GetToken returned an invalid expiration time")
}
if tk.ExpiresOn.Location() != time.UTC {
t.Error("ExpiresOn isn't UTC")
}
if !called {
t.Fatal("validation function wasn't called")
}
})
}
}