diff --git a/src/libraries/System.Linq/src/System/Linq/Enumerable.cs b/src/libraries/System.Linq/src/System/Linq/Enumerable.cs index 8e5571ddda534e..2fa719a2f52ce0 100644 --- a/src/libraries/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/libraries/System.Linq/src/System/Linq/Enumerable.cs @@ -55,7 +55,7 @@ internal static bool TryGetSpan(this IEnumerable source, out R { span = Unsafe.As(source); } - else if (source.GetType() == typeof(List)) + else if (source.GetType() == typeof(List)) // avoid accidentally bypassing a derived type's reimplementation of IEnumerable { span = CollectionsMarshal.AsSpan(Unsafe.As>(source)); } diff --git a/src/libraries/System.Linq/src/System/Linq/Select.cs b/src/libraries/System.Linq/src/System/Linq/Select.cs index ac27c8dda22eeb..acba44012423fd 100644 --- a/src/libraries/System.Linq/src/System/Linq/Select.cs +++ b/src/libraries/System.Linq/src/System/Linq/Select.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using static System.Linq.Utilities; namespace System.Linq @@ -52,9 +53,9 @@ public static IEnumerable Select( return new ArraySelectIterator(array, selector); } - if (source is List list) + if (source.GetType() == typeof(List)) // avoid accidentally bypassing a derived type's reimplementation of IEnumerable { - return new ListSelectIterator(list, selector); + return new ListSelectIterator(Unsafe.As>(source), selector); } return new IListSelectIterator(ilist, selector); diff --git a/src/libraries/System.Linq/src/System/Linq/ToCollection.cs b/src/libraries/System.Linq/src/System/Linq/ToCollection.cs index c6654b6e4fff1b..7cf50fc9631e61 100644 --- a/src/libraries/System.Linq/src/System/Linq/ToCollection.cs +++ b/src/libraries/System.Linq/src/System/Linq/ToCollection.cs @@ -148,6 +148,8 @@ public static Dictionary ToDictionary(this IEnumer ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); } + Dictionary dict; + if (source.TryGetNonEnumeratedCount(out int capacity)) { if (capacity == 0) @@ -155,35 +157,25 @@ public static Dictionary ToDictionary(this IEnumer return new Dictionary(comparer); } - if (source is TSource[] array) + if (source.TryGetSpan(out ReadOnlySpan span)) { - return SpanToDictionary(array, keySelector, comparer); - } + dict = new Dictionary(span.Length, comparer); + foreach (TSource element in span) + { + dict.Add(keySelector(element), element); + } - if (source is List list) - { - ReadOnlySpan span = CollectionsMarshal.AsSpan(list); - return SpanToDictionary(span, keySelector, comparer); + return dict; } } - Dictionary d = new Dictionary(capacity, comparer); + dict = new Dictionary(capacity, comparer); foreach (TSource element in source) { - d.Add(keySelector(element), element); + dict.Add(keySelector(element), element); } - return d; - } - - private static Dictionary SpanToDictionary(ReadOnlySpan source, Func keySelector, IEqualityComparer? comparer) where TKey : notnull - { - Dictionary d = new Dictionary(source.Length, comparer); - foreach (TSource element in source) - { - d.Add(keySelector(element), element); - } - return d; + return dict; } public static Dictionary ToDictionary(this IEnumerable source, Func keySelector, Func elementSelector) where TKey : notnull => @@ -206,6 +198,8 @@ public static Dictionary ToDictionary(t ThrowHelper.ThrowArgumentNullException(ExceptionArgument.elementSelector); } + Dictionary dict; + if (source.TryGetNonEnumeratedCount(out int capacity)) { if (capacity == 0) @@ -213,35 +207,24 @@ public static Dictionary ToDictionary(t return new Dictionary(comparer); } - if (source is TSource[] array) + if (source.TryGetSpan(out ReadOnlySpan span)) { - return SpanToDictionary(array, keySelector, elementSelector, comparer); - } - - if (source is List list) - { - ReadOnlySpan span = CollectionsMarshal.AsSpan(list); - return SpanToDictionary(span, keySelector, elementSelector, comparer); + dict = new Dictionary(span.Length, comparer); + foreach (TSource element in span) + { + dict.Add(keySelector(element), elementSelector(element)); + } + return dict; } } - Dictionary d = new Dictionary(capacity, comparer); + dict = new Dictionary(capacity, comparer); foreach (TSource element in source) { - d.Add(keySelector(element), elementSelector(element)); + dict.Add(keySelector(element), elementSelector(element)); } - return d; - } - - private static Dictionary SpanToDictionary(ReadOnlySpan source, Func keySelector, Func elementSelector, IEqualityComparer? comparer) where TKey : notnull - { - Dictionary d = new Dictionary(source.Length, comparer); - foreach (TSource element in source) - { - d.Add(keySelector(element), elementSelector(element)); - } - return d; + return dict; } public static HashSet ToHashSet(this IEnumerable source) => source.ToHashSet(comparer: null); diff --git a/src/libraries/System.Linq/src/System/Linq/Where.cs b/src/libraries/System.Linq/src/System/Linq/Where.cs index 4371af8299fb2e..1ebe76396c8aaa 100644 --- a/src/libraries/System.Linq/src/System/Linq/Where.cs +++ b/src/libraries/System.Linq/src/System/Linq/Where.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; +using System.Runtime.CompilerServices; using static System.Linq.Utilities; namespace System.Linq @@ -36,9 +37,9 @@ public static IEnumerable Where(this IEnumerable sour return new ArrayWhereIterator(array, predicate); } - if (source is List list) + if (source.GetType() == typeof(List)) { - return new ListWhereIterator(list, predicate); + return new ListWhereIterator(Unsafe.As>(source), predicate); } return new IEnumerableWhereIterator(source, predicate);