Skip to content

Commit

Permalink
ref and implicit broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelgsharp committed Apr 22, 2024
1 parent 278ccd9 commit 61faee1
Show file tree
Hide file tree
Showing 6 changed files with 749 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ public static partial class Tensor
public static System.Numerics.Tensors.SpanND<T> Add<T>(System.Numerics.Tensors.SpanND<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Add<T>(System.Numerics.Tensors.Tensor<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Add<T>(System.Numerics.Tensors.Tensor<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
public static bool AreShapesBroadcastToCompatible(System.ReadOnlySpan<nint> shape1, System.ReadOnlySpan<nint> shape2) { throw null; }
public static bool AreShapesBroadcastToCompatible<T>(System.Numerics.Tensors.Tensor<T> tensor1, System.Numerics.Tensors.Tensor<T> tensor2) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> BroadcastTo<T>(System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<nint> shape) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Broadcast<T>(System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<nint> shape) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Concatenate<T>(System.ReadOnlySpan<System.Numerics.Tensors.Tensor<T>> tensors, int axis = 0) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.SpanND<T> CosInPlace<T>(System.Numerics.Tensors.SpanND<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ITrigonometricFunctions<T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> CosInPlace<T>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ITrigonometricFunctions<T> { throw null; }
Expand Down Expand Up @@ -218,7 +216,6 @@ public static partial class Tensor
public static System.Numerics.Tensors.Tensor<T> FillRange<T>(System.Collections.Generic.IEnumerable<T> data) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> FilteredUpdate<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<bool> filter, System.Numerics.Tensors.Tensor<T> values) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> FilteredUpdate<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<bool> filter, T value) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static nint[] GetIntermediateShape(System.ReadOnlySpan<nint> shape1, int shape2Length) { throw null; }
public static bool GreaterThanAll<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IComparisonOperators<T, T, bool> { throw null; }
public static bool GreaterThanAny<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IComparisonOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<bool> GreaterThan<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IComparisonOperators<T, T, bool> { throw null; }
Expand Down Expand Up @@ -249,13 +246,14 @@ public static partial class Tensor
public static System.Numerics.Tensors.Tensor<T> MultiplyInPlace<T>(System.Numerics.Tensors.Tensor<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.SpanND<T> Multiply<T>(System.Numerics.Tensors.SpanND<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.SpanND<T> Multiply<T>(System.Numerics.Tensors.SpanND<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Multiply<T>(System.Numerics.Tensors.Tensor<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Multiply<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Multiply<T>(System.Numerics.Tensors.Tensor<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Normal<T>(params nint[] lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IFloatingPoint<T> { throw null; }
public static T Norm<T>(System.Numerics.Tensors.SpanND<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IRootFunctions<T> { throw null; }
public static T Norm<T>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IRootFunctions<T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Permute<T>(System.Numerics.Tensors.Tensor<T> input, params int[] axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Permute<T>(System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<int> axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.SpanND<T> Reshape<T>(this System.Numerics.Tensors.SpanND<T> input, System.ReadOnlySpan<nint> lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Reshape<T>(this System.Numerics.Tensors.Tensor<T> input, params nint[] lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Reshape<T>(this System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<nint> lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.SpanND<T> Resize<T>(System.Numerics.Tensors.SpanND<T> input, System.ReadOnlySpan<nint> shape) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
Expand All @@ -276,7 +274,9 @@ public static partial class Tensor
public static System.Numerics.Tensors.Tensor<T> Squeeze<T>(System.Numerics.Tensors.Tensor<T> input, int axis = -1) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static System.Numerics.Tensors.Tensor<T> Stack<T>(System.Numerics.Tensors.Tensor<T>[] input, int axis = 0) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
public static T StdDev<T>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IFloatingPoint<T>, System.Numerics.IPowerFunctions<T>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> StdDev<T>(System.Numerics.Tensors.Tensor<T> input, int axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IFloatingPoint<T>, System.Numerics.IPowerFunctions<T>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
public static TResult StdDev<T, TResult>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.INumber<T> where TResult : System.IEquatable<TResult>, System.Numerics.IEqualityOperators<TResult, TResult, bool>, System.Numerics.IFloatingPoint<TResult> { throw null; }
public static System.Numerics.Tensors.Tensor<TResult> StdDev<T, TResult>(System.Numerics.Tensors.Tensor<T> input, int axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.INumber<T> where TResult : System.IEquatable<TResult>, System.Numerics.IEqualityOperators<TResult, TResult, bool>, System.Numerics.IFloatingPoint<TResult> { throw null; }
public static System.Numerics.Tensors.SpanND<T> SubtractInPlace<T>(System.Numerics.Tensors.SpanND<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ISubtractionOperators<T, T, T> { throw null; }
public static System.Numerics.Tensors.SpanND<T> SubtractInPlace<T>(System.Numerics.Tensors.SpanND<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ISubtractionOperators<T, T, T> { throw null; }
public static System.Numerics.Tensors.Tensor<T> SubtractInPlace<T>(System.Numerics.Tensors.Tensor<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ISubtractionOperators<T, T, T> { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'">
<Compile Include="System\Numerics\Tensors\netcore\TensorHelpers.cs" />
<Compile Include="System\Numerics\Tensors\netcore\TensorExtensions.cs" />
<Compile Include="System\Numerics\Tensors\netcore\Tensor.Factory.cs" />
<Compile Include="System\Numerics\Tensors\netcore\Tensor.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@ public static unsafe bool SequenceEqual<T>(this SpanND<T> span, SpanND<T> other)
nint length = span.LinearLength;
nint otherLength = other.LinearLength;

//if (RuntimeHelpers.IsBitwiseEquatable<T>())
//{
// return length == otherLength &&
// SpanHelpers.SequenceEqual(
// ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
// ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
// ((uint)otherLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
//}

return length == otherLength && SpanHelpers.SequenceEqual(ref span.GetPinnableReference(), ref other.GetPinnableReference(), length);
}

Expand Down
Loading

0 comments on commit 61faee1

Please sign in to comment.