From a51bf0f6f9a5b212eb5baf5c62ab8d280531beb8 Mon Sep 17 00:00:00 2001 From: Zhenya Polyvanyi Date: Tue, 19 May 2026 11:00:34 +0100 Subject: [PATCH] Use LRU cache for issuer address in B2C OpenID connect event handler --- .../AzureADB2COpenIDConnectEventHandlers.cs | 30 ++++-- ...ureADB2COpenIDConnectEventHandlersTests.cs | 96 +++++++++++++++++++ 2 files changed, 120 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.Identity.Web/AzureADB2COpenIDConnectEventHandlers.cs b/src/Microsoft.Identity.Web/AzureADB2COpenIDConnectEventHandlers.cs index d2b822baf..9e7dc7254 100644 --- a/src/Microsoft.Identity.Web/AzureADB2COpenIDConnectEventHandlers.cs +++ b/src/Microsoft.Identity.Web/AzureADB2COpenIDConnectEventHandlers.cs @@ -2,10 +2,10 @@ // Licensed under the MIT License. using System; -using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.OpenIdConnect; +using Microsoft.Extensions.Caching.Memory; using Microsoft.IdentityModel.Protocols.OpenIdConnect; namespace Microsoft.Identity.Web @@ -14,8 +14,16 @@ internal class AzureADB2COpenIDConnectEventHandlers { private readonly ILoginErrorAccessor _errorAccessor; - private readonly Dictionary _userFlowToIssuerAddress = - new(StringComparer.OrdinalIgnoreCase); + internal const int MaxCacheEntries = 100; + + private static readonly TimeSpan CacheSlidingExpiration = TimeSpan.FromHours(1); + + private readonly MemoryCache _issuerAddressCache = new(new MemoryCacheOptions + { + SizeLimit = MaxCacheEntries, + }); + + private static readonly char[] _invalidPolicyCharacters = { '/', '?', '#', '%' }; public AzureADB2COpenIDConnectEventHandlers( string schemeName, @@ -38,6 +46,12 @@ public Task OnRedirectToIdentityProvider(RedirectContext context) !string.IsNullOrEmpty(userFlow) && !string.Equals(userFlow, defaultUserFlow, StringComparison.OrdinalIgnoreCase)) { + if (userFlow.IndexOfAny(_invalidPolicyCharacters) >= 0) + { + context.Properties.Items.Remove(OidcConstants.PolicyKey); + return Task.CompletedTask; + } + context.ProtocolMessage.IssuerAddress = BuildIssuerAddress(context, defaultUserFlow, userFlow); context.Properties.Items.Remove(OidcConstants.PolicyKey); @@ -100,16 +114,20 @@ public Task OnRemoteFailure(RemoteFailureContext context) private string BuildIssuerAddress(RedirectContext context, string? defaultUserFlow, string userFlow) { - if (!_userFlowToIssuerAddress.TryGetValue(userFlow, out var issuerAddress)) + if (!_issuerAddressCache.TryGetValue(userFlow, out string? issuerAddress)) { issuerAddress = context.ProtocolMessage.IssuerAddress .Replace($"/{defaultUserFlow}/", $"/{userFlow}/", StringComparison.OrdinalIgnoreCase); issuerAddress = issuerAddress.ToLowerInvariant(); - _userFlowToIssuerAddress[userFlow] = issuerAddress; + var cacheEntryOptions = new MemoryCacheEntryOptions() + .SetSize(1) + .SetSlidingExpiration(CacheSlidingExpiration); + + _issuerAddressCache.Set(userFlow, issuerAddress, cacheEntryOptions); } - return issuerAddress; + return issuerAddress!; } } } diff --git a/tests/Microsoft.Identity.Web.Test/AzureADB2COpenIDConnectEventHandlersTests.cs b/tests/Microsoft.Identity.Web.Test/AzureADB2COpenIDConnectEventHandlersTests.cs index e10b0e83e..73bb470bd 100644 --- a/tests/Microsoft.Identity.Web.Test/AzureADB2COpenIDConnectEventHandlersTests.cs +++ b/tests/Microsoft.Identity.Web.Test/AzureADB2COpenIDConnectEventHandlersTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.OpenIdConnect; @@ -149,5 +150,100 @@ await handler.OnRemoteFailure( errorAccessor.Received(1).SetMessage(httpContext, otherException); httpContext.Response.Received().Redirect($"{httpContext.Request.PathBase}/MicrosoftIdentity/Account/Error"); } + + [Theory] + [InlineData("../../some-path")] + [InlineData("policy?test=123")] + [InlineData("policy#fragment")] + [InlineData("policy%2F%2Fabc")] + [InlineData("/absolute/path")] + public async Task OnRedirectToIdentityProvider_PolicyWithInvalidChars_FallsBackToDefault(string invalidPolicy) + { + // Arrange + var errorAccessor = Substitute.For(); + var options = new MicrosoftIdentityOptions() { SignUpSignInPolicyId = DefaultUserFlow }; + var handler = new AzureADB2COpenIDConnectEventHandlers(OpenIdConnectDefaults.AuthenticationScheme, options, errorAccessor); + var httpContext = HttpContextUtilities.CreateHttpContext(); + var authProperties = new AuthenticationProperties(); + authProperties.Items.Add(OidcConstants.PolicyKey, invalidPolicy); + var context = new RedirectContext(httpContext, _authScheme, new OpenIdConnectOptions(), authProperties) + { + ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer }, + }; + + // Act + await handler.OnRedirectToIdentityProvider(context); + + // Assert + Assert.Equal(_defaultIssuer, context.ProtocolMessage.IssuerAddress); + Assert.False(context.Properties.Items.ContainsKey(OidcConstants.PolicyKey)); + Assert.Null(context.ProtocolMessage.ResponseType); + } + + [Fact] + public async Task OnRedirectToIdentityProvider_CacheBoundedAt100Entries() + { + // Arrange + var errorAccessor = Substitute.For(); + var options = new MicrosoftIdentityOptions() { SignUpSignInPolicyId = DefaultUserFlow }; + var handler = new AzureADB2COpenIDConnectEventHandlers(OpenIdConnectDefaults.AuthenticationScheme, options, errorAccessor); + + // Act — send 200 unique policy values (exceeds the 100 limit) + for (int i = 0; i < 200; i++) + { + var httpContext = HttpContextUtilities.CreateHttpContext(); + var authProperties = new AuthenticationProperties(); + authProperties.Items.Add(OidcConstants.PolicyKey, $"policy_{i}"); + var context = new RedirectContext(httpContext, _authScheme, new OpenIdConnectOptions(), authProperties) + { + ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer }, + }; + await handler.OnRedirectToIdentityProvider(context); + } + + // Assert + var cacheField = typeof(AzureADB2COpenIDConnectEventHandlers) + .GetField("_issuerAddressCache", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + var cache = cacheField!.GetValue(handler) as Microsoft.Extensions.Caching.Memory.MemoryCache; + Assert.NotNull(cache); + + Assert.True(cache!.Count <= AzureADB2COpenIDConnectEventHandlers.MaxCacheEntries, + $"Cache count {cache.Count} should not exceed {AzureADB2COpenIDConnectEventHandlers.MaxCacheEntries}"); + } + + [Fact] + public async Task OnRedirectToIdentityProvider_BeyondCacheLimit_StillComputesAddress() + { + // Arrange + var errorAccessor = Substitute.For(); + var options = new MicrosoftIdentityOptions() { SignUpSignInPolicyId = DefaultUserFlow }; + var handler = new AzureADB2COpenIDConnectEventHandlers(OpenIdConnectDefaults.AuthenticationScheme, options, errorAccessor); + + for (int i = 0; i < 100; i++) + { + var httpContext = HttpContextUtilities.CreateHttpContext(); + var authProperties = new AuthenticationProperties(); + authProperties.Items.Add(OidcConstants.PolicyKey, $"policy_{i}"); + var context = new RedirectContext(httpContext, _authScheme, new OpenIdConnectOptions(), authProperties) + { + ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer }, + }; + await handler.OnRedirectToIdentityProvider(context); + } + + // Act - address computed, LRU eviction occurs + var httpContext101 = HttpContextUtilities.CreateHttpContext(); + var authProperties101 = new AuthenticationProperties(); + authProperties101.Items.Add(OidcConstants.PolicyKey, "policy_beyond_limit"); + var context101 = new RedirectContext(httpContext101, _authScheme, new OpenIdConnectOptions(), authProperties101) + { + ProtocolMessage = new OpenIdConnectMessage() { IssuerAddress = _defaultIssuer }, + }; + await handler.OnRedirectToIdentityProvider(context101); + + // Assert — address was computed correctly + Assert.Contains("policy_beyond_limit", context101.ProtocolMessage.IssuerAddress, StringComparison.OrdinalIgnoreCase); + Assert.False(context101.Properties.Items.ContainsKey(OidcConstants.PolicyKey)); + } } }