Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4810,10 +4810,9 @@ public static T StdDev<T>(in ReadOnlyTensorSpan<T> x)
{
T mean = Average(x);
T result = T.AdditiveIdentity;
TensorOperation.Invoke<TensorOperation.SumOfSquaredDifferences<T>, T, T>(x, mean, ref result);
TensorOperation.Invoke<TensorOperation.SumOfSquaredMagnitudeDifferences<T>, T, T>(x, mean, ref result);
T variance = result / T.CreateChecked(x.FlattenedLength);
return T.Sqrt(variance);

}
#endregion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2103,6 +2103,25 @@ public static void Invoke(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destinat
}
}

public readonly struct SumOfSquaredMagnitudeDifferences<T>
: IBinaryOperation_Tensor_Scalar<T, T>
where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>, ISubtractionOperators<T, T, T>, INumberBase<T>
{
public static void Invoke(ref readonly T x, T y, ref T destination)
{
T diff = T.Abs(x - y);
destination += diff * diff;
}
public static void Invoke(ReadOnlySpan<T> x, T y, Span<T> destination)
{
for (int i = 0; i < x.Length; i++)
{
T diff = T.Abs(x[i] - y);
destination[i] = diff * diff;
}
}
}

public readonly struct SumOfSquares<T>
: IUnaryReduction_Tensor<T, T>
where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>
Expand Down
103 changes: 103 additions & 0 deletions src/libraries/System.Numerics.Tensors/tests/TensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
Expand Down Expand Up @@ -1169,6 +1171,21 @@ public static void TensorStdDevTests()
Assert.Equal(0f, Tensor.StdDev(upperLeft));
}

[Fact]
public static void TensorStdDevComplexTests()
{
// Test case from issue: Tensor.StdDev vs TensorPrimitives.StdDev differences for Complex input
var arr = new TestComplex[] { new(new(1, 2)), new(new(3, 4)) };
var tensor = Tensor.Create(arr);

var tensorPrimitivesResult = TensorPrimitives.StdDev<TestComplex>(arr);
var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan());

// Both should produce the same result
Assert.Equal(tensorPrimitivesResult.Real, tensorResult.Real, precision: 10);
Assert.Equal(tensorPrimitivesResult.Imaginary, tensorResult.Imaginary, precision: 10);
}

[Fact]
public static void TensorSumTests()
{
Expand Down Expand Up @@ -3190,4 +3207,90 @@ public static void ToStringZeroDataTest()
Assert.Equal(expected, tensor.ToString([2, 0, 2]));
}
}

/// <summary>
/// Test complex number type that implements IRootFunctions for testing StdDev consistency between
/// TensorPrimitives.StdDev and Tensor.StdDev
/// </summary>
public readonly struct TestComplex(Complex value) : IRootFunctions<TestComplex>, IEquatable<TestComplex>
{
private readonly Complex _value = value;

public double Real => _value.Real;
public double Imaginary => _value.Imaginary;

public static TestComplex operator +(TestComplex left, TestComplex right) => new(left._value + right._value);
public static TestComplex operator *(TestComplex left, TestComplex right) => new(left._value * right._value);
public static TestComplex operator /(TestComplex left, TestComplex right) => new(left._value / right._value);
public static TestComplex operator -(TestComplex left, TestComplex right) => new(left._value - right._value);
public static TestComplex Sqrt(TestComplex x) => new(Complex.Sqrt(x._value));
public static TestComplex Abs(TestComplex value) => new(new(Complex.Abs(value._value), 0));
public static TestComplex AdditiveIdentity => new(Complex.Zero);
public static TestComplex CreateChecked<TOther>(TOther value) where TOther : INumberBase<TOther> => new(Complex.CreateChecked(value));

// Override Object methods
public override bool Equals(object? obj) => obj is TestComplex other && Equals(other);
public override int GetHashCode() => _value.GetHashCode();
public override string ToString() => _value.ToString();

// IEquatable<TestComplex>
public bool Equals(TestComplex other) => _value.Equals(other._value);

// Operators
public static bool operator ==(TestComplex left, TestComplex right) => left.Equals(right);
public static bool operator !=(TestComplex left, TestComplex right) => !left.Equals(right);

// Required interface implementations - minimal stubs
public string ToString(string? format, IFormatProvider? formatProvider) => _value.ToString(format, formatProvider);
public bool TryFormat(Span<char> destination, out int charsWritten, ReadOnlySpan<char> format, IFormatProvider? provider) => _value.TryFormat(destination, out charsWritten, format, provider);
public static TestComplex Parse(string s, IFormatProvider? provider) => new(Complex.Parse(s, provider));
public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, out TestComplex result) { var success = Complex.TryParse(s, provider, out var c); result = new(c); return success; }
public static TestComplex Parse(ReadOnlySpan<char> s, IFormatProvider? provider) => new(Complex.Parse(s, provider));
public static bool TryParse(ReadOnlySpan<char> s, IFormatProvider? provider, out TestComplex result) { var success = Complex.TryParse(s, provider, out var c); result = new(c); return success; }
public static TestComplex operator --(TestComplex value) => new(value._value - Complex.One);
public static TestComplex operator ++(TestComplex value) => new(value._value + Complex.One);
public static TestComplex MultiplicativeIdentity => new(Complex.One);
public static TestComplex operator -(TestComplex value) => new(-value._value);
public static TestComplex operator +(TestComplex value) => new(value._value);
public static bool IsCanonical(TestComplex value) => true;
public static bool IsComplexNumber(TestComplex value) => true;
public static bool IsEvenInteger(TestComplex value) => false;
public static bool IsFinite(TestComplex value) => Complex.IsFinite(value._value);
public static bool IsImaginaryNumber(TestComplex value) => value._value.Real == 0 && value._value.Imaginary != 0;
public static bool IsInfinity(TestComplex value) => Complex.IsInfinity(value._value);
public static bool IsInteger(TestComplex value) => false;
public static bool IsNaN(TestComplex value) => Complex.IsNaN(value._value);
public static bool IsNegative(TestComplex value) => false;
public static bool IsNegativeInfinity(TestComplex value) => false;
public static bool IsNormal(TestComplex value) => true;
public static bool IsOddInteger(TestComplex value) => false;
public static bool IsPositive(TestComplex value) => false;
public static bool IsPositiveInfinity(TestComplex value) => false;
public static bool IsRealNumber(TestComplex value) => value._value.Imaginary == 0;
public static bool IsSubnormal(TestComplex value) => false;
public static bool IsZero(TestComplex value) => value._value == Complex.Zero;
public static TestComplex MaxMagnitude(TestComplex x, TestComplex y) => Complex.Abs(x._value) >= Complex.Abs(y._value) ? x : y;
public static TestComplex MaxMagnitudeNumber(TestComplex x, TestComplex y) => MaxMagnitude(x, y);
public static TestComplex MinMagnitude(TestComplex x, TestComplex y) => Complex.Abs(x._value) <= Complex.Abs(y._value) ? x : y;
public static TestComplex MinMagnitudeNumber(TestComplex x, TestComplex y) => MinMagnitude(x, y);
public static TestComplex Parse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider) => new(Complex.Parse(s, style, provider));
public static TestComplex Parse(string s, NumberStyles style, IFormatProvider? provider) => new(Complex.Parse(s, style, provider));
public static bool TryConvertFromSaturating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> { result = default; return false; }
public static bool TryConvertFromTruncating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> { result = default; return false; }
public static bool TryConvertFromChecked<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> { result = default; return false; }
public static bool TryConvertToChecked<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> { result = default; return false; }
public static bool TryConvertToSaturating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> { result = default; return false; }
public static bool TryConvertToTruncating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> { result = default; return false; }
public static bool TryParse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider, out TestComplex result) { var success = Complex.TryParse(s, style, provider, out var c); result = new(c); return success; }
public static bool TryParse([NotNullWhen(true)] string? s, NumberStyles style, IFormatProvider? provider, out TestComplex result) { var success = Complex.TryParse(s, style, provider, out var c); result = new(c); return success; }
public static TestComplex One => new(Complex.One);
public static int Radix => 2;
public static TestComplex Zero => new(Complex.Zero);
public static TestComplex E => new(new(Math.E, 0));
public static TestComplex Pi => new(new(Math.PI, 0));
public static TestComplex Tau => new(new(Math.Tau, 0));
public static TestComplex Cbrt(TestComplex x) => new(Complex.Pow(x._value, 1.0/3.0));
public static TestComplex Hypot(TestComplex x, TestComplex y) => new(new(Complex.Abs(Complex.Sqrt(x._value * x._value + y._value * y._value)), 0));
public static TestComplex RootN(TestComplex x, int n) => new(Complex.Pow(x._value, 1.0/n));
}
}
Loading