Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,21 +187,53 @@ get_bias_scale()
return 3;
}

static inline void
MlasAvx2LoaduDeinterleave32Ps(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3)
{
// Process 32 activations contiguously using loadu + shuffle.
// This allows us to mix neighbors (src[4i], src[4i+1], src[4i+2], src[4i+3]) across lanes,
// which matches the T-MAC weight packing.
// We use loadu + shuffle instead of gather to avoid potential issues with gather
// on some hardware and ensure deterministic behavior.
__m256 vec_b0 = _mm256_loadu_ps(src + 0);
__m256 vec_b1 = _mm256_loadu_ps(src + 8);
__m256 vec_b2 = _mm256_loadu_ps(src + 16);
__m256 vec_b3 = _mm256_loadu_ps(src + 24);

__m256 t0 = _mm256_unpacklo_ps(vec_b0, vec_b1);
__m256 t1 = _mm256_unpackhi_ps(vec_b0, vec_b1);
__m256 t2 = _mm256_unpacklo_ps(vec_b2, vec_b3);
__m256 t3 = _mm256_unpackhi_ps(vec_b2, vec_b3);

__m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
__m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
__m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
__m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));

const __m256i perm_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
v0 = _mm256_permutevar8x32_ps(u0, perm_idx);
v1 = _mm256_permutevar8x32_ps(u1, perm_idx);
v2 = _mm256_permutevar8x32_ps(u2, perm_idx);
v3 = _mm256_permutevar8x32_ps(u3, perm_idx);
}

void
partial_max_g4_int8_k8(float* lut_scales, const float* b)
{
// TODO(vraspar): add support for arm neon
const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
__m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1);
__m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1);
__m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1);
__m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1);
__m256 vec_b0, vec_b1, vec_b2, vec_b3;
MlasAvx2LoaduDeinterleave32Ps(b, vec_b0, vec_b1, vec_b2, vec_b3);

const __m256 vec_sign = _mm256_set1_ps(-0.0f);
__m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0);
__m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1);
__m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2);
__m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3);

// The upper bound for the LUT values (mixtures of 4 activations) is the sum
// of their absolute values.
__m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3));

// Reduce max across lanes to find the global maximum sum in this chunk.
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum));
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
Expand All @@ -222,16 +254,14 @@ lut_ctor_g4_int8_impl(
)
{
__m256 vec_lut[16];
float biases = 0.0;
const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
float biases = 0.0f;
float scales = *lut_scales;
float t_scales = scales ? 1.0f / scales : 0.0f;

for (int k = 0; k < act_k / 32; ++k) {
__m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1);
__m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1);
__m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1);
__m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1);
const float* b_chunk = b + k * 32;
__m256 vec_b0, vec_b1, vec_b2, vec_b3;
MlasAvx2LoaduDeinterleave32Ps(b_chunk, vec_b0, vec_b1, vec_b2, vec_b3);

PRAGMA_UNROLL
for (int g = 1; g < 16; g += 2) {
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/test/contrib_ops/matmul_2bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,10 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256) {
TestMatMul2BitsLutGemm<float>(1, 256, 256, 32, false);
}

// TODO: Re-enable once LUT GEMM asymmetric quantization accuracy issue is resolved
TEST(MatMulNBitsLutGemm, DISABLED_Float32_2Bits_Asymmetric_256x256) {
// Previously carried the DISABLED_ prefix because of an accuracy issue with LUT GEMM
// asymmetric quantization, which was attributed to the gather-based activation loads
// in the AVX2 LUT kernel. Re-enabled now that the kernel builds those vectors with
// loadu + shuffle instead of gather.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256) {
TestMatMul2BitsLutGemm<float>(1, 256, 256, 32, true);
}

Expand Down
Loading