Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.OpenIdConnect;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.IdentityModel.Protocols.OpenIdConnect;

namespace Microsoft.Identity.Web
Expand All @@ -14,8 +14,16 @@ internal class AzureADB2COpenIDConnectEventHandlers
{
private readonly ILoginErrorAccessor _errorAccessor;

private readonly Dictionary<string, string> _userFlowToIssuerAddress =
new(StringComparer.OrdinalIgnoreCase);
internal const int MaxCacheEntries = 100;

private static readonly TimeSpan CacheSlidingExpiration = TimeSpan.FromHours(1);

private readonly MemoryCache _issuerAddressCache = new(new MemoryCacheOptions
{
SizeLimit = MaxCacheEntries,
});

private static readonly char[] _invalidPolicyCharacters = { '/', '?', '#', '%' };

public AzureADB2COpenIDConnectEventHandlers(
string schemeName,
Expand All @@ -38,6 +46,12 @@ public Task OnRedirectToIdentityProvider(RedirectContext context)
!string.IsNullOrEmpty(userFlow) &&
!string.Equals(userFlow, defaultUserFlow, StringComparison.OrdinalIgnoreCase))
{
if (userFlow.IndexOfAny(_invalidPolicyCharacters) >= 0)
{
context.Properties.Items.Remove(OidcConstants.PolicyKey);
return Task.CompletedTask;
}

context.ProtocolMessage.IssuerAddress = BuildIssuerAddress(context, defaultUserFlow, userFlow);
context.Properties.Items.Remove(OidcConstants.PolicyKey);

Expand Down Expand Up @@ -100,16 +114,20 @@ public Task OnRemoteFailure(RemoteFailureContext context)

private string BuildIssuerAddress(RedirectContext context, string? defaultUserFlow, string userFlow)
{
if (!_userFlowToIssuerAddress.TryGetValue(userFlow, out var issuerAddress))
if (!_issuerAddressCache.TryGetValue(userFlow, out string? issuerAddress))
{
issuerAddress = context.ProtocolMessage.IssuerAddress
.Replace($"/{defaultUserFlow}/", $"/{userFlow}/", StringComparison.OrdinalIgnoreCase);
issuerAddress = issuerAddress.ToLowerInvariant();

_userFlowToIssuerAddress[userFlow] = issuerAddress;
var cacheEntryOptions = new MemoryCacheEntryOptions()
.SetSize(1)
.SetSlidingExpiration(CacheSlidingExpiration);

_issuerAddressCache.Set(userFlow, issuerAddress, cacheEntryOptions);
}

return issuerAddress;
return issuerAddress!;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.OpenIdConnect;
Expand Down Expand Up @@ -149,5 +150,100 @@ await handler.OnRemoteFailure(
errorAccessor.Received(1).SetMessage(httpContext, otherException);
httpContext.Response.Received().Redirect($"{httpContext.Request.PathBase}/MicrosoftIdentity/Account/Error");
}

[Theory]
[InlineData("../../some-path")]
[InlineData("policy?test=123")]
[InlineData("policy#fragment")]
[InlineData("policy%2F%2Fabc")]
[InlineData("/absolute/path")]
public async Task OnRedirectToIdentityProvider_PolicyWithInvalidChars_FallsBackToDefault(string invalidPolicy)
{
// Arrange
var errorAccessor = Substitute.For<ILoginErrorAccessor>();
var options = new MicrosoftIdentityOptions() { SignUpSignInPolicyId = DefaultUserFlow };
var handler = new AzureADB2COpenIDConnectEventHandlers(OpenIdConnectDefaults.AuthenticationScheme, options, errorAccessor);
var httpContext = HttpContextUtilities.CreateHttpContext();
var authProperties = new AuthenticationProperties();
authProperties.Items.Add(OidcConstants.PolicyKey, invalidPolicy);
var context = new RedirectContext(httpContext, _authScheme, new OpenIdConnectOptions(), authProperties)
{
ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer },
};

// Act
await handler.OnRedirectToIdentityProvider(context);

// Assert
Assert.Equal(_defaultIssuer, context.ProtocolMessage.IssuerAddress);
Assert.False(context.Properties.Items.ContainsKey(OidcConstants.PolicyKey));
Assert.Null(context.ProtocolMessage.ResponseType);
}

[Fact]
public async Task OnRedirectToIdentityProvider_CacheBoundedAt100Entries()
{
// Arrange
var errorAccessor = Substitute.For<ILoginErrorAccessor>();
var options = new MicrosoftIdentityOptions() { SignUpSignInPolicyId = DefaultUserFlow };
var handler = new AzureADB2COpenIDConnectEventHandlers(OpenIdConnectDefaults.AuthenticationScheme, options, errorAccessor);

// Act — send 200 unique policy values (exceeds the 100 limit)
for (int i = 0; i < 200; i++)
{
var httpContext = HttpContextUtilities.CreateHttpContext();
var authProperties = new AuthenticationProperties();
authProperties.Items.Add(OidcConstants.PolicyKey, $"policy_{i}");
var context = new RedirectContext(httpContext, _authScheme, new OpenIdConnectOptions(), authProperties)
{
ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer },
};
await handler.OnRedirectToIdentityProvider(context);
}

// Assert
var cacheField = typeof(AzureADB2COpenIDConnectEventHandlers)
.GetField("_issuerAddressCache", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var cache = cacheField!.GetValue(handler) as Microsoft.Extensions.Caching.Memory.MemoryCache;
Assert.NotNull(cache);

Assert.True(cache!.Count <= AzureADB2COpenIDConnectEventHandlers.MaxCacheEntries,
$"Cache count {cache.Count} should not exceed {AzureADB2COpenIDConnectEventHandlers.MaxCacheEntries}");
}

[Fact]
public async Task OnRedirectToIdentityProvider_BeyondCacheLimit_StillComputesAddress()
{
// Arrange
var errorAccessor = Substitute.For<ILoginErrorAccessor>();
var options = new MicrosoftIdentityOptions() { SignUpSignInPolicyId = DefaultUserFlow };
var handler = new AzureADB2COpenIDConnectEventHandlers(OpenIdConnectDefaults.AuthenticationScheme, options, errorAccessor);

for (int i = 0; i < 100; i++)
{
var httpContext = HttpContextUtilities.CreateHttpContext();
var authProperties = new AuthenticationProperties();
authProperties.Items.Add(OidcConstants.PolicyKey, $"policy_{i}");
var context = new RedirectContext(httpContext, _authScheme, new OpenIdConnectOptions(), authProperties)
{
ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer },
};
await handler.OnRedirectToIdentityProvider(context);
}

// Act - address computed, LRU eviction occurs
var httpContext101 = HttpContextUtilities.CreateHttpContext();
var authProperties101 = new AuthenticationProperties();
authProperties101.Items.Add(OidcConstants.PolicyKey, "policy_beyond_limit");
var context101 = new RedirectContext(httpContext101, _authScheme, new OpenIdConnectOptions(), authProperties101)
{
ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer },
};
await handler.OnRedirectToIdentityProvider(context101);

// Assert — address was computed correctly
Assert.Contains("policy_beyond_limit", context101.ProtocolMessage.IssuerAddress, StringComparison.OrdinalIgnoreCase);
Assert.False(context101.Properties.Items.ContainsKey(OidcConstants.PolicyKey));
}
}
}
Loading