[MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle#27174
Conversation
I tested the changes on a 2-bit Llama model and didn't see any noticeable throughput difference.
Pull request overview
Addresses flakiness in MatMulNBitsLutGemm by replacing AVX2 gather-based activation loading with deterministic contiguous loads plus shuffles.
Changes:
- Replaced `_mm256_i32gather_ps` with `_mm256_loadu_ps` + unpack/permute shuffles in the AVX2 LUT GEMM kernel.
- Updated the LUT scale computation path to use the same deterministic deinterleave approach.
- Re-enabled the previously disabled asymmetric 256x256 LUT GEMM test.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/test/contrib_ops/matmul_2bits_test.cc | Re-enables the asymmetric 256x256 LUT GEMM test case. |
| onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | Replaces gather with deterministic load+shuffle in LUT construction and max-scaling logic. |
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
## Problem Description

The `MatMulNBitsLutGemm` test suite, specifically `Float32_2Bits_Symmetric_256x256_BlkLen64`, was exhibiting intermittent failures (flakiness). The failures manifested as numerical mismatches exceeding the tolerance, suggesting non-deterministic behavior in the kernel execution.

## Root Cause Analysis

The issue was traced to the use of `_mm256_i32gather_ps` in `sqnbitgemm_lut_kernel_avx2.cpp`. While the gather indices computed addresses within the bounds of the allocated buffer, gather instructions on certain AVX2 hardware implementations can exhibit non-deterministic behavior or subtle performance/prefetching artifacts when operating on specific stride patterns (in this case, gathering with a stride of 4 floats).

## Solution

This PR replaces the `_mm256_i32gather_ps` instruction with a sequence of **contiguous loads (`_mm256_loadu_ps`) followed by deterministic shuffles**.

### How it works:
1. **Contiguous Load**: We load 4 contiguous vectors of 8 floats each using `_mm256_loadu_ps`. This is always memory-safe and deterministic.
2. **Deterministic Shuffle**: We apply a verified sequence of `unpack` and `permutevar8x32` instructions to rearrange these 32 linearly loaded elements into the exact same stride-4 layout that the gather instruction produced.

### Benefits:
* **Stability**: Eliminates the hardware-dependent non-determinism of gather.
* **Safety**: Using `loadu` guarantees we only touch memory within the explicit range of the 32 elements we intend to load.
* **Correctness**: The shuffle logic was verified against the reference gather behavior using a C++ reproduction script to ensure bit-exact layout equivalence.

### Performance

Micro-benchmark on MatMulNBitsLutGemm (256x256, BlkLen=64):

```
Original (Gather):    ~55.55 us
Fixed (Load+Shuffle): ~57.79 us
Delta:                +2.24 us (~4% slower)
```

The slight performance regression is expected: replacing a single hardware gather instruction with a sequence of `loadu`, `unpack`, and `permute` instructions adds instruction-count overhead. However, this is a necessary tradeoff to ensure deterministic behavior and memory safety across all AVX2 implementations.

## Verification

* **Tests**: All 9 tests in `MatMulNBitsLutGemm` passed successfully (including the previously flaky `BlkLen64` case).
| Commit | Commit Title | Author |
| :--- | :--- | :--- |
| `6861526` | [MLAS] Fix Data Race in MlasLutGemm by Serializing LUT Generation (#27179) | tianleiwu |
| `592bcb4` | remove coloredlogs (#27135) | tianleiwu |
| `0f153de` | Add API GetTensorElementTypeAndShapeDataReference (#27175) | adrianlizarraga |
| `1caa3e6` | [MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle (#27174) | tianleiwu |

Co-authored-by: Adrian Lizarraga <adlizarraga@microsoft.com>