Skip to content

Commit 6ae13e6

Browse files
danicolednicolescu
andauthored
Add the option to have UserAssignedMsi when authentication uses KeyVault certificate. Previously, the code would only use the system assigned MSI if any. (#14676)
Co-authored-by: dnicolescu <[email protected]>
1 parent 3cc9799 commit 6ae13e6

File tree

7 files changed

+47
-25
lines changed

7 files changed

+47
-25
lines changed

sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.IntegrationTests/EndtoEndPositiveTests.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ private enum CertIdentifierType
173173
/// <param name="certIdentifierType"></param>
174174
/// <returns></returns>
175175
[Theory]
176-
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier)]
176+
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, false)]
177+
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, true)]
177178
[InlineData(CertIdentifierType.SubjectName)]
178179
[InlineData(CertIdentifierType.Thumbprint)]
179-
private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType)
180+
private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType, bool useUserAssignedMsi = false)
180181
{
181182
string testCertUrl = Environment.GetEnvironmentVariable(Constants.TestCertUrlEnv);
182183

@@ -208,7 +209,9 @@ private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType
208209
connectionString = $"RunAs=App;AppId={app.AppId};TenantId={_tenantId};{thumbprintOrSubjectName};CertificateStoreLocation={Constants.CurrentUserStore};";
209210
break;
210211
case CertIdentifierType.KeyVaultCertificateSecretIdentifier:
211-
connectionString = $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};";
212+
connectionString = useUserAssignedMsi
213+
? $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};KeyVaultUserAssignedManagedIdentityId={Constants.TestUserAssignedManagedIdentityId}" //TODO: figure out real MSI to use here. Also, does the test really use MSI or does it rely on the fallback?
214+
: $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};";
212215
break;
213216
}
214217

sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.TestCommon/Constants.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ public class Constants
7171
public static readonly string CertificateConnStringThumbprintCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateThumbprint=123;CertificateStoreLocation=CurrentUser";
7272
public static readonly string CertificateConnStringSubjectNameCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateSubjectName=123;CertificateStoreLocation=CurrentUser";
7373
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifier = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier";
74+
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier;KeyVaultAppId={TestUserAssignedManagedIdentityId}";
7475
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};KeyVaultCertificateSecretIdentifier=SecretIdentifier";
7576
public static readonly string ClientSecretConnString = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};AppKey={ClientSecret}";
7677
public static readonly string ConnectionStringEnvironmentVariableName = "AzureServicesAuthConnectionString";

sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/AzureServiceTokenProviderFactoryTests.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,11 @@ public void CertValidTest()
194194
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifier, provider.ConnectionString);
195195
Assert.IsType<ClientCertificateAzureServiceTokenProvider>(provider);
196196

197+
provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, Constants.AzureAdInstance);
198+
Assert.NotNull(provider);
199+
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, provider.ConnectionString);
200+
Assert.IsType<ClientCertificateAzureServiceTokenProvider>(provider);
201+
197202
provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, Constants.AzureAdInstance);
198203
Assert.NotNull(provider);
199204
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, provider.ConnectionString);

sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication.Unit.Tests/ClientCertificateAccessTokenProviderTests.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public async Task ThumbprintSuccessTest()
3838

3939
// Create ClientCertificateAzureServiceTokenProvider instance
4040
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
41-
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
41+
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);
4242

4343
// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on thumbprint in the connection string.
4444
var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false);
@@ -64,7 +64,7 @@ public async Task ThumbprintFailTest()
6464

6565
// Create ClientCertificateAzureServiceTokenProvider instance
6666
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
67-
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
67+
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);
6868

6969
// Ensure exception is thrown when getting the token
7070
var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId));
@@ -89,12 +89,12 @@ public void ClientIdNullOrEmptyTest()
8989

9090
// Create ClientCertificateAzureServiceTokenProvider instance
9191
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(null,
92-
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
92+
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));
9393

9494
Assert.Contains(Constants.CannotBeNullError, exception.ToString());
9595

9696
exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(string.Empty,
97-
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
97+
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));
9898

9999
Assert.Contains(Constants.CannotBeNullError, exception.ToString());
100100
}
@@ -114,12 +114,12 @@ public void StoreLocationNullOrEmptyTest()
114114

115115
// Create ClientCertificateAzureServiceTokenProvider instance
116116
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
117-
cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
117+
cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));
118118

119119
Assert.Contains(Constants.CannotBeNullError, exception.ToString());
120120

121121
exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
122-
cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
122+
cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));
123123

124124
Assert.Contains(Constants.CannotBeNullError, exception.ToString());
125125
}
@@ -135,12 +135,12 @@ public void CertSubjectNameOrThumbprintNullOrEmptyTest()
135135

136136
// Create ClientCertificateAzureServiceTokenProvider instance
137137
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
138-
null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
138+
null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));
139139

140140
Assert.Contains(Constants.CannotBeNullError, exception.ToString());
141141

142142
exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
143-
string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
143+
string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));
144144

145145
Assert.Contains(Constants.CannotBeNullError, exception.ToString());
146146
}
@@ -160,7 +160,7 @@ public void InvalidStoreLocationTest()
160160

161161
// Create ClientCertificateAzureServiceTokenProvider instance
162162
var exception = Assert.Throws<ArgumentException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
163-
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
163+
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));
164164

165165
Assert.Contains(Constants.InvalidCertLocationError, exception.ToString());
166166
}
@@ -177,7 +177,7 @@ public async Task SubjectNameSuccessTest()
177177

178178
// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
179179
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
180-
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
180+
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);
181181

182182
// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
183183
var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty).ConfigureAwait(false);
@@ -204,7 +204,7 @@ public void CannotAcquireTokenThroughCertTest()
204204

205205
// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
206206
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
207-
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
207+
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);
208208

209209
// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
210210
var exception = Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty));
@@ -226,7 +226,7 @@ public async Task CertificateNotFoundTest()
226226
MockAuthenticationContext mockAuthenticationContext = new MockAuthenticationContext(MockAuthenticationContext.MockAuthenticationContextTestType.AcquireTokenAsyncClientCertificateSuccess);
227227

228228
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
229-
Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
229+
Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);
230230

231231
var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId)));
232232

@@ -257,7 +257,7 @@ public async Task KeyVaultCertificateSecretIdentifierSuccessTest(bool includeTen
257257

258258
// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
259259
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
260-
Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, mockAuthenticationContext, keyVaultClient);
260+
Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient);
261261

262262
// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
263263
var authResult = await provider.GetAuthResultAsync(Constants.ArmResourceId, string.Empty).ConfigureAwait(false);
@@ -283,7 +283,7 @@ public async Task KeyVaultCertificateNotFoundTest()
283283

284284
string SecretIdentifier = "https://testbedkeyvault.vault.azure.net/secrets/secret/";
285285
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
286-
SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext, keyVaultClient);
286+
SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient);
287287

288288
var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => provider.GetAuthResultAsync(Constants.ArmResourceId, Constants.TenantId)));
289289

sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderFactory.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ internal class AzureServiceTokenProviderFactory
2727
private const string CertificateSubjectName = "CertificateSubjectName";
2828
private const string CertificateThumbprint = "CertificateThumbprint";
2929
private const string KeyVaultCertificateSecretIdentifier = "KeyVaultCertificateSecretIdentifier";
30+
private const string KeyVaultUserAssignedManagedIdentityId = "KeyVaultUserAssignedManagedIdentityId";
3031
private const string CertificateStoreLocation = "CertificateStoreLocation";
3132
private const string MsiRetryTimeout = "MsiRetryTimeout";
3233

@@ -125,7 +126,7 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
125126
azureAdInstance,
126127
connectionSettings[TenantId],
127128
0,
128-
new AdalAuthenticationContext(httpClientFactory));
129+
authenticationContext: new AdalAuthenticationContext(httpClientFactory));
129130
}
130131
else if (connectionSettings.ContainsKey(CertificateThumbprint) ||
131132
connectionSettings.ContainsKey(CertificateSubjectName))
@@ -138,6 +139,11 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
138139
{
139140
ValidateMsiRetryTimeout(connectionSettings, connectionString);
140141

142+
var msiRetryTimeout = connectionSettings.ContainsKey(MsiRetryTimeout)
143+
? int.Parse(connectionSettings[MsiRetryTimeout])
144+
: 0;
145+
connectionSettings.TryGetValue(KeyVaultUserAssignedManagedIdentityId, out var keyVaultUserAssignedManagedIdentityId);
146+
141147
azureServiceTokenProvider =
142148
new ClientCertificateAzureServiceTokenProvider(
143149
connectionSettings[AppId],
@@ -148,9 +154,8 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
148154
connectionSettings.ContainsKey(TenantId) // tenantId can be specified in connection string or retrieved from Key Vault access token later
149155
? connectionSettings[TenantId]
150156
: default,
151-
connectionSettings.ContainsKey(MsiRetryTimeout)
152-
? int.Parse(connectionSettings[MsiRetryTimeout])
153-
: 0,
157+
msiRetryTimeout,
158+
keyVaultUserAssignedManagedIdentityId,
154159
new AdalAuthenticationContext(httpClientFactory));
155160
}
156161
else if (connectionSettings.ContainsKey(AppKey))

0 commit comments

Comments
 (0)