diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index a376a74e70..0dfb673cbe 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -21,10 +21,13 @@ internal class ManagedIdentityClient { private const string WindowsHimdsFilePath = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe"; private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds"; - internal static ManagedIdentitySource s_sourceName = ManagedIdentitySource.None; + // Preview guard: once we fall back to IMDSv1 while IMDSv2 is cached, // disallow switching to IMDSv2 PoP in the same process (preview behavior). internal static bool s_imdsV1UsedForPreview = false; + // Non-null only after the explicit discovery API (GetManagedIdentitySourceAsync) runs. + // Allows caching "NoneFound" (Source=None) without confusing it with "not discovered yet". + private static ManagedIdentitySourceResult s_cachedSourceResult = null; // Holds the most recently minted mTLS binding certificate for this application instance. private X509Certificate2 _runtimeMtlsBindingCertificate; @@ -32,7 +35,7 @@ internal class ManagedIdentityClient internal static void ResetSourceForTest() { - s_sourceName = ManagedIdentitySource.None; + s_cachedSourceResult = null; s_imdsV1UsedForPreview = false; // Clear cert caches so each test starts fresh @@ -48,42 +51,66 @@ internal async Task SendTokenRequestForManagedIdentityA return await msi.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false); } - // This method tries to create managed identity source for different sources, if none is created then defaults to IMDS. - private async Task GetOrSelectManagedIdentitySourceAsync( + // This method selects the managed identity source for token acquisition. + // It does NOT probe IMDS. It uses the cached explicit discovery result if available, + // otherwise checks environment variables, and defaults to IMDS without probing. + private Task GetOrSelectManagedIdentitySourceAsync( RequestContext requestContext, bool isMtlsPopRequested, CancellationToken cancellationToken) { using (requestContext.Logger.LogMethodDuration()) { - requestContext.Logger.Info($"[Managed Identity] Selecting managed identity source if not cached. Cached value is {s_sourceName} "); + requestContext.Logger.Info($"[Managed Identity] Selecting managed identity source. " + + $"Discovery cached: {s_cachedSourceResult != null}"); + + // Fail fast if cancellation was requested, before performing expensive network probes + cancellationToken.ThrowIfCancellationRequested(); - ManagedIdentitySourceResult sourceResult = null; ManagedIdentitySource source; - // If the source is not already set, determine it - if (s_sourceName == ManagedIdentitySource.None) + if (s_cachedSourceResult != null) { - // First invocation: detect and cache - sourceResult = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested, cancellationToken).ConfigureAwait(false); - source = sourceResult.Source; + // Use the cached explicit discovery result (including NoneFound) + source = s_cachedSourceResult.Source; + requestContext.Logger.Info($"[Managed Identity] Using cached discovery result: {source}"); } else { - // Reuse cached value - source = s_sourceName; + // Standard path: check environment variables only, no IMDS probing + source = GetManagedIdentitySourceNoImds(requestContext.Logger); + + if (source == ManagedIdentitySource.None) + { + // No environment-based source found; default to IMDS based on mTLS PoP flag + if (isMtlsPopRequested) + { + // Route mTLS PoP requests directly to IMDSv2 (no probing) + requestContext.Logger.Info("[Managed Identity] mTLS PoP requested, routing to IMDSv2 directly without probing."); + return Task.FromResult(ImdsV2ManagedIdentitySource.Create(requestContext)); + } + + // Default to IMDSv1 without probing + requestContext.Logger.Info("[Managed Identity] Defaulting to IMDSv1 without probing."); + return Task.FromResult(ImdsManagedIdentitySource.Create(requestContext)); + } + } + + // Handle NoneFound from cached discovery + if (source == ManagedIdentitySource.None) + { + throw CreateManagedIdentityUnavailableException(s_cachedSourceResult); } - // If the source has already been set to ImdsV2 (via this method, or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs) - // and mTLS PoP was NOT requested: fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests + // Preview fallback: if ImdsV2 is cached but mTLS PoP not requested, fall back per-request to ImdsV1 if (source == ManagedIdentitySource.ImdsV2 && !isMtlsPopRequested) { requestContext.Logger.Info("[Managed Identity] ImdsV2 detected, but mTLS PoP was not requested. Falling back to ImdsV1 for this request only. Please use the \"WithMtlsProofOfPossession\" API to request a token via ImdsV2."); - + // Mark that we used IMDSv1 in this process while IMDSv2 is cached (preview behavior). s_imdsV1UsedForPreview = true; - // Do NOT modify s_sourceName; keep cached ImdsV2 so future PoP + // Do NOT modify s_cachedSourceResult; keep cached ImdsV2 so future PoP // requests can leverage it. source = ManagedIdentitySource.Imds; } @@ -106,7 +133,7 @@ private async Task GetOrSelectManagedIdentitySourceAsyn MsalErrorMessage.MtlsPopTokenNotSupportedinImdsV1); } - return source switch + return Task.FromResult(source switch { ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext), ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext), @@ -115,63 +142,65 @@ private async Task GetOrSelectManagedIdentitySourceAsyn ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext), ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(requestContext), ManagedIdentitySource.Imds => ImdsManagedIdentitySource.Create(requestContext), - _ => throw CreateManagedIdentityUnavailableException(sourceResult) - }; + _ => throw CreateManagedIdentityUnavailableException(s_cachedSourceResult) + }); } } - // Detect managed identity source based on the availability of environment variables and csr metadata probe request. - // This method is perf sensitive any changes should be benchmarked. + private static ManagedIdentitySourceResult CacheDiscoveryResult(ManagedIdentitySourceResult result) + { + s_cachedSourceResult = result; + return result; + } + + // Detect managed identity source by probing IMDS endpoints. + // This method is called only by the explicit discovery path (GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs). + // It probes IMDS v1 first, then v2 if v1 fails, and caches the result. internal async Task GetManagedIdentitySourceAsync( RequestContext requestContext, - bool isMtlsPopRequested, CancellationToken cancellationToken) { + // Return cached result if explicit discovery already ran + if (s_cachedSourceResult != null) + { + return s_cachedSourceResult; + } + // First check env vars to avoid the probe if possible ManagedIdentitySource source = GetManagedIdentitySourceNoImds(requestContext.Logger); + if (source != ManagedIdentitySource.None) { - s_sourceName = source; - return new ManagedIdentitySourceResult(source); + return CacheDiscoveryResult(new ManagedIdentitySourceResult(source)); } - string imdsV2FailureReason = null; string imdsV1FailureReason = null; + string imdsV2FailureReason = null; - // skip the ImdsV2 probe if MtlsPop was NOT requested - if (isMtlsPopRequested) - { - var (imdsV2Success, imdsV2Failure) = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V2, cancellationToken).ConfigureAwait(false); - if (imdsV2Success) - { - requestContext.Logger.Info("[Managed Identity] ImdsV2 detected."); - s_sourceName = ManagedIdentitySource.ImdsV2; - return new ManagedIdentitySourceResult(s_sourceName); - } - imdsV2FailureReason = imdsV2Failure; - } - else - { - requestContext.Logger.Info("[Managed Identity] Mtls Pop was not requested; skipping ImdsV2 probe."); - } - + // Probe IMDS v1 first var (imdsV1Success, imdsV1Failure) = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V1, cancellationToken).ConfigureAwait(false); if (imdsV1Success) { requestContext.Logger.Info("[Managed Identity] ImdsV1 detected."); - s_sourceName = ManagedIdentitySource.Imds; - return new ManagedIdentitySourceResult(s_sourceName); + return CacheDiscoveryResult(new ManagedIdentitySourceResult(ManagedIdentitySource.Imds)); } imdsV1FailureReason = imdsV1Failure; + // If v1 fails, probe IMDS v2 + var (imdsV2Success, imdsV2Failure) = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V2, cancellationToken).ConfigureAwait(false); + if (imdsV2Success) + { + requestContext.Logger.Info("[Managed Identity] ImdsV2 detected."); + return CacheDiscoveryResult(new ManagedIdentitySourceResult(ManagedIdentitySource.ImdsV2)); + } + imdsV2FailureReason = imdsV2Failure; + requestContext.Logger.Info($"[Managed Identity] {MsalErrorMessage.ManagedIdentityAllSourcesUnavailable}"); - s_sourceName = ManagedIdentitySource.None; - - return new ManagedIdentitySourceResult(s_sourceName) + return CacheDiscoveryResult(new ManagedIdentitySourceResult(ManagedIdentitySource.None) { ImdsV1FailureReason = imdsV1FailureReason, ImdsV2FailureReason = imdsV2FailureReason - }; + }); } /// @@ -252,8 +281,8 @@ private static bool ValidateAzureArcEnvironment(string identityEndpoint, string { logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is available through file detection."); return true; - } - + } + logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is not available."); return false; } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 000d40d31b..93f8fbc441 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -96,6 +96,18 @@ private static void ThrowCsrMetadataRequestException( Exception ex = null, int? statusCode = null) { + // A 404 from the IMDSv2 CSR endpoint indicates that the host supports only IMDSv1. + // This happens when WithMtlsProofOfPossession() is used without a prior + // GetManagedIdentitySourceAsync() call: MSAL routes directly to IMDSv2, and + // on an IMDSv1-only host the /getplatformmetadata endpoint does not exist. + // Translate to a client error so callers know mTLS PoP is not supported here. + if (statusCode == (int)HttpStatusCode.NotFound) + { + throw new MsalClientException( + MsalError.MtlsPopTokenNotSupportedinImdsV1, + MsalErrorMessage.MtlsPopTokenNotSupportedinImdsV1); + } + throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, $"[ImdsV2] {errorMessage}", @@ -154,8 +166,8 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) return new ImdsV2ManagedIdentitySource(requestContext); } - internal ImdsV2ManagedIdentitySource(RequestContext requestContext) - : this(requestContext, + internal ImdsV2ManagedIdentitySource(RequestContext requestContext) + : this(requestContext, new MtlsBindingCache(s_mtlsCertificateCache, PersistentCertificateCacheFactory .Create(requestContext.Logger))) { @@ -436,9 +448,9 @@ protected override async Task CreateRequestAsync(string /// Cancellation token. /// JWT string suitable for the IMDSv2 attested PoP flow, or null for non-attested flow. private async Task GetAttestationJwtAsync( - string clientId, - Uri attestationEndpoint, - ManagedIdentityKeyInfo keyInfo, + string clientId, + Uri attestationEndpoint, + ManagedIdentityKeyInfo keyInfo, CancellationToken cancellationToken) { // Check if attestation token provider has been configured diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs index c2016a63e2..05a771759b 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs @@ -25,7 +25,7 @@ public sealed class ManagedIdentityApplication IManagedIdentityApplication { internal ManagedIdentityClient ManagedIdentityClient { get; } - + internal ManagedIdentityApplication( ApplicationConfiguration configuration) : base(configuration) @@ -34,8 +34,8 @@ internal ManagedIdentityApplication( AppTokenCacheInternal = configuration.AppTokenCacheInternalForTest ?? new TokenCache(ServiceBundle, true); - this.ServiceBundle.ApplicationLogger.Verbose(()=>$"ManagedIdentityApplication {configuration.GetHashCode()} created"); - + this.ServiceBundle.ApplicationLogger.Verbose(() => $"ManagedIdentityApplication {configuration.GetHashCode()} created"); + ManagedIdentityClient = new ManagedIdentityClient(); } @@ -58,16 +58,10 @@ public AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIden /// public async Task GetManagedIdentitySourceAsync(CancellationToken cancellationToken) { - if (ManagedIdentityClient.s_sourceName != ManagedIdentitySource.None) - { - return new ManagedIdentitySourceResult(ManagedIdentityClient.s_sourceName); - } - // Create a temporary RequestContext for the logger and the IMDS probe request. var requestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, cancellationToken); - // GetManagedIdentitySourceAsync might return ImdsV2 = true, but it still requires .WithMtlsProofOfPossesion on the Managed Identity Application object to hit the ImdsV2 flow - return await ManagedIdentityClient.GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested: true, cancellationToken).ConfigureAwait(false); + return await ManagedIdentityClient.GetManagedIdentitySourceAsync(requestContext, cancellationToken).ConfigureAwait(false); } /// @@ -78,7 +72,7 @@ public async Task GetManagedIdentitySourceAsync(Can public static ManagedIdentitySource GetManagedIdentitySource() { var source = ManagedIdentityClient.GetManagedIdentitySourceNoImds(); - + return source == ManagedIdentitySource.None #pragma warning disable CS0618 // ManagedIdentitySource.DefaultToImds is marked obsolete, but is intentionally used here as a sentinel value to support legacy detection logic. diff --git a/tests/Microsoft.Identity.Test.E2e/ManagedIdentityImdsV2Tests.cs b/tests/Microsoft.Identity.Test.E2e/ManagedIdentityImdsV2Tests.cs index 6160da1f8c..acc0ab37cc 100644 --- a/tests/Microsoft.Identity.Test.E2e/ManagedIdentityImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.E2e/ManagedIdentityImdsV2Tests.cs @@ -68,13 +68,6 @@ public async Task AcquireToken_OnImdsV2_MtlsPoP_WithAttestation_Succeeds(string Assert.Inconclusive("Credential Guard attestation is only available on Windows."); } - // Check if TOKEN_ATTESTATION_ENDPOINT is configured (required for attestation) - var attestationEndpoint = Environment.GetEnvironmentVariable("TOKEN_ATTESTATION_ENDPOINT"); - if (string.IsNullOrWhiteSpace(attestationEndpoint)) - { - Assert.Inconclusive("TOKEN_ATTESTATION_ENDPOINT is not configured. Attestation tests require this environment variable."); - } - var mi = BuildMi(id, idType); try diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs index ade71e97b4..f21628fdbc 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs @@ -43,7 +43,6 @@ public async Task ImdsFails404TwiceThenSucceeds200Async( IManagedIdentityApplication mi = miBuilder.Build(); - ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); // Simulate two 404s (to trigger retries), then a successful response const int Num404Errors = 2; @@ -91,7 +90,7 @@ public async Task ImdsFails410FourTimesThenSucceeds200Async( using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); - + ManagedIdentityId managedIdentityId = userAssignedId == null ? ManagedIdentityId.SystemAssigned : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); @@ -101,7 +100,6 @@ public async Task ImdsFails410FourTimesThenSucceeds200Async( IManagedIdentityApplication mi = miBuilder.Build(); - ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); // Simulate four 410s (to trigger retries), then a successful response const int Num410Errors = 4; @@ -149,7 +147,7 @@ public async Task ImdsFails410PermanentlyAsync( using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); - + ManagedIdentityId managedIdentityId = userAssignedId == null ? ManagedIdentityId.SystemAssigned : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); @@ -159,7 +157,6 @@ public async Task ImdsFails410PermanentlyAsync( IManagedIdentityApplication mi = miBuilder.Build(); - ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); // Simulate permanent 410s (to trigger the maximum number of retries) const int Num410Errors = 1 + TestImdsRetryPolicy.LinearStrategyNumRetries; // initial request + maximum number of retries @@ -204,7 +201,7 @@ public async Task ImdsFails504PermanentlyAsync( using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); - + ManagedIdentityId managedIdentityId = userAssignedId == null ? ManagedIdentityId.SystemAssigned : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); @@ -214,7 +211,6 @@ public async Task ImdsFails504PermanentlyAsync( IManagedIdentityApplication mi = miBuilder.Build(); - ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); // Simulate permanent 504s (to trigger the maximum number of retries) const int Num504Errors = 1 + TestImdsRetryPolicy.ExponentialStrategyNumRetries; // initial request + maximum number of retries @@ -259,7 +255,7 @@ public async Task ImdsFails400WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsy using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); - + ManagedIdentityId managedIdentityId = userAssignedId == null ? ManagedIdentityId.SystemAssigned : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); @@ -269,7 +265,6 @@ public async Task ImdsFails400WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsy IManagedIdentityApplication mi = miBuilder.Build(); - ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); httpManager.AddManagedIdentityMockHandler( ManagedIdentityTests.ImdsEndpoint, @@ -310,7 +305,7 @@ public async Task ImdsFails500AndRetryPolicyIsDisabledAndNotTriggeredAsync( using (var httpManager = new MockHttpManager(disableInternalRetries: true)) { SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); - + ManagedIdentityId managedIdentityId = userAssignedId == null ? ManagedIdentityId.SystemAssigned : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); @@ -320,7 +315,6 @@ public async Task ImdsFails500AndRetryPolicyIsDisabledAndNotTriggeredAsync( IManagedIdentityApplication mi = miBuilder.Build(); - ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds, userAssignedIdentityId, userAssignedId); httpManager.AddManagedIdentityMockHandler( ManagedIdentityTests.ImdsEndpoint, @@ -351,7 +345,7 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) } [TestMethod] - + public async Task ImdsRetryPolicyLifeTimeIsPerRequestAsync() { using (new EnvVariableContext()) @@ -365,7 +359,6 @@ public async Task ImdsRetryPolicyLifeTimeIsPerRequestAsync() IManagedIdentityApplication mi = miBuilder.Build(); - ManagedIdentityTests.MockImdsV1Probe(httpManager, ManagedIdentitySource.Imds); // Simulate permanent errors (to trigger the maximum number of retries) const int Num504Errors = 1 + TestImdsRetryPolicy.ExponentialStrategyNumRetries; // initial request + maximum number of retries diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 902080bc17..5527b1d912 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -121,13 +121,22 @@ private async Task CreateManagedIdentityAsync( if (imdsVersion == ImdsVersion.V1) { - httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, userAssignedIdentityId, userAssignedId)); + // New discovery order: V1 probed first (succeeds) → ImdsV1 cached httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1, userAssignedIdentityId, userAssignedId)); + + if (addSourceCheck) + { + var miSourceResultV1 = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.Imds, miSourceResultV1.Source); + } + return managedIdentityApp; } if (addProbeMock) { + // New discovery order: V1 probed first (fails), then V2 (succeeds) + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1, userAssignedIdentityId, userAssignedId)); httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2, userAssignedIdentityId, userAssignedId)); } @@ -238,7 +247,7 @@ public async Task mTLSPopTokenIsPerIdentity( result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .WithMtlsProofOfPossession() - .WithAttestationSupport() + .WithAttestationSupport() .ExecuteAsync().ConfigureAwait(false); Assert.IsNotNull(result); @@ -362,6 +371,60 @@ await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Res } } + /// + /// mTLS PoP token request should succeed with IMDSv2 even without prior discovery. + /// When .WithMtlsProofOfPossession() is used, MSAL routes directly to IMDSv2 (no probing, + /// and no fallback to IMDSv1, which does not support mTLS PoP). The presence of mTLS-specific + /// headers in the request signals to IMDS that it's an mTLS request and IMDS should respond accordingly. + /// + /// + /// + /// + [DataTestMethod] + [DataRow(UserAssignedIdentityId.None, null)] // SAMI + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI + public async Task MtlsPopWithoutPriorDiscovery_UsesImdsV2AndSucceeds( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + ManagedIdentityClient.ResetSourceForTest(); + + // Simulate IMDS host + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + // Create app WITHOUT discovery and WITHOUT probe mocks + var managedIdentityApp = await CreateManagedIdentityAsync( + httpManager, + userAssignedIdentityId, + userAssignedId, + addProbeMock: false, + addSourceCheck: false, + managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard, + imdsVersion: ImdsVersion.V2 + ).ConfigureAwait(false); + + // Add IMDSv2 mocks (CSR + issuecredential + token) + AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); + + var result = await managedIdentityApp + .AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .WithMtlsProofOfPossession() + .WithAttestationSupport() + .ExecuteAsync() + .ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.AreEqual(MTLSPoP, result.TokenType); + Assert.IsNotNull(result.BindingCertificate); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + } + } + [DataTestMethod] [DataRow(UserAssignedIdentityId.None, null)] // SAMI [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI @@ -422,6 +485,8 @@ public async Task ProbeImdsEndpointAsyncSucceeds() { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + // New discovery order: V1 probed first (fails), then V2 (succeeds) + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1)); httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2)); await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); @@ -436,11 +501,14 @@ public async Task ProbeImdsEndpointAsyncSucceedsAfterRetry() { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + // New discovery order: V1 probed first (fails), then V2 (first attempt fails with retry, second succeeds) + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1)); // `retry: true` indicates a retriable status code will be returned httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, retry: true)); + // Second V2 attempt succeeds + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2)); - // Second attempt succeeds (defined inside of CreateManagedIdentityAsync) - await CreateManagedIdentityAsync(httpManager).ConfigureAwait(false); + await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); } } @@ -452,14 +520,15 @@ public async Task ProbeImdsEndpointAsyncFails404WhichIsNonRetriableAndRetryPolic { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); - // `retry: false` indicates a retriable status code will be returned - httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, retry: false)); - httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); + // New discovery order: V1 probed first (fails with non-retriable 404), then V2 (succeeds) + // `retry: false` indicates a non-retriable status code (404) will be returned for V1 + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1, retry: false)); + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2)); var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); var miSourceResult = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.Imds, miSourceResult.Source); + Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSourceResult.Source); } } @@ -477,7 +546,8 @@ public async Task ImdsProbeEndpointAsync_TimeOutThrowsOperationCanceledException var managedIdentityApp = miBuilder.Build(); - httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2)); + // New discovery order: V1 is probed first + httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); var cts = new CancellationTokenSource(); cts.Cancel(); @@ -532,13 +602,16 @@ public async Task NonMtlsRequest_FallsBackToImdsV1( } [TestMethod] - public async Task ImdsV2ProbeFailsMaxRetries_FallsBackToImdsV1() + public async Task BothImdsProbesFailMaxRetries_ReturnsNoneFound() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + // New discovery order: V1 probed first (fails), then V2 fails with max retries → NoneFound + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1, retry: false)); + const int Num500Errors = 1 + TestImdsProbeRetryPolicy.ExponentialStrategyNumRetries; for (int i = 0; i < Num500Errors; i++) { @@ -546,12 +619,10 @@ public async Task ImdsV2ProbeFailsMaxRetries_FallsBackToImdsV1() httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2, retry: true)); } - httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); - var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); var miSourceResult = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync(ManagedIdentityTests.ImdsProbesCancellationToken).ConfigureAwait(false); - Assert.AreEqual(ManagedIdentitySource.Imds, miSourceResult.Source); + Assert.AreEqual(ManagedIdentitySource.None, miSourceResult.Source); } } #endregion diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index a1c6653451..60c5d23501 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -44,19 +44,6 @@ public class ManagedIdentityTests : TestBase private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); - // MtlsPop is disabled for all these tests, so no need to mock IMDSv2 probe here - internal static void MockImdsV1Probe( - MockHttpManager httpManager, - ManagedIdentitySource managedIdentitySource, - UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, - string userAssignedId = null) - { - if (managedIdentitySource == ManagedIdentitySource.Imds) - { - httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1, userAssignedIdentityId, userAssignedId)); - } - } - [DataTestMethod] [DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService)] [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService)] @@ -84,11 +71,13 @@ public async Task GetManagedIdentityTests( if (managedIdentitySource == ManagedIdentitySource.ImdsV2) { + // New discovery order: V1 probed first (fails), then V2 (succeeds) + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1)); httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V2)); } else if (managedIdentitySource == ManagedIdentitySource.Imds) { - httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2)); + // New discovery order: V1 probed first (succeeds) httpManager.AddMockHandler(MockHelpers.MockImdsProbe(ImdsVersion.V1)); } @@ -123,10 +112,9 @@ public async Task SAMIHappyPathAsync( var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -172,12 +160,11 @@ public async Task UAMIHappyPathAsync( SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); - + miBuilder.WithHttpManager(httpManager); var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource, userAssignedIdentityId, userAssignedId); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -222,10 +209,9 @@ public async Task ManagedIdentityDifferentScopesTestAsync( var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -281,10 +267,9 @@ public async Task ManagedIdentityForceRefreshTestAsync( var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -342,10 +327,9 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithClientCapabilities(TestConstants.ClientCapabilities) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -406,10 +390,9 @@ public async Task ManagedIdentityWithClaimsTestAsync( var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -479,10 +462,9 @@ public async Task ManagedIdentityTestWrongScopeAsync(string resource, ManagedIde var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler(endpoint, resource, MockHelpers.GetMsiErrorResponse(managedIdentitySource), managedIdentitySource, statusCode: HttpStatusCode.InternalServerError); @@ -521,7 +503,7 @@ public async Task ManagedIdentityTestErrorResponseParsing(string errorResponse, var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler(AppServiceEndpoint, Resource, errorResponse, @@ -543,7 +525,7 @@ await mi.AcquireTokenForManagedIdentity(Resource) foreach (var expectedErrorSubString in expectedInErrorResponse) { - Assert.IsTrue(ex.Message.Contains(expectedErrorSubString), + Assert.IsTrue(ex.Message.Contains(expectedErrorSubString), $"Expected to contain string {expectedErrorSubString}. Actual error message: {ex.Message}"); } } @@ -584,10 +566,9 @@ public async Task ManagedIdentityErrorResponseNoPayloadTestAsync(ManagedIdentity var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler(endpoint, "scope", "", managedIdentitySource, statusCode: HttpStatusCode.InternalServerError); @@ -625,10 +606,9 @@ public async Task ManagedIdentityNullResponseAsync(ManagedIdentitySource managed var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -664,10 +644,9 @@ public async Task ManagedIdentityUnreachableNetworkAsync(ManagedIdentitySource m var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddFailingRequest(new HttpRequestException("A socket operation was attempted to an unreachable network.", new SocketException(10051))); @@ -683,18 +662,18 @@ await mi.AcquireTokenForManagedIdentity(Resource) } } - [TestMethod] + [TestMethod] public async Task SystemAssignedManagedIdentityApiIdTestAsync() { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - + SetEnvironmentVariables(ManagedIdentitySource.AppService, AppServiceEndpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -724,7 +703,7 @@ public async Task UserAssignedManagedIdentityApiIdTestAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.WithUserAssignedClientId(TestConstants.ClientId)) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -756,7 +735,7 @@ public async Task ManagedIdentityCacheTestAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.BuildConcrete(); CancellationTokenSource cts = new CancellationTokenSource(); @@ -799,7 +778,7 @@ public async Task ManagedIdentityExpiresOnTestAsync(int expiresInHours, bool ref var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -831,7 +810,7 @@ public async Task ManagedIdentityInvalidRefreshOnThrowsAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); httpManager.AddManagedIdentityMockHandler( @@ -858,7 +837,7 @@ public async Task ManagedIdentityIsProActivelyRefreshedAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.BuildConcrete(); httpManager.AddManagedIdentityMockHandler( @@ -893,7 +872,7 @@ public async Task ManagedIdentityIsProActivelyRefreshedAsync() Assert.AreEqual(0, httpManager.QueueSize, "MSAL should have refreshed the token because the original AT was marked for refresh"); - + cacheAccess.WaitTo_AssertAcessCounts(1, 1); Assert.AreEqual(CacheRefreshReason.ProactivelyRefreshed, result.AuthenticationResultMetadata.CacheRefreshReason); @@ -903,7 +882,7 @@ public async Task ManagedIdentityIsProActivelyRefreshedAsync() result = await mi.AcquireTokenForManagedIdentity(Resource) .ExecuteAsync() .ConfigureAwait(false); - + Assert.AreEqual(CacheRefreshReason.NotApplicable, result.AuthenticationResultMetadata.CacheRefreshReason); } } @@ -921,7 +900,7 @@ public async Task ProactiveRefresh_CancelsSuccessfully_Async() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithLogging(LocalLogCallback) .WithHttpManager(httpManager); - + var mi = miBuilder.BuildConcrete(); httpManager.AddManagedIdentityMockHandler( @@ -961,7 +940,7 @@ void LocalLogCallback(LogLevel level, string message, bool containsPii) [TestMethod] public async Task ParallelRequests_CallTokenEndpointOnceAsync() { - int numOfTasks = 10; + int numOfTasks = 10; int identityProviderHits = 0; int cacheHits = 0; @@ -974,7 +953,7 @@ public async Task ParallelRequests_CallTokenEndpointOnceAsync() var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.BuildConcrete(); httpManager.AddManagedIdentityMockHandler( @@ -1050,7 +1029,6 @@ public async Task InvalidJsonResponseHandling(ManagedIdentitySource managedIdent var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -1076,9 +1054,9 @@ await mi.AcquireTokenForManagedIdentity("scope") [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( - string initialResource, - string newResource, - ManagedIdentitySource managedIdentitySource, + string initialResource, + string newResource, + ManagedIdentitySource managedIdentitySource, string endpoint) { using (new EnvVariableContext()) @@ -1089,10 +1067,9 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( var miBuilder = ManagedIdentityApplicationBuilder .Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - + var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); // Mock handler for the initial resource request httpManager.AddManagedIdentityMockHandler(endpoint, initialResource, @@ -1161,10 +1138,17 @@ public async Task UnavailableManagedIdentitySource_ThrowsExceptionDuringTokenAcq .Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager); - var mi = miBuilder.Build(); + var mi = miBuilder.Build() as ManagedIdentityApplication; + Assert.IsNotNull(mi, "Build() should return a ManagedIdentityApplication instance."); + // Explicit discovery: V1 probe fails, then V2 probe also fails → NoneFound cached httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V1)); + httpManager.AddMockHandler(MockHelpers.MockImdsProbeFailure(ImdsVersion.V2)); + + var sourceResult = await mi.GetManagedIdentitySourceAsync(ImdsProbesCancellationToken).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.None, sourceResult.Source); + // Token acquisition uses cached NoneFound and throws AllSourcesUnavailable var ex = await Assert.ThrowsExceptionAsync(async () => await mi.AcquireTokenForManagedIdentity("https://management.azure.com") .ExecuteAsync() @@ -1192,7 +1176,7 @@ public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() .Create(ManagedIdentityId .WithUserAssignedClientId(UserAssignedClientId)) .WithHttpManager(httpManager); - + userAssignedBuilder.Config.AccessorOptions = null; var userAssignedMI = userAssignedBuilder.BuildConcrete(); @@ -1277,7 +1261,7 @@ public async Task ManagedIdentityRetryPolicyLifeTimeIsPerRequestAsync( var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); - + var mi = miBuilder.Build(); // Simulate permanent errors (to trigger the maximum number of retries) @@ -1291,7 +1275,7 @@ public async Task ManagedIdentityRetryPolicyLifeTimeIsPerRequestAsync( managedIdentitySource, statusCode: statusCode); } - + MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => await mi.AcquireTokenForManagedIdentity(Resource) @@ -1370,7 +1354,6 @@ public async Task ManagedIdentityWithCapabilitiesTestAsync( var mi = miBuilder.Build(); - MockImdsV1Probe(httpManager, managedIdentitySource); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -1542,7 +1525,7 @@ public void WithExtraQueryParameters_MultipleCallsMergeValues() // Verify that parameters are merged Assert.AreEqual(4, miBuilder.Config.ExtraQueryParameters.Count); - + // Verify merged values Assert.AreEqual("newvalue1", miBuilder.Config.ExtraQueryParameters["param1"]); Assert.AreEqual("value2", miBuilder.Config.ExtraQueryParameters["param2"]);