Skip to content

Commit

Permalink
Add CollectionsMarshal.GetValueRefOrAddDefault (#54611)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 authored Jul 15, 2021
1 parent 720279c commit bbcb6b7
Show file tree
Hide file tree
Showing 5 changed files with 404 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ public static void ComparerImplementations_Dictionary_WithWellKnownStringCompare
expectedInternalComparerTypeBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture,
expectedInternalComparerTypeAfterCollisionThreshold: StringComparer.InvariantCulture.GetType());

// CollectionsMarshal.GetValueRefOrAddDefault

RunCollectionTestCommon(
() => new Dictionary<string, object>(StringComparer.Ordinal),
(dictionary, key) => CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out _) = null,
(dictionary, key) => dictionary.ContainsKey(key),
dictionary => dictionary.Comparer,
expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal,
expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType);

static void RunDictionaryTest(
IEqualityComparer<string> equalityComparer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ private int Initialize(int capacity)

private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
{
// NOTE: this method is mirrored in CollectionsMarshal.GetValueRefOrAddDefault below.
// If you make any changes here, make sure to keep that version in sync as well.

if (key == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
Expand Down Expand Up @@ -681,6 +684,190 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
return true;
}

/// <summary>
/// A helper class containing APIs exposed through <see cref="Runtime.InteropServices.CollectionsMarshal"/>.
/// These methods are relatively niche and only used in specific scenarios, so adding them in a separate type avoids
/// the additional overhead on each <see cref="Dictionary{TKey, TValue}"/> instantiation, especially in AOT scenarios.
/// </summary>
internal static class CollectionsMarshalHelper
{
/// <inheritdoc cref="Runtime.InteropServices.CollectionsMarshal.GetValueRefOrAddDefault{TKey, TValue}(Dictionary{TKey, TValue}, TKey, out bool)"/>
public static ref TValue? GetValueRefOrAddDefault(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists)
{
// NOTE: this method is mirrored by Dictionary<TKey, TValue>.TryInsert above.
// If you make any changes here, make sure to keep that version in sync as well.

if (key == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

if (dictionary._buckets == null)
{
dictionary.Initialize(0);
}
Debug.Assert(dictionary._buckets != null);

Entry[]? entries = dictionary._entries;
Debug.Assert(entries != null, "expected entries to be non-null");

IEqualityComparer<TKey>? comparer = dictionary._comparer;
uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));

uint collisionCount = 0;
ref int bucket = ref dictionary.GetBucket(hashCode);
int i = bucket - 1; // Value in _buckets is 1-based

if (comparer == null)
{
if (typeof(TKey).IsValueType)
{
// ValueType: Devirtualize with EqualityComparer<TValue>.Default intrinsic
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && EqualityComparer<TKey>.Default.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}
else
{
// Object type: Shared Generic, EqualityComparer<TValue>.Default won't devirtualize
// https://github.com/dotnet/runtime/issues/10050
// So cache in a local rather than get EqualityComparer per loop iteration
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && defaultComparer.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}
}
else
{
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}

int index;
if (dictionary._freeCount > 0)
{
index = dictionary._freeList;
Debug.Assert((StartOfFreeList - entries[dictionary._freeList].next) >= -1, "shouldn't overflow because `next` cannot underflow");
dictionary._freeList = StartOfFreeList - entries[dictionary._freeList].next;
dictionary._freeCount--;
}
else
{
int count = dictionary._count;
if (count == entries.Length)
{
dictionary.Resize();
bucket = ref dictionary.GetBucket(hashCode);
}
index = count;
dictionary._count = count + 1;
entries = dictionary._entries;
}

ref Entry entry = ref entries![index];
entry.hashCode = hashCode;
entry.next = bucket - 1; // Value in _buckets is 1-based
entry.key = key;
entry.value = default!;
bucket = index + 1; // Value in _buckets is 1-based
dictionary._version++;

// Value types never rehash
if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer)
{
// If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
// i.e. EqualityComparer<string>.Default.
dictionary.Resize(entries.Length, true);

exists = false;

// At this point the entries array has been resized, so the current reference we have is no longer valid.
// We're forced to do a new lookup and return an updated reference to the new entry instance. This new
// lookup is guaranteed to always find a value though and it will never return a null reference here.
ref TValue? value = ref dictionary.FindValue(key)!;

Debug.Assert(!Unsafe.IsNullRef(ref value), "the lookup result cannot be a null ref here");

return ref value;
}

exists = false;

return ref entry.value!;
}
}

public virtual void OnDeserialization(object? sender)
{
HashHelpers.SerializationInfoTable.TryGetValue(this, out SerializationInfo? siInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,15 @@ public static Span<T> AsSpan<T>(List<T>? list)
/// </remarks>
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull
=> ref dictionary.FindValue(key);

/// <summary>
/// Gets a ref to a <typeparamref name="TValue"/> in the <see cref="Dictionary{TKey, TValue}"/>, adding a new entry with a default value if it does not exist in the <paramref name="dictionary"/>.
/// </summary>
/// <param name="dictionary">The dictionary to get the ref to <typeparamref name="TValue"/> from.</param>
/// <param name="key">The key used for lookup.</param>
/// <param name="exists">Whether or not a new entry for the given key was added to the dictionary.</param>
/// <remarks>Items should not be added to or removed from the <see cref="Dictionary{TKey, TValue}"/> while the ref <typeparamref name="TValue"/> is in use.</remarks>
public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull
=> ref Dictionary<TKey, TValue>.CollectionsMarshalHelper.GetValueRefOrAddDefault(dictionary, key, out exists);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ public static partial class CollectionsMarshal
{
public static System.Span<T> AsSpan<T>(System.Collections.Generic.List<T>? list) { throw null; }
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull { throw null; }
public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull { throw null; }
}
[System.AttributeUsageAttribute(System.AttributeTargets.Field | System.AttributeTargets.Parameter | System.AttributeTargets.Property | System.AttributeTargets.ReturnValue, Inherited=false)]
public sealed partial class ComAliasNameAttribute : System.Attribute
Expand Down
Loading

0 comments on commit bbcb6b7

Please sign in to comment.