Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions onnxruntime/core/mlas/lib/qlutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,32 +548,32 @@ MlasLutGemm(

// const int num_groups = static_cast<int>(K / BlkLen);

// Parallelize over M (batch dimension)
// Each iteration processes one row of the activation matrix
// Iterate over M (batch dimension)
// Each iteration processes one row of the activation matrix.
// NOTE: This loop is intentionally serialized. Previous attempts to parallelize
// using MlasTrySimpleParallel caused flaky test failures (race conditions)
// when M > 1 (e.g., Batch32 case). Since GenerateLUT is lightweight,
// serial execution ensures correctness with negligible performance impact.
// TODO(vraspar): Ideally, this should be parallelized at block granularity instead of running serially.

MlasTrySimpleParallel(
threadpool,
static_cast<size_t>(M),
[&](ptrdiff_t ine11) {
const size_t row_offset = static_cast<size_t>(ine11) * K;
const size_t lut_offset = static_cast<size_t>(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT
const size_t scale_bias_offset = static_cast<size_t>(ine11) * lut_scales_size;

// Call the dispatch function for this row
// ggml_tmac_mul_mat_task_init
Dispatch->GenerateLUT(
const_cast<float*>(a_float + row_offset), // Input activation for this row
qlut + lut_offset, // Output LUT for this row
lut_scales + scale_bias_offset, // Scales for this row
lut_biases + scale_bias_offset, // Biases for this row
M,
K,
N,
tmac_params.act_group_size
);
}
);
for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
const size_t row_offset = ine11 * K;
const size_t lut_offset = ine11 * K * 4; // 4 bytes per K element for 2-bit LUT
const size_t scale_bias_offset = ine11 * lut_scales_size;

// Call the dispatch function for this row
// ggml_tmac_mul_mat_task_init
Dispatch->GenerateLUT(
const_cast<float*>(a_float + row_offset), // Input activation for this row
qlut + lut_offset, // Output LUT for this row
lut_scales + scale_bias_offset, // Scales for this row
lut_biases + scale_bias_offset, // Biases for this row
M,
K,
N,
tmac_params.act_group_size
);
}

// All relevant LUTs have been generated.
// This point is equivalent to the ggml_barrier line in lut_mul_mat's ggml_backend_tmac_mul_mat function.
Expand Down
Loading