diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index 396e2c0dd5..a8683f7281 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -58,7 +58,7 @@ private async Task GetOrSelectManagedIdentitySourceAsyn if (s_sourceName == ManagedIdentitySource.None) { // First invocation: detect and cache - source = await GetManagedIdentitySourceAsync(requestContext).ConfigureAwait(false); + source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested).ConfigureAwait(false); } else { @@ -101,13 +101,18 @@ private async Task GetOrSelectManagedIdentitySourceAsyn // Detect managed identity source based on the availability of environment variables and csr metadata probe request. // This method is perf sensitive any changes should be benchmarked. - internal async Task GetManagedIdentitySourceAsync(RequestContext requestContext) + internal async Task GetManagedIdentitySourceAsync( + RequestContext requestContext, + bool isMtlsPopRequested) { // First check env vars to avoid the probe if possible ManagedIdentitySource source = GetManagedIdentitySourceNoImdsV2(requestContext.Logger); - // If a source is detected via env vars, use it - if (source != ManagedIdentitySource.DefaultToImds) + // If a source is detected via env vars, or + // a source wasn't detected (it defaulted to ImdsV1) and MtlsPop was NOT requested, + // use the source. + // (don't trigger the ImdsV2 probe endpoint if MtlsPop was NOT requested) + if (source != ManagedIdentitySource.DefaultToImds || !isMtlsPopRequested) { s_sourceName = source; return source; diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs index 2144402c10..ea1fdb9c37 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs @@ -66,7 +66,8 @@ public async Task GetManagedIdentitySourceAsync() // Create a temporary RequestContext for the CSR metadata probe request. var csrMetadataProbeRequestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, CancellationToken.None); - return await ManagedIdentityClient.GetManagedIdentitySourceAsync(csrMetadataProbeRequestContext).ConfigureAwait(false); + // GetManagedIdentitySourceAsync might return ImdsV2 = true, but it still requires .WithMtlsProofOfPossesion on the Managed Identity Application object to hit the ImdsV2 flow + return await ManagedIdentityClient.GetManagedIdentitySourceAsync(csrMetadataProbeRequestContext, isMtlsPopRequested: true).ConfigureAwait(false); } /// diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index e8cd8f8b44..ab9a5d2a4a 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -506,6 +506,62 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs } } + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task ProbeDoesNotFireWhenMtlsPopNotRequested( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityApplicationBuilder miBuilder = null; + + var uami = userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null; + if (uami) + { + miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); + } + else + { + miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); + } + + miBuilder + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + var managedIdentityApp = miBuilder.Build(); + + // mock probe to show ImdsV2 is available + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId)); + + var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); + // this indicates ImdsV2 is available + Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource); + + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.Imds, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + + // ImdsV1 flow will be used since .WithMtlsProofOfPossession() is not used here + var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource).ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + } + } + #region Cuid Tests [TestMethod] public void TestCsrGeneration_OnlyVmId() @@ -684,7 +740,7 @@ await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Res } #endregion - #region cached certificate tests + #region Cached certificate tests [TestMethod] public async Task mTLSPop_ForceRefresh_UsesCachedCert_NoIssueCredential_PostsCanonicalClientId_AndSkipsAttestation() { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index 7c119315da..7ae48f4a54 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -42,16 +42,6 @@ public class ManagedIdentityTests : TestBase private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); - private void AddImdsV2CsrMockHandlerIfNeeded( - ManagedIdentitySource managedIdentitySource, - MockHttpManager httpManager) - { - if (managedIdentitySource == ManagedIdentitySource.Imds) - { - httpManager.AddMockHandler(MockHelpers.MockCsrResponseFailure()); - } - } - [DataTestMethod] [DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] @@ -69,15 +59,11 @@ public async Task GetManagedIdentityTests( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - - - ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication; Assert.AreEqual(expectedManagedIdentitySource, await mi.GetManagedIdentitySourceAsync().ConfigureAwait(false)); @@ -106,7 +92,6 @@ public async Task SAMIHappyPathAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -155,7 +140,6 @@ public async Task UAMIHappyPathAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); @@ -203,7 +187,6 @@ public async Task ManagedIdentityDifferentScopesTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -261,7 +244,6 @@ public async Task ManagedIdentityForceRefreshTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -320,7 +302,6 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -384,7 +365,6 @@ public async Task ManagedIdentityWithClaimsTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -456,7 +436,6 @@ public async Task ManagedIdentityTestWrongScopeAsync(string resource, ManagedIde using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -560,7 +539,6 @@ public async Task ManagedIdentityErrorResponseNoPayloadTestAsync(ManagedIdentity using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -600,7 +578,6 @@ public async Task ManagedIdentityNullResponseAsync(ManagedIdentitySource managed using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -638,7 +615,6 @@ public async Task ManagedIdentityUnreachableNetworkAsync(ManagedIdentitySource m using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) @@ -1036,7 +1012,6 @@ public async Task InvalidJsonResponseHandling(ManagedIdentitySource managedIdent using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder @@ -1077,7 +1052,6 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder @@ -1332,7 +1306,6 @@ public async Task ManagedIdentityWithCapabilitiesTestAsync( using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager); SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)