Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.Identity.Client.Platforms.net;
using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute;

namespace Microsoft.Identity.Client.Region
{
[JsonObject]
[Preserve(AllMembers = true)]
internal sealed class LocalImdsComputeResponse
{
[JsonProperty("location")]
public string Location { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ public RegionInfo(string region, RegionAutodetectionSource regionSource, string
}

// For information of the current api-version refer: https://learn.microsoft.com/azure/virtual-machines/instance-metadata-service?tabs=windows#versioning
private const string ImdsEndpoint = "http://169.254.169.254/metadata/instance/compute/location";
private const string DefaultApiVersion = "2020-06-01";
private const string ImdsEndpoint = "http://169.254.169.254/metadata/instance/compute";
private const string DefaultApiVersion = "2021-02-01";

private readonly IHttpManager _httpManager;
private readonly int _imdsCallTimeoutMs;
Expand Down Expand Up @@ -237,7 +237,8 @@ private async Task<RegionInfo> DiscoverAsync(ILoggerAdapter logger, Cancellation

if (response.StatusCode == HttpStatusCode.OK && !response.Body.IsNullOrEmpty())
Comment thread
Robbie-Microsoft marked this conversation as resolved.
{
region = response.Body;
LocalImdsComputeResponse computeResponse = JsonHelper.DeserializeFromJson<LocalImdsComputeResponse>(response.Body);
region = computeResponse?.Location;

if (ValidateRegion(region, $"IMDS call to {imdsUri.AbsoluteUri}", logger))
{
Expand Down Expand Up @@ -363,7 +364,6 @@ private static Uri BuildImdsUri(string apiVersion)
{
UriBuilder uriBuilder = new UriBuilder(ImdsEndpoint);
uriBuilder.AppendQueryParameters($"api-version={apiVersion}");
uriBuilder.AppendQueryParameters("format=text");
return uriBuilder.Uri;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace Microsoft.Identity.Client.Platforms.net
/// </summary>
[JsonSerializable(typeof(InstanceDiscoveryResponse))]
[JsonSerializable(typeof(LocalImdsErrorResponse))]
[JsonSerializable(typeof(LocalImdsComputeResponse))]
[JsonSerializable(typeof(AdalResultWrapper))]
[JsonSerializable(typeof(List<KeyValuePair<string, IEnumerable<string>>>))]
[JsonSerializable(typeof(ClientInfo))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,12 @@ public static void AddRegionDiscoveryMockHandler(
new MockHttpMessageHandler
{
ExpectedMethod = HttpMethod.Get,
ExpectedUrl = "http://169.254.169.254/metadata/instance/compute/location",
ExpectedUrl = "http://169.254.169.254/metadata/instance/compute",
ExpectedRequestHeaders = new Dictionary<string, string>
{
{"Metadata", "true"}
},
ResponseMessage = MockHelpers.CreateSuccessResponseMessage(response)
ResponseMessage = MockHelpers.CreateSuccessResponseMessage($"{{\"location\":\"{response}\"}}")
Comment thread
Robbie-Microsoft marked this conversation as resolved.
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ public static HashSet<string> s_scope
/// <summary>
/// IMDS region discovery URL.
/// </summary>
public const string ImdsUrl = $"http://{ImdsHost}/metadata/instance/compute/location";
public const string ImdsUrl = $"http://{ImdsHost}/metadata/instance/compute";

/// <summary>
/// App Service MSI endpoint used in tests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ const Microsoft.Identity.Test.Unit.TestConstants.HomeAccountId = "my-uid.my-utid
const Microsoft.Identity.Test.Unit.TestConstants.IdentityProvider = "my-idp" -> string
const Microsoft.Identity.Test.Unit.TestConstants.ImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" -> string
const Microsoft.Identity.Test.Unit.TestConstants.ImdsHost = "169.254.169.254" -> string
const Microsoft.Identity.Test.Unit.TestConstants.ImdsUrl = "http://169.254.169.254/metadata/instance/compute/location" -> string
const Microsoft.Identity.Test.Unit.TestConstants.ImdsUrl = "http://169.254.169.254/metadata/instance/compute" -> string
const Microsoft.Identity.Test.Unit.TestConstants.InvalidRegion = "invalidregion" -> string
const Microsoft.Identity.Test.Unit.TestConstants.IOSBrokerErrDescr = "Test Error Description" -> string
const Microsoft.Identity.Test.Unit.TestConstants.IOSBrokerErrorMetadata = "error_metadata" -> string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public async Task SuccessfulResponseFromEnvironmentVariableAsync()
[TestMethod]
public async Task SuccessfulResponseFromLocalImdsAsync()
{
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(TestConstants.Region));
AddMockedResponse(CreateImdsComputeResponse(TestConstants.Region));

_testRequestContext.ServiceBundle.Config.AzureRegion =
ConfidentialClientApplication.AttemptRegionDiscovery;
Expand All @@ -97,7 +97,7 @@ public void MultiThreadSuccessfulResponseFromLocalImds_HasOnlyOneImdsCall()
const int MaxThreadCount = 5;
// add the mock response only once and call it 5 times on multiple threads
// if the http mock is called more than once, it will fail in dispose as queue will be non-empty
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(TestConstants.Region));
AddMockedResponse(CreateImdsComputeResponse(TestConstants.Region));
int threadCount = MaxThreadCount;
#pragma warning disable VSTHRD101 // Avoid unsupported async delegates - acceptable risk (crash the test proj)
var result = Parallel.For(0, MaxThreadCount, async (i) =>
Expand Down Expand Up @@ -136,7 +136,7 @@ public void MultiThreadSuccessfulResponseFromLocalImds_HasOnlyOneImdsCall()
[TestMethod]
public async Task FetchRegionFromLocalImdsThenGetMetadataFromCacheAsync()
{
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(TestConstants.Region));
AddMockedResponse(CreateImdsComputeResponse(TestConstants.Region));

_testRequestContext.ServiceBundle.Config.AzureRegion =
ConfidentialClientApplication.AttemptRegionDiscovery;
Expand Down Expand Up @@ -261,7 +261,7 @@ public async Task InvalidRegionEnvVariableAsync()
{
Environment.SetEnvironmentVariable(TestConstants.RegionName, "invalid`region");

AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(TestConstants.Region)); // IMDS will return a valid region
AddMockedResponse(CreateImdsComputeResponse(TestConstants.Region)); // IMDS will return a valid region

_testRequestContext.ServiceBundle.Config.AzureRegion =
ConfidentialClientApplication.AttemptRegionDiscovery;
Expand All @@ -277,7 +277,7 @@ public async Task InvalidRegionEnvVariableAsync()
[DataRow("invalid`region")]
public async Task InvalidImdsAsync(string region)
{
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(region)); // IMDS will return an invalid region
AddMockedResponse(CreateImdsComputeResponse(region)); // IMDS will return an invalid region

_testRequestContext.ServiceBundle.Config.AzureRegion =
ConfidentialClientApplication.AttemptRegionDiscovery;
Expand Down Expand Up @@ -321,6 +321,40 @@ public async Task ResponseMissingRegionFromLocalImdsAsync()
Assert.Contains(TestConstants.RegionAutoDetectOkFailureMessage, _testRequestContext.ApiEvent.RegionDiscoveryFailureReason);
}

[TestMethod]
Comment thread
Robbie-Microsoft marked this conversation as resolved.
public async Task ResponseWithMissingLocationFieldFromLocalImdsAsync()
{
// Arrange - 200 OK with valid JSON that does not contain a "location" field
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage("{\"vmId\":\"11111111-1111-1111-1111-111111111111\"}"));
_testRequestContext.ServiceBundle.Config.AzureRegion = ConfidentialClientApplication.AttemptRegionDiscovery;

// Act
InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider.GetMetadataAsync(new Uri("https://login.microsoftonline.com/common/"), _testRequestContext).ConfigureAwait(false);

// Assert
Assert.IsNull(regionalMetadata, "Discovery requested, but it failed.");
Assert.IsNull(_testRequestContext.ApiEvent.RegionUsed);
Assert.AreEqual(RegionAutodetectionSource.FailedAutoDiscovery, _testRequestContext.ApiEvent.RegionAutodetectionSource);
Assert.AreEqual(RegionOutcome.FallbackToGlobal, _testRequestContext.ApiEvent.RegionOutcome);
}

[TestMethod]
Comment thread
Robbie-Microsoft marked this conversation as resolved.
Outdated
public async Task ResponseWithMalformedJsonFromLocalImdsAsync()
{
// Arrange - 200 OK with a non-empty but malformed JSON body
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage("{ this is not valid json"));
_testRequestContext.ServiceBundle.Config.AzureRegion = ConfidentialClientApplication.AttemptRegionDiscovery;

// Act
InstanceDiscoveryMetadataEntry regionalMetadata = await _regionDiscoveryProvider.GetMetadataAsync(new Uri("https://login.microsoftonline.com/common/"), _testRequestContext).ConfigureAwait(false);

// Assert
Assert.IsNull(regionalMetadata, "Discovery requested, but it failed.");
Assert.IsNull(_testRequestContext.ApiEvent.RegionUsed);
Assert.AreEqual(RegionAutodetectionSource.FailedAutoDiscovery, _testRequestContext.ApiEvent.RegionAutodetectionSource);
Assert.AreEqual(RegionOutcome.FallbackToGlobal, _testRequestContext.ApiEvent.RegionOutcome);
}

[TestMethod]
[DataRow(HttpStatusCode.NotFound, 0, TestConstants.RegionAutoDetectNotFoundFailureMessage)] // No retries for 404 errors
[DataRow(HttpStatusCode.InternalServerError, TestRegionDiscoveryRetryPolicy.NumRetries, TestConstants.RegionAutoDetectInternalServerErrorFailureMessage)]
Expand Down Expand Up @@ -357,7 +391,7 @@ public async Task UpdateImdsApiVersionWhenCurrentVersionExpiresForImdsAsync()
AddMockedResponse(MockHelpers.CreateNullMessage(System.Net.HttpStatusCode.BadRequest));
AddMockedResponse(MockHelpers.CreateFailureMessage(System.Net.HttpStatusCode.BadRequest, File.ReadAllText(
ResourceHelper.GetTestResourceRelativePath("local-imds-error-response.json"))), expectedParams: false);
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(TestConstants.Region), apiVersion: "2020-10-01");
AddMockedResponse(CreateImdsComputeResponse(TestConstants.Region), apiVersion: "2020-10-01");
Comment thread
Robbie-Microsoft marked this conversation as resolved.
_testRequestContext.ServiceBundle.Config.AzureRegion = ConfidentialClientApplication.AttemptRegionDiscovery;

// Act
Expand Down Expand Up @@ -415,7 +449,7 @@ public async Task UpdateApiversionFailsWithNoNewestVersionsAsync()
public async Task RegionDiscoveryFails500OnceThenSucceeds200Async()
{
AddMockedResponse(MockHelpers.CreateNullMessage(HttpStatusCode.InternalServerError));
AddMockedResponse(MockHelpers.CreateSuccessResponseMessage(TestConstants.Region));
AddMockedResponse(CreateImdsComputeResponse(TestConstants.Region));

_testRequestContext.ServiceBundle.Config.AzureRegion = ConfidentialClientApplication.AttemptRegionDiscovery;

Expand Down Expand Up @@ -480,14 +514,13 @@ public async Task RegionDiscoveryDoesNotRetryOnNonRetryableStatusCodesAsync(Http
Assert.AreEqual(NumRequests, requestsMade);
}

private void AddMockedResponse(HttpResponseMessage responseMessage, string apiVersion = "2020-06-01", bool expectedParams = true)
private void AddMockedResponse(HttpResponseMessage responseMessage, string apiVersion = "2021-02-01", bool expectedParams = true)
{
var queryParams = new Dictionary<string, string>();

if (expectedParams)
{
queryParams.Add("api-version", apiVersion);
queryParams.Add("format", "text");

_httpManager.AddMockHandler(
new MockHttpMessageHandler
Expand Down Expand Up @@ -518,6 +551,11 @@ private void AddMockedResponse(HttpResponseMessage responseMessage, string apiVe
}
}

private static HttpResponseMessage CreateImdsComputeResponse(string location)
{
return MockHelpers.CreateSuccessResponseMessage($"{{\"location\":\"{location}\"}}");
}
Comment thread
Robbie-Microsoft marked this conversation as resolved.

private void ValidateInstanceMetadata(InstanceDiscoveryMetadataEntry entry, string region = "centralus")
{
InstanceDiscoveryMetadataEntry expectedEntry = new InstanceDiscoveryMetadataEntry()
Expand Down
Loading