diff --git a/tests/Microsoft.Identity.Web.Test.Common/Mocks/MockHttpMessageHandler.cs b/tests/Microsoft.Identity.Web.Test.Common/Mocks/MockHttpMessageHandler.cs index d3c4b5041..ebbb6c292 100644 --- a/tests/Microsoft.Identity.Web.Test.Common/Mocks/MockHttpMessageHandler.cs +++ b/tests/Microsoft.Identity.Web.Test.Common/Mocks/MockHttpMessageHandler.cs @@ -13,13 +13,14 @@ namespace Microsoft.Identity.Web.Test.Common.Mocks { public class MockHttpMessageHandler : HttpMessageHandler { + public Func ReplaceMockHttpMessageHandler; + private readonly bool _ignoreInstanceDiscovery; #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. - public MockHttpMessageHandler(bool ignoreInstanceDiscovery = true) + public MockHttpMessageHandler() #pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. { - _ignoreInstanceDiscovery = ignoreInstanceDiscovery; } public HttpResponseMessage ResponseMessage { get; set; } @@ -48,6 +49,24 @@ protected override async Task SendAsync(HttpRequestMessage Assert.NotNull(uri); + //Intercept instance discovery requests and serve a response. + //Also, requeue the current mock handler for MSAL's next request. +#if NET6_0_OR_GREATER + if (uri.AbsoluteUri.Contains("/discovery/instance", StringComparison.OrdinalIgnoreCase)) +#else + if (uri.AbsoluteUri.Contains("/discovery/instance")) +#endif + { + ReplaceMockHttpMessageHandler?.Invoke(this); + + var responseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(TestConstants.DiscoveryJsonResponse), + }; + + return responseMessage; + } + if (!string.IsNullOrEmpty(ExpectedUrl)) { Assert.Equal( @@ -61,14 +80,10 @@ protected override async Task SendAsync(HttpRequestMessage Assert.Equal(ExpectedMethod, request.Method); - - - if (request.Method != HttpMethod.Get && request.Content != null) { string postData = await request.Content.ReadAsStringAsync(); ActualRequestPostData = QueryStringParser.ParseKeyValueList(postData, '&', true, false); - } return ResponseMessage; diff --git a/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs b/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs index 83f1aa84e..78986cf5f 100644 --- a/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs +++ b/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs @@ -211,6 +211,9 @@ private static async Task CreateAppAndGetTokenAsync( if (addTokenMock) { mockHttp.AddMockHandler(tokenHandler); + + //Enables the mock handler to requeue requests that have been intercepted for instance discovery for example + tokenHandler.ReplaceMockHttpMessageHandler = mockHttp.AddMockHandler; } var confidentialApp = ConfidentialClientApplicationBuilder @@ -248,6 +251,7 @@ private static async Task CreateAppAndGetTokenAsync( var result = await confidentialApp.AcquireTokenForClient(new[] { TestConstants.s_scopeForApp }) .ExecuteAsync().ConfigureAwait(false); + tokenHandler.ReplaceMockHttpMessageHandler = null!; return result; }