Skip to content

Commit 61faee1

Browse files
committed
ref and implicit broadcast
1 parent 278ccd9 commit 61faee1

File tree

6 files changed

+749
-139
lines changed

6 files changed

+749
-139
lines changed

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,7 @@ public static partial class Tensor
181181
public static System.Numerics.Tensors.SpanND<T> Add<T>(System.Numerics.Tensors.SpanND<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
182182
public static System.Numerics.Tensors.Tensor<T> Add<T>(System.Numerics.Tensors.Tensor<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
183183
public static System.Numerics.Tensors.Tensor<T> Add<T>(System.Numerics.Tensors.Tensor<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
184-
public static bool AreShapesBroadcastToCompatible(System.ReadOnlySpan<nint> shape1, System.ReadOnlySpan<nint> shape2) { throw null; }
185-
public static bool AreShapesBroadcastToCompatible<T>(System.Numerics.Tensors.Tensor<T> tensor1, System.Numerics.Tensors.Tensor<T> tensor2) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
186-
public static System.Numerics.Tensors.Tensor<T> BroadcastTo<T>(System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<nint> shape) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
184+
public static System.Numerics.Tensors.Tensor<T> Broadcast<T>(System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<nint> shape) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
187185
public static System.Numerics.Tensors.Tensor<T> Concatenate<T>(System.ReadOnlySpan<System.Numerics.Tensors.Tensor<T>> tensors, int axis = 0) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
188186
public static System.Numerics.Tensors.SpanND<T> CosInPlace<T>(System.Numerics.Tensors.SpanND<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ITrigonometricFunctions<T> { throw null; }
189187
public static System.Numerics.Tensors.Tensor<T> CosInPlace<T>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ITrigonometricFunctions<T> { throw null; }
@@ -218,7 +216,6 @@ public static partial class Tensor
218216
public static System.Numerics.Tensors.Tensor<T> FillRange<T>(System.Collections.Generic.IEnumerable<T> data) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
219217
public static System.Numerics.Tensors.Tensor<T> FilteredUpdate<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<bool> filter, System.Numerics.Tensors.Tensor<T> values) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
220218
public static System.Numerics.Tensors.Tensor<T> FilteredUpdate<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<bool> filter, T value) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
221-
public static nint[] GetIntermediateShape(System.ReadOnlySpan<nint> shape1, int shape2Length) { throw null; }
222219
public static bool GreaterThanAll<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IComparisonOperators<T, T, bool> { throw null; }
223220
public static bool GreaterThanAny<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IComparisonOperators<T, T, bool> { throw null; }
224221
public static System.Numerics.Tensors.Tensor<bool> GreaterThan<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IComparisonOperators<T, T, bool> { throw null; }
@@ -249,13 +246,14 @@ public static partial class Tensor
249246
public static System.Numerics.Tensors.Tensor<T> MultiplyInPlace<T>(System.Numerics.Tensors.Tensor<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
250247
public static System.Numerics.Tensors.SpanND<T> Multiply<T>(System.Numerics.Tensors.SpanND<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
251248
public static System.Numerics.Tensors.SpanND<T> Multiply<T>(System.Numerics.Tensors.SpanND<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
252-
public static System.Numerics.Tensors.Tensor<T> Multiply<T>(System.Numerics.Tensors.Tensor<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
249+
public static System.Numerics.Tensors.Tensor<T> Multiply<T>(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
253250
public static System.Numerics.Tensors.Tensor<T> Multiply<T>(System.Numerics.Tensors.Tensor<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
254251
public static System.Numerics.Tensors.Tensor<T> Normal<T>(params nint[] lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IFloatingPoint<T> { throw null; }
255252
public static T Norm<T>(System.Numerics.Tensors.SpanND<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IRootFunctions<T> { throw null; }
256253
public static T Norm<T>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IRootFunctions<T> { throw null; }
257254
public static System.Numerics.Tensors.Tensor<T> Permute<T>(System.Numerics.Tensors.Tensor<T> input, params int[] axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
258255
public static System.Numerics.Tensors.Tensor<T> Permute<T>(System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<int> axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
256+
public static System.Numerics.Tensors.SpanND<T> Reshape<T>(this System.Numerics.Tensors.SpanND<T> input, System.ReadOnlySpan<nint> lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
259257
public static System.Numerics.Tensors.Tensor<T> Reshape<T>(this System.Numerics.Tensors.Tensor<T> input, params nint[] lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
260258
public static System.Numerics.Tensors.Tensor<T> Reshape<T>(this System.Numerics.Tensors.Tensor<T> input, System.ReadOnlySpan<nint> lengths) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
261259
public static System.Numerics.Tensors.SpanND<T> Resize<T>(System.Numerics.Tensors.SpanND<T> input, System.ReadOnlySpan<nint> shape) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
@@ -276,7 +274,9 @@ public static partial class Tensor
276274
public static System.Numerics.Tensors.Tensor<T> Squeeze<T>(System.Numerics.Tensors.Tensor<T> input, int axis = -1) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
277275
public static System.Numerics.Tensors.Tensor<T> Stack<T>(System.Numerics.Tensors.Tensor<T>[] input, int axis = 0) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool> { throw null; }
278276
public static T StdDev<T>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IFloatingPoint<T>, System.Numerics.IPowerFunctions<T>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
277+
public static System.Numerics.Tensors.Tensor<T> StdDev<T>(System.Numerics.Tensors.Tensor<T> input, int axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.IFloatingPoint<T>, System.Numerics.IPowerFunctions<T>, System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
279278
public static TResult StdDev<T, TResult>(System.Numerics.Tensors.Tensor<T> input) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.INumber<T> where TResult : System.IEquatable<TResult>, System.Numerics.IEqualityOperators<TResult, TResult, bool>, System.Numerics.IFloatingPoint<TResult> { throw null; }
279+
public static System.Numerics.Tensors.Tensor<TResult> StdDev<T, TResult>(System.Numerics.Tensors.Tensor<T> input, int axis) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.INumber<T> where TResult : System.IEquatable<TResult>, System.Numerics.IEqualityOperators<TResult, TResult, bool>, System.Numerics.IFloatingPoint<TResult> { throw null; }
280280
public static System.Numerics.Tensors.SpanND<T> SubtractInPlace<T>(System.Numerics.Tensors.SpanND<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ISubtractionOperators<T, T, T> { throw null; }
281281
public static System.Numerics.Tensors.SpanND<T> SubtractInPlace<T>(System.Numerics.Tensors.SpanND<T> input, T val) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ISubtractionOperators<T, T, T> { throw null; }
282282
public static System.Numerics.Tensors.Tensor<T> SubtractInPlace<T>(System.Numerics.Tensors.Tensor<T> input, System.Numerics.Tensors.Tensor<T> other) where T : System.IEquatable<T>, System.Numerics.IEqualityOperators<T, T, bool>, System.Numerics.ISubtractionOperators<T, T, T> { throw null; }

src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
</ItemGroup>
1717

1818
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'">
19+
<Compile Include="System\Numerics\Tensors\netcore\TensorHelpers.cs" />
1920
<Compile Include="System\Numerics\Tensors\netcore\TensorExtensions.cs" />
2021
<Compile Include="System\Numerics\Tensors\netcore\Tensor.Factory.cs" />
2122
<Compile Include="System\Numerics\Tensors\netcore\Tensor.cs" />

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/SpanNDExtensions.cs

-9
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,6 @@ public static unsafe bool SequenceEqual<T>(this SpanND<T> span, SpanND<T> other)
2323
nint length = span.LinearLength;
2424
nint otherLength = other.LinearLength;
2525

26-
//if (RuntimeHelpers.IsBitwiseEquatable<T>())
27-
//{
28-
// return length == otherLength &&
29-
// SpanHelpers.SequenceEqual(
30-
// ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
31-
// ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
32-
// ((uint)otherLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
33-
//}
34-
3526
return length == otherLength && SpanHelpers.SequenceEqual(ref span.GetPinnableReference(), ref other.GetPinnableReference(), length);
3627
}
3728

0 commit comments

Comments
 (0)