diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.cs index a7e5f48d63180..d07ca01e814a4 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.cs @@ -3,7 +3,9 @@ using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; using System.Runtime.Intrinsics.X86; #pragma warning disable 8500 // sizeof of managed types @@ -341,16 +343,28 @@ public static unsafe void ClearWithReferences(ref IntPtr ip, nuint pointerSizeLe // Writing backward allows us to get away with only simple modifications to the // mov instruction's base and index registers between loop iterations. - for (; pointerSizeLength >= 8; pointerSizeLength -= 8) + if (pointerSizeLength >= 8) { - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -1) = default; - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -2) = default; - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -3) = default; - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -4) = default; - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -5) = default; - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -6) = default; - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -7) = default; - Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -8) = default; + // Handle large inputs separately, currently, only on ARM64 since it gives us + // needed atomicity guarantees even with just 8-byte alignment. + // TODO: consider pinning the input, align to 64 bytes and use AVX/AVX512 on x64. + // since VMOVAPQ guarantees 16-byte alignment if aligned to 16 bytes. + if (AdvSimd.Arm64.IsSupported && pointerSizeLength >= 32) + goto LargeInput; + + do + { + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -1) = 0; + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -2) = 0; + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -3) = 0; + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -4) = 0; + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -5) = 0; + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -6) = 0; + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -7) = 0; + Unsafe.Add(ref Unsafe.Add(ref ip, (nint)pointerSizeLength), -8) = 0; + pointerSizeLength -= 8; + } + while (pointerSizeLength >= 8); } Debug.Assert(pointerSizeLength <= 7); @@ -406,6 +420,55 @@ public static unsafe void ClearWithReferences(ref IntPtr ip, nuint pointerSizeLe // Write only element. ip = default; + return; + + LargeInput: + // Branch-less alignment to 16 bytes: unconditional zero the first pointer + // and then adjust ip pointer to +1 if it was misaligned. + ip = 0; + nuint misalignedPtrs = (nuint)(((nuint)Unsafe.AsPointer(ref ip) & 8) == 0 ? 0 : 1); + ref nint ipEnd = ref Unsafe.Add(ref ip, pointerSizeLength); + ip = ref Unsafe.Add(ref ip, misalignedPtrs); + pointerSizeLength -= misalignedPtrs; + + // Work with 64b blocks and use blocks--'s flag as a loop condition. + nuint blocks = pointerSizeLength >> 3; + do + { + // On ARM64, this is supposed to be optimized into: + // + // stp xzr, xzr, [x0] + // stp xzr, xzr, [x0, #0x10] + // stp xzr, xzr, [x0, #0x20] + // stp xzr, xzr, [x0, #0x30] + // + // Although, JIT is free to optimize it into + // + // stp q0, q0, [x0] + // stp q0, q0, [x0, #20] + // + Unsafe.Add(ref ip, 0) = 0; + Unsafe.Add(ref ip, 1) = 0; + Unsafe.Add(ref ip, 2) = 0; + Unsafe.Add(ref ip, 3) = 0; + Unsafe.Add(ref ip, 4) = 0; + Unsafe.Add(ref ip, 5) = 0; + Unsafe.Add(ref ip, 6) = 0; + Unsafe.Add(ref ip, 7) = 0; + ip = ref Unsafe.Add(ref ip, 8); + blocks--; + } while (blocks != 0); + + // Unconditional zero last 64 bytes to handle the remainder. + // Ideally, two instructions again. + Unsafe.Add(ref ipEnd, -1) = 0; + Unsafe.Add(ref ipEnd, -2) = 0; + Unsafe.Add(ref ipEnd, -3) = 0; + Unsafe.Add(ref ipEnd, -4) = 0; + Unsafe.Add(ref ipEnd, -5) = 0; + Unsafe.Add(ref ipEnd, -6) = 0; + Unsafe.Add(ref ipEnd, -7) = 0; + Unsafe.Add(ref ipEnd, -8) = 0; } public static void Reverse(ref int buf, nuint length)