Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SpanHelpers.ClearWithReferences for arm64 #93346

Closed
wants to merge 6 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 72 additions & 9 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;

#pragma warning disable 8500 // sizeof of managed types
Expand Down Expand Up @@ -341,16 +343,28 @@ public static unsafe void ClearWithReferences(ref IntPtr ip, nuint pointerSizeLe
// Writing backward allows us to get away with only simple modifications to the
// mov instruction's base and index registers between loop iterations.

for (; pointerSizeLength >= 8; pointerSizeLength -= 8)
if (pointerSizeLength >= 8)
{
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -1) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -2) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -3) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -4) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -5) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -6) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -7) = default;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -8) = default;
// Handle large inputs separately, currently, only on ARM64 since it gives us
// needed atomicity guarantees even with just 8-byte alignment.
// TODO: consider pinning the input, align to 64 bytes and use AVX/AVX512 on x64.
// since VMOVAPQ guarantees 16-byte alignment if aligned to 16 bytes.
if (AdvSimd.Arm64.IsSupported && pointerSizeLength >= 32)
goto LargeInput;

do
{
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -1) = 0;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -2) = 0;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -3) = 0;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -4) = 0;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -5) = 0;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -6) = 0;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -7) = 0;
Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -8) = 0;
pointerSizeLength -= 8;
}
while (pointerSizeLength >= 8);
}

Debug.Assert(pointerSizeLength <= 7);
Expand Down Expand Up @@ -406,6 +420,55 @@ public static unsafe void ClearWithReferences(ref IntPtr ip, nuint pointerSizeLe

// Write only element.
ip = default;
return;

LargeInput:
// Branch-less alignment to 16 bytes: unconditional zero the first pointer
// and then adjust ip pointer to +1 if it was misaligned.
ip = 0;
nuint misalignedPtrs = (nuint)(((nuint)Unsafe.AsPointer(ref ip) & 8) == 0 ? 0 : 1);
ref nint ipEnd = ref Unsafe.Add(ref ip, pointerSizeLength);
ip = ref Unsafe.Add(ref ip, misalignedPtrs);
pointerSizeLength -= misalignedPtrs;

// Work with 64b blocks and use blocks--'s flag as a loop condition.
nuint blocks = pointerSizeLength >> 3;
do
{
// On ARM64, this is supposed to be optimized into:
//
// stp xzr, xzr, [x0]
// stp xzr, xzr, [x0, #0x10]
// stp xzr, xzr, [x0, #0x20]
// stp xzr, xzr, [x0, #0x30]
//
// Although, JIT is free to optimize it into
//
// stp q0, q0, [x0]
// stp q0, q0, [x0, #20]
//
Unsafe.Add(ref ip, 0) = 0;
Unsafe.Add(ref ip, 1) = 0;
Unsafe.Add(ref ip, 2) = 0;
Unsafe.Add(ref ip, 3) = 0;
Unsafe.Add(ref ip, 4) = 0;
Unsafe.Add(ref ip, 5) = 0;
Unsafe.Add(ref ip, 6) = 0;
Unsafe.Add(ref ip, 7) = 0;
ip = ref Unsafe.Add(ref ip, 8);
blocks--;
} while (blocks != 0);

// Unconditional zero last 64 bytes to handle the remainder.
// Ideally, two instructions again.
Unsafe.Add(ref ipEnd, -1) = 0;
Unsafe.Add(ref ipEnd, -2) = 0;
Unsafe.Add(ref ipEnd, -3) = 0;
Unsafe.Add(ref ipEnd, -4) = 0;
Unsafe.Add(ref ipEnd, -5) = 0;
Unsafe.Add(ref ipEnd, -6) = 0;
Unsafe.Add(ref ipEnd, -7) = 0;
Unsafe.Add(ref ipEnd, -8) = 0;
}

public static void Reverse(ref int buf, nuint length)
Expand Down
Loading