diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs index c6e4e116744fbd..ff8e2373bc1c63 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs @@ -4810,10 +4810,9 @@ public static T StdDev(in ReadOnlyTensorSpan x) { T mean = Average(x); T result = T.AdditiveIdentity; - TensorOperation.Invoke, T, T>(x, mean, ref result); + TensorOperation.Invoke, T, T>(x, mean, ref result); T variance = result / T.CreateChecked(x.FlattenedLength); return T.Sqrt(variance); - } #endregion diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs index 227a8d95db58ed..c96103ee913c92 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs @@ -2103,6 +2103,27 @@ public static void Invoke(ReadOnlySpan x, ReadOnlySpan y, Span destinat } } + public readonly struct SumOfSquaredAbsoluteDifferences + : IBinaryOperation_Tensor_Scalar + where T : IAdditionOperators, IAdditiveIdentity, IMultiplyOperators, ISubtractionOperators, INumberBase + { + public static void Invoke(ref readonly T x, T y, ref T destination) + { + // Absolute value is needed before squaring to support complex numbers + T diff = T.Abs(x - y); + destination += diff * diff; + } + public static void Invoke(ReadOnlySpan x, T y, Span destination) + { + for (int i = 0; i < x.Length; i++) + { + // Absolute value is needed before squaring to support complex numbers + T diff = T.Abs(x[i] - y); + destination[i] = diff * diff; + } + } + } + public readonly struct SumOfSquares : IUnaryReduction_Tensor where T : IAdditionOperators, IAdditiveIdentity, IMultiplyOperators diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs index 22c86fbbeb695e..1a999ff33080e2 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs @@ -3,6 +3,8 @@ using System.Buffers; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; using System.Runtime.InteropServices; using Xunit; @@ -1152,14 +1154,89 @@ public static void TensorStackTests() Assert.Equal(40, resultTensor2[1, 1, 1]); } - [Fact] - public static void TensorStdDevTests() + public static IEnumerable StdDevFloatTestData() { - Tensor t0 = Tensor.Create(Enumerable.Sequence(0, 4, 1).ToArray(), lengths: [2, 2]); + // Test case where Abs doesn't change the result (real numbers, positive values) + yield return new object[] + { + new float[] { 0f, 1f, 2f, 3f }, + StdDev([0f, 1f, 2f, 3f]) + }; + + // Test case with all same values (should be 0) + yield return new object[] + { + new float[] { 1f, 1f, 1f, 1f }, + 0f + }; + + // Test case with negative values where Abs could matter (but doesn't for standard deviation) + yield return new object[] + { + new float[] { -2f, -1f, 1f, 2f }, + StdDev([-2f, -1f, 1f, 2f]) + }; + } + + public static IEnumerable StdDevComplexTestData() + { + // Test case where Abs is critical (complex numbers) + yield return new object[] + { + new TestComplex[] { new(new(1, 2)), new(new(3, 4)) } + }; + + // Test case with purely imaginary numbers + yield return new object[] + { + new TestComplex[] { new(new(0, 1)), new(new(0, 2)), new(new(0, 3)) } + }; + + // Test case with purely real numbers (should behave like floats) + yield return new object[] + { + new TestComplex[] { new(new(1, 0)), new(new(2, 0)), new(new(3, 0)) } + }; + } + + [Theory, MemberData(nameof(StdDevFloatTestData))] + public static void TensorStdDevFloatTests(float[] data, float expectedStdDev) + { + var tensor = Tensor.Create(data); + + var tensorPrimitivesResult = TensorPrimitives.StdDev(data); + var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan()); + + // Both should produce the same result + Assert.Equal(tensorPrimitivesResult, tensorResult, precision: 5); + Assert.Equal(expectedStdDev, tensorResult, precision: 5); + + // Test that non-contiguous calculations work with reshaped tensor + if (data.Length >= 4) + { + var reshapedTensor = Tensor.Create(data, lengths: [2, 2]); + var reshapedResult = Tensor.StdDev(reshapedTensor.AsReadOnlyTensorSpan()); + Assert.Equal(expectedStdDev, reshapedResult, precision: 5); + } + } - Assert.Equal(StdDev([0, 1, 2, 3]), Tensor.StdDev(t0), .1); + [Theory, MemberData(nameof(StdDevComplexTestData))] + public static void TensorStdDevComplexTests(TestComplex[] data) + { + var tensor = Tensor.Create(data); + + var tensorPrimitivesResult = TensorPrimitives.StdDev(data); + var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan()); + + // Both should produce the same result - this is the key test for the fix + Assert.Equal(tensorPrimitivesResult.Real, tensorResult.Real, precision: 10); + Assert.Equal(tensorPrimitivesResult.Imaginary, tensorResult.Imaginary, precision: 10); + } - // Test that non-contiguous calculations work + [Fact] + public static void TensorStdDevNonContiguousTests() + { + // Test that non-contiguous calculations work for float tensors Tensor fourByFour = Tensor.CreateFromShape([4, 4]); fourByFour[[0, 0]] = 1f; fourByFour[[0, 1]] = 1f; @@ -3190,4 +3267,91 @@ public static void ToStringZeroDataTest() Assert.Equal(expected, tensor.ToString([2, 0, 2])); } } + + /// + /// Test complex number type that implements IRootFunctions for testing consistency between + /// TensorPrimitives.StdDev and Tensor.StdDev + /// + public readonly struct TestComplex(Complex value) : IRootFunctions, IEquatable + { + private readonly Complex _value = value; + + public double Real => _value.Real; + public double Imaginary => _value.Imaginary; + + public static TestComplex One => new(Complex.One); + public static int Radix => 2; + public static TestComplex Zero => new(Complex.Zero); + public static TestComplex E => new(new(Math.E, 0)); + public static TestComplex Pi => new(new(Math.PI, 0)); + public static TestComplex Tau => new(new(Math.Tau, 0)); + + public static TestComplex operator +(TestComplex left, TestComplex right) => new(left._value + right._value); + public static TestComplex operator *(TestComplex left, TestComplex right) => new(left._value * right._value); + public static TestComplex operator /(TestComplex left, TestComplex right) => new(left._value / right._value); + public static TestComplex operator -(TestComplex left, TestComplex right) => new(left._value - right._value); + public static TestComplex Sqrt(TestComplex x) => new(Complex.Sqrt(x._value)); + public static TestComplex Abs(TestComplex value) => new(new(Complex.Abs(value._value), 0)); + public static TestComplex AdditiveIdentity => new(Complex.Zero); + public static TestComplex CreateChecked(TOther value) where TOther : INumberBase => new(Complex.CreateChecked(value)); + + // Override Object methods + public override bool Equals(object? obj) => obj is TestComplex other && Equals(other); + public override int GetHashCode() => _value.GetHashCode(); + public override string ToString() => _value.ToString(); + + // IEquatable + public bool Equals(TestComplex other) => _value.Equals(other._value); + + // Operators + public static bool operator ==(TestComplex left, TestComplex right) => left.Equals(right); + public static bool operator !=(TestComplex left, TestComplex right) => !left.Equals(right); + + // Required interface implementations not needed for tests - throw NotImplementedException + public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException(); + public bool TryFormat(Span destination, out int charsWritten, ReadOnlySpan format, IFormatProvider? provider) => throw new NotImplementedException(); + public static TestComplex Parse(string s, IFormatProvider? provider) => throw new NotImplementedException(); + public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); + public static TestComplex Parse(ReadOnlySpan s, IFormatProvider? provider) => throw new NotImplementedException(); + public static bool TryParse(ReadOnlySpan s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); + public static TestComplex operator --(TestComplex value) => throw new NotImplementedException(); + public static TestComplex operator ++(TestComplex value) => throw new NotImplementedException(); + public static TestComplex MultiplicativeIdentity => new(Complex.One); + public static TestComplex operator -(TestComplex value) => new(-value._value); + public static TestComplex operator +(TestComplex value) => new(value._value); + public static bool IsCanonical(TestComplex value) => throw new NotImplementedException(); + public static bool IsComplexNumber(TestComplex value) => throw new NotImplementedException(); + public static bool IsEvenInteger(TestComplex value) => throw new NotImplementedException(); + public static bool IsFinite(TestComplex value) => throw new NotImplementedException(); + public static bool IsImaginaryNumber(TestComplex value) => throw new NotImplementedException(); + public static bool IsInfinity(TestComplex value) => throw new NotImplementedException(); + public static bool IsInteger(TestComplex value) => throw new NotImplementedException(); + public static bool IsNaN(TestComplex value) => throw new NotImplementedException(); + public static bool IsNegative(TestComplex value) => throw new NotImplementedException(); + public static bool IsNegativeInfinity(TestComplex value) => throw new NotImplementedException(); + public static bool IsNormal(TestComplex value) => throw new NotImplementedException(); + public static bool IsOddInteger(TestComplex value) => throw new NotImplementedException(); + public static bool IsPositive(TestComplex value) => throw new NotImplementedException(); + public static bool IsPositiveInfinity(TestComplex value) => throw new NotImplementedException(); + public static bool IsRealNumber(TestComplex value) => throw new NotImplementedException(); + public static bool IsSubnormal(TestComplex value) => throw new NotImplementedException(); + public static bool IsZero(TestComplex value) => throw new NotImplementedException(); + public static TestComplex MaxMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException(); + public static TestComplex MaxMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException(); + public static TestComplex MinMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException(); + public static TestComplex MinMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException(); + public static TestComplex Parse(ReadOnlySpan s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException(); + public static TestComplex Parse(string s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException(); + public static bool TryConvertFromSaturating(TOther value, out TestComplex result) where TOther : INumberBase => throw new NotImplementedException(); + public static bool TryConvertFromTruncating(TOther value, out TestComplex result) where TOther : INumberBase => throw new NotImplementedException(); + public static bool TryConvertFromChecked(TOther value, out TestComplex result) where TOther : INumberBase => throw new NotImplementedException(); + public static bool TryConvertToChecked(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase => throw new NotImplementedException(); + public static bool TryConvertToSaturating(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase => throw new NotImplementedException(); + public static bool TryConvertToTruncating(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase => throw new NotImplementedException(); + public static bool TryParse(ReadOnlySpan s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); + public static bool TryParse([NotNullWhen(true)] string? s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); + public static TestComplex Cbrt(TestComplex x) => throw new NotImplementedException(); + public static TestComplex Hypot(TestComplex x, TestComplex y) => throw new NotImplementedException(); + public static TestComplex RootN(TestComplex x, int n) => throw new NotImplementedException(); + } }