Skip to content

Commit 0c5f13b

Browse files
authored
Fix Tensor.StdDev vs TensorPrimitives.StdDev differences for Complex input (#119229)
1 parent 116db00 commit 0c5f13b

File tree

3 files changed

+191
-7
lines changed

3 files changed

+191
-7
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4810,10 +4810,9 @@ public static T StdDev<T>(in ReadOnlyTensorSpan<T> x)
48104810
{
48114811
T mean = Average(x);
48124812
T result = T.AdditiveIdentity;
4813-
TensorOperation.Invoke<TensorOperation.SumOfSquaredDifferences<T>, T, T>(x, mean, ref result);
4813+
TensorOperation.Invoke<TensorOperation.SumOfSquaredAbsoluteDifferences<T>, T, T>(x, mean, ref result);
48144814
T variance = result / T.CreateChecked(x.FlattenedLength);
48154815
return T.Sqrt(variance);
4816-
48174816
}
48184817
#endregion
48194818

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorOperation.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,6 +2103,27 @@ public static void Invoke(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destinat
21032103
}
21042104
}
21052105

2106+
public readonly struct SumOfSquaredAbsoluteDifferences<T>
2107+
: IBinaryOperation_Tensor_Scalar<T, T>
2108+
where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>, ISubtractionOperators<T, T, T>, INumberBase<T>
2109+
{
2110+
public static void Invoke(ref readonly T x, T y, ref T destination)
2111+
{
2112+
// Absolute value is needed before squaring to support complex numbers
2113+
T diff = T.Abs(x - y);
2114+
destination += diff * diff;
2115+
}
2116+
public static void Invoke(ReadOnlySpan<T> x, T y, Span<T> destination)
2117+
{
2118+
for (int i = 0; i < x.Length; i++)
2119+
{
2120+
// Absolute value is needed before squaring to support complex numbers
2121+
T diff = T.Abs(x[i] - y);
2122+
destination[i] = diff * diff;
2123+
}
2124+
}
2125+
}
2126+
21062127
public readonly struct SumOfSquares<T>
21072128
: IUnaryReduction_Tensor<T, T>
21082129
where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>

src/libraries/System.Numerics.Tensors/tests/TensorTests.cs

Lines changed: 169 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
using System.Buffers;
55
using System.Collections.Generic;
6+
using System.Diagnostics.CodeAnalysis;
7+
using System.Globalization;
68
using System.Linq;
79
using System.Runtime.InteropServices;
810
using Xunit;
@@ -1152,14 +1154,89 @@ public static void TensorStackTests()
11521154
Assert.Equal(40, resultTensor2[1, 1, 1]);
11531155
}
11541156

1155-
[Fact]
1156-
public static void TensorStdDevTests()
1157+
public static IEnumerable<object[]> StdDevFloatTestData()
11571158
{
1158-
Tensor<float> t0 = Tensor.Create(Enumerable.Sequence<float>(0, 4, 1).ToArray(), lengths: [2, 2]);
1159+
// Test case where Abs doesn't change the result (real numbers, positive values)
1160+
yield return new object[]
1161+
{
1162+
new float[] { 0f, 1f, 2f, 3f },
1163+
StdDev([0f, 1f, 2f, 3f])
1164+
};
1165+
1166+
// Test case with all same values (should be 0)
1167+
yield return new object[]
1168+
{
1169+
new float[] { 1f, 1f, 1f, 1f },
1170+
0f
1171+
};
1172+
1173+
// Test case with negative values where Abs could matter (but doesn't for standard deviation)
1174+
yield return new object[]
1175+
{
1176+
new float[] { -2f, -1f, 1f, 2f },
1177+
StdDev([-2f, -1f, 1f, 2f])
1178+
};
1179+
}
1180+
1181+
public static IEnumerable<object[]> StdDevComplexTestData()
1182+
{
1183+
// Test case where Abs is critical (complex numbers)
1184+
yield return new object[]
1185+
{
1186+
new TestComplex[] { new(new(1, 2)), new(new(3, 4)) }
1187+
};
1188+
1189+
// Test case with purely imaginary numbers
1190+
yield return new object[]
1191+
{
1192+
new TestComplex[] { new(new(0, 1)), new(new(0, 2)), new(new(0, 3)) }
1193+
};
1194+
1195+
// Test case with purely real numbers (should behave like floats)
1196+
yield return new object[]
1197+
{
1198+
new TestComplex[] { new(new(1, 0)), new(new(2, 0)), new(new(3, 0)) }
1199+
};
1200+
}
1201+
1202+
[Theory, MemberData(nameof(StdDevFloatTestData))]
1203+
public static void TensorStdDevFloatTests(float[] data, float expectedStdDev)
1204+
{
1205+
var tensor = Tensor.Create(data);
1206+
1207+
var tensorPrimitivesResult = TensorPrimitives.StdDev<float>(data);
1208+
var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan());
1209+
1210+
// Both should produce the same result
1211+
Assert.Equal(tensorPrimitivesResult, tensorResult, precision: 5);
1212+
Assert.Equal(expectedStdDev, tensorResult, precision: 5);
1213+
1214+
// Test that non-contiguous calculations work with reshaped tensor
1215+
if (data.Length >= 4)
1216+
{
1217+
var reshapedTensor = Tensor.Create(data, lengths: [2, 2]);
1218+
var reshapedResult = Tensor.StdDev(reshapedTensor.AsReadOnlyTensorSpan());
1219+
Assert.Equal(expectedStdDev, reshapedResult, precision: 5);
1220+
}
1221+
}
11591222

1160-
Assert.Equal(StdDev([0, 1, 2, 3]), Tensor.StdDev<float>(t0), .1);
1223+
[Theory, MemberData(nameof(StdDevComplexTestData))]
1224+
public static void TensorStdDevComplexTests(TestComplex[] data)
1225+
{
1226+
var tensor = Tensor.Create(data);
1227+
1228+
var tensorPrimitivesResult = TensorPrimitives.StdDev<TestComplex>(data);
1229+
var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan());
1230+
1231+
// Both should produce the same result - this is the key test for the fix
1232+
Assert.Equal(tensorPrimitivesResult.Real, tensorResult.Real, precision: 10);
1233+
Assert.Equal(tensorPrimitivesResult.Imaginary, tensorResult.Imaginary, precision: 10);
1234+
}
11611235

1162-
// Test that non-contiguous calculations work
1236+
[Fact]
1237+
public static void TensorStdDevNonContiguousTests()
1238+
{
1239+
// Test that non-contiguous calculations work for float tensors
11631240
Tensor<float> fourByFour = Tensor.CreateFromShape<float>([4, 4]);
11641241
fourByFour[[0, 0]] = 1f;
11651242
fourByFour[[0, 1]] = 1f;
@@ -3190,4 +3267,91 @@ public static void ToStringZeroDataTest()
31903267
Assert.Equal(expected, tensor.ToString([2, 0, 2]));
31913268
}
31923269
}
3270+
3271+
/// <summary>
3272+
/// Test complex number type that implements IRootFunctions for testing consistency between
3273+
/// TensorPrimitives.StdDev and Tensor.StdDev
3274+
/// </summary>
3275+
public readonly struct TestComplex(Complex value) : IRootFunctions<TestComplex>, IEquatable<TestComplex>
3276+
{
3277+
private readonly Complex _value = value;
3278+
3279+
public double Real => _value.Real;
3280+
public double Imaginary => _value.Imaginary;
3281+
3282+
public static TestComplex One => new(Complex.One);
3283+
public static int Radix => 2;
3284+
public static TestComplex Zero => new(Complex.Zero);
3285+
public static TestComplex E => new(new(Math.E, 0));
3286+
public static TestComplex Pi => new(new(Math.PI, 0));
3287+
public static TestComplex Tau => new(new(Math.Tau, 0));
3288+
3289+
public static TestComplex operator +(TestComplex left, TestComplex right) => new(left._value + right._value);
3290+
public static TestComplex operator *(TestComplex left, TestComplex right) => new(left._value * right._value);
3291+
public static TestComplex operator /(TestComplex left, TestComplex right) => new(left._value / right._value);
3292+
public static TestComplex operator -(TestComplex left, TestComplex right) => new(left._value - right._value);
3293+
public static TestComplex Sqrt(TestComplex x) => new(Complex.Sqrt(x._value));
3294+
public static TestComplex Abs(TestComplex value) => new(new(Complex.Abs(value._value), 0));
3295+
public static TestComplex AdditiveIdentity => new(Complex.Zero);
3296+
public static TestComplex CreateChecked<TOther>(TOther value) where TOther : INumberBase<TOther> => new(Complex.CreateChecked(value));
3297+
3298+
// Override Object methods
3299+
public override bool Equals(object? obj) => obj is TestComplex other && Equals(other);
3300+
public override int GetHashCode() => _value.GetHashCode();
3301+
public override string ToString() => _value.ToString();
3302+
3303+
// IEquatable<TestComplex>
3304+
public bool Equals(TestComplex other) => _value.Equals(other._value);
3305+
3306+
// Operators
3307+
public static bool operator ==(TestComplex left, TestComplex right) => left.Equals(right);
3308+
public static bool operator !=(TestComplex left, TestComplex right) => !left.Equals(right);
3309+
3310+
// Required interface implementations not needed for tests - throw NotImplementedException
3311+
public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException();
3312+
public bool TryFormat(Span<char> destination, out int charsWritten, ReadOnlySpan<char> format, IFormatProvider? provider) => throw new NotImplementedException();
3313+
public static TestComplex Parse(string s, IFormatProvider? provider) => throw new NotImplementedException();
3314+
public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
3315+
public static TestComplex Parse(ReadOnlySpan<char> s, IFormatProvider? provider) => throw new NotImplementedException();
3316+
public static bool TryParse(ReadOnlySpan<char> s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
3317+
public static TestComplex operator --(TestComplex value) => throw new NotImplementedException();
3318+
public static TestComplex operator ++(TestComplex value) => throw new NotImplementedException();
3319+
public static TestComplex MultiplicativeIdentity => new(Complex.One);
3320+
public static TestComplex operator -(TestComplex value) => new(-value._value);
3321+
public static TestComplex operator +(TestComplex value) => new(value._value);
3322+
public static bool IsCanonical(TestComplex value) => throw new NotImplementedException();
3323+
public static bool IsComplexNumber(TestComplex value) => throw new NotImplementedException();
3324+
public static bool IsEvenInteger(TestComplex value) => throw new NotImplementedException();
3325+
public static bool IsFinite(TestComplex value) => throw new NotImplementedException();
3326+
public static bool IsImaginaryNumber(TestComplex value) => throw new NotImplementedException();
3327+
public static bool IsInfinity(TestComplex value) => throw new NotImplementedException();
3328+
public static bool IsInteger(TestComplex value) => throw new NotImplementedException();
3329+
public static bool IsNaN(TestComplex value) => throw new NotImplementedException();
3330+
public static bool IsNegative(TestComplex value) => throw new NotImplementedException();
3331+
public static bool IsNegativeInfinity(TestComplex value) => throw new NotImplementedException();
3332+
public static bool IsNormal(TestComplex value) => throw new NotImplementedException();
3333+
public static bool IsOddInteger(TestComplex value) => throw new NotImplementedException();
3334+
public static bool IsPositive(TestComplex value) => throw new NotImplementedException();
3335+
public static bool IsPositiveInfinity(TestComplex value) => throw new NotImplementedException();
3336+
public static bool IsRealNumber(TestComplex value) => throw new NotImplementedException();
3337+
public static bool IsSubnormal(TestComplex value) => throw new NotImplementedException();
3338+
public static bool IsZero(TestComplex value) => throw new NotImplementedException();
3339+
public static TestComplex MaxMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException();
3340+
public static TestComplex MaxMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException();
3341+
public static TestComplex MinMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException();
3342+
public static TestComplex MinMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException();
3343+
public static TestComplex Parse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException();
3344+
public static TestComplex Parse(string s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException();
3345+
public static bool TryConvertFromSaturating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
3346+
public static bool TryConvertFromTruncating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
3347+
public static bool TryConvertFromChecked<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
3348+
public static bool TryConvertToChecked<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
3349+
public static bool TryConvertToSaturating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
3350+
public static bool TryConvertToTruncating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
3351+
public static bool TryParse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
3352+
public static bool TryParse([NotNullWhen(true)] string? s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
3353+
public static TestComplex Cbrt(TestComplex x) => throw new NotImplementedException();
3354+
public static TestComplex Hypot(TestComplex x, TestComplex y) => throw new NotImplementedException();
3355+
public static TestComplex RootN(TestComplex x, int n) => throw new NotImplementedException();
3356+
}
31933357
}

0 commit comments

Comments
 (0)