From 021de58f7431a671f64e3bd6d2253c7a16fd9a6f Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 11 Jun 2024 15:51:50 -0400 Subject: [PATCH 1/2] Add TensorPrimitives.HammingDistance and friends --- .../ref/System.Numerics.Tensors.netcore.cs | 3 + .../src/System.Numerics.Tensors.csproj | 1 + .../TensorPrimitives.HammingDistance.cs | 191 ++++++++++++++++++ .../netcore/TensorPrimitives.PopCount.cs | 14 ++ .../tests/TensorPrimitives.Generic.cs | 106 ++++++++++ 5 files changed, 315 insertions(+) create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs index 7897a41ee6ce2..6572cbc5d34bf 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs @@ -407,6 +407,8 @@ public static void Floor(System.ReadOnlySpan x, System.Span destination public static void FusedMultiplyAdd(System.ReadOnlySpan x, System.ReadOnlySpan y, System.ReadOnlySpan addend, System.Span destination) where T : System.Numerics.IFloatingPointIeee754 { } public static void FusedMultiplyAdd(System.ReadOnlySpan x, System.ReadOnlySpan y, T addend, System.Span destination) where T : System.Numerics.IFloatingPointIeee754 { } public static void FusedMultiplyAdd(System.ReadOnlySpan x, T y, System.ReadOnlySpan addend, System.Span destination) where T : System.Numerics.IFloatingPointIeee754 { } + public static int HammingDistance(System.ReadOnlySpan x, System.ReadOnlySpan y) { throw null; } + public static long HammingBitDistance(System.ReadOnlySpan x, System.ReadOnlySpan y) where T : IBinaryInteger { throw null; } public static void Hypot(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) where T : System.Numerics.IRootFunctions { } public 
static void Ieee754Remainder(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) where T : System.Numerics.IFloatingPointIeee754 { } public static void Ieee754Remainder(System.ReadOnlySpan x, T y, System.Span destination) where T : System.Numerics.IFloatingPointIeee754 { } @@ -457,6 +459,7 @@ public static void Multiply(System.ReadOnlySpan x, T y, System.Span des public static void Negate(System.ReadOnlySpan x, System.Span destination) where T : System.Numerics.IUnaryNegationOperators { } public static T Norm(System.ReadOnlySpan x) where T : System.Numerics.IRootFunctions { throw null; } public static void OnesComplement(System.ReadOnlySpan x, System.Span destination) where T : System.Numerics.IBitwiseOperators { } + public static long PopCount(System.ReadOnlySpan x) where T : System.Numerics.IBinaryInteger { throw null; } public static void PopCount(System.ReadOnlySpan x, System.Span destination) where T : System.Numerics.IBinaryInteger { } public static void Pow(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) where T : System.Numerics.IPowerFunctions { } public static void Pow(System.ReadOnlySpan x, T y, System.Span destination) where T : System.Numerics.IPowerFunctions { } diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index 0044dbdd1a4cf..fd078d40a06e8 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -80,6 +80,7 @@ + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs new file mode 100644 index 0000000000000..50e3b13a3bfea --- /dev/null +++ 
b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs @@ -0,0 +1,191 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + /// <summary>Computes the bitwise Hamming distance between two equal-length tensors of values.</summary> + /// <param name="x">The first tensor, represented as a span.</param> + /// <param name="y">The second tensor, represented as a span.</param> + /// <returns>The number of bits that differ between the two spans.</returns> + /// <exception cref="ArgumentException">Length of <paramref name="x" /> must be same as length of <paramref name="y" />.</exception> + /// <remarks>Empty spans are permitted; the distance between two empty spans is 0.</remarks> + public static long HammingBitDistance(ReadOnlySpan x, ReadOnlySpan y) where T : IBinaryInteger + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + long count = 0; + for (int i = 0; i < x.Length; i++) + { + count += long.CreateTruncating(T.PopCount(x[i] ^ y[i])); + } + + return count; + } + + /// <summary>Computes the Hamming distance between two equal-length tensors of values.</summary> + /// <param name="x">The first tensor, represented as a span.</param> + /// <param name="y">The second tensor, represented as a span.</param> + /// <returns>The number of elements that differ between the two spans.</returns> + /// <exception cref="ArgumentException">Length of <paramref name="x" /> must be same as length of <paramref name="y" />.</exception> + /// <remarks> + /// <para>Empty spans are permitted; the distance between two empty spans is 0.</para> + /// <para> + /// This method computes the number of locations <c>i</c> where <c>!EqualityComparer&lt;T&gt;.Default.Equals(x[i], y[i])</c>. + /// </para> + /// </remarks> + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int HammingDistance(ReadOnlySpan x, ReadOnlySpan y) + { + if (typeof(T) == typeof(char)) + { + // Special-case char, as it's reasonable for someone to want to use HammingDistance on strings, + // and we want that accelerated. This can be removed if/when VectorXx supports char.
+ return CountUnequalElements( + MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(x)), x.Length), + MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(y)), y.Length)); + } + + return CountUnequalElements(x, y); + } + + /// Counts the number of elements that are pair-wise different between the two spans. + private static int CountUnequalElements(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // TODO: This has a very similar structure to CosineSimilarity, which is also open-coded rather than + // using a shared routine plus operator, as we don't have one implemented that exactly fits. We should + // look at refactoring these to share the core logic. + + int count = 0; + if (Vector128.IsHardwareAccelerated && Vector128.IsSupported && x.Length >= Vector128.Count) + { + if (Vector256.IsHardwareAccelerated && Vector256.IsSupported && x.Length >= Vector256.Count) + { + if (Vector512.IsHardwareAccelerated && Vector512.IsSupported && x.Length >= Vector512.Count) + { + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); + + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = 0; + do + { + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); + + count += BitOperations.PopCount((~Vector512.Equals(xVec, yVec)).ExtractMostSignificantBits()); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. 
+ if (i != x.Length) + { + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count)); + + Vector512 remainderMask = CreateRemainderMaskVector512(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + count += BitOperations.PopCount((~Vector512.Equals(xVec, yVec)).ExtractMostSignificantBits()); + } + } + else + { + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); + + // Process vectors, counting the elements that differ, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = 0; + do + { + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); + + count += BitOperations.PopCount((~Vector256.Equals(xVec, yVec)).ExtractMostSignificantBits()); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count)); + + Vector256 remainderMask = CreateRemainderMaskVector256(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + count += BitOperations.PopCount((~Vector256.Equals(xVec, yVec)).ExtractMostSignificantBits()); + } + } + } + else + { + ref T xRef = ref MemoryMarshal.GetReference(x); + ref T yRef = ref MemoryMarshal.GetReference(y); + + // Process vectors, counting the elements that differ, as long as there's a vector's worth remaining.
+ int oneVectorFromEnd = x.Length - Vector128.Count; + int i = 0; + do + { + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); + + count += BitOperations.PopCount((~Vector128.Equals(xVec, yVec)).ExtractMostSignificantBits()); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count)); + + Vector128 remainderMask = CreateRemainderMaskVector128(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + count += BitOperations.PopCount((~Vector128.Equals(xVec, yVec)).ExtractMostSignificantBits()); + } + } + } + else + { + for (int i = 0; i < x.Length; i++) + { + if (!EqualityComparer.Default.Equals(x[i], y[i])) + { + count++; + } + } + } + + Debug.Assert(count >= 0 && count <= x.Length, $"Expected count to be in the range [0, {x.Length}], got {count}."); + return count; + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.PopCount.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.PopCount.cs index 8bc90f3c6968c..da45bd720b129 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.PopCount.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.PopCount.cs @@ -13,6 +13,20 @@ namespace System.Numerics.Tensors { public static partial class TensorPrimitives { + /// <summary>Computes the population count of all elements in the specified tensor.</summary> + /// <param name="x">The tensor, represented as a span.</param> + /// <returns>The sum of the number of bits set in each element in <paramref name="x" />.</returns>
+ public static long PopCount(ReadOnlySpan x) where T : IBinaryInteger + { + long count = 0; + for (int i = 0; i < x.Length; i++) + { + count += long.CreateTruncating(T.PopCount(x[i])); + } + + return count; + } + /// Computes the element-wise population count of numbers in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs index 2768a03a07047..cfe4fd9896977 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs @@ -2087,6 +2087,81 @@ public void CopySign_ThrowsForOverlapppingInputsWithOutputs() AssertExtensions.Throws("destination", () => TensorPrimitives.CopySign(array.AsSpan(1, 2), default(T), array.AsSpan(2, 2))); } #endregion + + #region HammingBitDistance + [Fact] + public void HammingBitDistance_ThrowsForMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.HammingBitDistance(new int[1], new int[2])); + Assert.Throws(() => TensorPrimitives.HammingBitDistance(new int[2], new int[1])); + } + + [Fact] + public void HammingBitDistance_AllLengths() + { + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + long expected = 0; + for (int i = 0; i < tensorLength; i++) + { + expected += long.CreateTruncating(T.PopCount(x[i] ^ y[i])); + } + + Assert.Equal(expected, TensorPrimitives.HammingBitDistance(x.Span, y.Span)); + }); + } + + [Fact] + public void HammingBitDistance_KnownValues() + { + T value42 = T.CreateTruncating(42); + T value84 = T.CreateTruncating(84); + + T[] values1 = new T[100]; + T[] values2 = new T[100]; + + Array.Fill(values1, value42); + Array.Fill(values2, value84); + + 
Assert.Equal(0, TensorPrimitives.HammingBitDistance(values1, values1)); + Assert.Equal(600, TensorPrimitives.HammingBitDistance(values1, values2)); + Assert.Equal(0, TensorPrimitives.HammingBitDistance(values2, values2)); + } + #endregion + + #region PopCount + [Fact] + public void PopCount_AllLengths() + { + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + long expected = 0; + for (int i = 0; i < tensorLength; i++) + { + expected += long.CreateTruncating(T.PopCount(x[i])); + } + + Assert.Equal(expected, TensorPrimitives.PopCount(x.Span)); + }); + } + + [Fact] + public void PopCount_KnownValues() + { + T[] values = new T[255]; + for (int i = 0; i < values.Length; i++) + { + values[i] = T.CreateTruncating(i); + } + + Assert.Equal(1016, TensorPrimitives.PopCount(values)); + } + #endregion } public unsafe abstract class GenericNumberTensorPrimitivesTests : TensorPrimitivesTests @@ -2269,5 +2344,36 @@ public void ScalarSpanDestination_ThrowsForOverlapppingInputsWithOutputs(ScalarS AssertExtensions.Throws("destination", () => tensorPrimitivesMethod(default, array.AsSpan(4, 2), array.AsSpan(5, 2))); } #endregion + + #region HammingDistance + [Fact] + public void HammingDistance_ThrowsForMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.HammingDistance(new int[1], new int[2])); + Assert.Throws(() => TensorPrimitives.HammingDistance(new int[2], new int[1])); + } + + [Fact] + public void HammingDistance_AllLengths() + { + Assert.All(Helpers.TensorLengthsIncluding0, tensorLength => + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + int expected = 0; + ReadOnlySpan xSpan = x, ySpan = y; + for (int i = 0; i < xSpan.Length; i++) + { + if (xSpan[i] != ySpan[i]) + { + expected++; + } + } + + Assert.Equal(expected, TensorPrimitives.HammingDistance(x, y)); + }); + } + #endregion } } From 
358d38d581c1778d42ba380265d2f0b34a21b979 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 12 Jun 2024 08:26:39 -0400 Subject: [PATCH 2/2] Address PR feedback --- .../TensorPrimitives.HammingDistance.cs | 13 +++++++++- .../System.Numerics.Tensors.Net8.Tests.csproj | 1 + .../System.Numerics.Tensors.Tests.csproj | 1 + .../tests/TensorPrimitivesTests.Reference.cs | 25 +++++++++++++++++++ 4 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.Reference.cs diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs index 50e3b13a3bfea..c38b60e5f0ac3 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.HammingDistance.cs @@ -173,7 +173,7 @@ private static int CountUnequalElements(ReadOnlySpan x, ReadOnlySpan y) } } } - else + else if (typeof(T).IsValueType) { for (int i = 0; i < x.Length; i++) { @@ -183,6 +183,17 @@ private static int CountUnequalElements(ReadOnlySpan x, ReadOnlySpan y) } } } + else + { + EqualityComparer comparer = EqualityComparer.Default; + for (int i = 0; i < x.Length; i++) + { + if (!comparer.Equals(x[i], y[i])) + { + count++; + } + } + } Debug.Assert(count >= 0 && count <= x.Length, $"Expected count to be in the range [0, {x.Length}], got {count}."); return count; diff --git a/src/libraries/System.Numerics.Tensors/tests/Net8Tests/System.Numerics.Tensors.Net8.Tests.csproj b/src/libraries/System.Numerics.Tensors/tests/Net8Tests/System.Numerics.Tensors.Net8.Tests.csproj index 3b8f867b355c0..1f8f82a11dc84 100644 --- a/src/libraries/System.Numerics.Tensors/tests/Net8Tests/System.Numerics.Tensors.Net8.Tests.csproj +++ 
b/src/libraries/System.Numerics.Tensors/tests/Net8Tests/System.Numerics.Tensors.Net8.Tests.csproj @@ -17,6 +17,7 @@ + diff --git a/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj b/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj index 4b31efa2aaef1..136e487d70d0a 100644 --- a/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj +++ b/src/libraries/System.Numerics.Tensors/tests/System.Numerics.Tensors.Tests.csproj @@ -24,6 +24,7 @@ + diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.Reference.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.Reference.cs new file mode 100644 index 0000000000000..6dad877b88418 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.Reference.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace System.Numerics.Tensors.Tests +{ + [ActiveIssue("https://github.com/dotnet/runtime/issues/97295", typeof(PlatformDetection), nameof(PlatformDetection.IsMonoRuntime), nameof(PlatformDetection.IsNotMonoInterpreter))] + public class ReferenceTensorPrimitivesTests + { + // The 99% case for TensorPrimitives is working with value type Ts, and the rest of the tests are optimized for that. + // These tests provide additional coverage for when T is a reference type. 
+ + [Fact] + public void HammingDistance_ValidateReferenceType() + { + Assert.Equal(0, TensorPrimitives.HammingDistance(Array.Empty(), Array.Empty())); + Assert.Equal(1, TensorPrimitives.HammingDistance(["a"], ["b"])); + Assert.Equal(2, TensorPrimitives.HammingDistance(["a", "b", "c"], ["a", "c", "b"])); + Assert.Equal(2, TensorPrimitives.HammingDistance(["a", "b", "c"], ["a", "c", "b"])); + Assert.Equal(0, TensorPrimitives.HammingDistance(["a", "b", "c"], ["a", "b", "c"])); + Assert.Throws(() => TensorPrimitives.HammingDistance(["a", "b"], ["a", "b", "c"])); + } + } +}