From 61faee1e8c947bac6cf699b17081ddc106cf7bb8 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 22 Apr 2024 14:51:26 -0600 Subject: [PATCH] ref and implicit broadcast --- .../ref/System.Numerics.Tensors.netcore.cs | 10 +- .../src/System.Numerics.Tensors.csproj | 1 + .../Tensors/netcore/SpanNDExtensions.cs | 9 - .../Tensors/netcore/TensorExtensions.cs | 499 ++++++++++++++---- .../Numerics/Tensors/netcore/TensorHelpers.cs | 112 ++++ .../tests/TensorTests.cs | 257 ++++++++- 6 files changed, 749 insertions(+), 139 deletions(-) create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs index 773043b6fe6928..e724b2d442b813 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs @@ -181,9 +181,7 @@ public static partial class Tensor public static System.Numerics.Tensors.SpanND Add(System.Numerics.Tensors.SpanND input, T val) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity { throw null; } public static System.Numerics.Tensors.Tensor Add(System.Numerics.Tensors.Tensor input, System.Numerics.Tensors.Tensor other) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity { throw null; } public static System.Numerics.Tensors.Tensor Add(System.Numerics.Tensors.Tensor input, T val) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity { throw null; } - public static bool AreShapesBroadcastToCompatible(System.ReadOnlySpan shape1, System.ReadOnlySpan shape2) { throw null; } - public static bool AreShapesBroadcastToCompatible(System.Numerics.Tensors.Tensor tensor1, System.Numerics.Tensors.Tensor tensor2) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } - public static System.Numerics.Tensors.Tensor BroadcastTo(System.Numerics.Tensors.Tensor input, System.ReadOnlySpan shape) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } + public static System.Numerics.Tensors.Tensor Broadcast(System.Numerics.Tensors.Tensor input, System.ReadOnlySpan shape) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.Tensor Concatenate(System.ReadOnlySpan> tensors, int axis = 0) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.SpanND CosInPlace(System.Numerics.Tensors.SpanND input) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.ITrigonometricFunctions { throw null; } public static System.Numerics.Tensors.Tensor CosInPlace(System.Numerics.Tensors.Tensor input) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.ITrigonometricFunctions { throw null; } @@ -218,7 +216,6 @@ public static partial class Tensor public static System.Numerics.Tensors.Tensor FillRange(System.Collections.Generic.IEnumerable data) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.Tensor FilteredUpdate(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor filter, System.Numerics.Tensors.Tensor values) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.Tensor FilteredUpdate(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor filter, T value) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } - public static nint[] GetIntermediateShape(System.ReadOnlySpan shape1, int shape2Length) { throw null; } public static bool GreaterThanAll(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor right) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IComparisonOperators { throw null; } public static bool GreaterThanAny(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor right) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IComparisonOperators { throw null; } public static System.Numerics.Tensors.Tensor GreaterThan(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor right) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IComparisonOperators { throw null; } @@ -249,13 +246,14 @@ public static partial class Tensor public static System.Numerics.Tensors.Tensor MultiplyInPlace(System.Numerics.Tensors.Tensor input, T val) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IMultiplyOperators, System.Numerics.IMultiplicativeIdentity { throw null; } public static System.Numerics.Tensors.SpanND Multiply(System.Numerics.Tensors.SpanND input, System.Numerics.Tensors.Tensor other) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IMultiplyOperators, System.Numerics.IMultiplicativeIdentity { throw null; } public static System.Numerics.Tensors.SpanND Multiply(System.Numerics.Tensors.SpanND input, T val) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IMultiplyOperators, System.Numerics.IMultiplicativeIdentity { throw null; } - public static System.Numerics.Tensors.Tensor Multiply(System.Numerics.Tensors.Tensor input, System.Numerics.Tensors.Tensor other) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IMultiplyOperators, System.Numerics.IMultiplicativeIdentity { throw null; } + public static System.Numerics.Tensors.Tensor Multiply(System.Numerics.Tensors.Tensor left, System.Numerics.Tensors.Tensor right) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IMultiplyOperators, System.Numerics.IMultiplicativeIdentity { throw null; } public static System.Numerics.Tensors.Tensor Multiply(System.Numerics.Tensors.Tensor input, T val) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IMultiplyOperators, System.Numerics.IMultiplicativeIdentity { throw null; } public static System.Numerics.Tensors.Tensor Normal(params nint[] lengths) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IFloatingPoint { throw null; } public static T Norm(System.Numerics.Tensors.SpanND input) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IRootFunctions { throw null; } public static T Norm(System.Numerics.Tensors.Tensor input) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IRootFunctions { throw null; } public static System.Numerics.Tensors.Tensor Permute(System.Numerics.Tensors.Tensor input, params int[] axis) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.Tensor Permute(System.Numerics.Tensors.Tensor input, System.ReadOnlySpan axis) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } + public static System.Numerics.Tensors.SpanND Reshape(this System.Numerics.Tensors.SpanND input, System.ReadOnlySpan lengths) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.Tensor Reshape(this System.Numerics.Tensors.Tensor input, params nint[] lengths) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.Tensor Reshape(this System.Numerics.Tensors.Tensor input, System.ReadOnlySpan lengths) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.SpanND Resize(System.Numerics.Tensors.SpanND input, System.ReadOnlySpan shape) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } @@ -276,7 +274,9 @@ public static partial class Tensor public static System.Numerics.Tensors.Tensor Squeeze(System.Numerics.Tensors.Tensor input, int axis = -1) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static System.Numerics.Tensors.Tensor Stack(System.Numerics.Tensors.Tensor[] input, int axis = 0) where T : System.IEquatable, System.Numerics.IEqualityOperators { throw null; } public static T StdDev(System.Numerics.Tensors.Tensor input) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IFloatingPoint, System.Numerics.IPowerFunctions, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity { throw null; } + public static System.Numerics.Tensors.Tensor StdDev(System.Numerics.Tensors.Tensor input, int axis) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IFloatingPoint, System.Numerics.IPowerFunctions, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity { throw null; } public static TResult StdDev(System.Numerics.Tensors.Tensor input) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.INumber where TResult : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IFloatingPoint { throw null; } + public static System.Numerics.Tensors.Tensor StdDev(System.Numerics.Tensors.Tensor input, int axis) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.INumber where TResult : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.IFloatingPoint { throw null; } public static System.Numerics.Tensors.SpanND SubtractInPlace(System.Numerics.Tensors.SpanND input, System.Numerics.Tensors.Tensor other) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.ISubtractionOperators { throw null; } public static System.Numerics.Tensors.SpanND SubtractInPlace(System.Numerics.Tensors.SpanND input, T val) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.ISubtractionOperators { throw null; } public static System.Numerics.Tensors.Tensor SubtractInPlace(System.Numerics.Tensors.Tensor input, System.Numerics.Tensors.Tensor other) where T : System.IEquatable, System.Numerics.IEqualityOperators, System.Numerics.ISubtractionOperators { throw null; } diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index 79b971dad754cd..214e1c1bf2a02d 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -16,6 +16,7 @@ + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/SpanNDExtensions.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/SpanNDExtensions.cs index 05568a2f8ebb76..a2be5c69519d95 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/SpanNDExtensions.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/SpanNDExtensions.cs @@ -23,15 +23,6 @@ public static unsafe bool SequenceEqual(this SpanND span, SpanND other) nint length = span.LinearLength; nint otherLength = other.LinearLength; - //if (RuntimeHelpers.IsBitwiseEquatable()) - //{ - // return length == otherLength && - // SpanHelpers.SequenceEqual( - // ref Unsafe.As(ref MemoryMarshal.GetReference(span)), - // ref Unsafe.As(ref MemoryMarshal.GetReference(other)), - // ((uint)otherLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking. - //} - return length == otherLength && SpanHelpers.SequenceEqual(ref span.GetPinnableReference(), ref other.GetPinnableReference(), length); } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs index 1d720cac9f64f3..5a6066d99017df 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Security.Cryptography; #pragma warning disable CS8601 // Possible null reference assignment. #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. @@ -50,10 +51,22 @@ public static SpanND Resize(SpanND input, ReadOnlySpan shape) #endregion #region Broadcast - public static Tensor BroadcastTo(Tensor input, ReadOnlySpan shape) + + public static Tensor Broadcast(Tensor input, ReadOnlySpan shape) where T : IEquatable, IEqualityOperators { - if (!AreShapesBroadcastToCompatible(input.Lengths, shape)) + Tensor intermediate = BroadcastTo(input, shape); + return Tensor.Create(intermediate.ToArray(), intermediate.Lengths); + } + + // Lazy/non-copy broadcasting, internal only for now. + internal static Tensor BroadcastTo(Tensor input, ReadOnlySpan shape) + where T : IEquatable, IEqualityOperators + { + if (input.Lengths.SequenceEqual(shape)) + return new Tensor(input._values, shape.ToArray(), input.IsPinned); + + if (!TensorHelpers.AreShapesBroadcastCompatible(input.Lengths, shape)) throw new Exception("Shapes are not broadcast compatible."); var newSize = SpanHelpers.CalculateTotalLength(shape); @@ -61,7 +74,7 @@ public static Tensor BroadcastTo(Tensor input, ReadOnlySpan shape if (newSize == input.LinearLength) return Reshape(input, shape); - var intermediateShape = GetIntermediateShape(input.Lengths, shape.Length); + var intermediateShape = TensorHelpers.GetIntermediateShape(input.Lengths, shape.Length); nint[] strides = new nint[shape.Length]; nint stride = 1; @@ -82,60 +95,39 @@ public static Tensor BroadcastTo(Tensor input, ReadOnlySpan shape return output; } - public static bool AreShapesBroadcastToCompatible(Tensor tensor1, Tensor tensor2) - where T : IEquatable, IEqualityOperators => AreShapesBroadcastToCompatible(tensor1.Lengths, tensor2.Lengths); - - - public static bool AreShapesBroadcastToCompatible(ReadOnlySpan shape1, ReadOnlySpan shape2) + internal static SpanND BroadcastTo(SpanND input, ReadOnlySpan shape) + where T : IEquatable, IEqualityOperators { - var shape1Index = shape1.Length - 1; - var shape2Index = shape2.Length - 1; + if (input.Lengths.SequenceEqual(shape)) + return new SpanND(ref input._reference, shape, input.Strides, input.IsPinned); - bool areCompatible = true; + if (!TensorHelpers.AreShapesBroadcastCompatible(input.Lengths, shape)) + throw new Exception("Shapes are not broadcast compatible."); - nint s1; - nint s2; + var newSize = SpanHelpers.CalculateTotalLength(shape); - while (shape1Index >= 0 || shape2Index >= 0) - { - // if a dimension is missing in one of the shapes, it is considered to be 1 - if (shape1Index < 0) - s1 = 1; - else - s1 = shape1[shape1Index--]; + if (newSize == input.LinearLength) + return Reshape(input, shape); - if (shape2Index < 0) - s2 = 1; - else - s2 = shape2[shape2Index--]; + var intermediateShape = TensorHelpers.GetIntermediateShape(input.Lengths, shape.Length); + nint[] strides = new nint[shape.Length]; - if (s1 == s2 || (s1 == 1 && s2 != 1) || (s1 == 1 && s2 != 1)) { } + nint stride = 1; + + for (int i = strides.Length - 1; i >= 0; i--) + { + if ((intermediateShape[i] == 1 && shape[i] != 1) || (intermediateShape[i] == 1 && shape[i] == 1)) + strides[i] = 0; else { - areCompatible = false; - break; + strides[i] = stride; + stride *= intermediateShape[i]; } } - return areCompatible; - } + SpanND output = new SpanND(ref input._reference, shape, strides, input.IsPinned); - public static nint[] GetIntermediateShape(ReadOnlySpan shape1, int shape2Length) - { - var shape1Index = shape1.Length - 1; - var newShapeIndex = Math.Max(shape1.Length, shape2Length) - 1; - nint[] newShape = new nint[Math.Max(shape1.Length, shape2Length)]; - - while (newShapeIndex >= 0) - { - // if a dimension is missing in one of the shapes, it is considered to be 1 - if (shape1Index < 0) - newShape[newShapeIndex--] = 1; - else - newShape[newShapeIndex--] = shape1[shape1Index--]; - } - - return newShape; + return output; } #endregion @@ -317,8 +309,7 @@ public static Tensor[] Split(Tensor input, nint numSplits, nint axis) #endregion #region SetSlice - // REVIEW: NOT IN DESIGN DOC BUT NEEDED FOR NIKLAS NOTEBOOK. - // REVIEW: WHAT DO WE WANT TO CALL THIS? COPYTO? IT DOES FIT IN WITH THE EXISTING COPY TO CONVENTIONS FOR VECTOR. + // REVIEW: WHAT DO WE WANT TO CALL THIS? COPYTO? IT DOES FIT IN WITH THE EXISTING COPY TO CONVENTIONS FOR VECTOR (albeit backwards). public static Tensor SetSlice(this Tensor tensor, Tensor values, params NativeRange[] ranges) where T : IEquatable, IEqualityOperators { @@ -342,7 +333,6 @@ public static Tensor SetSlice(this Tensor tensor, Tensor values, par #endregion #region FilteredUpdate - // REVIEW: NOT IN DESIGN DOC BUT NEEDED FOR NIKLAS NOTEBOOK. // REVIEW: PYTORCH/NUMPY DO THIS. // t0[t0 < 2] = -1; // OR SHOULD THIS BE AN OVERLOAD OF FILL THAT TAKES IN A FUNC TO KNOW WHICH ONE TO UPDATE? @@ -396,31 +386,65 @@ public static Tensor FilteredUpdate(Tensor left, Tensor filter, T #endregion #region SequenceEqual - // REVIEW: THIS NEEDS TO SUPPORT BROADCASTING AND ADD APPROPRIATE CHECKING. public static Tensor SequenceEqual(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators { - Tensor result = Tensor.Create(false, left.Lengths); + Tensor result; + if (TensorHelpers.AreShapesTheSame(left, right)) + { + result = Tensor.Create(false, left.Lengths); - for (int i = 0; i < left.LinearLength; i++) + for (int i = 0; i < left.LinearLength; i++) + { + result._values[i] = left._values[i] == right._values[i]; + } + } + else { - result._values[i] = left._values[i] == right._values[i]; + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + result = Tensor.Create(false, newSize); + var broadcastedLeft = BroadcastTo(left, newSize); + var broadcastedRight = BroadcastTo(right, newSize); + nint[] curIndex = new nint[broadcastedRight.Lengths.Length]; + for (int i = 0; i < broadcastedLeft.LinearLength; i++) + { + result._values[i] = broadcastedLeft[curIndex] == broadcastedRight[curIndex]; + SpanHelpers.AdjustIndices(broadcastedRight.Rank - 1, 1, ref curIndex, broadcastedRight.Lengths); + } } + return result; } #endregion #region LessThan - // REVIEW: ALL OF THESE NEED TO SUPPORT BROADCASTING AND ADD APPROPRIATE CHECKING. public static Tensor LessThan(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators, IComparisonOperators { - Tensor result = Tensor.Create(false, left.Lengths); + Tensor result; + if (TensorHelpers.AreShapesTheSame(left, right)) + { + result = Tensor.Create(false, left.Lengths); - for (int i = 0; i < left.LinearLength; i++) + for (int i = 0; i < left.LinearLength; i++) + { + result._values[i] = left._values[i] < right._values[i]; + } + } + else { - result._values[i] = left._values[i] < right._values[i]; + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + result = Tensor.Create(false, newSize); + var broadcastedLeft = BroadcastTo(left, newSize); + var broadcastedRight = BroadcastTo(right, newSize); + nint[] curIndex = new nint[broadcastedRight.Lengths.Length]; + for (int i = 0; i < broadcastedLeft.LinearLength; i++) + { + result._values[i] = broadcastedLeft[curIndex] < broadcastedRight[curIndex]; + SpanHelpers.AdjustIndices(broadcastedRight.Rank - 1, 1, ref curIndex, broadcastedRight.Lengths); + } } + return result; } @@ -439,37 +463,90 @@ public static Tensor LessThan(Tensor left, T right) public static bool LessThanAny(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators, IComparisonOperators { - for (int i = 0; i < left.LinearLength; i++) + + if (TensorHelpers.AreShapesTheSame(left, right)) { - if (left._values[i] < right._values[i]) - return true; + + for (int i = 0; i < left.LinearLength; i++) + { + if (left._values[i] < right._values[i]) + return true; + } } + else + { + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + var broadcastedLeft = BroadcastTo(left, newSize); + var broadcastedRight = BroadcastTo(right, newSize); + nint[] curIndex = new nint[broadcastedRight.Lengths.Length]; + for (int i = 0; i < broadcastedLeft.LinearLength; i++) + { + if (broadcastedLeft[curIndex] < broadcastedRight[curIndex]) + return true; + SpanHelpers.AdjustIndices(broadcastedRight.Rank - 1, 1, ref curIndex, broadcastedRight.Lengths); + } + } + return false; } public static bool LessThanAll(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators, IComparisonOperators { - for (int i = 0; i < left.LinearLength; i++) + if (TensorHelpers.AreShapesTheSame(left, right)) { - if (left._values[i] > right._values[i]) - return false; + + for (int i = 0; i < left.LinearLength; i++) + { + if (left._values[i] > right._values[i]) + return false; + } + } + else + { + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + var broadcastedLeft = BroadcastTo(left, newSize); + var broadcastedRight = BroadcastTo(right, newSize); + nint[] curIndex = new nint[broadcastedRight.Lengths.Length]; + for (int i = 0; i < broadcastedLeft.LinearLength; i++) + { + if (broadcastedLeft[curIndex] > broadcastedRight[curIndex]) + return false; + SpanHelpers.AdjustIndices(broadcastedRight.Rank - 1, 1, ref curIndex, broadcastedRight.Lengths); + } } return true; } #endregion #region GreaterThan - // REVIEW: ALL OF THESE NEED TO SUPPORT BROADCASTING AND ADD APPROPRIATE CHECKING. public static Tensor GreaterThan(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators, IComparisonOperators { - Tensor result = Tensor.Create(false, left.Lengths); + Tensor result; + if (TensorHelpers.AreShapesTheSame(left, right)) + { + result = Tensor.Create(false, left.Lengths); - for (int i = 0; i < left.LinearLength; i++) + for (int i = 0; i < left.LinearLength; i++) + { + result._values[i] = left._values[i] > right._values[i]; + } + } + else { - result._values[i] = left._values[i] > right._values[i]; + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + result = Tensor.Create(false, newSize); + var broadcastedLeft = BroadcastTo(left, newSize); + var broadcastedRight = BroadcastTo(right, newSize); + nint[] curIndex = new nint[broadcastedRight.Lengths.Length]; + for (int i = 0; i < broadcastedLeft.LinearLength; i++) + { + result._values[i] = broadcastedLeft[curIndex] > broadcastedRight[curIndex]; + SpanHelpers.AdjustIndices(broadcastedRight.Rank - 1, 1, ref curIndex, broadcastedRight.Lengths); + } } + return result; } @@ -488,10 +565,27 @@ public static Tensor GreaterThan(Tensor left, T right) public static bool GreaterThanAny(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators, IComparisonOperators { - for (int i = 0; i < left.LinearLength; i++) + if (TensorHelpers.AreShapesTheSame(left, right)) { - if (left._values[i] > right._values[i]) - return true; + + for (int i = 0; i < left.LinearLength; i++) + { + if (left._values[i] > right._values[i]) + return true; + } + } + else + { + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + var broadcastedLeft = BroadcastTo(left, newSize); + var broadcastedRight = BroadcastTo(right, newSize); + nint[] curIndex = new nint[broadcastedRight.Lengths.Length]; + for (int i = 0; i < broadcastedLeft.LinearLength; i++) + { + if (broadcastedLeft[curIndex] > broadcastedRight[curIndex]) + return true; + SpanHelpers.AdjustIndices(broadcastedRight.Rank - 1, 1, ref curIndex, broadcastedRight.Lengths); + } } return false; } @@ -499,10 +593,27 @@ public static bool GreaterThanAny(Tensor left, Tensor right) public static bool GreaterThanAll(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators, IComparisonOperators { - for (int i = 0; i < left.LinearLength; i++) + if (TensorHelpers.AreShapesTheSame(left, right)) { - if (left._values[i] < right._values[i]) - return false; + + for (int i = 0; i < left.LinearLength; i++) + { + if (left._values[i] < right._values[i]) + return false; + } + } + else + { + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + var broadcastedLeft = BroadcastTo(left, newSize); + var broadcastedRight = BroadcastTo(right, newSize); + nint[] curIndex = new nint[broadcastedRight.Lengths.Length]; + for (int i = 0; i < broadcastedLeft.LinearLength; i++) + { + if (broadcastedLeft[curIndex] < broadcastedRight[curIndex]) + return false; + SpanHelpers.AdjustIndices(broadcastedRight.Rank - 1, 1, ref curIndex, broadcastedRight.Lengths); + } } return true; } @@ -560,6 +671,34 @@ public static Tensor Reshape(this Tensor input, ReadOnlySpan leng var strides = SpanHelpers.CalculateStrides(arrLengths.Length, arrLengths); return new Tensor(input._values, arrLengths, strides.ToArray(), input.IsPinned); } + + public static SpanND Reshape(this SpanND input, ReadOnlySpan lengths) + where T : IEquatable, IEqualityOperators + { + var arrLengths = lengths.ToArray(); + // Calculate wildcard info. + if (lengths.Contains(-1)) + { + if (lengths.Count(-1) > 1) + throw new ArgumentException("Provided dimensions can only include 1 wildcard."); + var tempTotal = input.LinearLength; + for (int i = 0; i < lengths.Length; i++) + { + if (lengths[i] != -1) + { + tempTotal /= lengths[i]; + } + } + arrLengths[lengths.IndexOf(-1)] = tempTotal; + + } + + var tempLinear = SpanHelpers.CalculateTotalLength(ref arrLengths); + if (tempLinear != input.LinearLength) + throw new ArgumentException("Provided dimensions are not valid for reshaping"); + var strides = SpanHelpers.CalculateStrides(arrLengths.Length, arrLengths); + return new SpanND(ref input._reference, arrLengths, strides, input.IsPinned); + } #endregion #region Squeeze @@ -626,13 +765,12 @@ public static Tensor Unsqueeze(Tensor input, int axis) #endregion #region Concatenate - //REVIEW: SHOULD AXIS BE NULLABLE INT SO NULL CAN BE PROVIDED INSTEAD OF -1? SENTINAL VALUE? /// - /// Join a sequence of arrays along an existing axis. + /// Join a sequence of tensors along an existing axis. /// /// - /// The arrays must have the same shape, except in the dimension corresponding to axis (the first, by default). - /// The axis along which the arrays will be joined. If axis is -1, arrays are flattened before use. Default is 0. + /// The tensors must have the same shape, except in the dimension corresponding to axis (the first, by default). + /// The axis along which the tensors will be joined. If axis is -1, arrays are flattened before use. Default is 0. /// public static Tensor Concatenate(ReadOnlySpan> tensors, int axis = 0) where T : IEquatable, IEqualityOperators @@ -735,6 +873,10 @@ public static T StdDev(Tensor input) return T.CreateChecked(sum / T.CreateChecked(input.LinearLength)); } + public static Tensor StdDev(Tensor input, int axis) + where T : IEquatable, IEqualityOperators, IFloatingPoint, IPowerFunctions, IAdditionOperators, IAdditiveIdentity + => throw new NotImplementedException(); + public static TResult StdDev(Tensor input) where T : IEquatable, IEqualityOperators, INumber where TResult : IEquatable, IEqualityOperators, IFloatingPoint @@ -743,6 +885,11 @@ public static TResult StdDev(Tensor input) T sum = Tensor.Sum(input); return TResult.CreateChecked(TResult.CreateChecked(sum) / TResult.CreateChecked(input.LinearLength)); } + + public static Tensor StdDev(Tensor input, int axis) + where T : IEquatable, IEqualityOperators, INumber + where TResult : IEquatable, IEqualityOperators, IFloatingPoint + => throw new NotImplementedException(); #endregion #region Mean @@ -826,7 +973,7 @@ public static Tensor Permute(Tensor input, ReadOnlySpan axis) indices = new nint[tensor.Rank]; for (int i = 0; i < input._linearLength; i++) { - PermuteIndices(ref indices, ref permutedIndices, ref permutation); + TensorHelpers.PermuteIndices(ref indices, ref permutedIndices, ref permutation); ospan[permutedIndices] = ispan[indices]; SpanHelpers.AdjustIndices(tensor.Rank - 1, 1, ref indices, input._lengths); } @@ -834,14 +981,6 @@ public static Tensor Permute(Tensor input, ReadOnlySpan axis) return tensor; } } - - private static void PermuteIndices(ref nint[] indices, ref nint[] permutedIndices, ref int[] permutation) - { - for (int i = 0; i < indices.Length; i++) - { - permutedIndices[i] = indices[permutation[i]]; - } - } #endregion #region TensorPrimitives @@ -866,10 +1005,10 @@ public static Tensor MultiplyInPlace(Tensor input, T val) return output; } - public static Tensor Multiply(Tensor input, Tensor other) + public static Tensor Multiply(Tensor left, Tensor right) where T : IEquatable, IEqualityOperators, IMultiplyOperators, IMultiplicativeIdentity { - return TensorPrimitivesHelperT1T2(input, other, TensorPrimitives.Multiply); + return TensorPrimitivesHelperT1T2(left, right, TensorPrimitives.Multiply); } public static Tensor MultiplyInPlace(Tensor input, Tensor other) @@ -1426,25 +1565,187 @@ private static SpanND TensorPrimitivesHelperT1(SpanND input, PerformCal return output; } - private static Tensor TensorPrimitivesHelperT1T2(Tensor input, Tensor inputTwo, PerformCalculationT1T2 performCalculation, bool inPlace = false) + private static Tensor TensorPrimitivesHelperT1T2(Tensor left, Tensor right, PerformCalculationT1T2 performCalculation, bool inPlace = false) where T : IEquatable, IEqualityOperators { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._values[0], (int)input._linearLength); - ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref inputTwo._values[0], (int)inputTwo._linearLength); - Tensor output = inPlace ? input : Create(input.IsPinned, input.Lengths); - Span ospan = MemoryMarshal.CreateSpan(ref output._values[0], (int)output._linearLength); - performCalculation(span, rspan, ospan); + if (inPlace && left.Lengths != right.Lengths) + throw new ArgumentException("In place operations require the same shape for both tensors"); + + Tensor output; + if (inPlace) + { + + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._values[0], (int)left._linearLength); + ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._values[0], (int)right._linearLength); + output = left; + Span ospan = MemoryMarshal.CreateSpan(ref output._values[0], (int)output._linearLength); + performCalculation(span, rspan, ospan); + } + // If not in place but sizes are the same. + else if (left.Lengths.SequenceEqual(right.Lengths)) + { + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left.AsSpan()._reference, (int)left.LinearLength); + ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right.AsSpan()._reference, (int)right.LinearLength); + output = Create(left.IsPinned, left.Lengths); + Span ospan = MemoryMarshal.CreateSpan(ref output.AsSpan()._reference, (int)output.LinearLength); + performCalculation(span, rspan, ospan); + return output; + } + // Not in place and broadcasting needs to happen. + else + { + // Have a couple different possible cases here. + // 1 - Both tensors have row contiguous memory (i.e. a 1x5 being broadcast to a 5x5) + // 2 - One tensor has row contiguous memory and the other has column contiguous memory (i.e. a 1x5 and a 5x1) + + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + + var broadcastedLeft = Tensor.BroadcastTo(left, newSize); + var broadcastedRight = Tensor.BroadcastTo(right, newSize); + + output = Create(left.IsPinned, newSize); + var rowLength = newSize[^1]; + Span ospan; + Span ispan; + Span buffer = new T[rowLength]; + nint[] curIndex = new nint[newSize.Length]; + int outputOffset = 0; + // left not row contiguous + if (broadcastedLeft.Strides[^1] == 0) + { + while (outputOffset < output.LinearLength) + { + ospan = MemoryMarshal.CreateSpan(ref output._values[outputOffset], (int)rowLength); + buffer.Fill(broadcastedLeft[curIndex]); + ispan = MemoryMarshal.CreateSpan(ref broadcastedRight[curIndex], (int)rowLength); + performCalculation(buffer, ispan, ospan); + outputOffset += (int)rowLength; + SpanHelpers.AdjustIndices(broadcastedLeft.Rank - 2, 1, ref curIndex, broadcastedLeft.Lengths); + } + } + // right now row contiguous + else if (broadcastedRight.Strides[^1] == 0) + { + while (outputOffset < output.LinearLength) + { + ospan = MemoryMarshal.CreateSpan(ref output._values[outputOffset], (int)rowLength); + buffer.Fill(broadcastedRight[curIndex]); + ispan = MemoryMarshal.CreateSpan(ref broadcastedLeft[curIndex], (int)rowLength); + performCalculation(ispan, buffer, ospan); + outputOffset += (int)rowLength; + SpanHelpers.AdjustIndices(broadcastedLeft.Rank - 2, 1, ref curIndex, broadcastedLeft.Lengths); + } + } + // both row contiguous + else + { + Span rspan; + while (outputOffset < output.LinearLength) + { + ospan = MemoryMarshal.CreateSpan(ref output._values[outputOffset], (int)rowLength); + ispan = MemoryMarshal.CreateSpan(ref broadcastedLeft[curIndex], (int)rowLength); + rspan = MemoryMarshal.CreateSpan(ref broadcastedRight[curIndex], (int)rowLength); + performCalculation(ispan, rspan, ospan); + outputOffset += (int)rowLength; + SpanHelpers.AdjustIndices(broadcastedLeft.Rank - 2, 1, ref curIndex, broadcastedLeft.Lengths); + } + } + } return output; } - private static SpanND TensorPrimitivesHelperT1T2(SpanND input, SpanND inputTwo, PerformCalculationT1T2 performCalculation, bool inPlace = false) + private static SpanND TensorPrimitivesHelperT1T2(SpanND left, SpanND right, PerformCalculationT1T2 performCalculation, bool inPlace = false) where T : IEquatable, IEqualityOperators { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input.LinearLength); - ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref inputTwo._reference, (int)inputTwo.LinearLength); - SpanND output = inPlace ? input : Create(input.IsPinned, input.Lengths); - Span ospan = MemoryMarshal.CreateSpan(ref output._reference, (int)output.LinearLength); - performCalculation(span, rspan, ospan); + if (inPlace && left.Lengths != right.Lengths) + throw new ArgumentException("In place operations require the same shape for both spans"); + + //ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._reference, (int)left.LinearLength); + //ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._reference, (int)right.LinearLength); + //SpanND output = inPlace ? left : Create(left.IsPinned, left.Lengths); + //Span ospan = MemoryMarshal.CreateSpan(ref output._reference, (int)output.LinearLength); + //performCalculation(span, rspan, ospan); + //return output; + + SpanND output; + if (inPlace) + { + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._reference, (int)left.LinearLength); + ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._reference, (int)right.LinearLength); + output = left; + Span ospan = MemoryMarshal.CreateSpan(ref output._reference, (int)output.LinearLength); + performCalculation(span, rspan, ospan); + } + // If not in place but sizes are the same. + else if (left.Lengths.SequenceEqual(right.Lengths)) + { + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._reference, (int)left.LinearLength); + ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._reference, (int)right.LinearLength); + output = Create(left.IsPinned, left.Lengths); + Span ospan = MemoryMarshal.CreateSpan(ref output._reference, (int)output.LinearLength); + performCalculation(span, rspan, ospan); + return output; + } + // Not in place and broadcasting needs to happen. + else + { + // Have a couple different possible cases here. + // 1 - Both tensors have row contiguous memory (i.e. a 1x5 being broadcast to a 5x5) + // 2 - One tensor has row contiguous memory and the other has column contiguous memory (i.e. a 1x5 and a 5x1) + + nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths); + + var broadcastedLeft = Tensor.BroadcastTo(left, newSize); + var broadcastedRight = Tensor.BroadcastTo(right, newSize); + + output = Create(left.IsPinned, newSize); + var rowLength = newSize[^1]; + Span ospan; + Span ispan; + Span buffer = new T[rowLength]; + nint[] curIndex = new nint[newSize.Length]; + int outputOffset = 0; + // left not row contiguous + if (broadcastedLeft.Strides[^1] == 0) + { + while (outputOffset < output.LinearLength) + { + ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref output._reference, outputOffset), (int)rowLength); + buffer.Fill(broadcastedLeft[curIndex]); + ispan = MemoryMarshal.CreateSpan(ref broadcastedRight[curIndex], (int)rowLength); + performCalculation(buffer, ispan, ospan); + outputOffset += (int)rowLength; + SpanHelpers.AdjustIndices(broadcastedLeft.Rank - 2, 1, ref curIndex, broadcastedLeft.Lengths); + } + } + // right now row contiguous + else if (broadcastedRight.Strides[^1] == 0) + { + while (outputOffset < output.LinearLength) + { + ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref output._reference, outputOffset), (int)rowLength); + buffer.Fill(broadcastedRight[curIndex]); + ispan = MemoryMarshal.CreateSpan(ref broadcastedLeft[curIndex], (int)rowLength); + performCalculation(ispan, buffer, ospan); + outputOffset += (int)rowLength; + SpanHelpers.AdjustIndices(broadcastedLeft.Rank - 2, 1, ref curIndex, broadcastedLeft.Lengths); + } + } + // both row contiguous + else + { + Span rspan; + while (outputOffset < output.LinearLength) + { + ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref output._reference, outputOffset), (int)rowLength); + ispan = MemoryMarshal.CreateSpan(ref broadcastedLeft[curIndex], (int)rowLength); + rspan = MemoryMarshal.CreateSpan(ref broadcastedRight[curIndex], (int)rowLength); + performCalculation(ispan, rspan, ospan); + outputOffset += (int)rowLength; + SpanHelpers.AdjustIndices(broadcastedLeft.Rank - 2, 1, ref curIndex, broadcastedLeft.Lengths); + } + } + } return output; } #endregion diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs new file mode 100644 index 00000000000000..862d5b95d46858 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Data; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Net.NetworkInformation; +using System.Reflection.Metadata.Ecma335; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using Microsoft.CSharp.RuntimeBinder; +using Microsoft.VisualBasic; + +namespace System.Numerics.Tensors +{ + internal static class TensorHelpers + { + + internal static bool AreShapesBroadcastCompatible(Tensor tensor1, Tensor tensor2) + where T : IEquatable, IEqualityOperators => AreShapesBroadcastCompatible(tensor1.Lengths, tensor2.Lengths); + + internal static bool AreShapesBroadcastCompatible(ReadOnlySpan shape1, ReadOnlySpan shape2) + { + var shape1Index = shape1.Length - 1; + var shape2Index = shape2.Length - 1; + + bool areCompatible = true; + + nint s1; + nint s2; + + while (shape1Index >= 0 || shape2Index >= 0) + { + // if a dimension is missing in one of the shapes, it is considered to be 1 + if (shape1Index < 0) + s1 = 1; + else + s1 = shape1[shape1Index--]; + + if (shape2Index < 0) + s2 = 1; + else + s2 = shape2[shape2Index--]; + + if (s1 == s2 || (s1 == 1 && s2 != 1) || (s2 == 1 && s1 != 1)) { } + else + { + areCompatible = false; + break; + } + } + + return areCompatible; + } + + internal static nint[] GetSmallestBroadcastableSize(ReadOnlySpan shape1, ReadOnlySpan shape2) + { + if (!AreShapesBroadcastCompatible(shape1, shape2)) + throw new Exception("Shapes are not broadcast compatible"); + + nint[] intermediateShape = GetIntermediateShape(shape1, shape2.Length); + for (int i = 1; i <= shape1.Length; i++) + { + intermediateShape[^i] = Math.Max(intermediateShape[^i], shape1[^i]); + } + for (int i = 1; i <= shape2.Length; i++) + { + intermediateShape[^i] = Math.Max(intermediateShape[^i], shape2[^i]); + } + + return intermediateShape; + } + + internal static nint[] GetIntermediateShape(ReadOnlySpan shape1, int shape2Length) + { + var shape1Index = shape1.Length - 1; + var newShapeIndex = Math.Max(shape1.Length, shape2Length) - 1; + nint[] newShape = new nint[Math.Max(shape1.Length, shape2Length)]; + + while (newShapeIndex >= 0) + { + // if a dimension is missing in one of the shapes, it is considered to be 1 + if (shape1Index < 0) + newShape[newShapeIndex--] = 1; + else + newShape[newShapeIndex--] = shape1[shape1Index--]; + } + + return newShape; + } + + internal static bool IsUnderlyingStorageSameSize(Tensor tensor1, Tensor tensor2) + where T : IEquatable, IEqualityOperators => tensor1.Lengths.Length == tensor2.Lengths.Length; + + internal static bool AreShapesTheSame(Tensor tensor1, Tensor tensor2) + where T : IEquatable, IEqualityOperators => tensor1._lengths.SequenceEqual(tensor2._lengths); + + + internal static void PermuteIndices(ref nint[] indices, ref nint[] permutedIndices, ref int[] permutation) + { + for (int i = 0; i < indices.Length; i++) + { + permutedIndices[i] = indices[permutation[i]]; + } + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs index 7e212a1ef40490..886e8ef1d03294 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; using Xunit; @@ -13,10 +14,135 @@ namespace System.Numerics.Tensors.Tests public class TensorTests { [Fact] - public static void TensorBroadcastToTests() + public static void TensorSequenceEqualTests() + { + Tensor t0 = Tensor.FillRange(Enumerable.Range(0, 3)); + Tensor t1 = Tensor.FillRange(Enumerable.Range(0, 3)); + Tensor equal = Tensor.SequenceEqual(t0, t1); + + Assert.Equal([3], equal.Lengths.ToArray()); + Assert.True(equal[0]); + Assert.True(equal[1]); + Assert.True(equal[2]); + + t0 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(1, 3); + t1 = Tensor.FillRange(Enumerable.Range(0, 3)); + equal = Tensor.SequenceEqual(t0, t1); + + Assert.Equal([1, 3], equal.Lengths.ToArray()); + Assert.True(equal[0, 0]); + Assert.True(equal[0, 1]); + Assert.True(equal[0, 2]); + + t0 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(1, 1, 3); + t1 = Tensor.FillRange(Enumerable.Range(0, 3)); + equal = Tensor.SequenceEqual(t0, t1); + + Assert.Equal([1, 1, 3], equal.Lengths.ToArray()); + Assert.True(equal[0, 0, 0]); + Assert.True(equal[0, 0, 1]); + Assert.True(equal[0, 0, 2]); + + t0 = Tensor.FillRange(Enumerable.Range(0, 3)); + t1 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(1, 3); + equal = Tensor.SequenceEqual(t0, t1); + + Assert.Equal([1, 3], equal.Lengths.ToArray()); + Assert.True(equal[0, 0]); + Assert.True(equal[0, 1]); + Assert.True(equal[0, 2]); + + t0 = Tensor.FillRange(Enumerable.Range(0, 3)); + t1 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(3, 1); + equal = Tensor.SequenceEqual(t0, t1); + + Assert.Equal([3, 3], equal.Lengths.ToArray()); + Assert.True(equal[0, 0]); + Assert.False(equal[0, 1]); + Assert.False(equal[0, 2]); + Assert.False(equal[1, 0]); + Assert.True(equal[1, 1]); + Assert.False(equal[1, 2]); + Assert.False(equal[2, 0]); + Assert.False(equal[2, 1]); + Assert.True(equal[2, 2]); + + t0 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(1, 3); + t1 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(3, 1); + equal = Tensor.SequenceEqual(t0, t1); + + Assert.Equal([3, 3], equal.Lengths.ToArray()); + Assert.True(equal[0, 0]); + Assert.False(equal[0, 1]); + Assert.False(equal[0, 2]); + Assert.False(equal[1, 0]); + Assert.True(equal[1, 1]); + Assert.False(equal[1, 2]); + Assert.False(equal[2, 0]); + Assert.False(equal[2, 1]); + Assert.True(equal[2, 2]); + + t0 = Tensor.FillRange(Enumerable.Range(0, 4)); + t1 = Tensor.FillRange(Enumerable.Range(0, 3)); + Assert.Throws(() => Tensor.SequenceEqual(t0, t1)); + } + + [Fact] + public static void TensorMultiplyTests() + { + Tensor t0 = Tensor.FillRange(Enumerable.Range(0, 3)); + Tensor t1 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(3, 1); + Tensor t2 = Tensor.Multiply(t0, t1); + + Assert.Equal([3,3], t2.Lengths.ToArray()); + Assert.Equal(0, t2[0, 0]); + Assert.Equal(0, t2[0, 1]); + Assert.Equal(0, t2[0, 2]); + Assert.Equal(0, t2[1, 0]); + Assert.Equal(1, t2[1, 1]); + Assert.Equal(2, t2[1, 2]); + Assert.Equal(0, t2[2, 0]); + Assert.Equal(2, t2[2, 1]); + Assert.Equal(4, t2[2, 2]); + + t2 = Tensor.Multiply(t1, t0); + + Assert.Equal([3, 3], t2.Lengths.ToArray()); + Assert.Equal(0, t2[0, 0]); + Assert.Equal(0, t2[0, 1]); + Assert.Equal(0, t2[0, 2]); + Assert.Equal(0, t2[1, 0]); + Assert.Equal(1, t2[1, 1]); + Assert.Equal(2, t2[1, 2]); + Assert.Equal(0, t2[2, 0]); + Assert.Equal(2, t2[2, 1]); + Assert.Equal(4, t2[2, 2]); + + t1 = Tensor.FillRange(Enumerable.Range(0, 9)).Reshape(3, 3); + t2 = Tensor.Multiply(t0, t1); + + Assert.Equal([3, 3], t2.Lengths.ToArray()); + Assert.Equal(0, t2[0, 0]); + Assert.Equal(1, t2[0, 1]); + Assert.Equal(4, t2[0, 2]); + Assert.Equal(0, t2[1, 0]); + Assert.Equal(4, t2[1, 1]); + Assert.Equal(10, t2[1, 2]); + Assert.Equal(0, t2[2, 0]); + Assert.Equal(7, t2[2, 1]); + Assert.Equal(16, t2[2, 2]); + + + + + } + + [Fact] + public static void TensorBroadcastTests() { Tensor t0 = Tensor.Reshape(Tensor.FillRange(Enumerable.Range(0, 3)), 1, 3, 1, 1, 1); - Tensor t1 = Tensor.BroadcastTo(t0, [1, 3, 1, 2, 1]); + Tensor t1 = Tensor.Broadcast(t0, [1, 3, 1, 2, 1]); + Assert.Equal([1, 3, 1, 2, 1], t1.Lengths.ToArray()); Assert.Equal(0, t1[0, 0, 0, 0, 0]); @@ -26,7 +152,7 @@ public static void TensorBroadcastToTests() Assert.Equal(2, t1[0, 2, 0, 0, 0]); Assert.Equal(2, t1[0, 2, 0, 1, 0]); - t1 = Tensor.BroadcastTo(t0, [1, 3, 2, 1, 1]); + t1 = Tensor.Broadcast(t0, [1, 3, 2, 1, 1]); Assert.Equal([1, 3, 2, 1, 1], t1.Lengths.ToArray()); Assert.Equal(0, t1[0, 0, 0, 0, 0]); @@ -38,7 +164,7 @@ public static void TensorBroadcastToTests() t0 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(1, 3); t1 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(3, 1); - var t2 = Tensor.BroadcastTo(t0, [3, 3]); + var t2 = Tensor.Broadcast(t0, [3, 3]); Assert.Equal([3, 3], t2.Lengths.ToArray()); Assert.Equal(0, t2[0, 0]); @@ -52,7 +178,7 @@ public static void TensorBroadcastToTests() Assert.Equal(2, t2[2, 2]); t1 = Tensor.FillRange(Enumerable.Range(0, 3)).Reshape(3, 1); - t2 = Tensor.BroadcastTo(t1, [3, 3]); + t2 = Tensor.Broadcast(t1, [3, 3]); Assert.Equal([3, 3], t2.Lengths.ToArray()); Assert.Equal(0, t2[0, 0]); @@ -78,25 +204,40 @@ public static void TensorBroadcastToTests() var t3 = t2.Slice(0..1, ..); Assert.Equal([1, 3], t3.Lengths.ToArray()); + + t1 = Tensor.FillRange(Enumerable.Range(0, 3)); + t2 = Tensor.Broadcast(t1, [3, 3]); + Assert.Equal([3, 3], t2.Lengths.ToArray()); + + Assert.Equal(0, t2[0, 0]); + Assert.Equal(1, t2[0, 1]); + Assert.Equal(2, t2[0, 2]); + Assert.Equal(0, t2[1, 0]); + Assert.Equal(1, t2[1, 1]); + Assert.Equal(2, t2[1, 2]); + Assert.Equal(0, t2[2, 0]); + Assert.Equal(1, t2[2, 1]); + Assert.Equal(2, t2[2, 2]); } - [Fact] - public static void TensorBroadcastToShapeCompatibleTests() - { - Tensor t0 = Tensor.Reshape(Tensor.FillRange(Enumerable.Range(0, 8)), 8); - Tensor t1 = Tensor.Reshape(Tensor.FillRange(Enumerable.Range(0, 8)), 1, 8); + //// Needs internals visible + //[Fact] + //public static void TensorBroadcastToShapeCompatibleTests() + //{ + // Tensor t0 = Tensor.Reshape(Tensor.FillRange(Enumerable.Range(0, 8)), 8); + // Tensor t1 = Tensor.Reshape(Tensor.FillRange(Enumerable.Range(0, 8)), 1, 8); - Assert.True(Tensor.AreShapesBroadcastToCompatible(t0.Lengths, t1.Lengths)); - Assert.True(Tensor.AreShapesBroadcastToCompatible(t0, t1)); + // Assert.True(Tensor.AreShapesBroadcastToCompatible(t0.Lengths, t1.Lengths)); + // Assert.True(Tensor.AreShapesBroadcastToCompatible(t0, t1)); - t1 = Tensor.Reshape(Tensor.FillRange(Enumerable.Range(0, 8)), 2, 4); - Assert.False(Tensor.AreShapesBroadcastToCompatible(t0, t1)); + // t1 = Tensor.Reshape(Tensor.FillRange(Enumerable.Range(0, 8)), 2, 4); + // Assert.False(Tensor.AreShapesBroadcastToCompatible(t0, t1)); - t0 = Tensor.FillRange(Enumerable.Range(0, 3)); + // t0 = Tensor.FillRange(Enumerable.Range(0, 3)); - Assert.False(Tensor.AreShapesBroadcastToCompatible(t0.Lengths, [1,3,1,1,1])); + // Assert.False(Tensor.AreShapesBroadcastToCompatible(t0.Lengths, [1,3,1,1,1])); - } + //} [Fact] public static void TensorResizeTests() @@ -108,7 +249,7 @@ public static void TensorResizeTests() t1 = Tensor.Resize(t0, [1, 1]); Assert.Equal([1, 1], t1.Lengths.ToArray()); - Assert.Equal(0, t1[0]); + Assert.Equal(0, t1[0, 0]); t1 = Tensor.Resize(t0, [6]); Assert.Equal([6], t1.Lengths.ToArray()); @@ -294,14 +435,78 @@ public static void TensorStackTests() Assert.Equal(0, resultTensor[0, 0, 0]); Assert.Equal(1, resultTensor[0, 0, 1]); - Assert.Equal(2, resultTensor[0, 0, 0]); - Assert.Equal(3, resultTensor[0, 1, 1]); - Assert.Equal(4, resultTensor[0, 2, 0]); - Assert.Equal(5, resultTensor[0, 2, 1]); - Assert.Equal(6, resultTensor[0, 3, 0]); - Assert.Equal(7, resultTensor[0, 3, 1]); - Assert.Equal(8, resultTensor[0, 3, 1]); - Assert.Equal(9, resultTensor[0, 3, 1]); + Assert.Equal(2, resultTensor[0, 0, 2]); + Assert.Equal(3, resultTensor[0, 0, 3]); + Assert.Equal(4, resultTensor[0, 0, 4]); + Assert.Equal(5, resultTensor[0, 1, 0]); + Assert.Equal(6, resultTensor[0, 1, 1]); + Assert.Equal(7, resultTensor[0, 1, 2]); + Assert.Equal(8, resultTensor[0, 1, 3]); + Assert.Equal(9, resultTensor[0, 1, 4]); + Assert.Equal(0, resultTensor[1, 0, 0]); + Assert.Equal(1, resultTensor[1, 0, 1]); + Assert.Equal(2, resultTensor[1, 0, 2]); + Assert.Equal(3, resultTensor[1, 0, 3]); + Assert.Equal(4, resultTensor[1, 0, 4]); + Assert.Equal(5, resultTensor[1, 1, 0]); + Assert.Equal(6, resultTensor[1, 1, 1]); + Assert.Equal(7, resultTensor[1, 1, 2]); + Assert.Equal(8, resultTensor[1, 1, 3]); + Assert.Equal(9, resultTensor[1, 1, 4]); + + resultTensor = Tensor.Stack([t0, t1], axis:1); + Assert.Equal(3, resultTensor.Rank); + Assert.Equal(2, resultTensor.Lengths[0]); + Assert.Equal(2, resultTensor.Lengths[1]); + Assert.Equal(5, resultTensor.Lengths[2]); + + Assert.Equal(0, resultTensor[0, 0, 0]); + Assert.Equal(1, resultTensor[0, 0, 1]); + Assert.Equal(2, resultTensor[0, 0, 2]); + Assert.Equal(3, resultTensor[0, 0, 3]); + Assert.Equal(4, resultTensor[0, 0, 4]); + Assert.Equal(0, resultTensor[0, 1, 0]); + Assert.Equal(1, resultTensor[0, 1, 1]); + Assert.Equal(2, resultTensor[0, 1, 2]); + Assert.Equal(3, resultTensor[0, 1, 3]); + Assert.Equal(4, resultTensor[0, 1, 4]); + Assert.Equal(5, resultTensor[1, 0, 0]); + Assert.Equal(6, resultTensor[1, 0, 1]); + Assert.Equal(7, resultTensor[1, 0, 2]); + Assert.Equal(8, resultTensor[1, 0, 3]); + Assert.Equal(9, resultTensor[1, 0, 4]); + Assert.Equal(5, resultTensor[1, 1, 0]); + Assert.Equal(6, resultTensor[1, 1, 1]); + Assert.Equal(7, resultTensor[1, 1, 2]); + Assert.Equal(8, resultTensor[1, 1, 3]); + Assert.Equal(9, resultTensor[1, 1, 4]); + + resultTensor = Tensor.Stack([t0, t1], axis: 2); + Assert.Equal(3, resultTensor.Rank); + Assert.Equal(2, resultTensor.Lengths[0]); + Assert.Equal(5, resultTensor.Lengths[1]); + Assert.Equal(2, resultTensor.Lengths[2]); + + Assert.Equal(0, resultTensor[0, 0, 0]); + Assert.Equal(0, resultTensor[0, 0, 1]); + Assert.Equal(1, resultTensor[0, 1, 0]); + Assert.Equal(1, resultTensor[0, 1, 1]); + Assert.Equal(2, resultTensor[0, 2, 0]); + Assert.Equal(2, resultTensor[0, 2, 1]); + Assert.Equal(3, resultTensor[0, 3, 0]); + Assert.Equal(3, resultTensor[0, 3, 1]); + Assert.Equal(4, resultTensor[0, 4, 0]); + Assert.Equal(4, resultTensor[0, 4, 1]); + Assert.Equal(5, resultTensor[1, 0, 0]); + Assert.Equal(5, resultTensor[1, 0, 1]); + Assert.Equal(6, resultTensor[1, 1, 0]); + Assert.Equal(6, resultTensor[1, 1, 1]); + Assert.Equal(7, resultTensor[1, 2, 0]); + Assert.Equal(7, resultTensor[1, 2, 1]); + Assert.Equal(8, resultTensor[1, 3, 0]); + Assert.Equal(8, resultTensor[1, 3, 1]); + Assert.Equal(9, resultTensor[1, 4, 0]); + Assert.Equal(9, resultTensor[1, 4, 1]); } [Fact]