Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/storage/azblob/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Added [Blob Batch API](https://learn.microsoft.com/rest/api/storageservices/blob-batch).
* Added support for bearer challenge for identity based managed disks.
* Added support for GetAccountInfo to container and blob level clients.
* Added support for CopySourceAuthorization to appendblob.AppendBlockFromURL

### Breaking Changes

Expand Down
89 changes: 89 additions & 0 deletions sdk/storage/azblob/appendblob/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"context"
"crypto/md5"
"encoding/binary"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"hash/crc64"
"io"
"math/rand"
Expand Down Expand Up @@ -407,6 +408,94 @@ func (s *AppendBlobUnrecordedTestsSuite) TestAppendBlockFromURLWithMD5() {
testcommon.ValidateBlobErrorCode(_require, err, bloberror.MD5Mismatch)
}

func (s *AppendBlobRecordedTestsSuite) TestAppendBlockFromURLCopySourceAuth() {
_require := require.New(s.T())
testName := s.T().Name()
svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil)
_require.NoError(err)

// Random seed for data generation
seed := int64(crc64.Checksum([]byte(testName), shared.CRC64Table))
random := rand.New(rand.NewSource(seed))

// Getting AAD Authentication
cred, err := testcommon.GetGenericTokenCredential()
_require.NoError(err)

containerName := testcommon.GenerateContainerName(testName)
containerClient := testcommon.CreateNewContainer(context.Background(), _require, containerName, svcClient)
defer testcommon.DeleteContainer(context.Background(), _require, containerClient)

// Create source and destination blobs
srcABClient := containerClient.NewAppendBlobClient(testcommon.GenerateBlobName("appendsrc"))
destABClient := containerClient.NewAppendBlobClient(testcommon.GenerateBlobName("appenddest"))

// Upload some data to source
_, err = srcABClient.Create(context.Background(), nil)
_require.Nil(err)
contentSize := 4 * 1024 // 4KB
r, sourceData := testcommon.GetDataAndReader(random, contentSize)
_, err = srcABClient.AppendBlock(context.Background(), streaming.NopCloser(r), nil)
_require.Nil(err)
_, err = destABClient.Create(context.Background(), nil)
_require.Nil(err)

// Getting token
token, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"https://storage.azure.com/.default"}})
_require.NoError(err)

options := appendblob.AppendBlockFromURLOptions{
CopySourceAuthorization: to.Ptr("Bearer " + token.Token),
}

pbResp, err := destABClient.AppendBlockFromURL(context.Background(), srcABClient.URL(), &options)
_require.NoError(err)
_require.NotNil(pbResp)

// Download data from destination
destBuffer := make([]byte, 4*1024)
_, err = destABClient.DownloadBuffer(context.Background(), destBuffer, nil)
_require.Nil(err)
_require.Equal(destBuffer, sourceData)
}

func (s *AppendBlobRecordedTestsSuite) TestAppendBlockFromURLCopySourceAuthNegative() {
_require := require.New(s.T())
testName := s.T().Name()
svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil)
_require.NoError(err)

// Random seed for data generation
seed := int64(crc64.Checksum([]byte(testName), shared.CRC64Table))
random := rand.New(rand.NewSource(seed))

containerName := testcommon.GenerateContainerName(testName)
containerClient := testcommon.CreateNewContainer(context.Background(), _require, containerName, svcClient)
defer testcommon.DeleteContainer(context.Background(), _require, containerClient)

// Create source and destination blobs
srcABClient := containerClient.NewAppendBlobClient(testcommon.GenerateBlobName("appendsrc"))
destABClient := containerClient.NewAppendBlobClient(testcommon.GenerateBlobName("appenddest"))

// Upload some data to source
_, err = srcABClient.Create(context.Background(), nil)
_require.Nil(err)
contentSize := 4 * 1024 // 4KB
r, _ := testcommon.GetDataAndReader(random, contentSize)
_, err = srcABClient.AppendBlock(context.Background(), streaming.NopCloser(r), nil)
_require.Nil(err)
_, err = destABClient.Create(context.Background(), nil)
_require.Nil(err)

options := appendblob.AppendBlockFromURLOptions{
CopySourceAuthorization: to.Ptr("Bearer faketoken"),
}

_, err = destABClient.AppendBlockFromURL(context.Background(), srcABClient.URL(), &options)
_require.Error(err)
_require.True(bloberror.HasCode(err, bloberror.CannotVerifyCopySource))
}

func (s *AppendBlobRecordedTestsSuite) TestBlobCreateAppendMetadataNonEmpty() {
_require := require.New(s.T())
testName := s.T().Name()
Expand Down
6 changes: 5 additions & 1 deletion sdk/storage/azblob/appendblob/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ func (o *AppendBlockOptions) format() (*generated.AppendBlobClientAppendBlockOpt

// AppendBlockFromURLOptions contains the optional parameters for the Client.AppendBlockFromURL method.
type AppendBlockFromURLOptions struct {
// Only Bearer type is supported. Credentials should be a valid OAuth access token to copy source.
CopySourceAuthorization *string

// SourceContentValidation contains the validation mechanism used on the range of bytes read from the source.
SourceContentValidation blob.SourceContentValidationType

Expand All @@ -125,7 +128,8 @@ func (o *AppendBlockFromURLOptions) format() (*generated.AppendBlobClientAppendB
}

options := &generated.AppendBlobClientAppendBlockFromURLOptions{
SourceRange: exported.FormatHTTPRange(o.Range),
SourceRange: exported.FormatHTTPRange(o.Range),
CopySourceAuthorization: o.CopySourceAuthorization,
}

if o.SourceContentValidation != nil {
Expand Down
2 changes: 1 addition & 1 deletion sdk/storage/azblob/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/storage/azblob",
"Tag": "go/storage/azblob_5d20008f59"
"Tag": "go/storage/azblob_fad5549316"
}
16 changes: 16 additions & 0 deletions sdk/storage/azblob/internal/testcommon/clients_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"context"
"errors"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas"
Expand Down Expand Up @@ -55,6 +56,7 @@ const (
const (
FakeStorageAccount = "fakestorage"
FakeStorageURL = "https://fakestorage.blob.core.windows.net"
FakeToken = "faketoken"
)

var (
Expand Down Expand Up @@ -145,6 +147,20 @@ func GetServiceClientNoCredential(t *testing.T, sasUrl string, options *service.
return serviceClient, err
}

type FakeCredential struct {
}

func (c *FakeCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
return azcore.AccessToken{Token: FakeToken, ExpiresOn: time.Now().Add(time.Hour).UTC()}, nil
}

func GetGenericTokenCredential() (azcore.TokenCredential, error) {
if recording.GetRecordMode() == recording.PlaybackMode {
return &FakeCredential{}, nil
}
return azidentity.NewDefaultAzureCredential(nil)
}

func GetGenericAccountInfo(accountType TestAccountType) (string, string) {
if recording.GetRecordMode() == recording.PlaybackMode {
return FakeStorageAccount, "ZmFrZQ=="
Expand Down
13 changes: 11 additions & 2 deletions sdk/storage/azblob/internal/testcommon/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ package testcommon
import (
"bytes"
"context"
"crypto/rand"
crypto_rand "crypto/rand"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"math/rand"
"os"
"runtime"
"strconv"
Expand Down Expand Up @@ -83,7 +84,13 @@ func GetReaderToGeneratedBytes(n int) io.ReadSeekCloser {

func GetRandomDataAndReader(n int) (*bytes.Reader, []byte) {
data := make([]byte, n)
_, _ = rand.Read(data)
_, _ = crypto_rand.Read(data)
return bytes.NewReader(data), data
}

func GetDataAndReader(r *rand.Rand, n int) (*bytes.Reader, []byte) {
data := make([]byte, n)
_, _ = r.Read(data)
return bytes.NewReader(data), data
}

Expand Down Expand Up @@ -181,8 +188,10 @@ func GetRequiredEnv(name string) (string, error) {

func BeforeTest(t *testing.T, suite string, test string) {
const urlRegex = `https://\S+\.blob\.core\.windows\.net`
const tokenRegex = `(?:Bearer\s).*`
require.NoError(t, recording.AddURISanitizer(FakeStorageURL, urlRegex, nil))
require.NoError(t, recording.AddHeaderRegexSanitizer("x-ms-copy-source", FakeStorageURL, urlRegex, nil))
require.NoError(t, recording.AddHeaderRegexSanitizer("x-ms-copy-source-authorization", FakeToken, tokenRegex, nil))
// we freeze request IDs and timestamps to avoid creating noisy diffs
// NOTE: we can't freeze time stamps as that breaks some tests that use if-modified-since etc (maybe it can be fixed?)
//testframework.AddHeaderRegexSanitizer("X-Ms-Date", "Wed, 10 Aug 2022 23:34:14 GMT", "", nil)
Expand Down