Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/storage/Azure.Storage.Blobs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Added support for listing deleted root blobs with versions to BlobContainerClient.GetBlobs() and .GetBlobsByHierarchy()
- Added support for OAuth copy sources for synchronous copy operations.
- Added support for Parquet as an input format in BlockBlobClient.Query().
- Added optimization to unwrap encryption key once for DownloadTo and OpenRead when Client Side Encryption is enabled.

## 12.9.0 (2021-06-08)
- Includes all features from 12.9.0-beta.4.
Expand Down
14 changes: 14 additions & 0 deletions sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,10 @@ internal async Task<Response> StagedDownloadAsync(
{
PartitionedDownloader downloader = new PartitionedDownloader(this, transferOptions);

if (UsingClientSideEncryption)
{
ClientSideDecryptor.BeginContentEncryptionKeyCaching();
}
if (async)
{
return await downloader.DownloadToAsync(destination, conditions, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -2077,12 +2081,22 @@ internal async Task<Stream> OpenReadInternal(
readConditions = readConditions?.WithIfMatch(etag) ?? new BlobRequestConditions { IfMatch = etag };
}

ClientSideDecryptor.ContentEncryptionKeyCache contentEncryptionKeyCache = default;
if (UsingClientSideEncryption && !allowModifications)
{
contentEncryptionKeyCache = new();
}

return new LazyLoadingReadOnlyStream<BlobProperties>(
async (HttpRange range,
bool rangeGetContentHash,
bool async,
CancellationToken cancellationToken) =>
{
if (UsingClientSideEncryption)
{
ClientSideDecryptor.BeginContentEncryptionKeyCaching(contentEncryptionKeyCache);
}
Response<BlobDownloadStreamingResult> response = await DownloadStreamingInternal(
range,
readConditions,
Expand Down
104 changes: 94 additions & 10 deletions sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,25 @@ private Mock<IKeyEncryptionKey> GetIKeyEncryptionKey(byte[] userKeyBytes = defau
{
keyMock.Setup(k => k.WrapKey(s_algorithmName, IsNotNull<ReadOnlyMemory<byte>>(), s_cancellationToken))
.Returns<string, ReadOnlyMemory<byte>, CancellationToken>((algorithm, key, cancellationToken) => Xor(userKeyBytes, key.ToArray()));
keyMock.Setup(k => k.UnwrapKey(s_algorithmName, IsNotNull<ReadOnlyMemory<byte>>(), s_cancellationToken))
keyMock.Setup(k => k.UnwrapKey(s_algorithmName, IsNotNull<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Returns<string, ReadOnlyMemory<byte>, CancellationToken>((algorithm, wrappedKey, cancellationToken) => Xor(userKeyBytes, wrappedKey.ToArray()));
}

return keyMock;
}

private void VerifyUnwrappedKeyWasCached(Mock<IKeyEncryptionKey> keyMock)
{
if (IsAsync)
{
keyMock.Verify(k => k.UnwrapKeyAsync(s_algorithmName, IsNotNull<ReadOnlyMemory<byte>>(), s_cancellationToken), Times.Once);
}
else
{
keyMock.Verify(k => k.UnwrapKey(s_algorithmName, IsNotNull<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
}
}

private Mock<IKeyEncryptionKeyResolver> GetAlwaysFailsKeyResolver(bool throws)
{
var mock = new Mock<IKeyEncryptionKeyResolver>(MockBehavior.Strict);
Expand Down Expand Up @@ -288,20 +300,21 @@ public async Task UploadAsync(long dataSize)
}
}

[TestCase(16)] // a single cipher block
[TestCase(14)] // a single unalligned cipher block
[TestCase(Constants.KB)] // multiple blocks
[TestCase(Constants.KB - 4)] // multiple unalligned blocks
[TestCase(16, null)] // a single cipher block
[TestCase(14, null)] // a single unalligned cipher block
[TestCase(Constants.KB, null)] // multiple blocks
[TestCase(Constants.KB - 4, null)] // multiple unalligned blocks
[TestCase(Constants.MB, 64*Constants.KB)] // make sure we cache unwrapped key for large downloads
[LiveOnly] // cannot seed content encryption key
public async Task RoundtripAsync(long dataSize)
public async Task RoundtripAsync(long dataSize, long? initialDownloadRequestSize)
{
var data = GetRandomBuffer(dataSize);
var mockKey = GetIKeyEncryptionKey().Object;
var mockKeyResolver = GetIKeyEncryptionKeyResolver(mockKey).Object;
var mockKey = GetIKeyEncryptionKey();
var mockKeyResolver = GetIKeyEncryptionKeyResolver(mockKey.Object).Object;
await using (var disposable = await GetTestContainerEncryptionAsync(
new ClientSideEncryptionOptions(ClientSideEncryptionVersion.V1_0)
{
KeyEncryptionKey = mockKey,
KeyEncryptionKey = mockKey.Object,
KeyResolver = mockKeyResolver,
KeyWrapAlgorithm = s_algorithmName
}))
Expand All @@ -315,12 +328,58 @@ public async Task RoundtripAsync(long dataSize)
byte[] downloadData;
using (var stream = new MemoryStream())
{
await blob.DownloadToAsync(stream, cancellationToken: s_cancellationToken);
await blob.DownloadToAsync(stream,
transferOptions: new StorageTransferOptions() { InitialTransferSize = initialDownloadRequestSize },
cancellationToken: s_cancellationToken);
downloadData = stream.ToArray();
}

// compare data
Assert.AreEqual(data, downloadData);
VerifyUnwrappedKeyWasCached(mockKey);
}
}

[TestCase(Constants.MB, 64*Constants.KB)]
[TestCase(Constants.MB, Constants.MB)]
[TestCase(Constants.MB, 4*Constants.MB)]
[LiveOnly] // cannot seed content encryption key
public async Task RoundtripAsyncWithOpenRead(long dataSize, int bufferSize)
{
var data = GetRandomBuffer(dataSize);
var mockKey = GetIKeyEncryptionKey();
var mockKeyResolver = GetIKeyEncryptionKeyResolver(mockKey.Object).Object;
await using (var disposable = await GetTestContainerEncryptionAsync(
new ClientSideEncryptionOptions(ClientSideEncryptionVersion.V1_0)
{
KeyEncryptionKey = mockKey.Object,
KeyResolver = mockKeyResolver,
KeyWrapAlgorithm = s_algorithmName
}))
{
var blob = InstrumentClient(disposable.Container.GetBlobClient(GetNewBlobName()));

// upload with encryption
await blob.UploadAsync(new MemoryStream(data), cancellationToken: s_cancellationToken);

// download with decryption
byte[] downloadData;
using (var stream = new MemoryStream())
{
using var blobStream = await blob.OpenReadAsync(new BlobOpenReadOptions(false) { BufferSize = bufferSize }, cancellationToken: s_cancellationToken);
if (IsAsync)
{
await blobStream.CopyToAsync(stream, bufferSize, s_cancellationToken);
} else
{
blobStream.CopyTo(stream, bufferSize);
}
downloadData = stream.ToArray();
}

// compare data
Assert.AreEqual(data, downloadData);
VerifyUnwrappedKeyWasCached(mockKey);
}
}

Expand Down Expand Up @@ -534,6 +593,31 @@ public async Task RoundtripWithKeyvaultProvider()
}
}

[TestCase(Constants.MB, 64*Constants.KB)]
[LiveOnly] // need access to keyvault service && cannot seed content encryption key
public async Task RoundtripWithKeyvaultProviderOpenRead(long dataSize, int bufferSize)
{
var data = GetRandomBuffer(dataSize);
IKeyEncryptionKey key = await GetKeyvaultIKeyEncryptionKey();
await using (var disposable = await GetTestContainerEncryptionAsync(
new ClientSideEncryptionOptions(ClientSideEncryptionVersion.V1_0)
{
KeyEncryptionKey = key,
KeyWrapAlgorithm = "RSA-OAEP-256"
}))
{
var blob = disposable.Container.GetBlobClient(GetNewBlobName());

await blob.UploadAsync(new MemoryStream(data), cancellationToken: s_cancellationToken);

var downloadStream = new MemoryStream();
using var blobStream = await blob.OpenReadAsync(new BlobOpenReadOptions(false) { BufferSize = bufferSize});
await blobStream.CopyToAsync(downloadStream);

Assert.AreEqual(data, downloadStream.ToArray());
}
}

[TestCase(true)]
[TestCase(false)]
[LiveOnly]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ namespace Azure.Storage.Cryptography
{
internal class ClientSideDecryptor
{
/// <summary>
/// A cache for encryption key if high level API spans across multiple service calls.
/// </summary>
private static readonly AsyncLocal<ContentEncryptionKeyCache> s_contentEncryptionKeyCache = new();

/// <summary>
/// Clients that can upload data have a key encryption key stored on them. Checking if
/// a cached key exists and matches a given <see cref="EncryptionData"/> saves a call
Expand Down Expand Up @@ -179,6 +184,11 @@ private async Task<Memory<byte>> GetContentEncryptionKeyAsync(
bool async,
CancellationToken cancellationToken)
{
if (s_contentEncryptionKeyCache.Value?.Key.HasValue ?? false)
{
return s_contentEncryptionKeyCache.Value.Key.Value;
}

IKeyEncryptionKey key = default;

// If we already have a local key and it is the correct one, use that.
Expand All @@ -201,7 +211,7 @@ private async Task<Memory<byte>> GetContentEncryptionKeyAsync(
throw Errors.ClientSideEncryption.KeyNotFound(encryptionData.WrappedContentKey.KeyId);
}

return async
var contentEncryptionKey = async
? await key.UnwrapKeyAsync(
encryptionData.WrappedContentKey.Algorithm,
encryptionData.WrappedContentKey.EncryptedKey,
Expand All @@ -210,6 +220,13 @@ private async Task<Memory<byte>> GetContentEncryptionKeyAsync(
encryptionData.WrappedContentKey.Algorithm,
encryptionData.WrappedContentKey.EncryptedKey,
cancellationToken);

if (s_contentEncryptionKeyCache.Value != default)
{
s_contentEncryptionKeyCache.Value.Key = contentEncryptionKey;
}

return contentEncryptionKey;
}

/// <summary>
Expand Down Expand Up @@ -250,5 +267,15 @@ private static Stream WrapStream(

throw Errors.ClientSideEncryption.BadEncryptionAlgorithm(encryptionData.EncryptionAgent.EncryptionAlgorithm.ToString());
}

internal static void BeginContentEncryptionKeyCaching(ContentEncryptionKeyCache cache = default)
{
s_contentEncryptionKeyCache.Value = cache ?? new ContentEncryptionKeyCache();
}

internal class ContentEncryptionKeyCache
{
public Memory<byte>? Key { get; set; }
}
}
}