Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build/template-run-mi-e2e-imdsv2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ steps:

# Ensure the exact SDK is available on the agent
- task: UseDotNet@2
displayName: 'Install .NET SDK 8.0.415'
displayName: 'Install .NET SDK 8.0.418'
inputs:
packageType: 'sdk'
version: '8.0.415'
version: '8.0.418'
includePreviewVersions: false
performMultiLevelLookup: true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,21 @@ internal interface IPersistentCertificateCache
void Write(string alias, X509Certificate2 cert, string endpointBase, ILoggerAdapter logger);

/// <summary>
/// Prunes expired entries for the alias (best-effort).
/// Implementations should remove stale/expired entries while leaving the
/// latest valid binding for the alias in place.
/// Deletes expired certificate entries for the alias (best-effort),
/// leaving the latest valid binding for the alias in place (if any).
/// Write calls DeleteAllForAlias, so this method is only expected to be called
/// by implementations of Write.
/// </summary>
void Delete(string alias, ILoggerAdapter logger);

/// <summary>
/// Deletes ALL certificate entries for the alias (best-effort), including non-expired ones.
/// Intended for "reset/evict" scenarios (e.g., SCHANNEL rejects the cached cert) to force a
/// re-mint. When a machine restarts the key becomes inaccessible and the cached certs should
/// be cleared to allow a new cert to be minted.
/// </summary>
/// <param name="alias"></param>
/// <param name="logger"></param>
void DeleteAllForAlias(string alias, ILoggerAdapter logger);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,58 @@ public override async Task<ManagedIdentityResponse> AuthenticateAsync(
{
// Capture the attestation token provider delegate before calling base
_attestationTokenProvider = parameters.AttestationTokenProvider;
return await base.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);

try
{
return await base.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
}
catch (MsalServiceException ex) when (ex.ErrorCode == MsalError.ManagedIdentityUnreachableNetwork && IsSchanelFailure(ex))
{
_requestContext.Logger.Verbose(() =>
"[ImdsV2] SCHANNEL mTLS failure detected. Removing bad persisted cert and retrying with fresh mint.");

// Remove the bad cert from both caches
string certCacheKey = GetMtlsCertCacheKey();
try
{
if (_mtlsCache is MtlsBindingCache mtlsCache)
{
mtlsCache.RemoveBadCert(certCacheKey, _requestContext.Logger);
}
}
catch (Exception removalEx)
{
_requestContext.Logger.Verbose(() => $"[ImdsV2] Error removing bad cert: {removalEx.Message}");
}

// Retry - will mint fresh cert since we just deleted the bad one
return await base.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
}
}

/// <summary>
/// Detects if the exception was caused by a SCHANNEL failure during mTLS authentication,
/// which can occur if the client certificate becomes invalid.
/// </summary>
/// <param name="ex"></param>
/// <returns></returns>
private static bool IsSchanelFailure(MsalServiceException ex)
{
for (Exception e = ex; e != null; e = e.InnerException)
{
if (e is System.Net.Sockets.SocketException se &&
(se.ErrorCode == 10054 || se.SocketErrorCode == System.Net.Sockets.SocketError.ConnectionReset))
{
return true;
}

if (e is System.Security.Authentication.AuthenticationException)
{
return true;
}
}

return false;
}

private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -23,6 +24,7 @@ internal sealed class MtlsBindingCache : IMtlsCertificateCache
private readonly KeyedSemaphorePool _gates = new();
private readonly ICertificateCache _memory;
private readonly IPersistentCertificateCache _persisted;
private readonly ConcurrentDictionary<string, byte> _forceMint = new();

/// <summary>
/// Inject both caches to avoid global state and enable testing.
Expand Down Expand Up @@ -59,8 +61,10 @@ public async Task<MtlsBindingInfo> GetOrCreateAsync(
throw new ArgumentNullException(nameof(factory));
}

bool forceMint = _forceMint.ContainsKey(cacheKey);

// 1) In-memory cache first
if (_memory.TryGet(cacheKey, out var cachedEntry, logger))
if (!forceMint && _memory.TryGet(cacheKey, out var cachedEntry, logger))
{
logger.Verbose(() =>
$"[PersistentCert] mTLS binding cache HIT (memory) for '{cacheKey}'.");
Expand All @@ -76,8 +80,10 @@ public async Task<MtlsBindingInfo> GetOrCreateAsync(

try
{
forceMint = _forceMint.ContainsKey(cacheKey);

// Re-check after acquiring the gate
if (_memory.TryGet(cacheKey, out cachedEntry, logger))
if (!forceMint && _memory.TryGet(cacheKey, out cachedEntry, logger))
{
logger.Verbose(() =>
$"[PersistentCert] mTLS binding cache HIT (memory-after-gate) for '{cacheKey}'.");
Expand All @@ -89,7 +95,7 @@ public async Task<MtlsBindingInfo> GetOrCreateAsync(
}

// 3) Persistent cache (best-effort)
if (_persisted.Read(cacheKey, out var persistedEntry, logger))
if (!forceMint && _persisted.Read(cacheKey, out var persistedEntry, logger))
{
logger.Verbose(() =>
$"[PersistentCert] mTLS binding cache HIT (persistent) for '{cacheKey}'.");
Expand Down Expand Up @@ -135,6 +141,11 @@ public async Task<MtlsBindingInfo> GetOrCreateAsync(
// This is also best-effort and must not throw.
_persisted.Delete(cacheKey, logger);

if (forceMint)
{
_forceMint.TryRemove(cacheKey, out _);
}

// Pass through the factory result (already an MtlsBindingInfo)
return mintedBinding;
}
Expand All @@ -143,5 +154,36 @@ public async Task<MtlsBindingInfo> GetOrCreateAsync(
_gates.Release(cacheKey);
}
}

/// <summary>
/// Removes a certificate from both in-memory and persistent cache when SCHANNEL rejects it.
/// </summary>
public void RemoveBadCert(string cacheKey, ILoggerAdapter logger)
{
if (cacheKey != null)
{
_forceMint[cacheKey] = 0;
}

try
{
_memory.Remove(cacheKey, logger);
logger?.Verbose(() => $"[PersistentCert] Removed bad cert from memory cache for '{cacheKey}'");
}
catch (Exception ex)
{
logger?.Verbose(() => $"[PersistentCert] Error removing from memory cache: {ex.Message}");
}

try
{
_persisted.DeleteAllForAlias(cacheKey, logger);
logger?.Verbose(() => $"[PersistentCert] Removed bad cert from persistent cache for '{cacheKey}'");
}
catch (Exception ex)
{
logger?.Verbose(() => $"[PersistentCert] Error removing from persistent cache: {ex.Message}");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,10 @@ public void Delete(string alias, ILoggerAdapter logger)
{
// no-op
}

public void DeleteAllForAlias(string alias, ILoggerAdapter logger)
{
// no-op
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,61 @@ public void Delete(string alias, ILoggerAdapter logger)
logVerbose: s => logger.Verbose(() => s));
}

public void DeleteAllForAlias(string alias, ILoggerAdapter logger)
{
// Best-effort: short, non-configurable timeout.
InterprocessLock.TryWithAliasLock(
alias,
timeout: TimeSpan.FromMilliseconds(300),
action: () =>
{
try
{
using var store = new X509Store(StoreName.My, StoreLocation.CurrentUser);
store.Open(OpenFlags.ReadWrite);

X509Certificate2[] items;
try
{
items = new X509Certificate2[store.Certificates.Count];
store.Certificates.CopyTo(items, 0);
}
catch (Exception ex)
{
logger.Verbose(() => "[PersistentCert] Store snapshot via CopyTo failed; falling back to enumeration. Details: " + ex.Message);
items = store.Certificates.Cast<X509Certificate2>().ToArray();
}

int removed = 0;

foreach (var existing in items)
{
try
{
if (!MsiCertificateFriendlyNameEncoder.TryDecode(existing.FriendlyName, out var decodedAlias, out _))
continue;
if (!StringComparer.Ordinal.Equals(decodedAlias, alias))
continue;

// Delete ALL certs for this alias
store.Remove(existing);
removed++;
logger?.Verbose(() => $"[PersistentCert] Deleted certificate from store for alias '{alias}'");
}
finally
{
existing.Dispose();
}
}
}
catch (Exception ex)
{
logger.Verbose(() => "[PersistentCert] DeleteAllForAlias failed: " + ex.Message);
}
},
logVerbose: s => logger.Verbose(() => s));
}

/// <summary>
/// Deletes only certificates that are actually expired (NotAfter &lt; nowUtc),
/// scoped to the given alias (cache key) via FriendlyName.
Expand Down
Loading