diff --git a/sdk/storage/azdatalake/CHANGELOG.md b/sdk/storage/azdatalake/CHANGELOG.md index d6ed27e1bfb5..ad685c81e279 100644 --- a/sdk/storage/azdatalake/CHANGELOG.md +++ b/sdk/storage/azdatalake/CHANGELOG.md @@ -10,12 +10,16 @@ * HNS Encryption Context support * Pagination Support for recursive directory deletion * Bundle ability to set permission, owner, group, acl, lease, expiry time and umask along with FileSystem.CreateFile and FileSystem.CreateDirectory APIs. +* Added support for AAD Audience when OAuth is used. +* Updated service version to `2023-11-03` +* Integrate `InsecureAllowCredentialWithHTTP` client options. ### Breaking Changes ### Bugs Fixed ### Other Changes +* Updated azcore version to `1.11.1` ## 1.1.1 (2024-02-29) diff --git a/sdk/storage/azdatalake/assets.json b/sdk/storage/azdatalake/assets.json index d2635b49e8b9..7c9f73ffb5c5 100644 --- a/sdk/storage/azdatalake/assets.json +++ b/sdk/storage/azdatalake/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/storage/azdatalake", - "Tag": "go/storage/azdatalake_36960f5092" + "Tag": "go/storage/azdatalake_8cf0ce4c24" } diff --git a/sdk/storage/azdatalake/directory/client.go b/sdk/storage/azdatalake/directory/client.go index 02671d07222e..c3cc6c30fd18 100644 --- a/sdk/storage/azdatalake/directory/client.go +++ b/sdk/storage/azdatalake/directory/client.go @@ -42,8 +42,9 @@ type Client base.CompositeClient[generated.PathClient, generated_blob.BlobClient func NewClient(directoryURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { blobURL, directoryURL := shared.GetURLs(directoryURL) - authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil) + audience := base.GetAudience((*base.ClientOptions)(options)) conOptions := shared.GetClientOptions(options) + authPolicy := shared.NewStorageChallengePolicy(cred, audience, conOptions.InsecureAllowCredentialWithHTTP) plOpts := runtime.PipelineOptions{ PerRetry: []policy.Policy{authPolicy}, } diff --git a/sdk/storage/azdatalake/directory/client_test.go b/sdk/storage/azdatalake/directory/client_test.go index 3dcc454ed75c..6cad9521cd5c 100644 --- a/sdk/storage/azdatalake/directory/client_test.go +++ b/sdk/storage/azdatalake/directory/client_test.go @@ -2922,3 +2922,73 @@ func (s *UnrecordedTestSuite) TestDirCreateDeleteUsingOAuth() { _, err = dirClient.GetProperties(context.Background(), nil) _require.NoError(err) } + +func (s *RecordedTestSuite) TestCreateDirectoryClientDefaultAudience() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFileSystemName(testName) + fsClient, err := testcommon.GetFileSystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFileSystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.NoError(err) + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + dirName := testcommon.GenerateDirName(testName) + dirURL := "https://" + accountName + ".dfs.core.windows.net/" + filesystemName + "/" + dirName + + options := &directory.ClientOptions{Audience: "https://storage.azure.com/"} + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + + dirClient, err := directory.NewClient(dirURL, cred, options) + _require.NoError(err) + + _, err = dirClient.Create(context.Background(), nil) + _require.NoError(err) + + _, err = dirClient.GetProperties(context.Background(), nil) + _require.NoError(err) + +} + +func (s *RecordedTestSuite) TestCreateDirectoryClientCustomAudience() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFileSystemName(testName) + fsClient, err := testcommon.GetFileSystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFileSystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.NoError(err) + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + dirName := testcommon.GenerateDirName(testName) + dirURL := "https://" + accountName + ".dfs.core.windows.net/" + filesystemName + "/" + dirName + + options := &directory.ClientOptions{Audience: "https://" + accountName + ".blob.core.windows.net"} + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + + dirClient, err := directory.NewClient(dirURL, cred, options) + _require.NoError(err) + + _, err = dirClient.Create(context.Background(), nil) + _require.NoError(err) + + _, err = dirClient.GetProperties(context.Background(), nil) + _require.NoError(err) + +} diff --git a/sdk/storage/azdatalake/file/client.go b/sdk/storage/azdatalake/file/client.go index 91024ecde10c..42795df065ae 100644 --- a/sdk/storage/azdatalake/file/client.go +++ b/sdk/storage/azdatalake/file/client.go @@ -48,8 +48,9 @@ type Client base.CompositeClient[generated.PathClient, generated_blob.BlobClient // - options - client options; pass nil to accept the default values func NewClient(fileURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { blobURL, fileURL := shared.GetURLs(fileURL) - authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil) + audience := base.GetAudience((*base.ClientOptions)(options)) conOptions := shared.GetClientOptions(options) + authPolicy := shared.NewStorageChallengePolicy(cred, audience, conOptions.InsecureAllowCredentialWithHTTP) plOpts := runtime.PipelineOptions{ PerRetry: []policy.Policy{authPolicy}, } diff --git a/sdk/storage/azdatalake/file/client_test.go b/sdk/storage/azdatalake/file/client_test.go index e541f4fd2084..802cdcb80d73 100644 --- a/sdk/storage/azdatalake/file/client_test.go +++ b/sdk/storage/azdatalake/file/client_test.go @@ -5484,3 +5484,71 @@ func TestUploadSmallChunkSize(t *testing.T) { _require.Equal(atomic.LoadUint64(&fbb.numChunks), numChunks) } + +func (s *RecordedTestSuite) TestFileClientCustomAudience() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFileSystemName(testName) + fsClient, err := testcommon.GetFileSystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFileSystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.NoError(err) + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + fileName := testcommon.GenerateFileName(testName) + fileURL := "https://" + accountName + ".dfs.core.windows.net/" + filesystemName + "/" + fileName + + options := &file.ClientOptions{Audience: "https://" + accountName + ".blob.core.windows.net"} + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + + fClient, err := file.NewClient(fileURL, cred, options) + _require.NoError(err) + + _, err = fClient.Create(context.Background(), nil) + _require.NoError(err) + + _, err = fClient.GetProperties(context.Background(), nil) + _require.NoError(err) +} + +func (s *RecordedTestSuite) TestFileClientDefaultAudience() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFileSystemName(testName) + fsClient, err := testcommon.GetFileSystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFileSystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.NoError(err) + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + fileName := testcommon.GenerateFileName(testName) + fileURL := "https://" + accountName + ".dfs.core.windows.net/" + filesystemName + "/" + fileName + + options := &file.ClientOptions{Audience: "https://storage.azure.com/"} + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + + fClient, err := file.NewClient(fileURL, cred, options) + _require.NoError(err) + + _, err = fClient.Create(context.Background(), nil) + _require.NoError(err) + + _, err = fClient.GetProperties(context.Background(), nil) + _require.NoError(err) +} diff --git a/sdk/storage/azdatalake/filesystem/client.go b/sdk/storage/azdatalake/filesystem/client.go index 242f27ca672d..cf6e06ce0c95 100644 --- a/sdk/storage/azdatalake/filesystem/client.go +++ b/sdk/storage/azdatalake/filesystem/client.go @@ -40,8 +40,9 @@ type Client base.CompositeClient[generated.FileSystemClient, generated.FileSyste // - options - client options; pass nil to accept the default values func NewClient(filesystemURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { containerURL, filesystemURL := shared.GetURLs(filesystemURL) - authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil) + audience := base.GetAudience((*base.ClientOptions)(options)) conOptions := shared.GetClientOptions(options) + authPolicy := shared.NewStorageChallengePolicy(cred, audience, conOptions.InsecureAllowCredentialWithHTTP) plOpts := runtime.PipelineOptions{ PerRetry: []policy.Policy{authPolicy}, } diff --git a/sdk/storage/azdatalake/filesystem/client_test.go b/sdk/storage/azdatalake/filesystem/client_test.go index e11bef495b92..695548b14598 100644 --- a/sdk/storage/azdatalake/filesystem/client_test.go +++ b/sdk/storage/azdatalake/filesystem/client_test.go @@ -2084,3 +2084,57 @@ func (s *RecordedTestSuite) TestCreateDirectoryInFileSystemSetOptions() { _require.Equal(filesystem.StateTypeLeased, *response.LeaseState) } + +func (s *RecordedTestSuite) TestFSCreateDefaultAudience() { + _require := require.New(s.T()) + testName := s.T().Name() + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + filesystemName := testcommon.GenerateFileSystemName(testName) + fsURL := "https://" + accountName + ".dfs.core.windows.net/" + filesystemName + + options := &filesystem.ClientOptions{Audience: "https://storage.azure.com/"} + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + fsClient, err := filesystem.NewClient(fsURL, cred, options) + _require.NoError(err) + defer testcommon.DeleteFileSystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.NoError(err) + + _, err = fsClient.GetProperties(context.Background(), nil) + _require.NoError(err) + +} + +func (s *RecordedTestSuite) TestFSCreateCustomAudience() { + _require := require.New(s.T()) + testName := s.T().Name() + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + filesystemName := testcommon.GenerateFileSystemName(testName) + fsURL := "https://" + accountName + ".dfs.core.windows.net/" + filesystemName + + options := &filesystem.ClientOptions{Audience: "https://" + accountName + ".blob.core.windows.net"} + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + fsClient, err := filesystem.NewClient(fsURL, cred, options) + _require.NoError(err) + defer testcommon.DeleteFileSystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.NoError(err) + + _, err = fsClient.GetProperties(context.Background(), nil) + _require.NoError(err) + +} diff --git a/sdk/storage/azdatalake/go.mod b/sdk/storage/azdatalake/go.mod index 469cd5ce4c6d..f7594741bc21 100644 --- a/sdk/storage/azdatalake/go.mod +++ b/sdk/storage/azdatalake/go.mod @@ -3,7 +3,7 @@ module github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake go 1.18 require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1 github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.1 @@ -19,9 +19,9 @@ require ( github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/crypto v0.18.0 // indirect - golang.org/x/net v0.20.0 // indirect - golang.org/x/sys v0.16.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/net v0.22.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/sdk/storage/azdatalake/go.sum b/sdk/storage/azdatalake/go.sum index f7e8b87df681..3bf2319e3ad6 100644 --- a/sdk/storage/azdatalake/go.sum +++ b/sdk/storage/azdatalake/go.sum @@ -1,5 +1,5 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2 h1:c4k2FIYIh4xtwqrQwV0Ct1v5+ehlNXj5NI/MWVsiTkQ= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.9.2/go.mod h1:5FDJtLEO/GxwNgUxbwrY3LP0pEoThTQJtk2oysdXHxM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1 h1:sO0/P7g68FrryJzljemN+6GTssUXdANk6aJ7T1ZxnsQ= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.1/go.mod h1:h8hyGFDsU5HMivxiS2iYFZsgDbU9OnnJ163x5UGVKYo= github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 h1:LqbJ/WzJUwBf8UiaSzgX7aMclParm9/5Vgp+TY51uBQ= @@ -26,13 +26,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/sdk/storage/azdatalake/internal/base/clients.go b/sdk/storage/azdatalake/internal/base/clients.go index e5b9e2a300c5..ad6de2d5583a 100644 --- a/sdk/storage/azdatalake/internal/base/clients.go +++ b/sdk/storage/azdatalake/internal/base/clients.go @@ -15,12 +15,18 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated_blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" + "strings" ) // ClientOptions contains the optional parameters when creating a Client. type ClientOptions struct { azcore.ClientOptions pipelineOptions *runtime.PipelineOptions + // Audience to use when requesting tokens for Azure Active Directory authentication. + // Only has an effect when credential is of type TokenCredential. The value could be + // https://storage.azure.com/ (default) or https://.blob.core.windows.net. + Audience string } func GetPipelineOptions(clOpts *ClientOptions) *runtime.PipelineOptions { @@ -91,3 +97,11 @@ func NewPathClient(pathURL string, pathURLWithBlobEndpoint string, client *block func GetCompositeClientOptions[T, K, U any](client *CompositeClient[T, K, U]) *ClientOptions { return client.options } + +func GetAudience(clOpts *ClientOptions) string { + if clOpts == nil || len(strings.TrimSpace(clOpts.Audience)) == 0 { + return shared.TokenScope + } else { + return strings.TrimRight(clOpts.Audience, "/") + "/.default" + } +} diff --git a/sdk/storage/azdatalake/internal/shared/challenge_policy.go b/sdk/storage/azdatalake/internal/shared/challenge_policy.go new file mode 100644 index 000000000000..4aea4ee83b9a --- /dev/null +++ b/sdk/storage/azdatalake/internal/shared/challenge_policy.go @@ -0,0 +1,113 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. +package shared + +import ( + "errors" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "net/http" + "strings" +) + +type storageAuthorizer struct { + scopes []string + tenantID string +} + +func NewStorageChallengePolicy(cred azcore.TokenCredential, audience string, allowHTTP bool) policy.Policy { + s := storageAuthorizer{scopes: []string{audience}} + return runtime.NewBearerTokenPolicy(cred, []string{audience}, &policy.BearerTokenOptions{ + AuthorizationHandler: policy.AuthorizationHandler{ + OnRequest: s.onRequest, + OnChallenge: s.onChallenge, + }, + InsecureAllowCredentialWithHTTP: allowHTTP, + }) +} + +func (s *storageAuthorizer) onRequest(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error { + return authNZ(policy.TokenRequestOptions{Scopes: s.scopes}) +} + +func (s *storageAuthorizer) onChallenge(req *policy.Request, resp *http.Response, authNZ func(policy.TokenRequestOptions) error) error { + // parse the challenge + err := s.parseChallenge(resp) + if err != nil { + return err + } + // TODO: Set tenantID when policy.TokenRequestOptions supports it. https://github.com/Azure/azure-sdk-for-go/issues/19841 + return authNZ(policy.TokenRequestOptions{Scopes: s.scopes}) +} + +type challengePolicyError struct { + err error +} + +func (c *challengePolicyError) Error() string { + return c.err.Error() +} + +func (*challengePolicyError) NonRetriable() { + // marker method +} + +func (c *challengePolicyError) Unwrap() error { + return c.err +} + +// parses Tenant ID from auth challenge +// https://login.microsoftonline.com/00000000-0000-0000-0000-000000000000/oauth2/authorize +func parseTenant(url string) string { + if url == "" { + return "" + } + parts := strings.Split(url, "/") + if len(parts) >= 3 { + tenant := parts[3] + tenant = strings.ReplaceAll(tenant, ",", "") + return tenant + } else { + return "" + } +} + +func (s *storageAuthorizer) parseChallenge(resp *http.Response) error { + authHeader := resp.Header.Get("WWW-Authenticate") + if authHeader == "" { + return &challengePolicyError{err: errors.New("response has no WWW-Authenticate header for challenge authentication")} + } + + // Strip down to auth and resource + // Format is "Bearer authorization_uri=\"\" resource_id=\"\"" + authHeader = strings.ReplaceAll(authHeader, "Bearer ", "") + + parts := strings.Split(authHeader, " ") + + vals := map[string]string{} + for _, part := range parts { + subParts := strings.Split(part, "=") + if len(subParts) == 2 { + stripped := strings.ReplaceAll(subParts[1], "\"", "") + stripped = strings.TrimSuffix(stripped, ",") + vals[subParts[0]] = stripped + } + } + + s.tenantID = parseTenant(vals["authorization_uri"]) + + scope := vals["resource_id"] + if scope == "" { + return &challengePolicyError{err: errors.New("could not find a valid resource in the WWW-Authenticate header")} + } + + if !strings.HasSuffix(scope, "/.default") { + scope += "/.default" + } + s.scopes = []string{scope} + return nil +} diff --git a/sdk/storage/azdatalake/internal/shared/challenge_policy_test.go b/sdk/storage/azdatalake/internal/shared/challenge_policy_test.go new file mode 100644 index 000000000000..8eb25d8fa050 --- /dev/null +++ b/sdk/storage/azdatalake/internal/shared/challenge_policy_test.go @@ -0,0 +1,114 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" +) + +type credentialFunc func(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error) + +func (cf credentialFunc) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + return cf(ctx, options) +} + +func TestChallengePolicyStorage(t *testing.T) { + accessToken := "***" + storageScope := "https://storage.azure.com/.default" + + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithStatusCode(200), + ) + authenticated := false + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + authenticated = true + require.Equal(t, []string{storageScope}, tro.Scopes) + return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil + }) + p := NewStorageChallengePolicy(cred, storageScope, false) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost") + require.NoError(t, err) + _, err = pl.Do(req) + require.NoError(t, err) + require.True(t, authenticated, "policy should have authenticated") +} + +func TestChallengePolicyDisk(t *testing.T) { + accessToken := "***" + diskResource := "https://disk.azure.com/" + diskScope := "https://disk.azure.com//.default" + challenge := `Bearer authorization_uri="https://login.microsoftonline.com/{tenant}", resource_id="{storageResource}"` + + srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.AppendResponse( + mock.WithHeader("WWW-Authenticate", strings.ReplaceAll(challenge, "{storageResource}", diskResource)), + mock.WithStatusCode(401), + ) + srv.AppendResponse( + mock.WithStatusCode(200), + ) + attemptedAuthentication := false + authenticated := false + cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) { + if attemptedAuthentication { + authenticated = true + require.Equal(t, []string{diskScope}, tro.Scopes) + return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil + } + attemptedAuthentication = true + return azcore.AccessToken{}, nil + }) + p := NewStorageChallengePolicy(cred, "https://storage.azure.com/.default", false) + pl := runtime.NewPipeline("", "", + runtime.PipelineOptions{PerRetry: []policy.Policy{p}}, + &policy.ClientOptions{Transport: srv}, + ) + req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost") + require.NoError(t, err) + _, err = pl.Do(req) + require.NoError(t, err) + require.True(t, authenticated, "policy should have authenticated") +} + +func TestParseTenant(t *testing.T) { + actual := parseTenant("") + require.Empty(t, actual) + + expected := "00000000-0000-0000-0000-000000000000" + sampleURL := "https://login.microsoftonline.com/" + expected + actual = parseTenant(sampleURL) + require.Equal(t, expected, actual, "tenant was not properly parsed") +} + +func TestParseTenantNegative(t *testing.T) { + actual := parseTenant("") + require.Empty(t, actual) + + expected := "" + sampleURL := "https://login.microsoftonline.com/" + expected + actual = parseTenant(sampleURL) + require.Equal(t, expected, actual) + + sampleURL = "" + actual = parseTenant(sampleURL) + require.Equal(t, expected, actual) +} diff --git a/sdk/storage/azdatalake/service/client.go b/sdk/storage/azdatalake/service/client.go index 327871cb4f56..fa2546c7bdb1 100644 --- a/sdk/storage/azdatalake/service/client.go +++ b/sdk/storage/azdatalake/service/client.go @@ -39,8 +39,9 @@ type Client base.CompositeClient[generated.ServiceClient, generated_blob.Service // - options - client options; pass nil to accept the default values func NewClient(serviceURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { blobServiceURL, datalakeServiceURL := shared.GetURLs(serviceURL) - authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil) + audience := base.GetAudience((*base.ClientOptions)(options)) conOptions := shared.GetClientOptions(options) + authPolicy := shared.NewStorageChallengePolicy(cred, audience, conOptions.InsecureAllowCredentialWithHTTP) plOpts := runtime.PipelineOptions{ PerRetry: []policy.Policy{authPolicy}, } diff --git a/sdk/storage/azdatalake/service/client_test.go b/sdk/storage/azdatalake/service/client_test.go index 8172fac057c0..1af7e1491b57 100644 --- a/sdk/storage/azdatalake/service/client_test.go +++ b/sdk/storage/azdatalake/service/client_test.go @@ -860,3 +860,75 @@ func (s *ServiceRecordedTestsSuite) TestServiceClientWithNilSharedKey() { _require.Error(err) _require.Nil(svcClient) } + +func (s *ServiceRecordedTestsSuite) TestServiceClientUsingOauth() { + _require := require.New(s.T()) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + serviceUrl := "https://" + accountName + ".dfs.core.windows.net/" + + svcClient, err := service.NewClient(serviceUrl, cred, nil) + _require.NoError(err) + _require.NotNil(svcClient) + + fs, _ := svcClient.CreateFileSystem(context.Background(), "test", nil) + _require.NotNil(fs) + _require.NoError(err) +} + +func (s *ServiceRecordedTestsSuite) TestServiceClientUsingOauthWithDefaultAudience() { + _require := require.New(s.T()) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + serviceUrl := "https://" + accountName + ".dfs.core.windows.net/" + + options := service.ClientOptions{ + Audience: "https://storage.azure.com/", + } + + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + svcClient, err := service.NewClient(serviceUrl, cred, &options) + _require.NoError(err) + _require.NotNil(svcClient) + + fs, _ := svcClient.CreateFileSystem(context.Background(), "test", nil) + _require.NotNil(fs) + _require.NoError(err) + +} + +func (s *ServiceRecordedTestsSuite) TestServiceClientUsingOauthWithCustomAudience() { + _require := require.New(s.T()) + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDatalake) + _require.Greater(len(accountName), 0) + + serviceUrl := "https://" + accountName + ".dfs.core.windows.net/" + + cred, err := testcommon.GetGenericTokenCredential() + _require.NoError(err) + + options := service.ClientOptions{ + Audience: "https://" + accountName + ".blob.core.windows.net", + } + + testcommon.SetClientOptions(s.T(), &options.ClientOptions) + svcClient, err := service.NewClient(serviceUrl, cred, &options) + _require.NoError(err) + _require.NotNil(svcClient) + + fs, _ := svcClient.CreateFileSystem(context.Background(), "test", nil) + _require.NotNil(fs) + _require.NoError(err) + +}