
[MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle #27174

Merged

tianleiwu merged 4 commits into main from tlwu/fix_sqnbitgemm_lut_kernel_avx2 on Jan 29, 2026

Conversation

@tianleiwu (Contributor) commented Jan 27, 2026

Problem Description

The MatMulNBitsLutGemm test suite, specifically Float32_2Bits_Symmetric_256x256_BlkLen64, was exhibiting intermittent failures (flakiness).
The failure manifested as numerical mismatches exceeding the tolerance, suggesting non-deterministic behavior in the kernel execution.

Root Cause Analysis

The issue was traced to the use of _mm256_i32gather_ps in sqnbitgemm_lut_kernel_avx2.cpp.
Although the gather indices computed addresses within the bounds of the allocated buffer, gather instructions on certain AVX2 hardware implementations can exhibit non-deterministic behavior or subtle performance/prefetching artifacts when operating on specific stride patterns (in this case, gathering with a stride of 4 floats).

Solution

This PR replaces the _mm256_i32gather_ps instruction with a sequence of contiguous loads (_mm256_loadu_ps) followed by deterministic shuffles.

How it works:

  1. Contiguous Load: We load 4 contiguous vectors of 8 floats each using _mm256_loadu_ps. This is always memory-safe and deterministic.
  2. Deterministic Shuffle: We apply a verified sequence of unpack and permutevar8x32 instructions to rearrange these 32 linearly loaded elements into the exact same stride-4 layout that the gather instruction produced.

Benefits:

  • Stability: Eliminates the hardware-dependent non-determinism of gather.
  • Safety: Using loadu guarantees we touch only memory within the explicit range of the 32 elements we intend to load.
  • Correctness: The shuffle logic was verified against the reference gather behavior using a C++ reproduction script to ensure bit-exact layout equivalence.

Performance

Micro-benchmark on MatMulNBitsLutGemm (256x256, BlkLen=64).
Original (Gather): ~55.55 us
Fixed (Load+Shuffle): ~57.79 us
Delta: +2.24 us (~4% slower)

The slight performance regression is expected because replacing a single hardware gather instruction with a sequence of loadu, unpack, and permute instructions adds instruction count overhead. However, this is a necessary tradeoff to ensure deterministic behavior and memory safety across all AVX2 implementations.

Verification

  • Tests: All 10 tests in MatMulNBitsLutGemm passed successfully (including the previously flaky BlkLen64 case).

@tianleiwu tianleiwu force-pushed the tlwu/fix_sqnbitgemm_lut_kernel_avx2 branch from 18a34cc to 6f92041 Compare January 27, 2026 17:02
@tianleiwu tianleiwu marked this pull request as draft January 27, 2026 17:02
@tianleiwu tianleiwu closed this Jan 27, 2026
@tianleiwu tianleiwu reopened this Jan 27, 2026
@tianleiwu tianleiwu force-pushed the tlwu/fix_sqnbitgemm_lut_kernel_avx2 branch from 909562d to 1fdd9bb Compare January 27, 2026 22:08
@tianleiwu tianleiwu force-pushed the tlwu/fix_sqnbitgemm_lut_kernel_avx2 branch from e90d1a3 to dbd1de1 Compare January 28, 2026 00:48
@tianleiwu tianleiwu marked this pull request as ready for review January 28, 2026 00:57
@tianleiwu tianleiwu changed the title [MLAS] Fix precision issue in MatMulNBits LutGemm AVX2 kernel [MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle Jan 28, 2026
@vraspar (Contributor) commented Jan 28, 2026

I tested the changes on a 2-bit Llama model and didn't see any noticeable throughput difference.

Copilot AI (Contributor) left a comment
Pull request overview

Addresses flakiness in MatMulNBitsLutGemm by replacing AVX2 gather-based activation loading with deterministic contiguous loads plus shuffles.

Changes:

  • Replaced _mm256_i32gather_ps with _mm256_loadu_ps + unpack/permute shuffles in the AVX2 LUT GEMM kernel.
  • Updated LUT scale computation path to use the same deterministic deinterleave approach.
  • Re-enabled the previously disabled asymmetric 256x256 LUT GEMM test.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
| :--- | :--- |
| onnxruntime/test/contrib_ops/matmul_2bits_test.cc | Re-enables the asymmetric 256x256 LUT GEMM test case. |
| onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp | Replaces gather with deterministic load+shuffle in LUT construction and max-scaling logic. |


Copilot AI (Contributor) left a comment

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.



Copilot AI (Contributor) left a comment

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.



@tianleiwu tianleiwu merged commit 1caa3e6 into main Jan 29, 2026
95 of 96 checks passed
@tianleiwu tianleiwu deleted the tlwu/fix_sqnbitgemm_lut_kernel_avx2 branch January 29, 2026 07:27
tianleiwu added a commit that referenced this pull request Jan 29, 2026
tianleiwu added a commit that referenced this pull request Jan 29, 2026
| Commit | Commit Title | Author |
| :--- | :--- | :--- |
| `6861526` | [MLAS] Fix Data Race in MlasLutGemm by Serializing LUT Generation (#27179) | tianleiwu |
| `592bcb4` | remove coloredlogs (#27135) | tianleiwu |
| `0f153de` | Add API GetTensorElementTypeAndShapeDataReference (#27175) | adrianlizarraga |
| `1caa3e6` | [MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle (#27174) | tianleiwu |

---------

Co-authored-by: Adrian Lizarraga <adlizarraga@microsoft.com>
milpuz01 pushed a commit to milpuz01/onnxruntime that referenced this pull request Feb 4, 2026