Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@ internal class ManagedIdentityClient
{
private const string WindowsHimdsFilePath = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe";
private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds";
internal static ManagedIdentitySource s_sourceName = ManagedIdentitySource.None;

// Preview guard: once we fall back to IMDSv1 while IMDSv2 is cached,
// disallow switching to IMDSv2 PoP in the same process (preview behavior).
internal static bool s_imdsV1UsedForPreview = false;
// Non-null only after the explicit discovery API (GetManagedIdentitySourceAsync) runs.
// Allows caching "NoneFound" (Source=None) without confusing it with "not discovered yet".
private static ManagedIdentitySourceResult s_cachedSourceResult = null;

// Holds the most recently minted mTLS binding certificate for this application instance.
private X509Certificate2 _runtimeMtlsBindingCertificate;
internal X509Certificate2 RuntimeMtlsBindingCertificate => Volatile.Read(ref _runtimeMtlsBindingCertificate);

internal static void ResetSourceForTest()
{
s_sourceName = ManagedIdentitySource.None;
s_cachedSourceResult = null;
s_imdsV1UsedForPreview = false;

// Clear cert caches so each test starts fresh
Expand All @@ -48,42 +51,66 @@ internal async Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityA
return await msi.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
}

// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(
// This method selects the managed identity source for token acquisition.
// It does NOT probe IMDS. It uses the cached explicit discovery result if available,
// otherwise checks environment variables, and defaults to IMDS without probing.
private Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(
RequestContext requestContext,
bool isMtlsPopRequested,
CancellationToken cancellationToken)
{
using (requestContext.Logger.LogMethodDuration())
{
requestContext.Logger.Info($"[Managed Identity] Selecting managed identity source if not cached. Cached value is {s_sourceName} ");
requestContext.Logger.Info($"[Managed Identity] Selecting managed identity source. " +
$"Discovery cached: {s_cachedSourceResult != null}");

// Fail fast if cancellation was requested, before performing expensive network probes
cancellationToken.ThrowIfCancellationRequested();

ManagedIdentitySourceResult sourceResult = null;
ManagedIdentitySource source;

// If the source is not already set, determine it
if (s_sourceName == ManagedIdentitySource.None)
if (s_cachedSourceResult != null)
{
// First invocation: detect and cache
sourceResult = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested, cancellationToken).ConfigureAwait(false);
source = sourceResult.Source;
// Use the cached explicit discovery result (including NoneFound)
source = s_cachedSourceResult.Source;
requestContext.Logger.Info($"[Managed Identity] Using cached discovery result: {source}");
}
else
{
// Reuse cached value
source = s_sourceName;
// Standard path: check environment variables only, no IMDS probing
source = GetManagedIdentitySourceNoImds(requestContext.Logger);

if (source == ManagedIdentitySource.None)
{
// No environment-based source found; default to IMDS based on mTLS PoP flag
if (isMtlsPopRequested)
{
// Route mTLS PoP requests directly to IMDSv2 (no probing)
requestContext.Logger.Info("[Managed Identity] mTLS PoP requested, routing to IMDSv2 directly without probing.");
return Task.FromResult<AbstractManagedIdentity>(ImdsV2ManagedIdentitySource.Create(requestContext));
}

// Default to IMDSv1 without probing
requestContext.Logger.Info("[Managed Identity] Defaulting to IMDSv1 without probing.");
return Task.FromResult<AbstractManagedIdentity>(ImdsManagedIdentitySource.Create(requestContext));
}
}

// Handle NoneFound from cached discovery
if (source == ManagedIdentitySource.None)
{
throw CreateManagedIdentityUnavailableException(s_cachedSourceResult);
}

// If the source has already been set to ImdsV2 (via this method, or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs)
// and mTLS PoP was NOT requested: fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests
// Preview fallback: if ImdsV2 is cached but mTLS PoP not requested, fall back per-request to ImdsV1
if (source == ManagedIdentitySource.ImdsV2 && !isMtlsPopRequested)
{
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected, but mTLS PoP was not requested. Falling back to ImdsV1 for this request only. Please use the \"WithMtlsProofOfPossession\" API to request a token via ImdsV2.");

// Mark that we used IMDSv1 in this process while IMDSv2 is cached (preview behavior).
s_imdsV1UsedForPreview = true;

// Do NOT modify s_sourceName; keep cached ImdsV2 so future PoP
// Do NOT modify s_cachedSourceResult; keep cached ImdsV2 so future PoP
// requests can leverage it.
source = ManagedIdentitySource.Imds;
}
Expand All @@ -106,7 +133,7 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
MsalErrorMessage.MtlsPopTokenNotSupportedinImdsV1);
}

return source switch
return Task.FromResult<AbstractManagedIdentity>(source switch
{
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext),
Expand All @@ -115,63 +142,65 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.Imds => ImdsManagedIdentitySource.Create(requestContext),
_ => throw CreateManagedIdentityUnavailableException(sourceResult)
};
_ => throw CreateManagedIdentityUnavailableException(s_cachedSourceResult)
});
}
}

// Detect managed identity source based on the availability of environment variables and csr metadata probe request.
// This method is perf sensitive any changes should be benchmarked.
private static ManagedIdentitySourceResult CacheDiscoveryResult(ManagedIdentitySourceResult result)
{
s_cachedSourceResult = result;
return result;
}

// Detect managed identity source by probing IMDS endpoints.
// This method is called only by the explicit discovery path (GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs).
// It probes IMDS v1 first, then v2 if v1 fails, and caches the result.
internal async Task<ManagedIdentitySourceResult> GetManagedIdentitySourceAsync(
RequestContext requestContext,
bool isMtlsPopRequested,
CancellationToken cancellationToken)
{
// Return cached result if explicit discovery already ran
if (s_cachedSourceResult != null)
{
return s_cachedSourceResult;
}

// First check env vars to avoid the probe if possible
ManagedIdentitySource source = GetManagedIdentitySourceNoImds(requestContext.Logger);

if (source != ManagedIdentitySource.None)
{
s_sourceName = source;
return new ManagedIdentitySourceResult(source);
return CacheDiscoveryResult(new ManagedIdentitySourceResult(source));
}

string imdsV2FailureReason = null;
string imdsV1FailureReason = null;
string imdsV2FailureReason = null;

// skip the ImdsV2 probe if MtlsPop was NOT requested
if (isMtlsPopRequested)
{
var (imdsV2Success, imdsV2Failure) = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V2, cancellationToken).ConfigureAwait(false);
if (imdsV2Success)
{
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected.");
s_sourceName = ManagedIdentitySource.ImdsV2;
return new ManagedIdentitySourceResult(s_sourceName);
}
imdsV2FailureReason = imdsV2Failure;
}
else
{
requestContext.Logger.Info("[Managed Identity] Mtls Pop was not requested; skipping ImdsV2 probe.");
}

// Probe IMDS v1 first
var (imdsV1Success, imdsV1Failure) = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V1, cancellationToken).ConfigureAwait(false);
if (imdsV1Success)
{
requestContext.Logger.Info("[Managed Identity] ImdsV1 detected.");
s_sourceName = ManagedIdentitySource.Imds;
return new ManagedIdentitySourceResult(s_sourceName);
return CacheDiscoveryResult(new ManagedIdentitySourceResult(ManagedIdentitySource.Imds));
}
imdsV1FailureReason = imdsV1Failure;

// If v1 fails, probe IMDS v2
var (imdsV2Success, imdsV2Failure) = await ImdsManagedIdentitySource.ProbeImdsEndpointAsync(requestContext, ImdsVersion.V2, cancellationToken).ConfigureAwait(false);
if (imdsV2Success)
{
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected.");
return CacheDiscoveryResult(new ManagedIdentitySourceResult(ManagedIdentitySource.ImdsV2));
}
imdsV2FailureReason = imdsV2Failure;

requestContext.Logger.Info($"[Managed Identity] {MsalErrorMessage.ManagedIdentityAllSourcesUnavailable}");
s_sourceName = ManagedIdentitySource.None;

return new ManagedIdentitySourceResult(s_sourceName)
return CacheDiscoveryResult(new ManagedIdentitySourceResult(ManagedIdentitySource.None)
{
ImdsV1FailureReason = imdsV1FailureReason,
ImdsV2FailureReason = imdsV2FailureReason
};
});
}

/// <summary>
Expand Down Expand Up @@ -252,8 +281,8 @@ private static bool ValidateAzureArcEnvironment(string identityEndpoint, string
{
logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is available through file detection.");
return true;
}
}

logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is not available.");
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ private static void ThrowCsrMetadataRequestException(
Exception ex = null,
int? statusCode = null)
{
// A 404 from the IMDSv2 CSR endpoint indicates that the host supports only IMDSv1.
// This happens when WithMtlsProofOfPossession() is used without a prior
// GetManagedIdentitySourceAsync() call: MSAL routes directly to IMDSv2, and
// on an IMDSv1-only host the /getplatformmetadata endpoint does not exist.
// Translate to a client error so callers know mTLS PoP is not supported here.
if (statusCode == (int)HttpStatusCode.NotFound)
{
throw new MsalClientException(
MsalError.MtlsPopTokenNotSupportedinImdsV1,
MsalErrorMessage.MtlsPopTokenNotSupportedinImdsV1);
}

throw MsalServiceExceptionFactory.CreateManagedIdentityException(
MsalError.ManagedIdentityRequestFailed,
$"[ImdsV2] {errorMessage}",
Expand Down Expand Up @@ -154,8 +166,8 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
return new ImdsV2ManagedIdentitySource(requestContext);
}

internal ImdsV2ManagedIdentitySource(RequestContext requestContext)
: this(requestContext,
internal ImdsV2ManagedIdentitySource(RequestContext requestContext)
: this(requestContext,
new MtlsBindingCache(s_mtlsCertificateCache, PersistentCertificateCacheFactory
.Create(requestContext.Logger)))
{
Expand Down Expand Up @@ -436,9 +448,9 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>JWT string suitable for the IMDSv2 attested PoP flow, or null for non-attested flow.</returns>
private async Task<string> GetAttestationJwtAsync(
string clientId,
Uri attestationEndpoint,
ManagedIdentityKeyInfo keyInfo,
string clientId,
Uri attestationEndpoint,
ManagedIdentityKeyInfo keyInfo,
CancellationToken cancellationToken)
{
// Check if attestation token provider has been configured
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public sealed class ManagedIdentityApplication
IManagedIdentityApplication
{
internal ManagedIdentityClient ManagedIdentityClient { get; }

internal ManagedIdentityApplication(
ApplicationConfiguration configuration)
: base(configuration)
Expand All @@ -34,8 +34,8 @@ internal ManagedIdentityApplication(

AppTokenCacheInternal = configuration.AppTokenCacheInternalForTest ?? new TokenCache(ServiceBundle, true);

this.ServiceBundle.ApplicationLogger.Verbose(()=>$"ManagedIdentityApplication {configuration.GetHashCode()} created");
this.ServiceBundle.ApplicationLogger.Verbose(() => $"ManagedIdentityApplication {configuration.GetHashCode()} created");

ManagedIdentityClient = new ManagedIdentityClient();
}

Expand All @@ -58,16 +58,10 @@ public AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIden
/// <inheritdoc/>
public async Task<ManagedIdentitySourceResult> GetManagedIdentitySourceAsync(CancellationToken cancellationToken)
{
if (ManagedIdentityClient.s_sourceName != ManagedIdentitySource.None)
{
return new ManagedIdentitySourceResult(ManagedIdentityClient.s_sourceName);
}

// Create a temporary RequestContext for the logger and the IMDS probe request.
var requestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, cancellationToken);

// GetManagedIdentitySourceAsync might return ImdsV2 = true, but it still requires .WithMtlsProofOfPossesion on the Managed Identity Application object to hit the ImdsV2 flow
return await ManagedIdentityClient.GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested: true, cancellationToken).ConfigureAwait(false);
return await ManagedIdentityClient.GetManagedIdentitySourceAsync(requestContext, cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -78,7 +72,7 @@ public async Task<ManagedIdentitySourceResult> GetManagedIdentitySourceAsync(Can
public static ManagedIdentitySource GetManagedIdentitySource()
{
var source = ManagedIdentityClient.GetManagedIdentitySourceNoImds();

return source == ManagedIdentitySource.None
#pragma warning disable CS0618
// ManagedIdentitySource.DefaultToImds is marked obsolete, but is intentionally used here as a sentinel value to support legacy detection logic.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,6 @@ public async Task AcquireToken_OnImdsV2_MtlsPoP_WithAttestation_Succeeds(string
Assert.Inconclusive("Credential Guard attestation is only available on Windows.");
}

// Check if TOKEN_ATTESTATION_ENDPOINT is configured (required for attestation)
var attestationEndpoint = Environment.GetEnvironmentVariable("TOKEN_ATTESTATION_ENDPOINT");
if (string.IsNullOrWhiteSpace(attestationEndpoint))
{
Assert.Inconclusive("TOKEN_ATTESTATION_ENDPOINT is not configured. Attestation tests require this environment variable.");
}

var mi = BuildMi(id, idType);

try
Expand Down
Loading
Loading