diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs
index 0da8b6dfcdec2..41fe81416b27a 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs
@@ -22,7 +22,8 @@ public static partial class TensorPrimitives
/// If a value is equal to , the result stored into the corresponding destination location is the original NaN value with the sign bit removed.
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// <paramref name="x"/> may overlap with <paramref name="destination"/>, but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
public static void Abs(ReadOnlySpan x, Span destination) =>
@@ -39,7 +40,9 @@ public static void Abs(ReadOnlySpan x, Span destination) =>
/// This method effectively computes [i] = [i] + [i].
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// <paramref name="x"/> and <paramref name="y"/> may overlap arbitrarily, but they may only overlap with <paramref name="destination"/>
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example,
+ /// to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -58,7 +61,9 @@ public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span
/// This method effectively computes [i] = [i] + .
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// <paramref name="x"/> and <paramref name="destination"/> may overlap, but only if they start at the same memory location;
+ /// otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters, such as to perform
+ /// an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -79,7 +84,9 @@ public static void Add(ReadOnlySpan x, float y, Span destination)
/// This method effectively computes [i] = ([i] + [i]) * [i].
///
///
- /// , , and may overlap, but none of them may overlap with ; if they do, behavior is undefined.
+ /// , , and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined.
+ /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -100,7 +107,9 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, Rea
/// This method effectively computes [i] = ([i] + [i]) * .
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined.
+ /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -121,7 +130,9 @@ public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, flo
/// This method effectively computes [i] = ([i] + ) * [i].
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined.
+ /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -139,7 +150,8 @@ public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan[i] = .Cosh([i]).
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// may overlap with , but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If a value is equal to or , the result stored into the corresponding destination location is set to .
@@ -250,7 +262,9 @@ public static float Distance(ReadOnlySpan x, ReadOnlySpan y)
/// This method effectively computes [i] = [i] / [i].
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// <paramref name="x"/> and <paramref name="y"/> may overlap arbitrarily, but they may only overlap with <paramref name="destination"/>
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example,
+ /// to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -269,7 +283,8 @@ public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span[i] = [i] / .
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// may overlap with , but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -320,7 +335,8 @@ public static float Dot(ReadOnlySpan x, ReadOnlySpan y)
/// This method effectively computes [i] = .Exp([i]).
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// may overlap with , but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If a value equals or , the result stored into the corresponding destination location is set to NaN.
@@ -559,7 +575,8 @@ public static unsafe int IndexOfMinMagnitude(ReadOnlySpan x)
/// This method effectively computes [i] = .Log([i]).
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// <paramref name="x"/> may overlap with <paramref name="destination"/>, but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If a value equals 0, the result stored into the corresponding destination location is set to .
@@ -594,7 +611,8 @@ public static void Log(ReadOnlySpan x, Span destination)
/// This method effectively computes [i] = .Log2([i]).
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// may overlap with , but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If a value equals 0, the result stored into the corresponding destination location is set to .
@@ -648,7 +666,9 @@ public static float Max(ReadOnlySpan x) =>
/// This method effectively computes [i] = MathF.Max([i], [i]).
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example,
+ /// to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to ,
@@ -689,12 +709,9 @@ public static float MaxMagnitude(ReadOnlySpan x) =>
/// This method effectively computes [i] = MathF.MaxMagnitude([i], [i]).
///
///
- /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If either value is equal to ,
- /// that value is stored as the result. If the two values have the same magnitude and one is positive and the other is negative,
- /// the positive value is considered to have the larger magnitude.
- ///
- ///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example,
+ /// to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
@@ -736,7 +753,9 @@ public static float Min(ReadOnlySpan x) =>
/// that value is stored as the result. Positive 0 is considered greater than negative 0.
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example,
+ /// to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
@@ -778,7 +797,9 @@ public static float MinMagnitude(ReadOnlySpan x) =>
/// the negative value is considered to have the smaller magnitude.
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example,
+ /// to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
@@ -799,7 +820,9 @@ public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Sp
/// This method effectively computes [i] = [i] * [i].
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// <paramref name="x"/> and <paramref name="y"/> may overlap arbitrarily, but they may only overlap with <paramref name="destination"/>
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined. It is safe, for example,
+ /// to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -819,7 +842,8 @@ public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Spanscal method defined by BLAS1.
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// <paramref name="x"/> may overlap with <paramref name="destination"/>, but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -840,12 +864,16 @@ public static void Multiply(ReadOnlySpan x, float y, Span destinat
/// This method effectively computes [i] = ([i] * [i]) + [i].
///
///
- /// , , and may overlap, but none of them may overlap with ; if they do, behavior is undefined.
+ /// , , and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined.
+ /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
///
///
+
+
public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) =>
InvokeSpanSpanSpanIntoSpan(x, y, addend, destination);
@@ -862,7 +890,9 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, Rea
/// It corresponds to the axpy method defined by BLAS1.
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined.
+ /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -883,7 +913,9 @@ public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, flo
/// This method effectively computes [i] = ([i] * ) + [i].
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// and may overlap arbitrarily, but they may only overlap with
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined.
+ /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -901,7 +933,8 @@ public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan[i] = -[i].
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// may overlap with , but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -1035,7 +1068,8 @@ public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y)
/// This method effectively computes [i] = 1f / (1f + .Exp(-[i])).
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// <paramref name="x"/> may overlap with <paramref name="destination"/>, but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
@@ -1069,7 +1103,8 @@ public static void Sigmoid(ReadOnlySpan x, Span destination)
/// This method effectively computes [i] = .Sinh([i]).
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// <paramref name="x"/> may overlap with <paramref name="destination"/>, but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If a value is equal to , , or ,
@@ -1107,7 +1142,8 @@ public static void Sinh(ReadOnlySpan x, Span destination)
/// It then effectively computes [i] = MathF.Exp([i]) / sum.
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// <paramref name="x"/> may overlap with <paramref name="destination"/>, but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
@@ -1150,7 +1186,9 @@ public static void SoftMax(ReadOnlySpan x, Span destination)
/// This method effectively computes [i] = [i] - [i].
///
///
- /// and may overlap, but neither may overlap with ; if they do, behavior is undefined.
+ /// <paramref name="x"/> and <paramref name="y"/> may overlap arbitrarily, but they may only overlap with <paramref name="destination"/>
+ /// if the input and the output span begin at the same memory location; otherwise, behavior is undefined.
+ /// It is safe, for example, to use the same span for any subset of the span parameters, such as to perform an in-place operation.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -1169,7 +1207,8 @@ public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span[i] = [i] - .
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// may overlap with , but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN.
@@ -1244,7 +1283,8 @@ public static float SumOfSquares(ReadOnlySpan x) =>
/// This method effectively computes [i] = .Tanh([i]).
///
///
- /// and may not overlap; if they do, behavior is undefined.
+ /// may overlap with , but only if the input and the output span begin at the same memory
+ /// location; otherwise, behavior is undefined. It is safe, for example, to use the same span for all span parameters.
///
///
/// If a value is equal to , the corresponding destination location is set to -1.
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs
index bed07eedfefd1..bd18b16d47b69 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs
@@ -23,6 +23,9 @@ public static partial class TensorPrimitives
///
/// This method effectively computes [i] = (Half)[i].
///
+ ///
+ /// <paramref name="source"/> and <paramref name="destination"/> must not overlap. If they do, behavior is undefined.
+ ///
///
public static void ConvertToHalf(ReadOnlySpan source, Span destination)
{
@@ -48,6 +51,9 @@ public static void ConvertToHalf(ReadOnlySpan source, Span destinat
///
/// This method effectively computes [i] = (float)[i].
///
+ ///
+ /// <paramref name="source"/> and <paramref name="destination"/> must not overlap. If they do, behavior is undefined.
+ ///
///
public static void ConvertToSingle(ReadOnlySpan source, Span destination)
{
@@ -519,7 +525,10 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax :
return GetFirstNaN(current);
}
- result = TMinMax.Invoke(result, current);
+ result = Vector512.ConditionalSelect(
+ Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero),
+ result,
+ TMinMax.Invoke(result, current));
}
// Aggregate the lanes in the vector to create the final scalar result.
@@ -565,7 +574,10 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax :
return GetFirstNaN(current);
}
- result = TMinMax.Invoke(result, current);
+ result = Vector256.ConditionalSelect(
+ Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero),
+ result,
+ TMinMax.Invoke(result, current));
}
// Aggregate the lanes in the vector to create the final scalar result.
@@ -610,7 +622,10 @@ private static float MinMaxCore(ReadOnlySpan x) where TMinMax :
return GetFirstNaN(current);
}
- result = TMinMax.Invoke(result, current);
+ result = Vector128.ConditionalSelect(
+ Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero),
+ result,
+ TMinMax.Invoke(result, current));
}
// Aggregate the lanes in the vector to create the final scalar result.
@@ -672,7 +687,10 @@ private static unsafe void InvokeSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector512.Count);
- TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector512.ConditionalSelect(
+ Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero),
+ Vector512.LoadUnsafe(ref dRef, lastVectorIndex),
+ TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -698,7 +716,10 @@ private static unsafe void InvokeSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector256.Count);
- TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector256.ConditionalSelect(
+ Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero),
+ Vector256.LoadUnsafe(ref dRef, lastVectorIndex),
+ TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -723,7 +744,10 @@ private static unsafe void InvokeSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector128.Count);
- TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector128.ConditionalSelect(
+ Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero),
+ Vector128.LoadUnsafe(ref dRef, lastVectorIndex),
+ TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -777,8 +801,11 @@ private static unsafe void InvokeSpanSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector512.Count);
- TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector512.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector512.ConditionalSelect(
+ Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero),
+ Vector512.LoadUnsafe(ref dRef, lastVectorIndex),
+ TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector512.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -805,8 +832,11 @@ private static unsafe void InvokeSpanSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector256.Count);
- TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector256.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector256.ConditionalSelect(
+ Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero),
+ Vector256.LoadUnsafe(ref dRef, lastVectorIndex),
+ TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector256.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -832,8 +862,11 @@ private static unsafe void InvokeSpanSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector128.Count);
- TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector128.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector128.ConditionalSelect(
+ Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero),
+ Vector128.LoadUnsafe(ref dRef, lastVectorIndex),
+ TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector128.LoadUnsafe(ref yRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -884,8 +917,11 @@ private static unsafe void InvokeSpanScalarIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector512.Count);
- TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
- yVec).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector512.ConditionalSelect(
+ Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero),
+ Vector512.LoadUnsafe(ref dRef, lastVectorIndex),
+ TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
+ yVec)).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -914,8 +950,11 @@ private static unsafe void InvokeSpanScalarIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector256.Count);
- TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
- yVec).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector256.ConditionalSelect(
+ Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero),
+ Vector256.LoadUnsafe(ref dRef, lastVectorIndex),
+ TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
+ yVec)).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -943,8 +982,11 @@ private static unsafe void InvokeSpanScalarIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector128.Count);
- TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
- yVec).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector128.ConditionalSelect(
+ Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero),
+ Vector128.LoadUnsafe(ref dRef, lastVectorIndex),
+ TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
+ yVec)).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1001,9 +1043,12 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector512.Count);
- TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector512.LoadUnsafe(ref yRef, lastVectorIndex),
- Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector512.ConditionalSelect(
+ Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero),
+ Vector512.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector512.LoadUnsafe(ref yRef, lastVectorIndex),
+ Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1031,9 +1076,12 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector256.Count);
- TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector256.LoadUnsafe(ref yRef, lastVectorIndex),
- Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector256.ConditionalSelect(
+ Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero),
+ Vector256.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector256.LoadUnsafe(ref yRef, lastVectorIndex),
+ Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1060,9 +1108,12 @@ private static unsafe void InvokeSpanSpanSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector128.Count);
- TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector128.LoadUnsafe(ref yRef, lastVectorIndex),
- Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector128.ConditionalSelect(
+ Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero),
+ Vector128.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector128.LoadUnsafe(ref yRef, lastVectorIndex),
+ Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1121,9 +1172,12 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector512.Count);
- TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector512.LoadUnsafe(ref yRef, lastVectorIndex),
- zVec).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector512.ConditionalSelect(
+ Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero),
+ Vector512.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector512.LoadUnsafe(ref yRef, lastVectorIndex),
+ zVec)).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1153,9 +1207,12 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector256.Count);
- TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector256.LoadUnsafe(ref yRef, lastVectorIndex),
- zVec).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector256.ConditionalSelect(
+ Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero),
+ Vector256.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector256.LoadUnsafe(ref yRef, lastVectorIndex),
+ zVec)).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1184,9 +1241,12 @@ private static unsafe void InvokeSpanSpanScalarIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector128.Count);
- TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
- Vector128.LoadUnsafe(ref yRef, lastVectorIndex),
- zVec).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector128.ConditionalSelect(
+ Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero),
+ Vector128.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
+ Vector128.LoadUnsafe(ref yRef, lastVectorIndex),
+ zVec)).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1245,9 +1305,12 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector512.Count);
- TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
- yVec,
- Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector512.ConditionalSelect(
+ Vector512.Equals(LoadRemainderMaskSingleVector512(x.Length - i), Vector512.Zero),
+ Vector512.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex),
+ yVec,
+ Vector512.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1277,9 +1340,12 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector256.Count);
- TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
- yVec,
- Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector256.ConditionalSelect(
+ Vector256.Equals(LoadRemainderMaskSingleVector256(x.Length - i), Vector256.Zero),
+ Vector256.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex),
+ yVec,
+ Vector256.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
@@ -1308,9 +1374,12 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan(
if (i != x.Length)
{
uint lastVectorIndex = (uint)(x.Length - Vector128.Count);
- TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
- yVec,
- Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex);
+ Vector128.ConditionalSelect(
+ Vector128.Equals(LoadRemainderMaskSingleVector128(x.Length - i), Vector128.Zero),
+ Vector128.LoadUnsafe(ref dRef, lastVectorIndex),
+ TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex),
+ yVec,
+ Vector128.LoadUnsafe(ref zRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex);
}
return;
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
index e05e54bcad769..70207a5c8995b 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
@@ -320,7 +320,11 @@ private static void InvokeSpanIntoSpan(
if (i != x.Length)
{
int lastVectorIndex = x.Length - Vector.Count;
- AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex));
+ ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); // ref into destination at lastVectorIndex; read and then overwritten below
+ dest = Vector.ConditionalSelect(
+ Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), // condition: lanes whose remainder mask is zero
+ dest, // those lanes keep the value already stored in destination
+ op.Invoke(AsVector(ref xRef, lastVectorIndex))); // every other lane receives the operator result
}
return;
@@ -374,8 +378,12 @@ private static void InvokeSpanSpanIntoSpan(
if (i != x.Length)
{
int lastVectorIndex = x.Length - Vector.Count;
- AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex),
- AsVector(ref yRef, lastVectorIndex));
+ ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); // ref into destination at lastVectorIndex; read and then overwritten below
+ dest = Vector.ConditionalSelect(
+ Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), // condition: lanes whose remainder mask is zero
+ dest, // those lanes keep their current destination value
+ op.Invoke(AsVector(ref xRef, lastVectorIndex), // remaining lanes take op(x, y) for the last full vector
+ AsVector(ref yRef, lastVectorIndex)));
}
return;
@@ -424,8 +432,11 @@ private static void InvokeSpanScalarIntoSpan(
if (i != x.Length)
{
int lastVectorIndex = x.Length - Vector.Count;
- AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex),
- yVec);
+ ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); // ref into destination at lastVectorIndex; read and then overwritten below
+ dest = Vector.ConditionalSelect(
+ Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), // condition: lanes whose remainder mask is zero
+ dest, // those lanes keep their current destination value
+ op.Invoke(AsVector(ref xRef, lastVectorIndex), yVec)); // remaining lanes take op(x, yVec) for the last full vector
}
return;
@@ -482,9 +493,13 @@ private static void InvokeSpanSpanSpanIntoSpan(
if (i != x.Length)
{
int lastVectorIndex = x.Length - Vector.Count;
- AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex),
- AsVector(ref yRef, lastVectorIndex),
- AsVector(ref zRef, lastVectorIndex));
+ ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); // ref into destination at lastVectorIndex; read and then overwritten below
+ dest = Vector.ConditionalSelect(
+ Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), // condition: lanes whose remainder mask is zero
+ dest, // those lanes keep their current destination value
+ op.Invoke(AsVector(ref xRef, lastVectorIndex), // remaining lanes take op(x, y, z) for the last full vector
+ AsVector(ref yRef, lastVectorIndex),
+ AsVector(ref zRef, lastVectorIndex)));
}
return;
@@ -543,9 +558,13 @@ private static void InvokeSpanSpanScalarIntoSpan(
if (i != x.Length)
{
int lastVectorIndex = x.Length - Vector.Count;
- AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex),
- AsVector(ref yRef, lastVectorIndex),
- zVec);
+ ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); // ref into destination at lastVectorIndex; read and then overwritten below
+ dest = Vector.ConditionalSelect(
+ Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), // condition: lanes whose remainder mask is zero
+ dest, // those lanes keep their current destination value
+ op.Invoke(AsVector(ref xRef, lastVectorIndex), // remaining lanes take op(x, y, zVec) for the last full vector
+ AsVector(ref yRef, lastVectorIndex),
+ zVec));
}
return;
@@ -604,9 +623,13 @@ private static void InvokeSpanScalarSpanIntoSpan(
if (i != x.Length)
{
int lastVectorIndex = x.Length - Vector.Count;
- AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex),
- yVec,
- AsVector(ref zRef, lastVectorIndex));
+ ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); // ref into destination at lastVectorIndex; read and then overwritten below
+ dest = Vector.ConditionalSelect(
+ Vector.Equals(LoadRemainderMaskSingleVector(x.Length - i), Vector.Zero), // condition: lanes whose remainder mask is zero
+ dest, // those lanes keep their current destination value
+ op.Invoke(AsVector(ref xRef, lastVectorIndex), // remaining lanes take op(x, yVec, z) for the last full vector
+ yVec,
+ AsVector(ref zRef, lastVectorIndex)));
}
return;
diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs
index edcebe8eb4775..751e352dd1da5 100644
--- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs
+++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs
@@ -75,6 +75,21 @@ public static void Abs(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Abs_InPlace(int tensorLength) // Abs with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Abs(x, x); // x is source and destination
+
+ for (int i = 0; i < x.Length; i++)
+ {
+ Assert.Equal(MathF.Abs(xOrig[i]), x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Abs_ThrowsForTooShortDestination(int tensorLength)
@@ -96,11 +111,34 @@ public static void Add_TwoTensors(int tensorLength)
using BoundedMemory destination = CreateTensor(tensorLength);
TensorPrimitives.Add(x, y, destination);
-
for (int i = 0; i < tensorLength; i++)
{
Assert.Equal(x[i] + y[i], destination[i], Tolerance);
}
+
+ float[] xOrig = x.Span.ToArray(); // copy of x taken before the in-place call below overwrites it
+
+ // Validate that the destination can be the same as an input.
+ TensorPrimitives.Add(x, x, x); // x serves as both operands and as the destination
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] + xOrig[i], x[i], Tolerance);
+ }
+ }
+
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Add_TwoTensors_InPlace(int tensorLength) // same in-place scenario as above, run over TensorLengthsIncluding0
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Add(x, x, x); // x serves as both operands and as the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] + xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
}
[Theory]
@@ -142,6 +180,22 @@ public static void Add_TensorScalar(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Add_TensorScalar_InPlace(int tensorLength) // tensor + scalar, written back into the input tensor
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float y = NextSingle(); // scalar addend
+
+ TensorPrimitives.Add(x, y, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] + y, x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Add_TensorScalar_ThrowsForTooShortDestination(int tensorLength)
@@ -172,6 +226,21 @@ public static void AddMultiply_ThreeTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void AddMultiply_ThreeTensors_InPlace(int tensorLength) // (x + x) * x written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.AddMultiply(x, x, x, x); // the same tensor is passed for all three inputs and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal((xOrig[i] + xOrig[i]) * xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void AddMultiply_ThreeTensors_ThrowsForMismatchedLengths(int tensorLength)
@@ -215,6 +284,22 @@ public static void AddMultiply_TensorTensorScalar(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void AddMultiply_TensorTensorScalar_InPlace(int tensorLength) // (x + x) * scalar written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float multiplier = NextSingle(); // scalar factor
+
+ TensorPrimitives.AddMultiply(x, x, multiplier, x); // x is both tensor inputs and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal((xOrig[i] + xOrig[i]) * multiplier, x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void AddMultiply_TensorTensorScalar_ThrowsForMismatchedLengths_x_y(int tensorLength)
@@ -257,6 +342,22 @@ public static void AddMultiply_TensorScalarTensor(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void AddMultiply_TensorScalarTensor_InPlace(int tensorLength) // (x + scalar) * x written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float y = NextSingle(); // scalar addend
+
+ TensorPrimitives.AddMultiply(x, y, x, x); // x is both tensor inputs and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal((xOrig[i] + y) * xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void AddMultiply_TensorScalarTensor_ThrowsForMismatchedLengths_x_z(int tensorLength)
@@ -299,6 +400,21 @@ public static void Cosh(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Cosh_InPlace(int tensorLength) // Cosh with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Cosh(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Cosh(xOrig[i]), x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Cosh_ThrowsForTooShortDestination(int tensorLength)
@@ -421,6 +537,21 @@ public static void Divide_TwoTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Divide_TwoTensors_InPlace(int tensorLength) // x / x written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Divide(x, x, x); // x is dividend, divisor and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] / xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Divide_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
@@ -460,6 +591,22 @@ public static void Divide_TensorScalar(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Divide_TensorScalar_InPlace(int tensorLength) // tensor / scalar written back into the input tensor
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float y = NextSingle(); // scalar divisor
+
+ TensorPrimitives.Divide(x, y, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] / y, x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Divide_TensorScalar_ThrowsForTooShortDestination(int tensorLength)
@@ -527,6 +674,21 @@ public static void Exp(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Exp_InPlace(int tensorLength) // Exp with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Exp(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Exp(xOrig[i]), x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Exp_ThrowsForTooShortDestination(int tensorLength)
@@ -735,6 +897,21 @@ public static void Log(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Log_InPlace(int tensorLength) // Log with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Log(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Log(xOrig[i]), x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Log_ThrowsForTooShortDestination(int tensorLength)
@@ -762,6 +939,21 @@ public static void Log2(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Log2_InPlace(int tensorLength) // Log2 with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Log2(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Log(xOrig[i], 2), x[i], Tolerance); // reference is the two-argument MathF.Log with base 2
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Log2_ThrowsForTooShortDestination(int tensorLength)
@@ -834,6 +1026,32 @@ public static void Max_TwoTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Max_TwoTensors_InPlace(int tensorLength) // element-wise Max with the destination aliasing x, then aliasing y
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ using BoundedMemory y = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); // copies of both inputs, taken before either is overwritten
+
+ TensorPrimitives.Max(x, y, x); // first pass: result is written into x
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Max(xOrig[i], y[i]), x[i], Tolerance); // x was overwritten, so its inputs come from xOrig
+ }
+
+ xOrig.AsSpan().CopyTo(x.Span); // put the original inputs back before the second pass
+ yOrig.AsSpan().CopyTo(y.Span);
+
+ TensorPrimitives.Max(x, y, y); // second pass: result is written into y
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Max(x[i], yOrig[i]), y[i], Tolerance); // y was overwritten, so its inputs come from yOrig
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Max_TwoTensors_SpecialValues(int tensorLength)
@@ -955,6 +1173,32 @@ public static void MaxMagnitude_TwoTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void MaxMagnitude_TwoTensors_InPlace(int tensorLength) // element-wise MaxMagnitude with the destination aliasing x, then aliasing y
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ using BoundedMemory y = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); // copies of both inputs, taken before either is overwritten
+
+ TensorPrimitives.MaxMagnitude(x, y, x); // first pass: result is written into x
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathFMaxMagnitude(xOrig[i], y[i]), x[i], Tolerance); // x was overwritten, so its inputs come from xOrig
+ }
+
+ xOrig.AsSpan().CopyTo(x.Span); // put the original inputs back before the second pass
+ yOrig.AsSpan().CopyTo(y.Span);
+
+ TensorPrimitives.MaxMagnitude(x, y, y); // second pass: result is written into y
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathFMaxMagnitude(x[i], yOrig[i]), y[i], Tolerance); // y was overwritten, so its inputs come from yOrig
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void MaxMagnitude_TwoTensors_SpecialValues(int tensorLength)
@@ -1075,6 +1319,32 @@ public static void Min_TwoTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Min_TwoTensors_InPlace(int tensorLength) // element-wise Min with the destination aliasing x, then aliasing y
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ using BoundedMemory y = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); // copies of both inputs, taken before either is overwritten
+
+ TensorPrimitives.Min(x, y, x); // first pass: result is written into x
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Min(xOrig[i], y[i]), x[i], Tolerance); // x was overwritten, so its inputs come from xOrig
+ }
+
+ xOrig.AsSpan().CopyTo(x.Span); // put the original inputs back before the second pass
+ yOrig.AsSpan().CopyTo(y.Span);
+
+ TensorPrimitives.Min(x, y, y); // second pass: result is written into y
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Min(x[i], yOrig[i]), y[i], Tolerance); // y was overwritten, so its inputs come from yOrig
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Min_TwoTensors_SpecialValues(int tensorLength)
@@ -1194,6 +1464,32 @@ public static void MinMagnitude_TwoTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void MinMagnitude_TwoTensors_InPlace(int tensorLength) // element-wise MinMagnitude with the destination aliasing x, then aliasing y
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ using BoundedMemory y = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); // copies of both inputs, taken before either is overwritten
+
+ TensorPrimitives.MinMagnitude(x, y, x); // first pass: result is written into x
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathFMinMagnitude(xOrig[i], y[i]), x[i], Tolerance); // x was overwritten, so its inputs come from xOrig
+ }
+
+ xOrig.AsSpan().CopyTo(x.Span); // put the original inputs back before the second pass
+ yOrig.AsSpan().CopyTo(y.Span);
+
+ TensorPrimitives.MinMagnitude(x, y, y); // second pass: result is written into y
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathFMinMagnitude(x[i], yOrig[i]), y[i], Tolerance); // y was overwritten, so its inputs come from yOrig
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void MinMagnitude_TwoTensors_SpecialValues(int tensorLength)
@@ -1270,6 +1566,21 @@ public static void Multiply_TwoTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Multiply_TwoTensors_InPlace(int tensorLength) // x * x written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Multiply(x, x, x); // x is both factors and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] * xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Multiply_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
@@ -1309,6 +1620,22 @@ public static void Multiply_TensorScalar(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Multiply_TensorScalar_InPlace(int tensorLength) // tensor * scalar written back into the input tensor
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float y = NextSingle(); // scalar factor
+
+ TensorPrimitives.Multiply(x, y, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] * y, x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Multiply_TensorScalar_ThrowsForTooShortDestination(int tensorLength)
@@ -1339,6 +1666,21 @@ public static void MultiplyAdd_ThreeTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void MultiplyAdd_ThreeTensors_InPlace(int tensorLength) // (x * x) + x written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.MultiplyAdd(x, x, x, x); // the same tensor is passed for all three inputs and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal((xOrig[i] * xOrig[i]) + xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void MultiplyAdd_ThreeTensors_ThrowsForMismatchedLengths_x_y(int tensorLength)
@@ -1382,6 +1724,22 @@ public static void MultiplyAdd_TensorTensorScalar(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void MultiplyAdd_TensorTensorScalar_InPlace(int tensorLength) // (x * x) + scalar written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float addend = NextSingle(); // scalar addend
+
+ TensorPrimitives.MultiplyAdd(x, x, addend, x); // x is both tensor inputs and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal((xOrig[i] * xOrig[i]) + addend, x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength)
@@ -1411,6 +1769,22 @@ public static void MultiplyAdd_TensorScalarTensor(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void MultiplyAdd_TensorScalarTensor_InPlace(int tensorLength) // (x * scalar) + x written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float y = NextSingle(); // scalar factor
+
+ TensorPrimitives.MultiplyAdd(x, y, x, x); // x is both tensor inputs and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal((xOrig[i] * y) + xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength)
@@ -1440,6 +1814,21 @@ public static void Negate(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Negate_InPlace(int tensorLength) // Negate with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Negate(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(-xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Negate_ThrowsForTooShortDestination(int tensorLength)
@@ -1598,6 +1987,36 @@ public static void ProductOfSums_KnownValues()
#endregion
#region Sigmoid
+ [Theory]
+ [MemberData(nameof(TensorLengths))]
+ public static void Sigmoid(int tensorLength) // Sigmoid into a separate destination, checked element by element
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ using BoundedMemory destination = CreateTensor(tensorLength);
+
+ TensorPrimitives.Sigmoid(x, destination);
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(1f / (1f + MathF.Exp(-x[i])), destination[i], Tolerance); // reference: 1 / (1 + e^-x)
+ }
+ }
+
+ [Theory]
+ [MemberData(nameof(TensorLengths))] // NOTE(review): the other in-place tests use TensorLengthsIncluding0; presumably Sigmoid rejects empty input - confirm
+ public static void Sigmoid_InPlace(int tensorLength) // Sigmoid with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Sigmoid(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(1f / (1f + MathF.Exp(-xOrig[i])), x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength)
@@ -1612,7 +2031,7 @@ public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength)
[InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })]
[InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })]
[InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })]
- public static void Sigmoid(float[] x, float[] expectedResult)
+ public static void Sigmoid_KnownValues(float[] x, float[] expectedResult) // runs Sigmoid on the fixed InlineData input/expected pairs listed above
{
using BoundedMemory dest = CreateTensor(x.Length);
TensorPrimitives.Sigmoid(x, dest);
@@ -1663,6 +2082,21 @@ public static void Sinh(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Sinh_InPlace(int tensorLength) // Sinh with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Sinh(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Sinh(xOrig[i]), x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Sinh_ThrowsForTooShortDestination(int tensorLength)
@@ -1675,6 +2109,38 @@ public static void Sinh_ThrowsForTooShortDestination(int tensorLength)
#endregion
#region SoftMax
+ [Theory]
+ [MemberData(nameof(TensorLengths))]
+ public static void SoftMax(int tensorLength) // SoftMax into a separate destination, checked element by element
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ using BoundedMemory destination = CreateTensor(tensorLength);
+
+ TensorPrimitives.SoftMax(x, destination);
+
+ float expSum = MemoryMarshal.ToEnumerable(x.Memory).Sum(MathF.Exp); // denominator: sum of e^x over every element
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Exp(x[i]) / expSum, destination[i], Tolerance); // reference: e^x[i] / sum of e^x
+ }
+ }
+
+ [Theory]
+ [MemberData(nameof(TensorLengths))] // NOTE(review): the other in-place tests use TensorLengthsIncluding0; presumably SoftMax rejects empty input - confirm
+ public static void SoftMax_InPlace(int tensorLength) // SoftMax with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.SoftMax(x, x); // x is source and destination
+
+ float expSum = xOrig.Sum(MathF.Exp); // denominator computed from the saved copy, since x now holds results
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Exp(xOrig[i]) / expSum, x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void SoftMax_ThrowsForTooShortDestination(int tensorLength)
@@ -1690,7 +2156,7 @@ public static void SoftMax_ThrowsForTooShortDestination(int tensorLength)
[InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })]
[InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })]
[InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f })]
- public static void SoftMax(float[] x, float[] expectedResult)
+ public static void SoftMax_KnownValues(float[] x, float[] expectedResult) // runs SoftMax on the fixed InlineData input/expected pairs listed above
{
using BoundedMemory dest = CreateTensor(x.Length);
TensorPrimitives.SoftMax(x, dest);
@@ -1739,6 +2205,21 @@ public static void Subtract_TwoTensors(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Subtract_TwoTensors_InPlace(int tensorLength) // x - x written back into x
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Subtract(x, x, x); // x is both operands and the destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] - xOrig[i], x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Subtract_TwoTensors_ThrowsForMismatchedLengths(int tensorLength)
@@ -1778,6 +2259,22 @@ public static void Subtract_TensorScalar(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Subtract_TensorScalar_InPlace(int tensorLength) // tensor - scalar written back into the input tensor
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+ float y = NextSingle(); // scalar subtrahend
+
+ TensorPrimitives.Subtract(x, y, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(xOrig[i] - y, x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Subtract_TensorScalar_ThrowsForTooShortDestination(int tensorLength)
@@ -1797,7 +2294,7 @@ public static void Sum(int tensorLength)
{
using BoundedMemory x = CreateAndFillTensor(tensorLength);
- Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Sum(x), Tolerance);
+ Assert.Equal(MemoryMarshal.ToEnumerable(x.Memory).Sum(), TensorPrimitives.Sum(x), Tolerance); // reference value: Sum() over the tensor's elements, now written as an extension-method call
float sum = 0;
foreach (float f in x.Span)
@@ -1890,6 +2387,21 @@ public static void Tanh(int tensorLength)
}
}
+ [Theory]
+ [MemberData(nameof(TensorLengthsIncluding0))]
+ public static void Tanh_InPlace(int tensorLength) // Tanh with one tensor used as both input and destination
+ {
+ using BoundedMemory x = CreateAndFillTensor(tensorLength);
+ float[] xOrig = x.Span.ToArray(); // copy of the inputs, taken before x is overwritten
+
+ TensorPrimitives.Tanh(x, x); // x is source and destination
+
+ for (int i = 0; i < tensorLength; i++)
+ {
+ Assert.Equal(MathF.Tanh(xOrig[i]), x[i], Tolerance); // expected value comes from the saved copy
+ }
+ }
+
[Theory]
[MemberData(nameof(TensorLengths))]
public static void Tanh_ThrowsForTooShortDestination(int tensorLength)