Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorized MemoryExtensions.CommonPrefixLength #68210

Merged
merged 4 commits into from
Apr 27, 2022

Conversation

gfoidl
Copy link
Member

@gfoidl gfoidl commented Apr 19, 2022

Functional implementation of the new API got added in #67929
This PR vectorizes the implementation.

|     Method | LengthAndIndexOfMatch |      Mean |     Error |    StdDev | Ratio | RatioSD |
|----------- |---------------------- |----------:|----------:|----------:|------:|--------:|
|    Default |               (10, 5) |  5.432 ns | 0.1167 ns | 0.0975 ns |  1.00 |    0.00 |
| PR_Unroll4 |               (10, 5) |  4.601 ns | 0.0991 ns | 0.0927 ns |  0.85 |    0.02 |
|            |                       |           |           |           |       |         |
|    Default |               (10, 9) |  7.185 ns | 0.1153 ns | 0.1078 ns |  1.00 |    0.00 |
| PR_Unroll4 |               (10, 9) |  5.813 ns | 0.1311 ns | 0.1095 ns |  0.81 |    0.02 |
|            |                       |           |           |           |       |         |
|    Default |             (100, 16) | 11.043 ns | 0.2307 ns | 0.2158 ns |  1.00 |    0.00 |
| PR_Unroll4 |             (100, 16) |  4.112 ns | 0.0898 ns | 0.0840 ns |  0.37 |    0.01 |
|            |                       |           |           |           |       |         |
|    Default |             (100, 99) | 66.556 ns | 0.8823 ns | 0.8253 ns |  1.00 |    0.00 |
| PR_Unroll4 |             (100, 99) |  8.013 ns | 0.1776 ns | 0.1974 ns |  0.12 |    0.00 |
|            |                       |           |           |           |       |         |
|    Default |              (20, 13) |  9.158 ns | 0.1724 ns | 0.1613 ns |  1.00 |    0.00 |
| PR_Unroll4 |              (20, 13) |  3.134 ns | 0.0843 ns | 0.0789 ns |  0.34 |    0.01 |
|            |                       |           |           |           |       |         |
|    Default |                (3, 2) |  4.505 ns | 0.0481 ns | 0.0427 ns |  1.00 |    0.00 |
| PR_Unroll4 |                (3, 2) |  3.658 ns | 0.0733 ns | 0.0686 ns |  0.81 |    0.02 |

Note: I tried different unroll-levels, 4 gave best results.

Benchmarks aren't in https://github.com/dotnet/performance
Should these be added? Or is it covered via usage in other parts, or too niche to have it covered?

Benchmark code
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using BenchmarkDotNet.Attributes;

namespace ConsoleApp3;

//[ShortRunJob]
//[DisassemblyDiagnoser]
public class CommonPrefixLengthBenchmark
{
    public static void Run()
    {
        CommonPrefixLengthBenchmark bench = new();
        bench.Setup();
        Console.WriteLine(bench.Default());
        Console.WriteLine(bench.PR_Unroll4());
    }

    public record Args(int Length, int IndexOfMatch)
    {
        public override string ToString() => $"({this.Length}, {this.IndexOfMatch})";
    }

    public static IEnumerable<Args> Arguments()
    {
        yield return new Args(3, 2);
        yield return new Args(10, 5);
        yield return new Args(10, 9);
        yield return new Args(20, 13);
        yield return new Args(100, 16);
        yield return new Args(100, 99);
    }

    [ParamsSource(nameof(Arguments))]
    public Args LengthAndIndexOfMatch { get; set; } = new Args(3, 2);

    private byte[]? _arr0;
    private byte[]? _arr1;

    [GlobalSetup]
    public void Setup()
    {
        _arr0 = new byte[this.LengthAndIndexOfMatch.Length];
        Random.Shared.NextBytes(_arr0);

        _arr1 = _arr0.AsSpan().ToArray();
        _arr1[this.LengthAndIndexOfMatch.IndexOfMatch] = 0xff;
    }

    [Benchmark(Baseline = true)]
    public int Default() => SpanHelpers.CommonPrefixLength<byte>(_arr0, _arr1);

    [Benchmark]
    public int PR_Unroll4() => SpanHelpers.CommonPrefixLength_PRUnroll4<byte>(_arr0, _arr1);
}

public static class SpanHelpers
{
    public static int CommonPrefixLength<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
    {
        // Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until
        // the Length of one of them and at least have bounds checks removed from that one.
        if (other.Length > span.Length)
        {
            other = other.Slice(0, span.Length);
        }
        else if (span.Length > other.Length)
        {
            span = span.Slice(0, other.Length);
        }
        Debug.Assert(span.Length == other.Length);

        // Find the first element pairwise that is not equal, and return its index as the length
        // of the sequence before it that matches.
        for (int i = 0; i < span.Length; i++)
        {
            if (!EqualityComparer<T>.Default.Equals(span[i], other[i]))
            {
                return i;
            }
        }

        return span.Length;
    }

    public static int CommonPrefixLength_PRUnroll4<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
    {
        if (typeof(T) == typeof(byte))
        {
            nuint length = Math.Min((nuint)(uint)span.Length, (nuint)(uint)other.Length);
            nuint size = (nuint)Unsafe.SizeOf<T>();

            return (int)CommonPrefixLengthUnroll4(
                ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
                ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
                length * size);
        }

        // Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until
        // the Length of one of them and at least have bounds checks removed from that one.
        if (other.Length > span.Length)
        {
            other = other.Slice(0, span.Length);
        }
        else if (span.Length > other.Length)
        {
            span = span.Slice(0, other.Length);
        }
        Debug.Assert(span.Length == other.Length);

        // Find the first element pairwise that is not equal, and return its index as the length
        // of the sequence before it that matches.
        for (int i = 0; i < span.Length; i++)
        {
            if (!EqualityComparer<T>.Default.Equals(span[i], other[i]))
            {
                return i;
            }
        }

        return span.Length;
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nuint CommonPrefixLengthUnroll4(ref byte first, ref byte second, nuint length)
    {
        nuint i;

        if (!Vector128.IsHardwareAccelerated || length < (nuint)Vector128<byte>.Count)
        {
            // To have kind of fast path for small inputs, we handle as much elements needed
            // so that either we are done or can use the unrolled loop below.
            i = length % 4;

            if (i > 0)
            {
                if (first != second)
                {
                    return 0;
                }

                if (i > 1)
                {
                    if (Unsafe.Add(ref first, 1) != Unsafe.Add(ref second, 1))
                    {
                        return 1;
                    }

                    if (i > 2 && Unsafe.Add(ref first, 2) != Unsafe.Add(ref second, 2))
                    {
                        return 2;
                    }
                }
            }

            for (; (nint)i <= (nint)length - 4; i += 4)
            {
                if (Unsafe.Add(ref first, i + 0) != Unsafe.Add(ref second, i + 0)) return i + 0;
                if (Unsafe.Add(ref first, i + 1) != Unsafe.Add(ref second, i + 1)) return i + 1;
                if (Unsafe.Add(ref first, i + 2) != Unsafe.Add(ref second, i + 2)) return i + 2;
                if (Unsafe.Add(ref first, i + 3) != Unsafe.Add(ref second, i + 3)) return i + 3;
            }

            return length;
        }

        Debug.Assert(length >= (uint)Vector128<byte>.Count);

        int mask;
        nuint lengthToExamine = length - (nuint)Vector128<byte>.Count;

        Vector128<byte> firstVec;
        Vector128<byte> secondVec;
        Vector128<byte> maskVec;
        i = 0;

        if (lengthToExamine != 0)
        {
            do
            {
                firstVec = Vector128.LoadUnsafe(ref first, i);
                secondVec = Vector128.LoadUnsafe(ref second, i);
                maskVec = Vector128.Equals(firstVec, secondVec);
                mask = (int)maskVec.ExtractMostSignificantBits();

                if (mask != 0xFFFF)
                {
                    goto Found;
                }

                i += (nuint)Vector128<byte>.Count;
            } while (i < lengthToExamine);
        }

        // Do final compare as Vector128<byte>.Count from end rather than start
        i = lengthToExamine;
        firstVec = Vector128.LoadUnsafe(ref first, i);
        secondVec = Vector128.LoadUnsafe(ref second, i);
        maskVec = Vector128.Equals(firstVec, secondVec);
        mask = (int)maskVec.ExtractMostSignificantBits();

        if (mask != 0xFFFF)
        {
            goto Found;
        }

        return length;

    Found:
        mask = ~mask;
        return i + (uint)BitOperations.TrailingZeroCount(mask);
    }
}

I tried forwarding SequenceEqual to CommonPrefixLength, but this regressed as in SequenceEqual some nice tricks are employed. These can't be used for CommonPrefixLength, as here we need the actual index of the difference, and not only true/false.


@stephentoub: in regards to #67929 (comment) I know you favored #67942, but the PR for the CommonPrefixLength is merged, and I had already similar code lying around.
And for my usecase (a kind of parser) I'd prefer to use the (optimized) implementation in .NET instead of having to use my own (which in this case is the same code anyway 😉).

@ghost ghost added the community-contribution Indicates that the PR has been added by a community member label Apr 19, 2022
@ghost
Copy link

ghost commented Apr 19, 2022

Tagging subscribers to this area: @dotnet/area-system-memory
See info in area-owners.md if you want to be subscribed.

Issue Details

Functional implementation of the new API got added in #67929
This PR vectorizes the implementation.

|     Method | LengthAndIndexOfMatch |      Mean |     Error |    StdDev | Ratio | RatioSD |
|----------- |---------------------- |----------:|----------:|----------:|------:|--------:|
|    Default |               (10, 5) |  5.432 ns | 0.1167 ns | 0.0975 ns |  1.00 |    0.00 |
| PR_Unroll4 |               (10, 5) |  4.601 ns | 0.0991 ns | 0.0927 ns |  0.85 |    0.02 |
|            |                       |           |           |           |       |         |
|    Default |               (10, 9) |  7.185 ns | 0.1153 ns | 0.1078 ns |  1.00 |    0.00 |
| PR_Unroll4 |               (10, 9) |  5.813 ns | 0.1311 ns | 0.1095 ns |  0.81 |    0.02 |
|            |                       |           |           |           |       |         |
|    Default |             (100, 16) | 11.043 ns | 0.2307 ns | 0.2158 ns |  1.00 |    0.00 |
| PR_Unroll4 |             (100, 16) |  4.112 ns | 0.0898 ns | 0.0840 ns |  0.37 |    0.01 |
|            |                       |           |           |           |       |         |
|    Default |             (100, 99) | 66.556 ns | 0.8823 ns | 0.8253 ns |  1.00 |    0.00 |
| PR_Unroll4 |             (100, 99) |  8.013 ns | 0.1776 ns | 0.1974 ns |  0.12 |    0.00 |
|            |                       |           |           |           |       |         |
|    Default |              (20, 13) |  9.158 ns | 0.1724 ns | 0.1613 ns |  1.00 |    0.00 |
| PR_Unroll4 |              (20, 13) |  3.134 ns | 0.0843 ns | 0.0789 ns |  0.34 |    0.01 |
|            |                       |           |           |           |       |         |
|    Default |                (3, 2) |  4.505 ns | 0.0481 ns | 0.0427 ns |  1.00 |    0.00 |
| PR_Unroll4 |                (3, 2) |  3.658 ns | 0.0733 ns | 0.0686 ns |  0.81 |    0.02 |

Note: I tried different unroll-levels, 4 gave best results.

Benchmarks aren't in https://github.com/dotnet/performance
Should these be added? Or is it covered via usage in other parts, or too niche to have it covered?

Benchmark code
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using BenchmarkDotNet.Attributes;

namespace ConsoleApp3;

//[ShortRunJob]
//[DisassemblyDiagnoser]
public class CommonPrefixLengthBenchmark
{
    public static void Run()
    {
        CommonPrefixLengthBenchmark bench = new();
        bench.Setup();
        Console.WriteLine(bench.Default());
        Console.WriteLine(bench.PR_Unroll4());
    }

    public record Args(int Length, int IndexOfMatch)
    {
        public override string ToString() => $"({this.Length}, {this.IndexOfMatch})";
    }

    public static IEnumerable<Args> Arguments()
    {
        yield return new Args(3, 2);
        yield return new Args(10, 5);
        yield return new Args(10, 9);
        yield return new Args(20, 13);
        yield return new Args(100, 16);
        yield return new Args(100, 99);
    }

    [ParamsSource(nameof(Arguments))]
    public Args LengthAndIndexOfMatch { get; set; } = new Args(3, 2);

    private byte[]? _arr0;
    private byte[]? _arr1;

    [GlobalSetup]
    public void Setup()
    {
        _arr0 = new byte[this.LengthAndIndexOfMatch.Length];
        Random.Shared.NextBytes(_arr0);

        _arr1 = _arr0.AsSpan().ToArray();
        _arr1[this.LengthAndIndexOfMatch.IndexOfMatch] = 0xff;
    }

    [Benchmark(Baseline = true)]
    public int Default() => SpanHelpers.CommonPrefixLength<byte>(_arr0, _arr1);

    [Benchmark]
    public int PR_Unroll4() => SpanHelpers.CommonPrefixLength_PRUnroll4<byte>(_arr0, _arr1);
}

public static class SpanHelpers
{
    public static int CommonPrefixLength<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
    {
        // Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until
        // the Length of one of them and at least have bounds checks removed from that one.
        if (other.Length > span.Length)
        {
            other = other.Slice(0, span.Length);
        }
        else if (span.Length > other.Length)
        {
            span = span.Slice(0, other.Length);
        }
        Debug.Assert(span.Length == other.Length);

        // Find the first element pairwise that is not equal, and return its index as the length
        // of the sequence before it that matches.
        for (int i = 0; i < span.Length; i++)
        {
            if (!EqualityComparer<T>.Default.Equals(span[i], other[i]))
            {
                return i;
            }
        }

        return span.Length;
    }

    public static int CommonPrefixLength_PRUnroll4<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
    {
        if (typeof(T) == typeof(byte))
        {
            nuint length = Math.Min((nuint)(uint)span.Length, (nuint)(uint)other.Length);
            nuint size = (nuint)Unsafe.SizeOf<T>();

            return (int)CommonPrefixLengthUnroll4(
                ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
                ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
                length * size);
        }

        // Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until
        // the Length of one of them and at least have bounds checks removed from that one.
        if (other.Length > span.Length)
        {
            other = other.Slice(0, span.Length);
        }
        else if (span.Length > other.Length)
        {
            span = span.Slice(0, other.Length);
        }
        Debug.Assert(span.Length == other.Length);

        // Find the first element pairwise that is not equal, and return its index as the length
        // of the sequence before it that matches.
        for (int i = 0; i < span.Length; i++)
        {
            if (!EqualityComparer<T>.Default.Equals(span[i], other[i]))
            {
                return i;
            }
        }

        return span.Length;
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nuint CommonPrefixLengthUnroll4(ref byte first, ref byte second, nuint length)
    {
        nuint i;

        if (!Vector128.IsHardwareAccelerated || length < (nuint)Vector128<byte>.Count)
        {
            // To have kind of fast path for small inputs, we handle as much elements needed
            // so that either we are done or can use the unrolled loop below.
            i = length % 4;

            if (i > 0)
            {
                if (first != second)
                {
                    return 0;
                }

                if (i > 1)
                {
                    if (Unsafe.Add(ref first, 1) != Unsafe.Add(ref second, 1))
                    {
                        return 1;
                    }

                    if (i > 2 && Unsafe.Add(ref first, 2) != Unsafe.Add(ref second, 2))
                    {
                        return 2;
                    }
                }
            }

            for (; (nint)i <= (nint)length - 4; i += 4)
            {
                if (Unsafe.Add(ref first, i + 0) != Unsafe.Add(ref second, i + 0)) return i + 0;
                if (Unsafe.Add(ref first, i + 1) != Unsafe.Add(ref second, i + 1)) return i + 1;
                if (Unsafe.Add(ref first, i + 2) != Unsafe.Add(ref second, i + 2)) return i + 2;
                if (Unsafe.Add(ref first, i + 3) != Unsafe.Add(ref second, i + 3)) return i + 3;
            }

            return length;
        }

        Debug.Assert(length >= (uint)Vector128<byte>.Count);

        int mask;
        nuint lengthToExamine = length - (nuint)Vector128<byte>.Count;

        Vector128<byte> firstVec;
        Vector128<byte> secondVec;
        Vector128<byte> maskVec;
        i = 0;

        if (lengthToExamine != 0)
        {
            do
            {
                firstVec = Vector128.LoadUnsafe(ref first, i);
                secondVec = Vector128.LoadUnsafe(ref second, i);
                maskVec = Vector128.Equals(firstVec, secondVec);
                mask = (int)maskVec.ExtractMostSignificantBits();

                if (mask != 0xFFFF)
                {
                    goto Found;
                }

                i += (nuint)Vector128<byte>.Count;
            } while (i < lengthToExamine);
        }

        // Do final compare as Vector128<byte>.Count from end rather than start
        i = lengthToExamine;
        firstVec = Vector128.LoadUnsafe(ref first, i);
        secondVec = Vector128.LoadUnsafe(ref second, i);
        maskVec = Vector128.Equals(firstVec, secondVec);
        mask = (int)maskVec.ExtractMostSignificantBits();

        if (mask != 0xFFFF)
        {
            goto Found;
        }

        return length;

    Found:
        mask = ~mask;
        return i + (uint)BitOperations.TrailingZeroCount(mask);
    }
}

I tried forwarding SequenceEqual to CommonPrefixLength, but this regressed as in SequenceEqual some nice tricks are employed. These can't be used for CommonPrefixLength, as here we need the actual index of the difference, and not only true/false.


@stephentoub: in regards to #67929 (comment) I know you favored #67942, but the PR for the CommonPrefixLength is merged, and I had already similar code lying around.
And for my usecase (a kind of parser) I'd prefer to use the (optimized) implementation in .NET instead of having to use my own (which in this case is the same code anyway 😉).

Author: gfoidl
Assignees: -
Labels:

area-System.Memory, community-contribution

Milestone: -

ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
length * size);

return (int)(index / size);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of byte and then dividing to get rid of partials is clever... but deserves a big comment :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too verbose now?

@stephentoub
Copy link
Member

Thanks. This LGTM, but @tannergooding should review as well.

Benchmarks aren't in https://github.com/dotnet/performance. Should these be added? Or is it covered via usage in other parts, or too niche to have it covered?

It's too new; last I checked, the performance repo hadn't ingested a runtime with these APIs. If it now has them, it'd be worth adding a few.

@@ -2033,6 +2033,18 @@ public static void Sort<T>(this Span<T> span, Comparison<T> comparison)
/// <returns>The length of the common prefix shared by the two spans. If there's no shared prefix, 0 is returned.</returns>
public static int CommonPrefixLength<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
{
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
nuint length = Math.Min((nuint)(uint)span.Length, (nuint)(uint)other.Length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to suggest nuint.Min but then I remembered that IntPtr and UIntPtr still have them explicitly implemented until the work around dotnet/csharplang#6031 goes in.

}

i += (nuint)Vector128<byte>.Count;
} while (i < lengthToExamine);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: using while (i < lengthToExamine) { } would cover the if check and reduce nesting

nuint lengthToExamine = length - (nuint)Vector128<byte>.Count;

Vector128<byte> firstVec;
Vector128<byte> secondVec;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Remove this and do Vector128.Equals(firstVector, Vector128.LoadUnsafe(ref second, i)) as the JIT tends to optimize the load better in this scenario

-- It does still fold the load most of the time, but there are cases where storing to a local can make it "worse"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as the JIT tends to optimize the load better in this scenario

Good to know, thanks!

Checked codegen, no difference.

firstVec = Vector128.LoadUnsafe(ref first, i);
secondVec = Vector128.LoadUnsafe(ref second, i);
maskVec = Vector128.Equals(firstVec, secondVec);
mask = (int)maskVec.ExtractMostSignificantBits();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why convert to int here and below? TrailingZeroCount has an overload that takes uint and i is uint (and its also available as uint.TrailingZeroCount now)

So keeping it as uint just simplifies things and reduces code complexity.

Copy link
Member Author

@gfoidl gfoidl Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻

uint.TrailingZeroCount just delegates to BitOperations.TrailingZeroCount:

public static uint TrailingZeroCount(uint value) => (uint)BitOperations.TrailingZeroCount(value);

So it's a bit more work for the JIT. What is peferable now?
I like the generic math approach a bit more, but no strong reason for it, only feeling.
(Besides that uint.TrailingZeroCount returns uint and not int).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is peferable now?

We don't have any official guidance right yet. My personal recommendation is that if you need a uint back out to just use the uint.TrailingZeroCount version.

Copy link
Member

@tannergooding tannergooding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. A couple styling/consistency/simplification nits, but nothing that should block this

@danmoseley
Copy link
Member

@tannergooding seems this can be merged?

@adamsitnik adamsitnik added the tenet-performance Performance related issue label Apr 27, 2022
@tannergooding tannergooding merged commit aa96f04 into dotnet:main Apr 27, 2022
@gfoidl gfoidl deleted the commonprefixlength_vectorization branch April 30, 2022 09:19
@ghost ghost locked as resolved and limited conversation to collaborators May 30, 2022
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
area-System.Memory community-contribution Indicates that the PR has been added by a community member tenet-performance Performance related issue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants