Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using Azure.Security.KeyVault.Keys.Cryptography;
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;
using static Azure.Security.KeyVault.Keys.Cryptography.SignatureAlgorithm;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
Expand All @@ -29,7 +28,7 @@ internal class AzureSqlKeyCryptographer
/// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the
/// key into the key dictionary.
/// </summary>
private readonly ConcurrentDictionary<string, Task<Azure.Response<KeyVaultKey>>> _keyFetchTaskDictionary = new();
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyFetchTaskDictionary = new();

/// <summary>
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
Expand Down Expand Up @@ -79,10 +78,10 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
return key;
}

if (_keyFetchTaskDictionary.TryGetValue(keyIdentifierUri, out Task<Azure.Response<KeyVaultKey>> task))
if (_keyFetchTaskDictionary.TryGetValue(keyIdentifierUri, out KeyVaultKey keyVaultKey))
{
AKVEventSource.Log.TryTraceEvent("New Master key fetched.");
return Task.Run(() => task).GetAwaiter().GetResult();
return keyVaultKey;
}

// Not a public exception - not likely to occur.
Expand Down Expand Up @@ -155,14 +154,11 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri)
/// <param name="keyResourceUri">The Azure Key Vault key identifier</param>
private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string keyResourceUri)
{
Task<Azure.Response<KeyVaultKey>> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion);
_keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, fetchKeyTask, (k, v) => fetchKeyTask);
KeyVaultKey key = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion);
_keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, key, (k, v) => key);

fetchKeyTask
.ContinueWith(k => ValidateRsaKey(k.GetAwaiter().GetResult()))
.ContinueWith(k => _keyDictionary.AddOrUpdate(keyResourceUri, k.GetAwaiter().GetResult(), (key, v) => k.GetAwaiter().GetResult()));

Task.Run(() => fetchKeyTask);
ValidateRsaKey(key);
_keyDictionary.AddOrUpdate(keyResourceUri, key, (k, v) => key);
}

/// <summary>
Expand All @@ -172,19 +168,31 @@ private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string ke
/// <param name="keyName">Then name of the key</param>
/// <param name="keyVersion">Then version of the key</param>
/// <returns></returns>
private Task<Azure.Response<KeyVaultKey>> FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion)
private KeyVaultKey FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion)
{
_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient);
AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName);
return keyClient?.GetKeyAsync(keyName, keyVersion);
return keyClient?.GetKey(keyName, keyVersion);
}

/// <summary>
/// Instantiates and adds a KeyClient to the KeyClient dictionary
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
private void CreateKeyClient(Uri vaultUri)
{
if (!_keyClientDictionary.ContainsKey(vaultUri))
{
_keyClientDictionary.TryAdd(vaultUri, new KeyClient(vaultUri, TokenCredential));
}
}

/// <summary>
/// Validates that a key is of type RSA
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
private static KeyVaultKey ValidateRsaKey(KeyVaultKey key)
{
if (key.KeyType != KeyType.Rsa && key.KeyType != KeyType.RsaHsm)
{
Expand All @@ -195,26 +203,14 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
return key;
}

/// <summary>
/// Instantiates and adds a KeyClient to the KeyClient dictionary
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
private void CreateKeyClient(Uri vaultUri)
{
if (!_keyClientDictionary.ContainsKey(vaultUri))
{
_keyClientDictionary.TryAdd(vaultUri, new KeyClient(vaultUri, TokenCredential));
}
}

/// <summary>
/// Validates and parses the Azure Key Vault URI and key name.
/// </summary>
/// <param name="masterKeyPath">The Azure Key Vault key identifier</param>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="masterKeyName">The name of the key</param>
/// <param name="masterKeyVersion">The version of the key</param>
private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
private static void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
{
Uri masterKeyPathUri = new(masterKeyPath);
vaultUri = new Uri(masterKeyPathUri.GetLeftPart(UriPartial.Authority));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Extensions.Caching.Memory;
using System;
using Microsoft.Extensions.Caching.Memory;
using static System.Math;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
Expand Down Expand Up @@ -89,15 +89,5 @@ internal TValue GetOrCreate(TKey key, Func<TValue> createItem)

return cacheEntry;
}

/// <summary>
/// Determines whether the <see cref="LocalCache{TKey, TValue}">LocalCache</see> contains the specified key.
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
internal bool Contains(TKey key)
{
return _cache.TryGetValue(key, out _);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Text;
using System.Threading;
using Azure.Core;
using Azure.Security.KeyVault.Keys.Cryptography;
using static Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider.Validator;
Expand Down Expand Up @@ -55,6 +56,8 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt

private readonly static KeyWrapAlgorithm s_keyWrapAlgorithm = KeyWrapAlgorithm.RsaOaep;

private SemaphoreSlim _cacheSemaphore = new(1, 1);

/// <summary>
/// List of Trusted Endpoints
///
Expand All @@ -69,7 +72,7 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt
/// <summary>
/// A cache for storing the results of signature verification of column master key metadata.
/// </summary>
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
new(maxSizeLimit: 2000) { TimeToLive = TimeSpan.FromDays(10) };

/// <summary>
Expand Down Expand Up @@ -230,7 +233,7 @@ byte[] DecryptEncryptionKey()
// Get ciphertext
byte[] cipherText = new byte[cipherTextLength];
Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);

currentIndex += cipherTextLength;

// Get signature
Expand Down Expand Up @@ -415,8 +418,19 @@ private string ToHexString(byte[] source)
/// <remarks>
///
/// </remarks>
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
=> _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
{
try
{
// Aow only one thread to access the cache at a time.
_cacheSemaphore.Wait();
return _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
}
finally
{
_cacheSemaphore.Release();
}
}

/// <summary>
/// Returns the cached signature verification result, or proceeds to verify if not present.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ private SqlConnection(SqlConnection connection)

internal static bool TryGetSystemColumnEncryptionKeyStoreProvider(string keyStoreName, out SqlColumnEncryptionKeyStoreProvider provider)
{
return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider);
return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider);
}

/// <summary>
Expand Down Expand Up @@ -276,9 +276,9 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProvidersNames()
{
if (s_systemColumnEncryptionKeyStoreProviders.Count > 0)
{
return new List<string>(s_systemColumnEncryptionKeyStoreProviders.Keys);
return [.. s_systemColumnEncryptionKeyStoreProviders.Keys];
}
return new List<string>(0);
return [];
}

/// <summary>
Expand All @@ -291,13 +291,13 @@ internal List<string> GetColumnEncryptionCustomKeyStoreProvidersNames()
if (_customColumnEncryptionKeyStoreProviders is not null &&
_customColumnEncryptionKeyStoreProviders.Count > 0)
{
return new List<string>(_customColumnEncryptionKeyStoreProviders.Keys);
return [.. _customColumnEncryptionKeyStoreProviders.Keys];
}
if (s_globalCustomColumnEncryptionKeyStoreProviders is not null)
{
return new List<string>(s_globalCustomColumnEncryptionKeyStoreProviders.Keys);
return [.. s_globalCustomColumnEncryptionKeyStoreProviders.Keys];
}
return new List<string>(0);
return [];
}

/// <summary>
Expand Down Expand Up @@ -325,12 +325,6 @@ public static void RegisterColumnEncryptionKeyStoreProviders(IDictionary<string,
throw SQL.CanOnlyCallOnce();
}

// to prevent conflicts between CEK caches, global providers should not use their own CEK caches
foreach (SqlColumnEncryptionKeyStoreProvider provider in customProviders.Values)
{
provider.ColumnEncryptionKeyCacheTtl = new TimeSpan(0);
}

// Create a temporary dictionary and then add items from the provided dictionary.
// Dictionary constructor does shallow copying by simply copying the provider name and provider reference pairs
// in the provided customerProviders dictionary.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ internal SqlClientSymmetricKey GetKey(SqlEncryptionKeyInfo keyInfo, SqlConnectio
{
string serverName = connection.DataSource;
Debug.Assert(serverName is not null, @"serverName should not be null.");
StringBuilder cacheLookupKeyBuilder = new StringBuilder(serverName, capacity: serverName.Length + SqlSecurityUtility.GetBase64LengthFromByteLength(keyInfo.encryptedKey.Length) + keyInfo.keyStoreName.Length + 2/*separators*/);
StringBuilder cacheLookupKeyBuilder = new(serverName, capacity: serverName.Length + SqlSecurityUtility.GetBase64LengthFromByteLength(keyInfo.encryptedKey.Length) + keyInfo.keyStoreName.Length + 2/*separators*/);

#if DEBUG
int capacity = cacheLookupKeyBuilder.Capacity;
Expand All @@ -53,8 +53,7 @@ internal SqlClientSymmetricKey GetKey(SqlEncryptionKeyInfo keyInfo, SqlConnectio
#endif //DEBUG

// Lookup the key in cache
SqlClientSymmetricKey encryptionKey;
if (!(_cache.TryGetValue(cacheLookupKey, out encryptionKey)))
if (!(_cache.TryGetValue(cacheLookupKey, out SqlClientSymmetricKey encryptionKey)))
{
Debug.Assert(SqlConnection.ColumnEncryptionTrustedMasterKeyPaths is not null, @"SqlConnection.ColumnEncryptionTrustedMasterKeyPaths should not be null");

Expand All @@ -73,8 +72,6 @@ internal SqlClientSymmetricKey GetKey(SqlEncryptionKeyInfo keyInfo, SqlConnectio
byte[] plaintextKey;
try
{
// to prevent conflicts between CEK caches, global providers should not use their own CEK caches
provider.ColumnEncryptionKeyCacheTtl = new TimeSpan(0);
plaintextKey = provider.DecryptColumnEncryptionKey(keyInfo.keyPath, keyInfo.algorithmName, keyInfo.encryptedKey);
}
catch (Exception e)
Expand All @@ -91,11 +88,11 @@ internal SqlClientSymmetricKey GetKey(SqlEncryptionKeyInfo keyInfo, SqlConnectio
{
// In case multiple threads reach here at the same time, the first one wins.
// The allocated memory will be reclaimed by Garbage Collector.
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
MemoryCacheEntryOptions options = new()
{
AbsoluteExpirationRelativeToNow = SqlConnection.ColumnEncryptionKeyCacheTtl
};
_cache.Set<SqlClientSymmetricKey>(cacheLookupKey, encryptionKey, options);
_cache.Set(cacheLookupKey, encryptionKey, options);
}
}

Expand Down
Loading