diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index f029e539f02a1..cb099c2409a44 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -548,32 +548,32 @@ MlasLutGemm( // const int num_groups = static_cast(K / BlkLen); - // Parallelize over M (batch dimension) - // Each iteration processes one row of the activation matrix + // Iterate over M (batch dimension) + // Each iteration processes one row of the activation matrix. + // NOTE: This loop is intentionally serialized. Previous attempts to parallelize + // using MlasTrySimpleParallel caused flaky test failures (race conditions) + // when M > 1 (e.g., Batch32 case). Since GenerateLUT is lightweight, + // serial execution ensures correctness with negligible performance impact. // TODO(vraspar): Ideally we have to do block parallelism here - MlasTrySimpleParallel( - threadpool, - static_cast(M), - [&](ptrdiff_t ine11) { - const size_t row_offset = static_cast(ine11) * K; - const size_t lut_offset = static_cast(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT - const size_t scale_bias_offset = static_cast(ine11) * lut_scales_size; - - // Call the dispatch function for this row - // ggml_tmac_mul_mat_task_init - Dispatch->GenerateLUT( - const_cast(a_float + row_offset), // Input activation for this row - qlut + lut_offset, // Output LUT for this row - lut_scales + scale_bias_offset, // Scales for this row - lut_biases + scale_bias_offset, // Biases for this row - M, - K, - N, - tmac_params.act_group_size - ); - } - ); + for (size_t ine11 = 0; ine11 < static_cast(M); ine11++) { + const size_t row_offset = ine11 * K; + const size_t lut_offset = ine11 * K * 4; // 4 bytes per K element for 2-bit LUT + const size_t scale_bias_offset = ine11 * lut_scales_size; + + // Call the dispatch function for this row + // ggml_tmac_mul_mat_task_init + Dispatch->GenerateLUT( + const_cast(a_float + row_offset), // Input activation for this row + qlut + lut_offset, // Output LUT for this row + lut_scales + scale_bias_offset, // Scales for this row + lut_biases + scale_bias_offset, // Biases for this row + M, + K, + N, + tmac_params.act_group_size + ); + } // all relevant LUT's have been generated // equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line