diff --git a/sdk/storage/Azure.Storage.Blobs/CHANGELOG.md b/sdk/storage/Azure.Storage.Blobs/CHANGELOG.md index aab3736e493f..931d6bdabfbb 100644 --- a/sdk/storage/Azure.Storage.Blobs/CHANGELOG.md +++ b/sdk/storage/Azure.Storage.Blobs/CHANGELOG.md @@ -9,6 +9,7 @@ - Added support for listing deleted root blobs with versions to BlobContainerClient.GetBlobs() and .GetBlobsByHierarchy() - Added support for OAuth copy sources for synchronous copy operations. - Added support for Parquet as an input format in BlockBlobClient.Query(). +- Added optimization to unwrap encryption key once for DownloadTo and OpenRead when Client Side Encryption is enabled. ## 12.9.0 (2021-06-08) - Includes all features from 12.9.0-beta.4. diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs index b6c54af90ba4..3bf1460ecec1 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs @@ -1787,6 +1787,10 @@ internal async Task StagedDownloadAsync( { PartitionedDownloader downloader = new PartitionedDownloader(this, transferOptions); + if (UsingClientSideEncryption) + { + ClientSideDecryptor.BeginContentEncryptionKeyCaching(); + } if (async) { return await downloader.DownloadToAsync(destination, conditions, cancellationToken).ConfigureAwait(false); @@ -2077,12 +2081,22 @@ internal async Task OpenReadInternal( readConditions = readConditions?.WithIfMatch(etag) ?? new BlobRequestConditions { IfMatch = etag }; } + ClientSideDecryptor.ContentEncryptionKeyCache contentEncryptionKeyCache = default; + if (UsingClientSideEncryption && !allowModifications) + { + contentEncryptionKeyCache = new(); + } + return new LazyLoadingReadOnlyStream( async (HttpRange range, bool rangeGetContentHash, bool async, CancellationToken cancellationToken) => { + if (UsingClientSideEncryption) + { + ClientSideDecryptor.BeginContentEncryptionKeyCaching(contentEncryptionKeyCache); + } Response response = await DownloadStreamingInternal( range, readConditions, diff --git a/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs b/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs index 8d1e5089e7ca..0a9df3732382 100644 --- a/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs +++ b/sdk/storage/Azure.Storage.Blobs/tests/ClientSideEncryptionTests.cs @@ -100,13 +100,25 @@ private Mock GetIKeyEncryptionKey(byte[] userKeyBytes = defau { keyMock.Setup(k => k.WrapKey(s_algorithmName, IsNotNull>(), s_cancellationToken)) .Returns, CancellationToken>((algorithm, key, cancellationToken) => Xor(userKeyBytes, key.ToArray())); - keyMock.Setup(k => k.UnwrapKey(s_algorithmName, IsNotNull>(), s_cancellationToken)) + keyMock.Setup(k => k.UnwrapKey(s_algorithmName, IsNotNull>(), It.IsAny())) .Returns, CancellationToken>((algorithm, wrappedKey, cancellationToken) => Xor(userKeyBytes, wrappedKey.ToArray())); } return keyMock; } + private void VerifyUnwrappedKeyWasCached(Mock keyMock) + { + if (IsAsync) + { + keyMock.Verify(k => k.UnwrapKeyAsync(s_algorithmName, IsNotNull>(), s_cancellationToken), Times.Once); + } + else + { + keyMock.Verify(k => k.UnwrapKey(s_algorithmName, IsNotNull>(), It.IsAny()), Times.Once); + } + } + private Mock GetAlwaysFailsKeyResolver(bool throws) { var mock = new Mock(MockBehavior.Strict); @@ -288,20 +300,21 @@ public async Task UploadAsync(long dataSize) } } - [TestCase(16)] // a single cipher block - [TestCase(14)] // a single unalligned cipher block - [TestCase(Constants.KB)] // multiple blocks - [TestCase(Constants.KB - 4)] // multiple unalligned blocks + [TestCase(16, null)] // a single cipher block + [TestCase(14, null)] // a single unalligned cipher block + [TestCase(Constants.KB, null)] // multiple blocks + [TestCase(Constants.KB - 4, null)] // multiple unalligned blocks + [TestCase(Constants.MB, 64*Constants.KB)] // make sure we cache unwrapped key for large downloads [LiveOnly] // cannot seed content encryption key - public async Task RoundtripAsync(long dataSize) + public async Task RoundtripAsync(long dataSize, long? initialDownloadRequestSize) { var data = GetRandomBuffer(dataSize); - var mockKey = GetIKeyEncryptionKey().Object; - var mockKeyResolver = GetIKeyEncryptionKeyResolver(mockKey).Object; + var mockKey = GetIKeyEncryptionKey(); + var mockKeyResolver = GetIKeyEncryptionKeyResolver(mockKey.Object).Object; await using (var disposable = await GetTestContainerEncryptionAsync( new ClientSideEncryptionOptions(ClientSideEncryptionVersion.V1_0) { - KeyEncryptionKey = mockKey, + KeyEncryptionKey = mockKey.Object, KeyResolver = mockKeyResolver, KeyWrapAlgorithm = s_algorithmName })) @@ -315,12 +328,58 @@ public async Task RoundtripAsync(long dataSize) byte[] downloadData; using (var stream = new MemoryStream()) { - await blob.DownloadToAsync(stream, cancellationToken: s_cancellationToken); + await blob.DownloadToAsync(stream, + transferOptions: new StorageTransferOptions() { InitialTransferSize = initialDownloadRequestSize }, + cancellationToken: s_cancellationToken); downloadData = stream.ToArray(); } // compare data Assert.AreEqual(data, downloadData); + VerifyUnwrappedKeyWasCached(mockKey); + } + } + + [TestCase(Constants.MB, 64*Constants.KB)] + [TestCase(Constants.MB, Constants.MB)] + [TestCase(Constants.MB, 4*Constants.MB)] + [LiveOnly] // cannot seed content encryption key + public async Task RoundtripAsyncWithOpenRead(long dataSize, int bufferSize) + { + var data = GetRandomBuffer(dataSize); + var mockKey = GetIKeyEncryptionKey(); + var mockKeyResolver = GetIKeyEncryptionKeyResolver(mockKey.Object).Object; + await using (var disposable = await GetTestContainerEncryptionAsync( + new ClientSideEncryptionOptions(ClientSideEncryptionVersion.V1_0) + { + KeyEncryptionKey = mockKey.Object, + KeyResolver = mockKeyResolver, + KeyWrapAlgorithm = s_algorithmName + })) + { + var blob = InstrumentClient(disposable.Container.GetBlobClient(GetNewBlobName())); + + // upload with encryption + await blob.UploadAsync(new MemoryStream(data), cancellationToken: s_cancellationToken); + + // download with decryption + byte[] downloadData; + using (var stream = new MemoryStream()) + { + using var blobStream = await blob.OpenReadAsync(new BlobOpenReadOptions(false) { BufferSize = bufferSize }, cancellationToken: s_cancellationToken); + if (IsAsync) + { + await blobStream.CopyToAsync(stream, bufferSize, s_cancellationToken); + } else + { + blobStream.CopyTo(stream, bufferSize); + } + downloadData = stream.ToArray(); + } + + // compare data + Assert.AreEqual(data, downloadData); + VerifyUnwrappedKeyWasCached(mockKey); } } @@ -534,6 +593,31 @@ public async Task RoundtripWithKeyvaultProvider() } } + [TestCase(Constants.MB, 64*Constants.KB)] + [LiveOnly] // need access to keyvault service && cannot seed content encryption key + public async Task RoundtripWithKeyvaultProviderOpenRead(long dataSize, int bufferSize) + { + var data = GetRandomBuffer(dataSize); + IKeyEncryptionKey key = await GetKeyvaultIKeyEncryptionKey(); + await using (var disposable = await GetTestContainerEncryptionAsync( + new ClientSideEncryptionOptions(ClientSideEncryptionVersion.V1_0) + { + KeyEncryptionKey = key, + KeyWrapAlgorithm = "RSA-OAEP-256" + })) + { + var blob = disposable.Container.GetBlobClient(GetNewBlobName()); + + await blob.UploadAsync(new MemoryStream(data), cancellationToken: s_cancellationToken); + + var downloadStream = new MemoryStream(); + using var blobStream = await blob.OpenReadAsync(new BlobOpenReadOptions(false) { BufferSize = bufferSize}); + await blobStream.CopyToAsync(downloadStream); + + Assert.AreEqual(data, downloadStream.ToArray()); + } + } + [TestCase(true)] [TestCase(false)] [LiveOnly] diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/ClientsideEncryption/ClientSideDecryptor.cs b/sdk/storage/Azure.Storage.Common/src/Shared/ClientsideEncryption/ClientSideDecryptor.cs index d7fd4074b295..2b8318a9413d 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/ClientsideEncryption/ClientSideDecryptor.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/ClientsideEncryption/ClientSideDecryptor.cs @@ -13,6 +13,11 @@ namespace Azure.Storage.Cryptography { internal class ClientSideDecryptor { + /// + /// A cache for encryption key if high level API spans across multiple service calls. + /// + private static readonly AsyncLocal s_contentEncryptionKeyCache = new(); + /// /// Clients that can upload data have a key encryption key stored on them. Checking if /// a cached key exists and matches a given saves a call @@ -179,6 +184,11 @@ private async Task> GetContentEncryptionKeyAsync( bool async, CancellationToken cancellationToken) { + if (s_contentEncryptionKeyCache.Value?.Key.HasValue ?? false) + { + return s_contentEncryptionKeyCache.Value.Key.Value; + } + IKeyEncryptionKey key = default; // If we already have a local key and it is the correct one, use that. @@ -201,7 +211,7 @@ private async Task> GetContentEncryptionKeyAsync( throw Errors.ClientSideEncryption.KeyNotFound(encryptionData.WrappedContentKey.KeyId); } - return async + var contentEncryptionKey = async ? await key.UnwrapKeyAsync( encryptionData.WrappedContentKey.Algorithm, encryptionData.WrappedContentKey.EncryptedKey, @@ -210,6 +220,13 @@ private async Task> GetContentEncryptionKeyAsync( encryptionData.WrappedContentKey.Algorithm, encryptionData.WrappedContentKey.EncryptedKey, cancellationToken); + + if (s_contentEncryptionKeyCache.Value != default) + { + s_contentEncryptionKeyCache.Value.Key = contentEncryptionKey; + } + + return contentEncryptionKey; } /// @@ -250,5 +267,15 @@ private static Stream WrapStream( throw Errors.ClientSideEncryption.BadEncryptionAlgorithm(encryptionData.EncryptionAgent.EncryptionAlgorithm.ToString()); } + + internal static void BeginContentEncryptionKeyCaching(ContentEncryptionKeyCache cache = default) + { + s_contentEncryptionKeyCache.Value = cache ?? new ContentEncryptionKeyCache(); + } + + internal class ContentEncryptionKeyCache + { + public Memory? Key { get; set; } + } } }