Skip to content

Commit

Permalink
Add generic overloads to TensorPrimitives (#94555)
Browse files Browse the repository at this point in the history
* Add generic overloads to TensorPrimitives

This overhauls the implementation and tests to have a generic overload for each existing float-based overload.  I've avoided touching the core logic, but have augmented the structure in a few ways, e.g. only taking vectorized code paths when the type supports vectorization.  To keep the shared definitions of the float-based APIs, on .NET 9 they delegate to shims that are implemented on top of the generic variants.

The tests have all been made instance members, with an abstract base class containing most of the tests, and calling into abstract methods for the core operations and validation routines. Derived types then fill in this logic, letting us use all the tests for both the non-generic and generic overloads.  Generic tests are validating most of the primitive types that implement the required interfaces.

This does not yet:
- Provide generic overloads for the IndexOfMin/Max{Magnitude} methods
- Vectorize the trig-related functions for Ts other than floats

* Disable tests on mono due to Vector128 bug

* Change "Float" to "Single" in various file, type, and member names
  • Loading branch information
stephentoub authored Jan 3, 2024
1 parent 5455432 commit 7e51126
Show file tree
Hide file tree
Showing 16 changed files with 8,302 additions and 5,308 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,49 @@ namespace System.Numerics.Tensors
{
public static partial class TensorPrimitives
{
public static void Abs<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
public static void AddMultiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.ReadOnlySpan<T> multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
public static void AddMultiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, T multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
public static void AddMultiply<T>(System.ReadOnlySpan<T> x, T y, System.ReadOnlySpan<T> multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
public static void Add<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { }
public static void Add<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { }
public static void ConvertToHalf(System.ReadOnlySpan<float> source, System.Span<System.Half> destination) { throw null; }
public static void ConvertToSingle(System.ReadOnlySpan<System.Half> source, System.Span<float> destination) { throw null; }
public static void Cosh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
public static T CosineSimilarity<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IRootFunctions<T> { throw null; }
public static T Distance<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IRootFunctions<T> { throw null; }
public static void Divide<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IDivisionOperators<T, T, T> { }
public static void Divide<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IDivisionOperators<T, T, T> { }
public static T Dot<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static void Exp<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
public static void Log2<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.ILogarithmicFunctions<T> { }
public static void Log<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.ILogarithmicFunctions<T> { }
public static T MaxMagnitude<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
public static void MaxMagnitude<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
public static T Max<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
public static void Max<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
public static T MinMagnitude<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
public static void MinMagnitude<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
public static T Min<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
public static void Min<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, T addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, T y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
public static void Multiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { }
public static void Multiply<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { }
public static void Negate<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IUnaryNegationOperators<T, T> { }
public static T Norm<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IRootFunctions<T> { throw null; }
public static T ProductOfDifferences<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.ISubtractionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static T ProductOfSums<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static T Product<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static void Sigmoid<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
public static void Sinh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
public static void SoftMax<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
public static void Subtract<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.ISubtractionOperators<T, T, T> { }
public static void Subtract<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.ISubtractionOperators<T, T, T> { }
public static T SumOfMagnitudes<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
public static T SumOfSquares<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T> { throw null; }
public static T Sum<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
public static void Tanh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,7 @@
<data name="Argument_InputAndDestinationSpanMustNotOverlap" xml:space="preserve">
<value>The destination span may only overlap with an input span if the two spans start at the same memory location.</value>
</data>
<data name="Overflow_NegateTwosCompNum" xml:space="preserve">
<value>Negating the minimum value of a twos complement number is invalid.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
</PropertyGroup>

<ItemGroup>
<Compile Include="System\Numerics\Tensors\TensorPrimitives.cs" />
<Compile Include="System\Numerics\Tensors\TensorPrimitives.Single.cs" />
<Compile Include="System\Numerics\Tensors\TensorPrimitives.Helpers.cs" />
<Compile Include="System\ThrowHelper.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'">
<Compile Include="System\Numerics\Tensors\TensorPrimitives.netcore.cs" />
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Single.netcore.cs" />
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.T.cs" />
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.netcore.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
<Compile Include="System\Numerics\Tensors\TensorPrimitives.netstandard.cs" />
<Compile Include="System\Numerics\Tensors\netstandard\TensorPrimitives.Single.netstandard.cs" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<ProjectReference Include="$(LibrariesProjectRoot)Microsoft.Bcl.Numerics\src\Microsoft.Bcl.Numerics.csproj" />
</ItemGroup>
Expand Down
Loading

0 comments on commit 7e51126

Please sign in to comment.