diff --git a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs index de8f7d89725c68..2fd38469dd25c9 100644 --- a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs +++ b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs @@ -61,11 +61,31 @@ protected ClaimsPrincipal(SerializationInfo info, StreamingContext context) { ArgumentNullException.ThrowIfNull(identities); - foreach (ClaimsIdentity identity in identities) + // If the identities value is exactly a List, special case it so that + // the enumerator allocation can be skipped. Doing this for List is the 99% + // case because it is normally used on the _identities value, which is a List. + if (identities.GetType() == typeof(List)) { - if (identity != null) + List identitiesList = (identities as List)!; + + for (int i = 0; i < identitiesList.Count; i++) + { + ClaimsIdentity identity = identitiesList[i]; + + if (identity != null) + { + return identity; + } + } + } + else + { + foreach (ClaimsIdentity identity in identities) { - return identity; + if (identity != null) + { + return identity; + } } } diff --git a/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs b/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs index c51d2dc37202da..a3cc102e212c1e 100644 --- a/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs +++ b/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections; using System.Collections.Generic; using System.IO; using System.Linq; @@ -242,6 +243,53 @@ public void Current_FallsBackToThread_UnauthenticatedPrincipalPolicy() }).Dispose(); } + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public void PrimaryIdentitySelector_Default() + { + RemoteExecutor.Invoke(static () => + { + ClaimsIdentity identity0 = null; + ClaimsIdentity identity1 = new([new Claim("type", "value")]); + ClaimsIdentity identity2 = new([new Claim("type", "value")]); + IEnumerable identities = [identity0, identity1, identity2]; + Func, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector; + + Assert.Same(identity1, selector(identities)); + Assert.Null(selector([])); + AssertExtensions.Throws("identities", () => selector(null)); + }).Dispose(); + } + + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public void PrimaryIdentitySelector_DefaultOnlySpecialCasesList() + { + RemoteExecutor.Invoke(static () => + { + ClaimsIdentity identity0 = null; + ClaimsIdentity identity1 = new([new Claim("type", "value")]); + ClaimsIdentity identity2 = new([new Claim("type", "value")]); + ClaimsIdentityList identities = [identity0, identity1, identity2]; + Func selector = ClaimsPrincipal.PrimaryIdentitySelector; + + Assert.Same(identity1, selector(identities)); + Assert.Equal(1, identities.GetEnumeratorCount); + Assert.Null(selector(new ClaimsIdentityList())); + }).Dispose(); + } + + private sealed class ClaimsIdentityList : List, IEnumerable + { + private readonly List _claimsIdentities = []; + + public int GetEnumeratorCount { get; private set; } + + public new IEnumerator GetEnumerator() + { + GetEnumeratorCount++; + return base.GetEnumerator(); + } + } + private class NonClaimsPrincipal : IPrincipal { public IIdentity Identity { get; set; }