diff --git a/src/libraries/Common/tests/System/Collections/TestBase.NonGeneric.cs b/src/libraries/Common/tests/System/Collections/TestBase.NonGeneric.cs index 64bb1d0db1925..587b03422049f 100644 --- a/src/libraries/Common/tests/System/Collections/TestBase.NonGeneric.cs +++ b/src/libraries/Common/tests/System/Collections/TestBase.NonGeneric.cs @@ -20,6 +20,12 @@ public static IEnumerable ValidCollectionSizes() yield return new object[] { 75 }; } + public static IEnumerable ValidPositiveCollectionSizes() + { + yield return new object[] { 1 }; + yield return new object[] { 75 }; + } + public enum EnumerableType { HashSet, diff --git a/src/libraries/Common/tests/System/Collections/TestingTypes.cs b/src/libraries/Common/tests/System/Collections/TestingTypes.cs index a74753849c807..9f66fb472a594 100644 --- a/src/libraries/Common/tests/System/Collections/TestingTypes.cs +++ b/src/libraries/Common/tests/System/Collections/TestingTypes.cs @@ -366,5 +366,23 @@ public struct ValueDelegateEquatable : IEquatable public bool Equals(ValueDelegateEquatable other) => EqualsWorker(other); } + public sealed class TrackingEqualityComparer : IEqualityComparer + { + public int EqualsCalls; + public int GetHashCodeCalls; + + public bool Equals(T x, T y) + { + EqualsCalls++; + return EqualityComparer.Default.Equals(x, y); + } + + public int GetHashCode(T obj) + { + GetHashCodeCalls++; + return EqualityComparer.Default.GetHashCode(obj); + } + } + #endregion } diff --git a/src/libraries/System.Collections/src/System/Collections/Generic/HashSet.cs b/src/libraries/System.Collections/src/System/Collections/Generic/HashSet.cs index 4f5f973e9356e..9a440fbdd89dd 100644 --- a/src/libraries/System.Collections/src/System/Collections/Generic/HashSet.cs +++ b/src/libraries/System.Collections/src/System/Collections/Generic/HashSet.cs @@ -416,7 +416,7 @@ public bool Remove(T item) for (i = _buckets[bucket] - 1; i >= 0; last = i, i = slots[i].next) { - if (slots[i].hashCode == hashCode && EqualityComparer.Default.Equals(slots[i].value, item)) + if (slots[i].hashCode == hashCode && comparer.Equals(slots[i].value, item)) { goto ReturnFound; } diff --git a/src/libraries/System.Collections/tests/Generic/HashSet/HashSet.Generic.Tests.cs b/src/libraries/System.Collections/tests/Generic/HashSet/HashSet.Generic.Tests.cs index 1339f97f42747..ef527301b4afc 100644 --- a/src/libraries/System.Collections/tests/Generic/HashSet/HashSet.Generic.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/HashSet/HashSet.Generic.Tests.cs @@ -626,5 +626,29 @@ public void EnsureCapacity_Generic_GrowCapacityWithFreeList(int setLength) } #endregion + + #region Remove + + [Theory] + [MemberData(nameof(ValidPositiveCollectionSizes))] + public void Remove_NonDefaultComparer_ComparerUsed(int capacity) + { + var c = new TrackingEqualityComparer(); + var set = new HashSet(capacity, c); + + AddToCollection(set, capacity); + T first = set.First(); + c.EqualsCalls = 0; + c.GetHashCodeCalls = 0; + + Assert.Equal(capacity, set.Count); + set.Remove(first); + Assert.Equal(capacity - 1, set.Count); + + Assert.InRange(c.EqualsCalls, 1, int.MaxValue); + Assert.InRange(c.GetHashCodeCalls, 1, int.MaxValue); + } + + #endregion } }