From ec84e3f973def7941cb7f0712b88500e429a3449 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 6 Oct 2022 04:51:58 -0700 Subject: [PATCH] Consolidate common MSAL options --- sdk/azidentity/azidentity.go | 31 +++++++++++++++++++ sdk/azidentity/client_assertion_credential.go | 15 +-------- .../client_certificate_credential.go | 17 ++-------- sdk/azidentity/client_secret_credential.go | 15 +-------- sdk/azidentity/device_code_credential.go | 13 +------- .../interactive_browser_credential.go | 13 +------- .../username_password_credential.go | 13 +------- 7 files changed, 38 insertions(+), 79 deletions(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index a95c54046d2d..bc2975d8178e 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -35,6 +35,37 @@ const ( tenantIDValidationErr = "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names" ) +func getConfidentialClient(clientID, tenantID string, cred confidential.Credential, co *azcore.ClientOptions, additionalOpts ...confidential.Option) (confidential.Client, error) { + if !validTenantID(tenantID) { + return confidential.Client{}, errors.New(tenantIDValidationErr) + } + authorityHost, err := setAuthorityHost(co.Cloud) + if err != nil { + return confidential.Client{}, err + } + o := []confidential.Option{ + confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), + confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), + confidential.WithHTTPClient(newPipelineAdapter(co)), + } + o = append(o, additionalOpts...) + return confidential.New(clientID, cred, o...) +} + +func getPublicClient(clientID, tenantID string, co *azcore.ClientOptions) (public.Client, error) { + if !validTenantID(tenantID) { + return public.Client{}, errors.New(tenantIDValidationErr) + } + authorityHost, err := setAuthorityHost(co.Cloud) + if err != nil { + return public.Client{}, err + } + return public.New(clientID, + public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), + public.WithHTTPClient(newPipelineAdapter(co)), + ) +} + // setAuthorityHost initializes the authority host for credentials. Precedence is: // 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user // 2. value of AZURE_AUTHORITY_HOST diff --git a/sdk/azidentity/client_assertion_credential.go b/sdk/azidentity/client_assertion_credential.go index 27930b09bfad..ffcf2094be20 100644 --- a/sdk/azidentity/client_assertion_credential.go +++ b/sdk/azidentity/client_assertion_credential.go @@ -9,11 +9,9 @@ package azidentity import ( "context" "errors" - "os" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) @@ -39,26 +37,15 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c if getAssertion == nil { return nil, errors.New("getAssertion must be a function that returns assertions") } - if !validTenantID(tenantID) { - return nil, errors.New(tenantIDValidationErr) - } if options == nil { options = &ClientAssertionCredentialOptions{} } - authorityHost, err := setAuthorityHost(options.Cloud) - if err != nil { - return nil, err - } cred := confidential.NewCredFromAssertionCallback( func(ctx context.Context, _ confidential.AssertionRequestOptions) (string, error) { return getAssertion(ctx) }, ) - c, err := confidential.New(clientID, cred, - confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), - confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), - confidential.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)), - ) + c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions) if err != nil { return nil, err } diff --git a/sdk/azidentity/client_certificate_credential.go b/sdk/azidentity/client_certificate_credential.go index 855b25fdf7f1..a61d824ef59f 100644 --- a/sdk/azidentity/client_certificate_credential.go +++ b/sdk/azidentity/client_certificate_credential.go @@ -12,11 +12,9 @@ import ( "crypto/x509" "encoding/pem" "errors" - "os" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "golang.org/x/crypto/pkcs12" ) @@ -43,29 +41,18 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x if len(certs) == 0 { return nil, errors.New("at least one certificate is required") } - if !validTenantID(tenantID) { - return nil, errors.New(tenantIDValidationErr) - } if options == nil { options = &ClientCertificateCredentialOptions{} } - authorityHost, err := setAuthorityHost(options.Cloud) - if err != nil { - return nil, err - } cred, err := confidential.NewCredFromCertChain(certs, key) if err != nil { return nil, err } - o := []confidential.Option{ - confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), - confidential.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)), - confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), - } + var o []confidential.Option if options.SendCertificateChain { o = append(o, confidential.WithX5C()) } - c, err := confidential.New(clientID, cred, o...) + c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions, o...) if err != nil { return nil, err } diff --git a/sdk/azidentity/client_secret_credential.go b/sdk/azidentity/client_secret_credential.go index 6ecb8f4db816..1c3a516601b3 100644 --- a/sdk/azidentity/client_secret_credential.go +++ b/sdk/azidentity/client_secret_credential.go @@ -9,11 +9,9 @@ package azidentity import ( "context" "errors" - "os" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" ) @@ -31,25 +29,14 @@ type ClientSecretCredential struct { // NewClientSecretCredential constructs a ClientSecretCredential. Pass nil for options to accept defaults. func NewClientSecretCredential(tenantID string, clientID string, clientSecret string, options *ClientSecretCredentialOptions) (*ClientSecretCredential, error) { - if !validTenantID(tenantID) { - return nil, errors.New(tenantIDValidationErr) - } if options == nil { options = &ClientSecretCredentialOptions{} } - authorityHost, err := setAuthorityHost(options.Cloud) - if err != nil { - return nil, err - } cred, err := confidential.NewCredFromSecret(clientSecret) if err != nil { return nil, err } - c, err := confidential.New(clientID, cred, - confidential.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), - confidential.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)), - confidential.WithAzureRegion(os.Getenv(azureRegionalAuthorityName)), - ) + c, err := getConfidentialClient(clientID, tenantID, cred, &options.ClientOptions) if err != nil { return nil, err } diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go index d0c72c348548..2e9b5438dbd0 100644 --- a/sdk/azidentity/device_code_credential.go +++ b/sdk/azidentity/device_code_credential.go @@ -13,7 +13,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) @@ -79,17 +78,7 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC cp = *options } cp.init() - if !validTenantID(cp.TenantID) { - return nil, errors.New(tenantIDValidationErr) - } - authorityHost, err := setAuthorityHost(cp.Cloud) - if err != nil { - return nil, err - } - c, err := public.New(cp.ClientID, - public.WithAuthority(runtime.JoinPaths(authorityHost, cp.TenantID)), - public.WithHTTPClient(newPipelineAdapter(&cp.ClientOptions)), - ) + c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions) if err != nil { return nil, err } diff --git a/sdk/azidentity/interactive_browser_credential.go b/sdk/azidentity/interactive_browser_credential.go index e4aaf45b6dda..9032ae9886a5 100644 --- a/sdk/azidentity/interactive_browser_credential.go +++ b/sdk/azidentity/interactive_browser_credential.go @@ -12,7 +12,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) @@ -56,17 +55,7 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption cp = *options } cp.init() - if !validTenantID(cp.TenantID) { - return nil, errors.New(tenantIDValidationErr) - } - authorityHost, err := setAuthorityHost(cp.Cloud) - if err != nil { - return nil, err - } - c, err := public.New(cp.ClientID, - public.WithAuthority(runtime.JoinPaths(authorityHost, cp.TenantID)), - public.WithHTTPClient(newPipelineAdapter(&cp.ClientOptions)), - ) + c, err := getPublicClient(cp.ClientID, cp.TenantID, &cp.ClientOptions) if err != nil { return nil, err } diff --git a/sdk/azidentity/username_password_credential.go b/sdk/azidentity/username_password_credential.go index 8b02e7b47bab..2ab248c3c616 100644 --- a/sdk/azidentity/username_password_credential.go +++ b/sdk/azidentity/username_password_credential.go @@ -12,7 +12,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) @@ -37,20 +36,10 @@ type UsernamePasswordCredential struct { // NewUsernamePasswordCredential creates a UsernamePasswordCredential. clientID is the ID of the application the user // will authenticate to. Pass nil for options to accept defaults. func NewUsernamePasswordCredential(tenantID string, clientID string, username string, password string, options *UsernamePasswordCredentialOptions) (*UsernamePasswordCredential, error) { - if !validTenantID(tenantID) { - return nil, errors.New(tenantIDValidationErr) - } if options == nil { options = &UsernamePasswordCredentialOptions{} } - authorityHost, err := setAuthorityHost(options.Cloud) - if err != nil { - return nil, err - } - c, err := public.New(clientID, - public.WithAuthority(runtime.JoinPaths(authorityHost, tenantID)), - public.WithHTTPClient(newPipelineAdapter(&options.ClientOptions)), - ) + c, err := getPublicClient(clientID, tenantID, &options.ClientOptions) if err != nil { return nil, err }