diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.IntegrationTests/EndtoEndPositiveTests.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.IntegrationTests/EndtoEndPositiveTests.cs index 1c43939b844e..75748c6074c9 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.IntegrationTests/EndtoEndPositiveTests.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.IntegrationTests/EndtoEndPositiveTests.cs @@ -173,10 +173,11 @@ private enum CertIdentifierType /// /// [Theory] - [InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier)] + [InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, false)] + [InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, true)] [InlineData(CertIdentifierType.SubjectName)] [InlineData(CertIdentifierType.Thumbprint)] - private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType) + private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType, bool useUserAssignedMsi = false) { string testCertUrl = Environment.GetEnvironmentVariable(Constants.TestCertUrlEnv); @@ -208,7 +209,9 @@ private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType connectionString = $"RunAs=App;AppId={app.AppId};TenantId={_tenantId};{thumbprintOrSubjectName};CertificateStoreLocation={Constants.CurrentUserStore};"; break; case CertIdentifierType.KeyVaultCertificateSecretIdentifier: - connectionString = $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};"; + connectionString = useUserAssignedMsi + ? $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};KeyVaultUserAssignedManagedIdentityId={Constants.TestUserAssignedManagedIdentityId}" //TODO: figure out real MSI to use here. Also, does the test really use MSI or does it rely on the fallback? + : $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};"; break; } diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.TestCommon/Constants.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.TestCommon/Constants.cs index 933cf214ad96..ee8bf44e60a0 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.TestCommon/Constants.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.TestCommon/Constants.cs @@ -71,6 +71,7 @@ public class Constants public static readonly string CertificateConnStringThumbprintCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateThumbprint=123;CertificateStoreLocation=CurrentUser"; public static readonly string CertificateConnStringSubjectNameCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateSubjectName=123;CertificateStoreLocation=CurrentUser"; public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifier = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier"; + public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier;KeyVaultAppId={TestUserAssignedManagedIdentityId}"; public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};KeyVaultCertificateSecretIdentifier=SecretIdentifier"; public static readonly string ClientSecretConnString = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};AppKey={ClientSecret}"; public static readonly string ConnectionStringEnvironmentVariableName = "AzureServicesAuthConnectionString"; diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/AzureServiceTokenProviderFactoryTests.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/AzureServiceTokenProviderFactoryTests.cs index 54ccf7e5ee1c..3a1a979e9b38 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/AzureServiceTokenProviderFactoryTests.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/AzureServiceTokenProviderFactoryTests.cs @@ -194,6 +194,11 @@ public void CertValidTest() Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifier, provider.ConnectionString); Assert.IsType(provider); + provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, Constants.AzureAdInstance); + Assert.NotNull(provider); + Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, provider.ConnectionString); + Assert.IsType(provider); + provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, Constants.AzureAdInstance); Assert.NotNull(provider); Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, provider.ConnectionString); diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/ClientCertificateAccessTokenProviderTests.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/ClientCertificateAccessTokenProviderTests.cs index 60d01ce718c6..f7f7757500cc 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/ClientCertificateAccessTokenProviderTests.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/ClientCertificateAccessTokenProviderTests.cs @@ -38,7 +38,7 @@ public async Task ThumbprintSuccessTest() // Create ClientCertificateAzureServiceTokenProvider instance ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext); + cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext); // Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on thumbprint in the connection string. var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false); @@ -64,7 +64,7 @@ public async Task ThumbprintFailTest() // Create ClientCertificateAzureServiceTokenProvider instance ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext); + cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext); // Ensure exception is thrown when getting the token var exception = await Assert.ThrowsAsync(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId)); @@ -89,12 +89,12 @@ public void ClientIdNullOrEmptyTest() // Create ClientCertificateAzureServiceTokenProvider instance var exception = Assert.Throws(() => new ClientCertificateAzureServiceTokenProvider(null, - cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext)); + cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext)); Assert.Contains(Constants.CannotBeNullError, exception.ToString()); exception = Assert.Throws(() => new ClientCertificateAzureServiceTokenProvider(string.Empty, - cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext)); + cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext)); Assert.Contains(Constants.CannotBeNullError, exception.ToString()); } @@ -114,12 +114,12 @@ public void StoreLocationNullOrEmptyTest() // Create ClientCertificateAzureServiceTokenProvider instance var exception = Assert.Throws(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext)); + cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext)); Assert.Contains(Constants.CannotBeNullError, exception.ToString()); exception = Assert.Throws(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext)); + cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext)); Assert.Contains(Constants.CannotBeNullError, exception.ToString()); } @@ -135,12 +135,12 @@ public void CertSubjectNameOrThumbprintNullOrEmptyTest() // Create ClientCertificateAzureServiceTokenProvider instance var exception = Assert.Throws(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext)); + null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext)); Assert.Contains(Constants.CannotBeNullError, exception.ToString()); exception = Assert.Throws(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext)); + string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext)); Assert.Contains(Constants.CannotBeNullError, exception.ToString()); } @@ -160,7 +160,7 @@ public void InvalidStoreLocationTest() // Create ClientCertificateAzureServiceTokenProvider instance var exception = Assert.Throws(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext)); + cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext)); Assert.Contains(Constants.InvalidCertLocationError, exception.ToString()); } @@ -177,7 +177,7 @@ public async Task SubjectNameSuccessTest() // Create ClientCertificateAzureServiceTokenProvider instance with a subject name ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext); + cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext); // Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string. var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty).ConfigureAwait(false); @@ -204,7 +204,7 @@ public void CannotAcquireTokenThroughCertTest() // Create ClientCertificateAzureServiceTokenProvider instance with a subject name ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext); + cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext); // Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string. var exception = Assert.ThrowsAsync(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty)); @@ -226,7 +226,7 @@ public async Task CertificateNotFoundTest() MockAuthenticationContext mockAuthenticationContext = new MockAuthenticationContext(MockAuthenticationContext.MockAuthenticationContextTestType.AcquireTokenAsyncClientCertificateSuccess); ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext); + Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext); var exception = await Assert.ThrowsAsync(() => Task.Run(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId))); @@ -257,7 +257,7 @@ public async Task KeyVaultCertificateSecretIdentifierSuccessTest(bool includeTen // Create ClientCertificateAzureServiceTokenProvider instance with a subject name ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, mockAuthenticationContext, keyVaultClient); + Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient); // Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string. var authResult = await provider.GetAuthResultAsync(Constants.ArmResourceId, string.Empty).ConfigureAwait(false); @@ -283,7 +283,7 @@ public async Task KeyVaultCertificateNotFoundTest() string SecretIdentifier = "https://testbedkeyvault.vault.azure.net/secrets/secret/"; ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId, - SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext, keyVaultClient); + SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient); var exception = await Assert.ThrowsAsync(() => Task.Run(() => provider.GetAuthResultAsync(Constants.ArmResourceId, Constants.TenantId))); diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderFactory.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderFactory.cs index a8934af8740e..3a8465c20031 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderFactory.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderFactory.cs @@ -27,6 +27,7 @@ internal class AzureServiceTokenProviderFactory private const string CertificateSubjectName = "CertificateSubjectName"; private const string CertificateThumbprint = "CertificateThumbprint"; private const string KeyVaultCertificateSecretIdentifier = "KeyVaultCertificateSecretIdentifier"; + private const string KeyVaultUserAssignedManagedIdentityId = "KeyVaultUserAssignedManagedIdentityId"; private const string CertificateStoreLocation = "CertificateStoreLocation"; private const string MsiRetryTimeout = "MsiRetryTimeout"; @@ -125,7 +126,7 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec azureAdInstance, connectionSettings[TenantId], 0, - new AdalAuthenticationContext(httpClientFactory)); + authenticationContext: new AdalAuthenticationContext(httpClientFactory)); } else if (connectionSettings.ContainsKey(CertificateThumbprint) || connectionSettings.ContainsKey(CertificateSubjectName)) @@ -138,6 +139,11 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec { ValidateMsiRetryTimeout(connectionSettings, connectionString); + var msiRetryTimeout = connectionSettings.ContainsKey(MsiRetryTimeout) + ? int.Parse(connectionSettings[MsiRetryTimeout]) + : 0; + connectionSettings.TryGetValue(KeyVaultUserAssignedManagedIdentityId, out var keyVaultUserAssignedManagedIdentityId); + azureServiceTokenProvider = new ClientCertificateAzureServiceTokenProvider( connectionSettings[AppId], @@ -148,9 +154,8 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec connectionSettings.ContainsKey(TenantId) // tenantId can be specified in connection string or retrieved from Key Vault access token later ? connectionSettings[TenantId] : default, - connectionSettings.ContainsKey(MsiRetryTimeout) - ? int.Parse(connectionSettings[MsiRetryTimeout]) - : 0, + msiRetryTimeout, + keyVaultUserAssignedManagedIdentityId, new AdalAuthenticationContext(httpClientFactory)); } else if (connectionSettings.ContainsKey(AppKey)) diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/Clients/KeyVault/KeyVaultClient.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/Clients/KeyVault/KeyVaultClient.cs index 6757f48dcaf7..c5b7e129e687 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/Clients/KeyVault/KeyVaultClient.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/Clients/KeyVault/KeyVaultClient.cs @@ -20,6 +20,9 @@ internal class KeyVaultClient private readonly HttpClient _httpClient; private NonInteractiveAzureServiceTokenProviderBase _tokenProvider; + // In case of User assigned MSI, this needs to be provided + private string _managedIdentityClientId; + private const string KeyVaultRestApiVersion = "2016-10-01"; // Error messages @@ -35,7 +38,7 @@ internal class KeyVaultClient internal Principal PrincipalUsed { get; private set; } - internal KeyVaultClient(int msiRetryTimeoutInSeconds = 0, HttpClient httpClient = null, NonInteractiveAzureServiceTokenProviderBase tokenProvider = null) + internal KeyVaultClient(int msiRetryTimeoutInSeconds = 0, string managedIdentityClientId = null, HttpClient httpClient = null, NonInteractiveAzureServiceTokenProviderBase tokenProvider = null) { _msiRetryTimeoutInSeconds = msiRetryTimeoutInSeconds; #if NETSTANDARD1_4 || net452 || net461 @@ -44,9 +47,10 @@ internal KeyVaultClient(int msiRetryTimeoutInSeconds = 0, HttpClient httpClient _httpClient = httpClient ?? new HttpClient(new HttpClientHandler() { CheckCertificateRevocationList = true }); #endif _tokenProvider = tokenProvider; + _managedIdentityClientId = managedIdentityClientId; } - internal KeyVaultClient(HttpClient httpClient, NonInteractiveAzureServiceTokenProviderBase tokenProvider = null) : this(0, httpClient, tokenProvider) + internal KeyVaultClient(HttpClient httpClient, NonInteractiveAzureServiceTokenProviderBase tokenProvider = null, string managedIdentityClientId = null) : this(0, managedIdentityClientId, httpClient, tokenProvider) { } @@ -186,7 +190,7 @@ private List GetTokenProviders(stri string azureAdInstance = UriHelper.GetAzureAdInstanceByAuthority(authority); tokenProviders = new List { - new MsiAccessTokenProvider(_msiRetryTimeoutInSeconds), + new MsiAccessTokenProvider(_msiRetryTimeoutInSeconds, _managedIdentityClientId), new VisualStudioAccessTokenProvider(new ProcessManager()), new AzureCliAccessTokenProvider(new ProcessManager()), #if FullNetFx diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs index 7778b10ae0c4..2e67a90ac9b9 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs @@ -56,6 +56,7 @@ internal enum CertificateIdentifierType internal ClientCertificateAzureServiceTokenProvider(string clientId, string certificateIdentifier, CertificateIdentifierType certificateIdentifierType, string storeLocation, string azureAdInstance, string tenantId = default, int msiRetryTimeoutInSeconds = 0, + string keyVaultUserAssignedManagedIdentityId = null, IAuthenticationContext authenticationContext = null, KeyVaultClient keyVaultClient = null) { if (string.IsNullOrWhiteSpace(clientId)) @@ -89,6 +90,10 @@ internal ClientCertificateAzureServiceTokenProvider(string clientId, $"StoreLocation {storeLocation} is not valid. Valid values are CurrentUser and LocalMachine."); } } + else + { + _keyVaultClient = keyVaultClient ?? new KeyVaultClient(msiRetryTimeoutInSeconds, keyVaultUserAssignedManagedIdentityId); + } _clientId = clientId; @@ -96,7 +101,6 @@ internal ClientCertificateAzureServiceTokenProvider(string clientId, _azureAdInstance = azureAdInstance; _tenantId = tenantId; _authenticationContext = authenticationContext ?? new AdalAuthenticationContext(); - _keyVaultClient = keyVaultClient ?? new KeyVaultClient(msiRetryTimeoutInSeconds); _certificateIdentifier = certificateIdentifier;