|
3 | 3 |
|
4 | 4 | using System.Buffers; |
5 | 5 | using System.Collections.Generic; |
| 6 | +using System.Diagnostics.CodeAnalysis; |
| 7 | +using System.Globalization; |
6 | 8 | using System.Linq; |
7 | 9 | using System.Runtime.InteropServices; |
8 | 10 | using Xunit; |
@@ -1152,14 +1154,89 @@ public static void TensorStackTests() |
1152 | 1154 | Assert.Equal(40, resultTensor2[1, 1, 1]); |
1153 | 1155 | } |
1154 | 1156 |
|
1155 | | - [Fact] |
1156 | | - public static void TensorStdDevTests() |
| 1157 | + public static IEnumerable<object[]> StdDevFloatTestData() |
1157 | 1158 | { |
1158 | | - Tensor<float> t0 = Tensor.Create(Enumerable.Sequence<float>(0, 4, 1).ToArray(), lengths: [2, 2]); |
| 1159 | + // Test case where Abs doesn't change the result (real numbers, positive values) |
| 1160 | + yield return new object[] |
| 1161 | + { |
| 1162 | + new float[] { 0f, 1f, 2f, 3f }, |
| 1163 | + StdDev([0f, 1f, 2f, 3f]) |
| 1164 | + }; |
| 1165 | + |
| 1166 | + // Test case with all same values (should be 0) |
| 1167 | + yield return new object[] |
| 1168 | + { |
| 1169 | + new float[] { 1f, 1f, 1f, 1f }, |
| 1170 | + 0f |
| 1171 | + }; |
| 1172 | + |
| 1173 | + // Test case with negative values where Abs could matter (but doesn't for standard deviation) |
| 1174 | + yield return new object[] |
| 1175 | + { |
| 1176 | + new float[] { -2f, -1f, 1f, 2f }, |
| 1177 | + StdDev([-2f, -1f, 1f, 2f]) |
| 1178 | + }; |
| 1179 | + } |
| 1180 | + |
| 1181 | + public static IEnumerable<object[]> StdDevComplexTestData() |
| 1182 | + { |
| 1183 | + // Test case where Abs is critical (complex numbers) |
| 1184 | + yield return new object[] |
| 1185 | + { |
| 1186 | + new TestComplex[] { new(new(1, 2)), new(new(3, 4)) } |
| 1187 | + }; |
| 1188 | + |
| 1189 | + // Test case with purely imaginary numbers |
| 1190 | + yield return new object[] |
| 1191 | + { |
| 1192 | + new TestComplex[] { new(new(0, 1)), new(new(0, 2)), new(new(0, 3)) } |
| 1193 | + }; |
| 1194 | + |
| 1195 | + // Test case with purely real numbers (should behave like floats) |
| 1196 | + yield return new object[] |
| 1197 | + { |
| 1198 | + new TestComplex[] { new(new(1, 0)), new(new(2, 0)), new(new(3, 0)) } |
| 1199 | + }; |
| 1200 | + } |
| 1201 | + |
| 1202 | + [Theory, MemberData(nameof(StdDevFloatTestData))] |
| 1203 | + public static void TensorStdDevFloatTests(float[] data, float expectedStdDev) |
| 1204 | + { |
| 1205 | + var tensor = Tensor.Create(data); |
| 1206 | + |
| 1207 | + var tensorPrimitivesResult = TensorPrimitives.StdDev<float>(data); |
| 1208 | + var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan()); |
| 1209 | + |
| 1210 | + // Both should produce the same result |
| 1211 | + Assert.Equal(tensorPrimitivesResult, tensorResult, precision: 5); |
| 1212 | + Assert.Equal(expectedStdDev, tensorResult, precision: 5); |
| 1213 | + |
| 1214 | + // Test that non-contiguous calculations work with reshaped tensor |
| 1215 | + if (data.Length >= 4) |
| 1216 | + { |
| 1217 | + var reshapedTensor = Tensor.Create(data, lengths: [2, 2]); |
| 1218 | + var reshapedResult = Tensor.StdDev(reshapedTensor.AsReadOnlyTensorSpan()); |
| 1219 | + Assert.Equal(expectedStdDev, reshapedResult, precision: 5); |
| 1220 | + } |
| 1221 | + } |
1159 | 1222 |
|
1160 | | - Assert.Equal(StdDev([0, 1, 2, 3]), Tensor.StdDev<float>(t0), .1); |
| 1223 | + [Theory, MemberData(nameof(StdDevComplexTestData))] |
| 1224 | + public static void TensorStdDevComplexTests(TestComplex[] data) |
| 1225 | + { |
| 1226 | + var tensor = Tensor.Create(data); |
| 1227 | + |
| 1228 | + var tensorPrimitivesResult = TensorPrimitives.StdDev<TestComplex>(data); |
| 1229 | + var tensorResult = Tensor.StdDev(tensor.AsReadOnlyTensorSpan()); |
| 1230 | + |
| 1231 | + // Both should produce the same result - this is the key test for the fix |
| 1232 | + Assert.Equal(tensorPrimitivesResult.Real, tensorResult.Real, precision: 10); |
| 1233 | + Assert.Equal(tensorPrimitivesResult.Imaginary, tensorResult.Imaginary, precision: 10); |
| 1234 | + } |
1161 | 1235 |
|
1162 | | - // Test that non-contiguous calculations work |
| 1236 | + [Fact] |
| 1237 | + public static void TensorStdDevNonContiguousTests() |
| 1238 | + { |
| 1239 | + // Test that non-contiguous calculations work for float tensors |
1163 | 1240 | Tensor<float> fourByFour = Tensor.CreateFromShape<float>([4, 4]); |
1164 | 1241 | fourByFour[[0, 0]] = 1f; |
1165 | 1242 | fourByFour[[0, 1]] = 1f; |
@@ -3190,4 +3267,91 @@ public static void ToStringZeroDataTest() |
3190 | 3267 | Assert.Equal(expected, tensor.ToString([2, 0, 2])); |
3191 | 3268 | } |
3192 | 3269 | } |
| 3270 | + |
| 3271 | + /// <summary> |
| 3272 | + /// Test complex number type that implements IRootFunctions for testing consistency between |
| 3273 | + /// TensorPrimitives.StdDev and Tensor.StdDev |
| 3274 | + /// </summary> |
| 3275 | + public readonly struct TestComplex(Complex value) : IRootFunctions<TestComplex>, IEquatable<TestComplex> |
| 3276 | + { |
| 3277 | + private readonly Complex _value = value; |
| 3278 | + |
| 3279 | + public double Real => _value.Real; |
| 3280 | + public double Imaginary => _value.Imaginary; |
| 3281 | + |
| 3282 | + public static TestComplex One => new(Complex.One); |
| 3283 | + public static int Radix => 2; |
| 3284 | + public static TestComplex Zero => new(Complex.Zero); |
| 3285 | + public static TestComplex E => new(new(Math.E, 0)); |
| 3286 | + public static TestComplex Pi => new(new(Math.PI, 0)); |
| 3287 | + public static TestComplex Tau => new(new(Math.Tau, 0)); |
| 3288 | + |
| 3289 | + public static TestComplex operator +(TestComplex left, TestComplex right) => new(left._value + right._value); |
| 3290 | + public static TestComplex operator *(TestComplex left, TestComplex right) => new(left._value * right._value); |
| 3291 | + public static TestComplex operator /(TestComplex left, TestComplex right) => new(left._value / right._value); |
| 3292 | + public static TestComplex operator -(TestComplex left, TestComplex right) => new(left._value - right._value); |
| 3293 | + public static TestComplex Sqrt(TestComplex x) => new(Complex.Sqrt(x._value)); |
| 3294 | + public static TestComplex Abs(TestComplex value) => new(new(Complex.Abs(value._value), 0)); |
| 3295 | + public static TestComplex AdditiveIdentity => new(Complex.Zero); |
| 3296 | + public static TestComplex CreateChecked<TOther>(TOther value) where TOther : INumberBase<TOther> => new(Complex.CreateChecked(value)); |
| 3297 | + |
| 3298 | + // Override Object methods |
| 3299 | + public override bool Equals(object? obj) => obj is TestComplex other && Equals(other); |
| 3300 | + public override int GetHashCode() => _value.GetHashCode(); |
| 3301 | + public override string ToString() => _value.ToString(); |
| 3302 | + |
| 3303 | + // IEquatable<TestComplex> |
| 3304 | + public bool Equals(TestComplex other) => _value.Equals(other._value); |
| 3305 | + |
| 3306 | + // Operators |
| 3307 | + public static bool operator ==(TestComplex left, TestComplex right) => left.Equals(right); |
| 3308 | + public static bool operator !=(TestComplex left, TestComplex right) => !left.Equals(right); |
| 3309 | + |
| 3310 | + // Required interface implementations not needed for tests - throw NotImplementedException |
| 3311 | + public string ToString(string? format, IFormatProvider? formatProvider) => throw new NotImplementedException(); |
| 3312 | + public bool TryFormat(Span<char> destination, out int charsWritten, ReadOnlySpan<char> format, IFormatProvider? provider) => throw new NotImplementedException(); |
| 3313 | + public static TestComplex Parse(string s, IFormatProvider? provider) => throw new NotImplementedException(); |
| 3314 | + public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); |
| 3315 | + public static TestComplex Parse(ReadOnlySpan<char> s, IFormatProvider? provider) => throw new NotImplementedException(); |
| 3316 | + public static bool TryParse(ReadOnlySpan<char> s, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); |
| 3317 | + public static TestComplex operator --(TestComplex value) => throw new NotImplementedException(); |
| 3318 | + public static TestComplex operator ++(TestComplex value) => throw new NotImplementedException(); |
| 3319 | + public static TestComplex MultiplicativeIdentity => new(Complex.One); |
| 3320 | + public static TestComplex operator -(TestComplex value) => new(-value._value); |
| 3321 | + public static TestComplex operator +(TestComplex value) => new(value._value); |
| 3322 | + public static bool IsCanonical(TestComplex value) => throw new NotImplementedException(); |
| 3323 | + public static bool IsComplexNumber(TestComplex value) => throw new NotImplementedException(); |
| 3324 | + public static bool IsEvenInteger(TestComplex value) => throw new NotImplementedException(); |
| 3325 | + public static bool IsFinite(TestComplex value) => throw new NotImplementedException(); |
| 3326 | + public static bool IsImaginaryNumber(TestComplex value) => throw new NotImplementedException(); |
| 3327 | + public static bool IsInfinity(TestComplex value) => throw new NotImplementedException(); |
| 3328 | + public static bool IsInteger(TestComplex value) => throw new NotImplementedException(); |
| 3329 | + public static bool IsNaN(TestComplex value) => throw new NotImplementedException(); |
| 3330 | + public static bool IsNegative(TestComplex value) => throw new NotImplementedException(); |
| 3331 | + public static bool IsNegativeInfinity(TestComplex value) => throw new NotImplementedException(); |
| 3332 | + public static bool IsNormal(TestComplex value) => throw new NotImplementedException(); |
| 3333 | + public static bool IsOddInteger(TestComplex value) => throw new NotImplementedException(); |
| 3334 | + public static bool IsPositive(TestComplex value) => throw new NotImplementedException(); |
| 3335 | + public static bool IsPositiveInfinity(TestComplex value) => throw new NotImplementedException(); |
| 3336 | + public static bool IsRealNumber(TestComplex value) => throw new NotImplementedException(); |
| 3337 | + public static bool IsSubnormal(TestComplex value) => throw new NotImplementedException(); |
| 3338 | + public static bool IsZero(TestComplex value) => throw new NotImplementedException(); |
| 3339 | + public static TestComplex MaxMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException(); |
| 3340 | + public static TestComplex MaxMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException(); |
| 3341 | + public static TestComplex MinMagnitude(TestComplex x, TestComplex y) => throw new NotImplementedException(); |
| 3342 | + public static TestComplex MinMagnitudeNumber(TestComplex x, TestComplex y) => throw new NotImplementedException(); |
| 3343 | + public static TestComplex Parse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException(); |
| 3344 | + public static TestComplex Parse(string s, NumberStyles style, IFormatProvider? provider) => throw new NotImplementedException(); |
| 3345 | + public static bool TryConvertFromSaturating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException(); |
| 3346 | + public static bool TryConvertFromTruncating<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException(); |
| 3347 | + public static bool TryConvertFromChecked<TOther>(TOther value, out TestComplex result) where TOther : INumberBase<TOther> => throw new NotImplementedException(); |
| 3348 | + public static bool TryConvertToChecked<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException(); |
| 3349 | + public static bool TryConvertToSaturating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException(); |
| 3350 | + public static bool TryConvertToTruncating<TOther>(TestComplex value, [MaybeNullWhen(false)] out TOther result) where TOther : INumberBase<TOther> => throw new NotImplementedException(); |
| 3351 | + public static bool TryParse(ReadOnlySpan<char> s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); |
| 3352 | + public static bool TryParse([NotNullWhen(true)] string? s, NumberStyles style, IFormatProvider? provider, out TestComplex result) => throw new NotImplementedException(); |
| 3353 | + public static TestComplex Cbrt(TestComplex x) => throw new NotImplementedException(); |
| 3354 | + public static TestComplex Hypot(TestComplex x, TestComplex y) => throw new NotImplementedException(); |
| 3355 | + public static TestComplex RootN(TestComplex x, int n) => throw new NotImplementedException(); |
| 3356 | + } |
3193 | 3357 | } |
0 commit comments