Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
Add ConcurrentDictionary GetOrAdd/AddOrUpdate overloads with generic arg
Browse files Browse the repository at this point in the history
Adds one overload to each of GetOrAdd and AddOrUpdate.  These overloads accept a generic argument that is passed through to any invocations of the supplied delegates, enabling developers to avoid delegate/closure allocations when more input is needed than just the key or existing value.  For AddOrUpdate, there are two existing overloads with delegates; this only provide a new overload for the one that accepts two delegates.
  • Loading branch information
stephentoub authored and justinvp committed Nov 19, 2016
1 parent ff81e90 commit c8f90c6
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,37 @@ public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)
return resultingValue;
}

/// <summary>
/// Adds a key/value pair to the <see cref="ConcurrentDictionary{TKey,TValue}"/>
/// if the key does not already exist.
/// </summary>
/// <param name="key">The key of the element to add.</param>
/// <param name="valueFactory">The function used to generate a value for the key</param>
/// <param name="factoryArgument">An argument value to pass into <paramref name="valueFactory"/>.</param>
/// <exception cref="T:System.ArgumentNullException"><paramref name="key"/> is a null reference
/// (Nothing in Visual Basic).</exception>
/// <exception cref="T:System.ArgumentNullException"><paramref name="valueFactory"/> is a null reference
/// (Nothing in Visual Basic).</exception>
/// <exception cref="T:System.OverflowException">The dictionary contains too many
/// elements.</exception>
/// <returns>The value for the key. This will be either the existing value for the key if the
/// key is already in the dictionary, or the new value for the key as returned by valueFactory
/// if the key was not in the dictionary.</returns>
public TValue GetOrAdd<TArg>(TKey key, Func<TKey, TArg, TValue> valueFactory, TArg factoryArgument)
{
if (key == null) throw new ArgumentNullException("key");
if (valueFactory == null) throw new ArgumentNullException("valueFactory");

int hashcode = _comparer.GetHashCode(key);

TValue resultingValue;
if (!TryGetValueInternal(key, hashcode, out resultingValue))
{
TryAddInternal(key, hashcode, valueFactory(key, factoryArgument), false, true, out resultingValue);
}
return resultingValue;
}

/// <summary>
/// Adds a key/value pair to the <see cref="ConcurrentDictionary{TKey,TValue}"/>
/// if the key does not already exist.
Expand All @@ -1068,6 +1099,59 @@ public TValue GetOrAdd(TKey key, TValue value)
return resultingValue;
}

/// <summary>
/// Adds a key/value pair to the <see cref="ConcurrentDictionary{TKey,TValue}"/> if the key does not already
/// exist, or updates a key/value pair in the <see cref="ConcurrentDictionary{TKey,TValue}"/> if the key
/// already exists.
/// </summary>
/// <param name="key">The key to be added or whose value should be updated</param>
/// <param name="addValueFactory">The function used to generate a value for an absent key</param>
/// <param name="updateValueFactory">The function used to generate a new value for an existing key
/// based on the key's existing value</param>
/// <param name="factoryArgument">An argument to pass into <paramref name="addValueFactory"/> and <paramref name="updateValueFactory"/>.</param>
/// <exception cref="T:System.ArgumentNullException"><paramref name="key"/> is a null reference
/// (Nothing in Visual Basic).</exception>
/// <exception cref="T:System.ArgumentNullException"><paramref name="addValueFactory"/> is a null reference
/// (Nothing in Visual Basic).</exception>
/// <exception cref="T:System.ArgumentNullException"><paramref name="updateValueFactory"/> is a null reference
/// (Nothing in Visual Basic).</exception>
/// <exception cref="T:System.OverflowException">The dictionary contains too many
/// elements.</exception>
/// <returns>The new value for the key. This will be either be the result of addValueFactory (if the key was
/// absent) or the result of updateValueFactory (if the key was present).</returns>
public TValue AddOrUpdate<TArg>(
TKey key, Func<TKey, TArg, TValue> addValueFactory, Func<TKey, TValue, TArg, TValue> updateValueFactory, TArg factoryArgument)
{
if (key == null) throw new ArgumentNullException("key");
if (addValueFactory == null) throw new ArgumentNullException("addValueFactory");
if (updateValueFactory == null) throw new ArgumentNullException("updateValueFactory");

int hashcode = _comparer.GetHashCode(key);

while (true)
{
TValue oldValue;
if (TryGetValueInternal(key, hashcode, out oldValue))
{
// key exists, try to update
TValue newValue = updateValueFactory(key, oldValue, factoryArgument);
if (TryUpdateInternal(key, hashcode, newValue, oldValue))
{
return newValue;
}
}
else
{
// key doesn't exist, try to add
TValue resultingValue;
if (TryAddInternal(key, hashcode, addValueFactory(key, factoryArgument), false, true, out resultingValue))
{
return resultingValue;
}
}
}
}

/// <summary>
/// Adds a key/value pair to the <see cref="ConcurrentDictionary{TKey,TValue}"/> if the key does not already
/// exist, or updates a key/value pair in the <see cref="ConcurrentDictionary{TKey,TValue}"/> if the key
Expand Down Expand Up @@ -1099,16 +1183,17 @@ public TValue AddOrUpdate(TKey key, Func<TKey, TValue> addValueFactory, Func<TKe
{
TValue oldValue;
if (TryGetValueInternal(key, hashcode, out oldValue))
//key exists, try to update
{
// key exists, try to update
TValue newValue = updateValueFactory(key, oldValue);
if (TryUpdateInternal(key, hashcode, newValue, oldValue))
{
return newValue;
}
}
else //try add
else
{
// key doesn't exist, try to add
TValue resultingValue;
if (TryAddInternal(key, hashcode, addValueFactory(key), false, true, out resultingValue))
{
Expand Down Expand Up @@ -1146,16 +1231,17 @@ public TValue AddOrUpdate(TKey key, TValue addValue, Func<TKey, TValue, TValue>
{
TValue oldValue;
if (TryGetValueInternal(key, hashcode, out oldValue))
//key exists, try to update
{
// key exists, try to update
TValue newValue = updateValueFactory(key, oldValue);
if (TryUpdateInternal(key, hashcode, newValue, oldValue))
{
return newValue;
}
}
else //try add
else
{
// key doesn't exist, try to add
TValue resultingValue;
if (TryAddInternal(key, hashcode, addValue, false, true, out resultingValue))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,25 +441,33 @@ private static void TestGetOrAddOrUpdate(int cLevel, int initSize, int threads,
{
if (isAdd)
{
//call either of the two overloads of GetOrAdd
if (j + ii % 2 == 0)
//call one of the overloads of GetOrAdd
switch (j % 3)
{
dict.GetOrAdd(j, -j);
}
else
{
dict.GetOrAdd(j, x => -x);
case 0:
dict.GetOrAdd(j, -j);
break;
case 1:
dict.GetOrAdd(j, x => -x);
break;
case 2:
dict.GetOrAdd(j, (x,m) => x * m, -1);
break;
}
}
else
{
if (j + ii % 2 == 0)
{
dict.AddOrUpdate(j, -j, (k, v) => -j);
}
else
switch (j % 3)
{
dict.AddOrUpdate(j, (k) => -k, (k, v) => -k);
case 0:
dict.AddOrUpdate(j, -j, (k, v) => -j);
break;
case 1:
dict.AddOrUpdate(j, (k) => -k, (k, v) => -k);
break;
case 2:
dict.AddOrUpdate(j, (k,m) => k*m, (k, v, m) => k * m, -1);
break;
}
}
}
Expand Down Expand Up @@ -615,6 +623,12 @@ public static void TestExceptions()
() => dictionary[null] = 1);
// "TestExceptions: FAILED. this[] didn't throw ANE when null key is passed");

Assert.Throws<ArgumentNullException>(
() => dictionary.GetOrAdd(null, (k,m) => 0, 0));
// "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null key is passed");
Assert.Throws<ArgumentNullException>(
() => dictionary.GetOrAdd("1", null, 0));
// "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null valueFactory is passed");
Assert.Throws<ArgumentNullException>(
() => dictionary.GetOrAdd(null, (k) => 0));
// "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null key is passed");
Expand All @@ -625,6 +639,15 @@ public static void TestExceptions()
() => dictionary.GetOrAdd(null, 0));
// "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null key is passed");

Assert.Throws<ArgumentNullException>(
() => dictionary.AddOrUpdate(null, (k, m) => 0, (k, v, m) => 0, 42));
// "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null key is passed");
Assert.Throws<ArgumentNullException>(
() => dictionary.AddOrUpdate("1", (k, m) => 0, null, 42));
// "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null updateFactory is passed");
Assert.Throws<ArgumentNullException>(
() => dictionary.AddOrUpdate("1", null, (k, v, m) => 0, 42));
// "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null addFactory is passed");
Assert.Throws<ArgumentNullException>(
() => dictionary.AddOrUpdate(null, (k) => 0, (k, v) => 0));
// "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null key is passed");
Expand Down

0 comments on commit c8f90c6

Please sign in to comment.