Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
if (s_sourceName == ManagedIdentitySource.None)
{
// First invocation: detect and cache
source = await GetManagedIdentitySourceAsync(requestContext).ConfigureAwait(false);
source = await GetManagedIdentitySourceAsync(requestContext, isMtlsPopRequested).ConfigureAwait(false);
}
else
{
Expand Down Expand Up @@ -101,13 +101,18 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn

// Detect managed identity source based on the availability of environment variables and csr metadata probe request.
// This method is perf sensitive any changes should be benchmarked.
internal async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(RequestContext requestContext)
internal async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(
RequestContext requestContext,
bool isMtlsPopRequested)
{
// First check env vars to avoid the probe if possible
ManagedIdentitySource source = GetManagedIdentitySourceNoImdsV2(requestContext.Logger);

// If a source is detected via env vars, use it
if (source != ManagedIdentitySource.DefaultToImds)
// If a source is detected via env vars, or
// a source wasn't detected (it defaulted to ImdsV1) and MtlsPop was NOT requested,
// use the source.
// (don't trigger the ImdsV2 probe endpoint if MtlsPop was NOT requested)
if (source != ManagedIdentitySource.DefaultToImds || !isMtlsPopRequested)
{
s_sourceName = source;
return source;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync()
// Create a temporary RequestContext for the CSR metadata probe request.
var csrMetadataProbeRequestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, CancellationToken.None);

return await ManagedIdentityClient.GetManagedIdentitySourceAsync(csrMetadataProbeRequestContext).ConfigureAwait(false);
// GetManagedIdentitySourceAsync might return ImdsV2 = true, but it still requires .WithMtlsProofOfPossesion on the Managed Identity Application object to hit the ImdsV2 flow
return await ManagedIdentityClient.GetManagedIdentitySourceAsync(csrMetadataProbeRequestContext, isMtlsPopRequested: true).ConfigureAwait(false);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,62 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs
}
}

[DataTestMethod]
[DataRow(UserAssignedIdentityId.None, null)] // SAMI
[DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] // UAMI
[DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] // UAMI
[DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] // UAMI
public async Task ProbeDoesNotFireWhenMtlsPopNotRequested(
UserAssignedIdentityId userAssignedIdentityId,
string userAssignedId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint);

ManagedIdentityApplicationBuilder miBuilder = null;

var uami = userAssignedIdentityId != UserAssignedIdentityId.None && userAssignedId != null;
if (uami)
{
miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId);
}
else
{
miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned);
}

miBuilder
.WithHttpManager(httpManager)
.WithRetryPolicyFactory(_testRetryPolicyFactory);

var managedIdentityApp = miBuilder.Build();

// mock probe to show ImdsV2 is available
httpManager.AddMockHandler(MockHelpers.MockCsrResponse(userAssignedIdentityId: userAssignedIdentityId, userAssignedId: userAssignedId));

var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false);
// this indicates ImdsV2 is available
Assert.AreEqual(ManagedIdentitySource.ImdsV2, miSource);

httpManager.AddManagedIdentityMockHandler(
ManagedIdentityTests.ImdsEndpoint,
ManagedIdentityTests.Resource,
MockHelpers.GetMsiSuccessfulResponse(),
ManagedIdentitySource.Imds,
userAssignedId: userAssignedId,
userAssignedIdentityId: userAssignedIdentityId);

// ImdsV1 flow will be used since .WithMtlsProofOfPossession() is not used here
var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource).ExecuteAsync().ConfigureAwait(false);

Assert.IsNotNull(result);
Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
}
}

#region Cuid Tests
[TestMethod]
public void TestCsrGeneration_OnlyVmId()
Expand Down Expand Up @@ -684,7 +740,7 @@ await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Res
}
#endregion

#region cached certificate tests
#region Cached certificate tests
[TestMethod]
public async Task mTLSPop_ForceRefresh_UsesCachedCert_NoIssueCredential_PostsCanonicalClientId_AndSkipsAttestation()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ public class ManagedIdentityTests : TestBase

private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory();

private void AddImdsV2CsrMockHandlerIfNeeded(
ManagedIdentitySource managedIdentitySource,
MockHttpManager httpManager)
{
if (managedIdentitySource == ManagedIdentitySource.Imds)
{
httpManager.AddMockHandler(MockHelpers.MockCsrResponseFailure());
}
}

[DataTestMethod]
[DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)]
[DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)]
Expand All @@ -69,15 +59,11 @@ public async Task GetManagedIdentityTests(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
.WithHttpManager(httpManager);




ManagedIdentityApplication mi = miBuilder.Build() as ManagedIdentityApplication;

Assert.AreEqual(expectedManagedIdentitySource, await mi.GetManagedIdentitySourceAsync().ConfigureAwait(false));
Expand Down Expand Up @@ -106,7 +92,6 @@ public async Task SAMIHappyPathAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -155,7 +140,6 @@ public async Task UAMIHappyPathAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId);
Expand Down Expand Up @@ -203,7 +187,6 @@ public async Task ManagedIdentityDifferentScopesTestAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -261,7 +244,6 @@ public async Task ManagedIdentityForceRefreshTestAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -320,7 +302,6 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -384,7 +365,6 @@ public async Task ManagedIdentityWithClaimsTestAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -456,7 +436,6 @@ public async Task ManagedIdentityTestWrongScopeAsync(string resource, ManagedIde
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -560,7 +539,6 @@ public async Task ManagedIdentityErrorResponseNoPayloadTestAsync(ManagedIdentity
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -600,7 +578,6 @@ public async Task ManagedIdentityNullResponseAsync(ManagedIdentitySource managed
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -638,7 +615,6 @@ public async Task ManagedIdentityUnreachableNetworkAsync(ManagedIdentitySource m
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down Expand Up @@ -1036,7 +1012,6 @@ public async Task InvalidJsonResponseHandling(ManagedIdentitySource managedIdent
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder
Expand Down Expand Up @@ -1077,7 +1052,6 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder
Expand Down Expand Up @@ -1332,7 +1306,6 @@ public async Task ManagedIdentityWithCapabilitiesTestAsync(
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
AddImdsV2CsrMockHandlerIfNeeded(managedIdentitySource, httpManager);
SetEnvironmentVariables(managedIdentitySource, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
Expand Down