diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
index b54f051ca1504..a89993d4515b8 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
@@ -187,21 +187,53 @@ get_bias_scale()
     return 3;
 }
 
+static inline void
+MlasAvx2LoaduDeinterleave32Ps(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3)
+{
+    // Deinterleave 32 contiguous activations into four stride-4 vectors:
+    // v0 = src[4i+0], v1 = src[4i+1], v2 = src[4i+2], v3 = src[4i+3] for i = 0..7,
+    // which matches the T-MAC weight packing.
+    // Contiguous loadu + in-register shuffles replace the previous strided gather:
+    // vpgatherdps is slow on many x86 cores, and the gather path caused accuracy problems here.
+    __m256 vec_b0 = _mm256_loadu_ps(src + 0);
+    __m256 vec_b1 = _mm256_loadu_ps(src + 8);
+    __m256 vec_b2 = _mm256_loadu_ps(src + 16);
+    __m256 vec_b3 = _mm256_loadu_ps(src + 24);
+
+    __m256 t0 = _mm256_unpacklo_ps(vec_b0, vec_b1);
+    __m256 t1 = _mm256_unpackhi_ps(vec_b0, vec_b1);
+    __m256 t2 = _mm256_unpacklo_ps(vec_b2, vec_b3);
+    __m256 t3 = _mm256_unpackhi_ps(vec_b2, vec_b3);
+
+    __m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
+    __m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
+    __m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
+    __m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
+
+    const __m256i perm_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+    v0 = _mm256_permutevar8x32_ps(u0, perm_idx);
+    v1 = _mm256_permutevar8x32_ps(u1, perm_idx);
+    v2 = _mm256_permutevar8x32_ps(u2, perm_idx);
+    v3 = _mm256_permutevar8x32_ps(u3, perm_idx);
+}
+
 void
 partial_max_g4_int8_k8(float* lut_scales, const float* b)
 {
-    // TODO(vraspar): add support for arm neon
-    const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
-    __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1);
-    __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1);
-    __m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1);
-    __m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1);
+    __m256 vec_b0, vec_b1, vec_b2, vec_b3;
+    MlasAvx2LoaduDeinterleave32Ps(b, vec_b0, vec_b1, vec_b2, vec_b3);
+
     const __m256 vec_sign = _mm256_set1_ps(-0.0f);
     __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0);
     __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1);
     __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2);
     __m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3);
+
+    // Each LUT entry is a signed combination of four activations, so its
+    // magnitude is bounded above by the sum of their absolute values.
     __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3));
+
+    // Horizontal max across lanes: the largest group-of-4 absolute sum in this chunk.
     __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum));
     max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
     max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
@@ -222,16 +254,14 @@ lut_ctor_g4_int8_impl(
 )
 {
     __m256 vec_lut[16];
-    float biases = 0.0;
-    const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
+    float biases = 0.0f;
     float scales = *lut_scales;
     float t_scales = scales ? 1.0f / scales : 0.0f;
 
     for (int k = 0; k < act_k / 32; ++k) {
-        __m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1);
-        __m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1);
-        __m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1);
-        __m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1);
+        const float* b_chunk = b + k * 32;
+        __m256 vec_b0, vec_b1, vec_b2, vec_b3;
+        MlasAvx2LoaduDeinterleave32Ps(b_chunk, vec_b0, vec_b1, vec_b2, vec_b3);
 
         PRAGMA_UNROLL
         for (int g = 1; g < 16; g += 2) {
diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc
index 853458312cd1f..3d5e3e5f360b4 100644
--- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc
@@ -371,8 +371,10 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256) {
   TestMatMul2BitsLutGemm(1, 256, 256, 32, false);
 }
 
-// TODO: Re-enable once LUT GEMM asymmetric quantization accuracy issue is resolved
-TEST(MatMulNBitsLutGemm, DISABLED_Float32_2Bits_Asymmetric_256x256) {
+// Previously disabled due to an accuracy issue in the asymmetric LUT GEMM path.
+// Re-enabled now that LUT construction reads activations with contiguous loads
+// plus in-register shuffles instead of strided gathers.
+TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256) {
   TestMatMul2BitsLutGemm(1, 256, 256, 32, true);
 }
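Reviewer note: a standalone sketch (not part of the patch) that checks the unpack/permute network in MlasAvx2LoaduDeinterleave32Ps against a scalar stride-4 reference, i.e. that lane i of output j equals src[4*i + j], the same layout the old _mm256_i32gather_ps calls produced with byte offsets {0, 16, ..., 112} and scale 1. Assumes an AVX2 build (e.g. -mavx2); the harness and its names are illustrative only.

#include <immintrin.h>
#include <cstdio>

// Same shuffle network as MlasAvx2LoaduDeinterleave32Ps, copied here so the
// sketch is self-contained.
static inline void
DeinterleaveSketch(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3)
{
    __m256 b0 = _mm256_loadu_ps(src + 0);
    __m256 b1 = _mm256_loadu_ps(src + 8);
    __m256 b2 = _mm256_loadu_ps(src + 16);
    __m256 b3 = _mm256_loadu_ps(src + 24);
    __m256 t0 = _mm256_unpacklo_ps(b0, b1), t1 = _mm256_unpackhi_ps(b0, b1);
    __m256 t2 = _mm256_unpacklo_ps(b2, b3), t3 = _mm256_unpackhi_ps(b2, b3);
    __m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
    __m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
    __m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
    __m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
    const __m256i p = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
    v0 = _mm256_permutevar8x32_ps(u0, p);
    v1 = _mm256_permutevar8x32_ps(u1, p);
    v2 = _mm256_permutevar8x32_ps(u2, p);
    v3 = _mm256_permutevar8x32_ps(u3, p);
}

int main()
{
    float src[32];
    for (int i = 0; i < 32; ++i) src[i] = static_cast<float>(i);

    __m256 v[4];
    DeinterleaveSketch(src, v[0], v[1], v[2], v[3]);

    float out[4][8];
    for (int j = 0; j < 4; ++j) _mm256_storeu_ps(out[j], v[j]);

    // Expect out[j][i] == src[4 * i + j]: each output vector holds one
    // stride-4 slice of the 32 contiguous activations.
    for (int j = 0; j < 4; ++j)
        for (int i = 0; i < 8; ++i)
            if (out[j][i] != src[4 * i + j]) { std::printf("mismatch\n"); return 1; }
    std::printf("deinterleave matches stride-4 gather layout\n");
    return 0;
}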
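Reviewer note: likewise, a scalar reference (a sketch under the same caveats, with an illustrative function name) for the reduction shown in the first hunk. Per 32-float chunk, partial_max_g4_int8_k8 computes the maximum over i of |b[4i]| + |b[4i+1]| + |b[4i+2]| + |b[4i+3]|; the final write to lut_scales falls outside the hunk and is not reproduced here.

#include <algorithm>
#include <cmath>

// Scalar equivalent of the abs-sum / horizontal-max portion of
// partial_max_g4_int8_k8 for one 32-activation chunk.
float PartialMaxG4Reference(const float* b)
{
    float max_sum = 0.0f;
    for (int i = 0; i < 8; ++i) {
        // Upper bound on any LUT entry built from this group of four activations.
        float sum = std::fabs(b[4 * i + 0]) + std::fabs(b[4 * i + 1]) +
                    std::fabs(b[4 * i + 2]) + std::fabs(b[4 * i + 3]);
        max_sum = std::max(max_sum, sum);
    }
    return max_sum;
}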