Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ namespace Microsoft.Identity.Web.Test.Common.Mocks
{
public class MockHttpMessageHandler : HttpMessageHandler
{
public Func<MockHttpMessageHandler, MockHttpMessageHandler> ReplaceMockHttpMessageHandler;

private readonly bool _ignoreInstanceDiscovery;

#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
public MockHttpMessageHandler(bool ignoreInstanceDiscovery = true)
public MockHttpMessageHandler()
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
{
_ignoreInstanceDiscovery = ignoreInstanceDiscovery;
}
public HttpResponseMessage ResponseMessage { get; set; }

Expand Down Expand Up @@ -48,6 +49,24 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage

Assert.NotNull(uri);

//Intercept instance discovery requests and serve a response.
//Also, requeue the current mock handler for MSAL's next request.
#if NET6_0_OR_GREATER
if (uri.AbsoluteUri.Contains("/discovery/instance", StringComparison.OrdinalIgnoreCase))
#else
if (uri.AbsoluteUri.Contains("/discovery/instance"))
#endif
{
ReplaceMockHttpMessageHandler?.Invoke(this);

var responseMessage = new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent(TestConstants.DiscoveryJsonResponse),
};

return responseMessage;
}

if (!string.IsNullOrEmpty(ExpectedUrl))
{
Assert.Equal(
Expand All @@ -61,14 +80,10 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage

Assert.Equal(ExpectedMethod, request.Method);




if (request.Method != HttpMethod.Get && request.Content != null)
{
string postData = await request.Content.ReadAsStringAsync();
ActualRequestPostData = QueryStringParser.ParseKeyValueList(postData, '&', true, false);

}

return ResponseMessage;
Expand Down
4 changes: 4 additions & 0 deletions tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ private static async Task<AuthenticationResult> CreateAppAndGetTokenAsync(
if (addTokenMock)
{
mockHttp.AddMockHandler(tokenHandler);

//Enables the mock handler to requeue requests that have been intercepted for instance discovery for example
tokenHandler.ReplaceMockHttpMessageHandler = mockHttp.AddMockHandler;
}

var confidentialApp = ConfidentialClientApplicationBuilder
Expand Down Expand Up @@ -248,6 +251,7 @@ private static async Task<AuthenticationResult> CreateAppAndGetTokenAsync(
var result = await confidentialApp.AcquireTokenForClient(new[] { TestConstants.s_scopeForApp })
.ExecuteAsync().ConfigureAwait(false);

tokenHandler.ReplaceMockHttpMessageHandler = null!;
return result;
}

Expand Down