Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4810,10 +4810,9 @@ public static T StdDev<T>(in ReadOnlyTensorSpan<T> x)
{
T mean = Average(x);
T result = T.AdditiveIdentity;
TensorOperation.Invoke<TensorOperation.SumOfSquaredDifferences<T>, T, T>(x, mean, ref result);
TensorOperation.Invoke<TensorOperation.SumOfSquaredAbsoluteDifferences<T>, T, T>(x, mean, ref result);
T variance = result / T.CreateChecked(x.FlattenedLength);
return T.Sqrt(variance);

}
#endregion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2103,6 +2103,27 @@ public static void Invoke(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destinat
}
}

public readonly struct SumOfSquaredAbsoluteDifferences<T>
: IBinaryOperation_Tensor_Scalar<T, T>
where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>, ISubtractionOperators<T, T, T>, INumberBase<T>
{
public static void Invoke(ref readonly T x, T y, ref T destination)
{
// Absolute value is needed before squaring to support complex numbers
T diff = T.Abs(x - y);
destination += diff * diff;
}
public static void Invoke(ReadOnlySpan<T> x, T y, Span<T> destination)
{
for (int i = 0; i < x.Length; i++)
{
// Absolute value is needed before squaring to support complex numbers
T diff = T.Abs(x[i] - y);
destination[i] = diff * diff;
}
}
}

public readonly struct SumOfSquares<T>
: IUnaryReduction_Tensor<T, T>
where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>
Expand Down
174 changes: 169 additions & 5 deletions src/libraries/System.Numerics.Tensors/tests/TensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
Expand Down Expand Up @@ -1152,14 +1154,89 @@ public static void TensorStackTests()
Assert.Equal(40, resultTensor2[1, 1, 1]);
}

[Fact]
public static void TensorStdDevTests()
public static IEnumerable<object[]> StdDevFloatTestData()
{
Tensor<float> t0 = Tensor.Create(Enumerable.Sequence<float>(0, 4, 1).ToArray(), lengths: [2, 2]);
// Test case where Abs doesn't change the result (real numbers, positive values)
yield return new object[]
{
new float[] { 0f, 1f, 2f, 3f },
StdDev([0f, 1f, 2f, 3f])
};

// Test case with all same values (should be 0)
yield return new object[]
{
new float[] { 1f, 1f, 1f, 1f },
0f
};

// Test case with negative values where Abs could matter (but doesn't for standard deviation)
yield return new object[]
{
new float[] { -2f, -1f, 1f, 2f },
StdDev([-2f, -1f, 1f, 2f])
};
}

public static IEnumerable<object[]> StdDevComplexTestData()
{
// Test case where Abs is critical (complex numbers)
yield return new object[]
{
new TestComplex[] { new(new(1, 2)), new(new(3, 4)) }
};

// Test case with purely imaginary numbers
yield return new object[]
{
new TestComplex[] { new(new(0, 1)), new(new(0, 2)), new(new(0, 3)) }
};

// Test case with purely real numbers (should behave like floats)
yield return new object[]
{
new TestComplex[] { new(new(1, 0)), new(new(2, 0)), new(new(3, 0)) }
};
}

[Theory, MemberData(nameof(StdDevFloatTestData))]
public static void TensorStdDevFloatTests(float[] data, float expectedStdDev)
{
var tensor = Tensor.Create(data);

var tensorPrimitivesResult = TensorPrimitives.StdDev<float>(data);
var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan());

// Both should produce the same result
Assert.Equal(tensorPrimitivesResult, tensorResult, precision: 5);
Assert.Equal(expectedStdDev, tensorResult, precision: 5);

// Test that non-contiguous calculations work with reshaped tensor
if (data.Length >= 4)
{
var reshapedTensor = Tensor.Create(data, lengths: [2, 2]);
var reshapedResult = Tensor.StdDev(reshapedTensor.AsReadOnlyTensorSpan());
Assert.Equal(expectedStdDev, reshapedResult, precision: 5);
}
}

Assert.Equal(StdDev([0, 1, 2, 3]), Tensor.StdDev<float>(t0), .1);
[Theory, MemberData(nameof(StdDevComplexTestData))]
public static void TensorStdDevComplexTests(TestComplex[] data)
{
var tensor = Tensor.Create(data);

var tensorPrimitivesResult = TensorPrimitives.StdDev<TestComplex>(data);
var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan());

// Both should produce the same result - this is the key test for the fix
Assert.Equal(tensorPrimitivesResult.Real, tensorResult.Real, precision: 10);
Assert.Equal(tensorPrimitivesResult.Imaginary, tensorResult.Imaginary, precision: 10);
}

// Test that non-contiguous calculations work
[Fact]
public static void TensorStdDevNonContiguousTests()
{
// Test that non-contiguous calculations work for float tensors
Tensor<float> fourByFour = Tensor.CreateFromShape<float>([4, 4]);
fourByFour[[0, 0]] = 1f;
fourByFour[[0, 1]] = 1f;
Expand Down Expand Up @@ -3190,4 +3267,91 @@ public static void ToStringZeroDataTest()
Assert.Equal(expected, tensor.ToString([2, 0, 2]));
}
}

/// <summary>
/// Test complex number type that implements IRootFunctions for testing consistency between
/// TensorPrimitives.StdDev and Tensor.StdDev
/// </summary>
public readonly struct TestComplex(Complex value) : IRootFunctions<TestComplex>, IEquatable<TestComplex>
{
private readonly Complex _value = value;

public double Real => _value.Real;
public double Imaginary => _value.Imaginary;

public static TestComplex One => new(Complex.One);
public static int Radix => 2;
public static TestComplex Zero => new(Complex.Zero);
public static TestComplex E => new(new(Math.E, 0));
public static TestComplex Pi => new(new(Math.PI, 0));
public static TestComplex Tau => new(new(Math.Tau, 0));

public static TestComplex operator +(TestComplex left, TestComplex right) => new(left._value + right._value);
public static TestComplex operator *(TestComplex left, TestComplex right) => new(left._value * right._value);
public static TestComplex operator /(TestComplex left, TestComplex right) => new(left._value / right._value);
public static TestComplex operator -(TestComplex left, TestComplex right) => new(left._value - right._value);
public static TestComplex Sqrt(TestComplex x) => new(Complex.Sqrt(x._value));
public static TestComplex Abs(TestComplex value) => new(new(Complex.Abs(value._value), 0));
public static TestComplex AdditiveIdentity => new(Complex.Zero);
public static TestComplex CreateChecked<TOther>(TOther value) where TOther : INumberBase<TOther> => new(Complex.CreateChecked(value));

// Override Object methods
public override bool Equals(object? obj) => obj is TestComplex other && Equals(other);
public override int GetHashCode() => _value.GetHashCode();
public override string ToString() => _value.ToString();

// IEquatable<TestComplex>
public bool Equals(TestComplex other) => _value.Equals(other._value);

// Operators
public static bool operator ==(TestComplex left, TestComplex right) => left.Equals(right);
public static bool operator !=(TestComplex left, TestComplex right) => !left.Equals(right);

// Required interface implementations not needed for tests - throw NotImplementedException
public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException();
public bool TryFormat(Span<char> destination, out int charsWritten, ReadOnlySpan<char> format, IFormatProvider? provider) => throw new NotImplementedException();
public static TestComplex Parse(string s, IFormatProvider? provider) => throw new NotImplementedException();
public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
public static TestComplex Parse(ReadOnlySpan<char> s, IFormatProvider? provider) => throw new NotImplementedException();
public static bool TryParse(ReadOnlySpan<char> s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
public static TestComplex operator --(TestComplex value) => throw new NotImplementedException();
public static TestComplex operator ++(TestComplex value) => throw new NotImplementedException();
public static TestComplex MultiplicativeIdentity => new(Complex.One);
public static TestComplex operator -(TestComplex value) => new(-value._value);
public static TestComplex operator +(TestComplex value) => new(value._value);
public static bool IsCanonical(TestComplex value) => throw new NotImplementedException();
public static bool IsComplexNumber(TestComplex value) => throw new NotImplementedException();
public static bool IsEvenInteger(TestComplex value) => throw new NotImplementedException();
public static bool IsFinite(TestComplex value) => throw new NotImplementedException();
public static bool IsImaginaryNumber(TestComplex value) => throw new NotImplementedException();
public static bool IsInfinity(TestComplex value) => throw new NotImplementedException();
public static bool IsInteger(TestComplex value) => throw new NotImplementedException();
public static bool IsNaN(TestComplex value) => throw new NotImplementedException();
public static bool IsNegative(TestComplex value) => throw new NotImplementedException();
public static bool IsNegativeInfinity(TestComplex value) => throw new NotImplementedException();
public static bool IsNormal(TestComplex value) => throw new NotImplementedException();
public static bool IsOddInteger(TestComplex value) => throw new NotImplementedException();
public static bool IsPositive(TestComplex value) => throw new NotImplementedException();
public static bool IsPositiveInfinity(TestComplex value) => throw new NotImplementedException();
public static bool IsRealNumber(TestComplex value) => throw new NotImplementedException();
public static bool IsSubnormal(TestComplex value) => throw new NotImplementedException();
public static bool IsZero(TestComplex value) => throw new NotImplementedException();
public static TestComplex MaxMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException();
public static TestComplex MaxMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException();
public static TestComplex MinMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException();
public static TestComplex MinMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException();
public static TestComplex Parse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException();
public static TestComplex Parse(string s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException();
public static bool TryConvertFromSaturating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
public static bool TryConvertFromTruncating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
public static bool TryConvertFromChecked<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
public static bool TryConvertToChecked<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
public static bool TryConvertToSaturating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
public static bool TryConvertToTruncating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException();
public static bool TryParse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
public static bool TryParse([NotNullWhen(true)] string? s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException();
public static TestComplex Cbrt(TestComplex x) => throw new NotImplementedException();
public static TestComplex Hypot(TestComplex x, TestComplex y) => throw new NotImplementedException();
public static TestComplex RootN(TestComplex x, int n) => throw new NotImplementedException();
}
}
Loading