diff --git a/build/template-run-mi-e2e-imdsv2.yaml b/build/template-run-mi-e2e-imdsv2.yaml index cc6c3c7858..8449a24bbb 100644 --- a/build/template-run-mi-e2e-imdsv2.yaml +++ b/build/template-run-mi-e2e-imdsv2.yaml @@ -6,10 +6,10 @@ steps: # Ensure the exact SDK is available on the agent - task: UseDotNet@2 - displayName: 'Install .NET SDK 8.0.415' + displayName: 'Install .NET SDK 8.0.418' inputs: packageType: 'sdk' - version: '8.0.415' + version: '8.0.418' includePreviewVersions: false performMultiLevelLookup: true diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/IPersistentCertificateCache.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/IPersistentCertificateCache.cs index 3f48ce2560..3c4f1950e3 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/IPersistentCertificateCache.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/IPersistentCertificateCache.cs @@ -27,10 +27,21 @@ internal interface IPersistentCertificateCache void Write(string alias, X509Certificate2 cert, string endpointBase, ILoggerAdapter logger); /// - /// Prunes expired entries for the alias (best-effort). - /// Implementations should remove stale/expired entries while leaving the - /// latest valid binding for the alias in place. + /// Deletes expired certificate entries for the alias (best-effort), + /// leaving the latest valid binding for the alias in place (if any). + /// Write calls DeleteAllForAlias, so this method is only expected to be called + /// by implementations of Write. /// void Delete(string alias, ILoggerAdapter logger); + + /// + /// Deletes ALL certificate entries for the alias (best-effort), including non-expired ones. + /// Intended for "reset/evict" scenarios (e.g., SCHANNEL rejects the cached cert) to force a + /// re-mint. When a machine restarts the key becomes inaccessible and the cached certs should + /// be cleared to allow a new cert to be minted. + /// + /// + /// + void DeleteAllForAlias(string alias, ILoggerAdapter logger); } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 25bd0b918e..000d40d31b 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -175,7 +175,58 @@ public override async Task AuthenticateAsync( { // Capture the attestation token provider delegate before calling base _attestationTokenProvider = parameters.AttestationTokenProvider; - return await base.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false); + + try + { + return await base.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false); + } + catch (MsalServiceException ex) when (ex.ErrorCode == MsalError.ManagedIdentityUnreachableNetwork && IsSchanelFailure(ex)) + { + _requestContext.Logger.Verbose(() => + "[ImdsV2] SCHANNEL mTLS failure detected. Removing bad persisted cert and retrying with fresh mint."); + + // Remove the bad cert from both caches + string certCacheKey = GetMtlsCertCacheKey(); + try + { + if (_mtlsCache is MtlsBindingCache mtlsCache) + { + mtlsCache.RemoveBadCert(certCacheKey, _requestContext.Logger); + } + } + catch (Exception removalEx) + { + _requestContext.Logger.Verbose(() => $"[ImdsV2] Error removing bad cert: {removalEx.Message}"); + } + + // Retry - will mint fresh cert since we just deleted the bad one + return await base.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Detects if the exception was caused by a SCHANNEL failure during mTLS authentication, + /// which can occur if the client certificate becomes invalid. + /// + /// + /// + private static bool IsSchanelFailure(MsalServiceException ex) + { + for (Exception e = ex; e != null; e = e.InnerException) + { + if (e is System.Net.Sockets.SocketException se && + (se.ErrorCode == 10054 || se.SocketErrorCode == System.Net.Sockets.SocketError.ConnectionReset)) + { + return true; + } + + if (e is System.Security.Authentication.AuthenticationException) + { + return true; + } + } + + return false; } private async Task ExecuteCertificateRequestAsync( diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/MtlsCertificateCache.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/MtlsCertificateCache.cs index fbea8e00d4..538c0fd249 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/MtlsCertificateCache.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/MtlsCertificateCache.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Concurrent; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -23,6 +24,7 @@ internal sealed class MtlsBindingCache : IMtlsCertificateCache private readonly KeyedSemaphorePool _gates = new(); private readonly ICertificateCache _memory; private readonly IPersistentCertificateCache _persisted; + private readonly ConcurrentDictionary _forceMint = new(); /// /// Inject both caches to avoid global state and enable testing. @@ -59,8 +61,10 @@ public async Task GetOrCreateAsync( throw new ArgumentNullException(nameof(factory)); } + bool forceMint = _forceMint.ContainsKey(cacheKey); + // 1) In-memory cache first - if (_memory.TryGet(cacheKey, out var cachedEntry, logger)) + if (!forceMint && _memory.TryGet(cacheKey, out var cachedEntry, logger)) { logger.Verbose(() => $"[PersistentCert] mTLS binding cache HIT (memory) for '{cacheKey}'."); @@ -76,8 +80,10 @@ public async Task GetOrCreateAsync( try { + forceMint = _forceMint.ContainsKey(cacheKey); + // Re-check after acquiring the gate - if (_memory.TryGet(cacheKey, out cachedEntry, logger)) + if (!forceMint && _memory.TryGet(cacheKey, out cachedEntry, logger)) { logger.Verbose(() => $"[PersistentCert] mTLS binding cache HIT (memory-after-gate) for '{cacheKey}'."); @@ -89,7 +95,7 @@ public async Task GetOrCreateAsync( } // 3) Persistent cache (best-effort) - if (_persisted.Read(cacheKey, out var persistedEntry, logger)) + if (!forceMint && _persisted.Read(cacheKey, out var persistedEntry, logger)) { logger.Verbose(() => $"[PersistentCert] mTLS binding cache HIT (persistent) for '{cacheKey}'."); @@ -135,6 +141,11 @@ public async Task GetOrCreateAsync( // This is also best-effort and must not throw. _persisted.Delete(cacheKey, logger); + if (forceMint) + { + _forceMint.TryRemove(cacheKey, out _); + } + // Pass through the factory result (already an MtlsBindingInfo) return mintedBinding; } @@ -143,5 +154,36 @@ public async Task GetOrCreateAsync( _gates.Release(cacheKey); } } + + /// + /// Removes a certificate from both in-memory and persistent cache when SCHANNEL rejects it. + /// + public void RemoveBadCert(string cacheKey, ILoggerAdapter logger) + { + if (cacheKey != null) + { + _forceMint[cacheKey] = 0; + } + + try + { + _memory.Remove(cacheKey, logger); + logger?.Verbose(() => $"[PersistentCert] Removed bad cert from memory cache for '{cacheKey}'"); + } + catch (Exception ex) + { + logger?.Verbose(() => $"[PersistentCert] Error removing from memory cache: {ex.Message}"); + } + + try + { + _persisted.DeleteAllForAlias(cacheKey, logger); + logger?.Verbose(() => $"[PersistentCert] Removed bad cert from persistent cache for '{cacheKey}'"); + } + catch (Exception ex) + { + logger?.Verbose(() => $"[PersistentCert] Error removing from persistent cache: {ex.Message}"); + } + } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/NoOpPersistentCertificateCache.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/NoOpPersistentCertificateCache.cs index 18ae4d99cc..a07ae0f97f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/NoOpPersistentCertificateCache.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/NoOpPersistentCertificateCache.cs @@ -26,5 +26,10 @@ public void Delete(string alias, ILoggerAdapter logger) { // no-op } + + public void DeleteAllForAlias(string alias, ILoggerAdapter logger) + { + // no-op + } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/WindowsPersistentCertificateCache.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/WindowsPersistentCertificateCache.cs index 718a209828..2f1ddff3cb 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/WindowsPersistentCertificateCache.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/WindowsPersistentCertificateCache.cs @@ -270,6 +270,61 @@ public void Delete(string alias, ILoggerAdapter logger) logVerbose: s => logger.Verbose(() => s)); } + public void DeleteAllForAlias(string alias, ILoggerAdapter logger) + { + // Best-effort: short, non-configurable timeout. + InterprocessLock.TryWithAliasLock( + alias, + timeout: TimeSpan.FromMilliseconds(300), + action: () => + { + try + { + using var store = new X509Store(StoreName.My, StoreLocation.CurrentUser); + store.Open(OpenFlags.ReadWrite); + + X509Certificate2[] items; + try + { + items = new X509Certificate2[store.Certificates.Count]; + store.Certificates.CopyTo(items, 0); + } + catch (Exception ex) + { + logger.Verbose(() => "[PersistentCert] Store snapshot via CopyTo failed; falling back to enumeration. Details: " + ex.Message); + items = store.Certificates.Cast().ToArray(); + } + + int removed = 0; + + foreach (var existing in items) + { + try + { + if (!MsiCertificateFriendlyNameEncoder.TryDecode(existing.FriendlyName, out var decodedAlias, out _)) + continue; + if (!StringComparer.Ordinal.Equals(decodedAlias, alias)) + continue; + + // Delete ALL certs for this alias + store.Remove(existing); + removed++; + logger?.Verbose(() => $"[PersistentCert] Deleted certificate from store for alias '{alias}'"); + } + finally + { + existing.Dispose(); + } + } + } + catch (Exception ex) + { + logger.Verbose(() => "[PersistentCert] DeleteAllForAlias failed: " + ex.Message); + } + }, + logVerbose: s => logger.Verbose(() => s)); + } + /// /// Deletes only certificates that are actually expired (NotAfter < nowUtc), /// scoped to the given alias (cache key) via FriendlyName. diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/PersistentCertificateStoreUnitTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/PersistentCertificateStoreUnitTests.cs index 7079c4b490..017f9821c3 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/PersistentCertificateStoreUnitTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/PersistentCertificateStoreUnitTests.cs @@ -2,11 +2,16 @@ // Licensed under the MIT License. using System; +using System.IO; using System.Linq; +using System.Net.Http; +using System.Net.Sockets; +using System.Reflection; using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Threading; +using Microsoft.Identity.Client; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.PlatformsCommon.Shared; @@ -621,5 +626,222 @@ public void Write_And_Read_Handle_Alias_EdgeCases() } #endregion + + #region //MTLS specific tests + + [TestMethod] + public void DeleteAllForAlias_Removes_All_Certificates_For_Alias() + { + WindowsOnly(); + + var alias = "alias-delall-" + Guid.NewGuid().ToString("N"); + var ep = "https://fake_mtls/delall"; + var logger = Logger; + + try + { + // Write 3 certs with increasing NotAfter so all 3 are added (policy only skips older/equal) + using var c1 = CreateSelfSignedWithKey("CN=" + Guid.NewGuid().ToString("D"), TimeSpan.FromDays(2)); + using var c2 = CreateSelfSignedWithKey("CN=" + Guid.NewGuid().ToString("D"), TimeSpan.FromDays(3)); + using var c3 = CreateSelfSignedWithKey("CN=" + Guid.NewGuid().ToString("D"), TimeSpan.FromDays(4)); + + _cache.Write(alias, c1, ep, logger); + _cache.Write(alias, c2, ep, logger); + _cache.Write(alias, c3, ep, logger); + + Assert.IsTrue(WaitForFind(alias, out _), "Expected at least one persisted entry."); + Assert.IsTrue(CountAliasInStore(alias) >= 2, "Expected multiple certs persisted for alias."); + + // Act + _cache.DeleteAllForAlias(alias, logger); + + // Assert: store should have 0 for alias + Assert.IsTrue(WaitForAliasCount(alias, expected: 0), "Expected all certs for alias to be deleted."); + Assert.IsFalse(_cache.Read(alias, out _, logger), "Read should return false after DeleteAllForAlias."); + } + finally + { + RemoveAliasFromStore(alias); + } + } + + [TestMethod] + public void DeleteAllForAlias_Does_Not_Remove_Other_Aliases() + { + WindowsOnly(); + + var alias1 = "alias-delall-a-" + Guid.NewGuid().ToString("N"); + var alias2 = "alias-delall-b-" + Guid.NewGuid().ToString("N"); + var ep1 = "https://fake_mtls/a"; + var ep2 = "https://fake_mtls/b"; + var logger = Logger; + + try + { + using var c1 = CreateSelfSignedWithKey("CN=" + Guid.NewGuid().ToString("D"), TimeSpan.FromDays(3)); + using var c2 = CreateSelfSignedWithKey("CN=" + Guid.NewGuid().ToString("D"), TimeSpan.FromDays(3)); + + _cache.Write(alias1, c1, ep1, logger); + _cache.Write(alias2, c2, ep2, logger); + + Assert.IsTrue(WaitForFind(alias1, out _)); + Assert.IsTrue(WaitForFind(alias2, out _)); + Assert.AreEqual(1, CountAliasInStore(alias1)); + Assert.AreEqual(1, CountAliasInStore(alias2)); + + // Act + _cache.DeleteAllForAlias(alias1, logger); + + // Assert + Assert.IsTrue(WaitForAliasCount(alias1, expected: 0), "alias1 should be removed."); + Assert.AreEqual(1, CountAliasInStore(alias2), "alias2 must remain."); + Assert.IsTrue(_cache.Read(alias2, out var v2, logger), "Read(alias2) should still succeed."); + + // caller owns returned cert + v2.Certificate.Dispose(); + } + finally + { + RemoveAliasFromStore(alias1); + RemoveAliasFromStore(alias2); + } + } + + [TestMethod] + public void RemoveBadCert_Removes_From_Memory_And_Calls_Persistent_DeleteAll() + { + var memory = new InMemoryCertificateCache(); + var persisted = Substitute.For(); + var logger = Substitute.For(); + + var cache = new MtlsBindingCache(memory, persisted); + + const string alias = "alias-remove-bad-cert"; + const string ep = "https://fake_mtls/ep"; + const string cid = "11111111-1111-1111-1111-111111111111"; + + using var cert = CreateSelfSignedCert(TimeSpan.FromDays(2)); + memory.Set(alias, new CertificateCacheValue(cert, ep, cid)); + + // Sanity: should be present + Assert.IsTrue(memory.TryGet(alias, out var before)); + before.Certificate.Dispose(); + + // Act + cache.RemoveBadCert(alias, logger); + + // Assert: memory entry gone + Assert.IsFalse(memory.TryGet(alias, out _), "Expected memory cache eviction."); + + // Assert: persistent delete-all invoked + persisted.Received(1).DeleteAllForAlias(alias, logger); + } + + [TestMethod] + public void RemoveBadCert_Is_BestEffort_DoesNotThrow_When_Persistent_Throws() + { + var memory = new InMemoryCertificateCache(); + var persisted = Substitute.For(); + var logger = Substitute.For(); + + persisted + .When(p => p.DeleteAllForAlias(Arg.Any(), Arg.Any())) + .Do(_ => throw new InvalidOperationException("boom")); + + var cache = new MtlsBindingCache(memory, persisted); + + // Should not throw + cache.RemoveBadCert("alias", logger); + } + + [TestMethod] + public void IsSchanelFailure_ReturnsTrue_For_SocketException_10054_Chain() + { + // Build exception chain like your logs + var sock = new SocketException(10054); + var io = new IOException("Unable to write data to the transport connection: An existing connection was forcibly closed by the remote host.", sock); + var http = new HttpRequestException("An error occurred while sending the request.", io); + + // ErrorCode must be managed_identity_unreachable_network for the catch filter, + // but the private method only checks ToString() content. + var msal = new MsalServiceException(MsalError.ManagedIdentityUnreachableNetwork, "An error occurred while sending the request.", http); + + // Invoke private static bool IsSchanelFailure(MsalServiceException ex) + var mi = typeof(ImdsV2ManagedIdentitySource) + .GetMethod("IsSchanelFailure", BindingFlags.NonPublic | BindingFlags.Static); + + Assert.IsNotNull(mi, "Could not find IsSchanelFailure via reflection."); + + var result = (bool)mi.Invoke(null, new object[] { msal }); + + Assert.IsTrue(result, "Expected 10054 chain to be detected as SCHANNEL failure."); + } + + private static X509Certificate2 CreateSelfSignedCert(TimeSpan lifetime, string subjectCn = "CN=RemoveBadCertTest") + { + using var rsa = RSA.Create(2048); + var req = new System.Security.Cryptography.X509Certificates.CertificateRequest( + new X500DistinguishedName(subjectCn), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + + var notBefore = DateTimeOffset.UtcNow.AddMinutes(-2); + var notAfter = notBefore.Add(lifetime); + return req.CreateSelfSigned(notBefore, notAfter); + } + + private static int CountAliasInStore(string alias) + { + using var store = new X509Store(StoreName.My, StoreLocation.CurrentUser); + store.Open(OpenFlags.OpenExistingOnly | OpenFlags.ReadOnly); + + X509Certificate2[] items; + try + { + items = new X509Certificate2[store.Certificates.Count]; + store.Certificates.CopyTo(items, 0); + } + catch + { + items = store.Certificates.Cast().ToArray(); + } + + int count = 0; + foreach (var cert in items) + { + try + { + if (MsiCertificateFriendlyNameEncoder.TryDecode(cert.FriendlyName, out var decodedAlias, out _) + && StringComparer.Ordinal.Equals(decodedAlias, alias)) + { + count++; + } + } + finally + { + cert.Dispose(); + } + } + + return count; + } + + private static bool WaitForAliasCount(string alias, int expected, int retries = 20, int delayMs = 50) + { + for (int i = 0; i < retries; i++) + { + if (CountAliasInStore(alias) == expected) + { + return true; + } + + Thread.Sleep(delayMs); + } + + return false; + } + + #endregion } }