From 83ffc5d9f6d5129d786d8a1970a903dbd96a71b5 Mon Sep 17 00:00:00 2001 From: apsonawane Date: Thu, 25 Sep 2025 21:51:48 +0000 Subject: [PATCH 1/4] Fix merge conflicts --- onnxruntime/contrib_ops/cpu/moe/moe_helper.h | 14 ++++ .../cpu/moe/moe_quantization_cpu.cc | 68 +++++++++++-------- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index 39249f842e632..b3801fda0598e 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -33,8 +33,22 @@ struct MoEParameters { MoEParallelType parallel_type{MoEParallelType::None}; int64_t tensor_shards{1}; }; + namespace moe_helper { +// Validate block-wise quantization requirements according to op spec +inline Status ValidateBlockwiseQuantization(int64_t block_size, int64_t hidden_size, int64_t inter_size) { + if (block_size > 0) { + ORT_ENFORCE(hidden_size % block_size == 0, + "For block-wise quantization, hidden_size (", hidden_size, + ") must be divisible by block_size (", block_size, ")."); + ORT_ENFORCE(inter_size % block_size == 0, + "For block-wise quantization, inter_size (", inter_size, + ") must be divisible by block_size (", block_size, ")."); + } + return Status::OK(); +} + template Status CheckInputs(MoEParameters& parameters, const Tensor* input, // required diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 8a3c3f6d9f37a..0287dc78a642a 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -61,22 +61,8 @@ bool CanUseMlasQ4Dequant(int64_t num_bits) { bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) { - if (expert_weight_bits != 4) { - return false; - } - - if (block_size == 64) { - out_qtype = BlkQ4Sym64; - } else if (block_size == 128) { - out_qtype = BlkQ4Sym128; - } else if (block_size == 0) { - out_qtype = BlkQ4Sym; - } else { - return false; - } - - size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); - return expected_size > 0; + // TEMPORARY: Disable direct Q4 GEMM + return false; } } // namespace @@ -392,6 +378,9 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t num_experts = moe_params.num_experts; const int64_t fc1_out_features = inter_size * (swiglu_fusion_ > 0 ? 2 : 1); + // Validate block-wise quantization requirements + ORT_RETURN_IF_ERROR(moe_helper::ValidateBlockwiseQuantization(block_size_, hidden_size, inter_size)); + auto* output = context->Output(0, input_shape); auto* tp = context->GetOperatorThreadPool(); @@ -503,7 +492,11 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int max_expert_threads = tp ? 
concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; const int64_t total_expert_work = std::accumulate(expert_token_map.begin(), expert_token_map.end(), 0LL, [](int64_t sum, const std::vector& tokens) { return sum + static_cast(tokens.size()); }); + // Expert threading uses larger divisor (8x vs 4x for routing) because expert processing involves + // heavier computation (GEMM operations) that benefits more from parallelization const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8); + // Minimum work per expert thread is smaller (16 vs 32 for routing) because expert work + // involves matrix operations that are more compute-intensive per token const int64_t min_expert_work_per_thread = std::max(int64_t{16}, total_expert_work / expert_thread_divisor); int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 1 : std::min(narrow(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(narrow(num_experts), max_expert_threads)); @@ -568,12 +561,26 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } - if (total_work < 48) { + // Adaptive thresholds optimized for both decoding (1 token) and batch inference + // For single-token decoding: allow more parallelization across experts since each expert + // represents substantial work (especially for large hidden_size) + // For batch: higher thresholds ensure efficient thread utilization + if (total_work < 4) { + // Very small workload - single token with few experts, use 1 thread num_expert_threads = 1; - } else if (total_work < 192) { - num_expert_threads = std::min(num_expert_threads, 2); - } else if (total_work < 512) { + } else if (total_work < 12) { + // Small workload - single token or small batch, use up to 4 threads + // This covers typical decoding case with top-k=2-4 where each expert is substantial work num_expert_threads = std::min(num_expert_threads, 4); + } else if (total_work < 32) { + // Medium workload - use up to 6 threads + num_expert_threads = std::min(num_expert_threads, 6); + } else if (total_work < 128) { + // Large workload - use up to 8 threads + num_expert_threads = std::min(num_expert_threads, 8); + } else if (total_work < 384) { + // Very large workload - use more threads + num_expert_threads = std::min(num_expert_threads, 12); } std::sort(expert_workload.begin(), expert_workload.end(), @@ -652,7 +659,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const T* fc1_scales_ptr; if (is_fc1_block_wise) { - const int64_t fc1_blocks_per_row = fc1_scales_dims[2]; + const int64_t fc1_blocks_per_row = (hidden_size + block_size_ - 1) / block_size_; fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features * fc1_blocks_per_row; } else { fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features; @@ -714,15 +721,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } // Traditional approach: dequantize + regular GEMM + // Use parallel dequantization when we have multiple blocks and sufficient work per thread + // Threshold of 32 features ensures each thread has meaningful work to justify threading overhead if (num_dequant_blocks > 1 && fc1_out_features >= 32) { concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * dequant_block_size; const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; 
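                                                     // Offset bookkeeping for this parallel dequant step (illustrative numbers; assumes
                                                     // 4-bit weights, i.e. two quantized values per packed byte — other widths differ):
                                                     //   fc1_packed_cols    ~ (hidden_size + 1) / 2                         packed bytes per output row
                                                     //   fc1_blocks_per_row = (hidden_size + block_size_ - 1) / block_size_ scales per output row
                                                     // e.g. hidden_size = 2880, block_size_ = 32 -> 1440 packed bytes and 90 scales per row,
                                                     // so `offset` advances by whole packed rows while the scale pointer advances by
                                                     // start_row * fc1_blocks_per_row (block-wise) or start_row (row-wise).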
DequantizeBlock(fc1_weights_data + offset, - fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_scales_dims[2] : start_row), + fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_blocks_per_row : start_row), is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, - end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, tp); + end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, nullptr); }); } else { DequantizeBlock(fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, @@ -778,7 +787,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc1_gemm_done: - const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + // Activation threshold scales inversely with inter_size to balance work per thread + // For large inter_size, fewer tokens justify parallel activation; for small inter_size, more tokens needed + // Base value of 256 ensures reasonable work distribution across different model configurations + const int64_t activation_threshold = std::max(int64_t{4}, 256 / inter_size); if (num_expert_tokens >= activation_threshold && tp != nullptr) { const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; @@ -812,7 +824,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const T* fc2_scales_ptr; if (is_fc2_block_wise) { - const int64_t fc2_blocks_per_row = fc2_scales_dims[2]; + const int64_t fc2_blocks_per_row = (inter_size + block_size_ - 1) / block_size_; fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size * fc2_blocks_per_row; } else { fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size; @@ -874,15 +886,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } // Traditional approach: dequantize + regular GEMM + // Use parallel dequantization when we have multiple blocks and sufficient work per thread + // Threshold of 32 for hidden_size ensures each thread has meaningful work to justify threading overhead if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * fc2_dequant_block_size; const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; DequantizeBlock(fc2_weights_data + offset, - fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_scales_dims[2] : start_row), + fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_blocks_per_row : start_row), is_fc2_block_wise ? 
block_size_ : 0, expert_weight_bits_, - end_row - start_row, inter_size, B2_dequant + start_row * inter_size, tp); + end_row - start_row, inter_size, B2_dequant + start_row * inter_size, nullptr); }); } else { DequantizeBlock(fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, From d399c66d97d61dcb8e8fac9bf7bfd442ee14f0ad Mon Sep 17 00:00:00 2001 From: apsonawane Date: Thu, 25 Sep 2025 21:50:02 +0000 Subject: [PATCH 2/4] Re-enable quantized Mlas --- .../cpu/moe/moe_quantization_cpu.cc | 62 +++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 0287dc78a642a..82141890d0ee5 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -61,8 +61,56 @@ bool CanUseMlasQ4Dequant(int64_t num_bits) { bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) { - // TEMPORARY: Disable direct Q4 GEMM - return false; + if (expert_weight_bits != 4) { + return false; + } + + if (rows <= 0 || cols <= 0) { + return false; + } + + MLAS_BLK_QUANT_TYPE qtype; + switch (block_size) { + case 0: + case 32: + qtype = BlkQ4Sym; + break; + case 64: + qtype = BlkQ4Sym64; + break; + case 128: + qtype = BlkQ4Sym128; + break; + default: + return false; + } + + const size_t pack_size = + MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); + + if (pack_size == 0) { + return false; + } + + out_qtype = qtype; + return true; +} + +void TransposeFp32RowMajorToColumnMajor(const float* src, + float* dst, + int64_t rows, + int64_t cols) { + if (rows <= 0 || cols <= 0) { + return; + } + + for (int64_t r = 0; r < rows; ++r) { + const size_t row_offset = static_cast(r) * static_cast(cols); + for (int64_t c = 0; c < cols; ++c) { + dst[static_cast(c) * static_cast(rows) + static_cast(r)] = + src[row_offset + static_cast(c)]; + } + } } } // namespace @@ -99,13 +147,19 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr); - size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); + IAllocatorUniquePtr temp_float_col_major_buffer = + IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + float* temp_float_col_major = temp_float_col_major_buffer.get(); + + TransposeFp32RowMajorToColumnMajor(temp_float, temp_float_col_major, rows, cols); + + size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); if (packed_size == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); } mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float_col_major, static_cast(rows), static_cast(cols), static_cast(rows)); return Status::OK(); } From 8289fcb1e0a0503b07c6649770de2f74a0aee468 Mon Sep 17 00:00:00 2001 From: apsonawane Date: Thu, 25 Sep 2025 22:18:01 +0000 Subject: [PATCH 3/4] Add overflow safety changes --- .../cpu/moe/moe_quantization_cpu.cc | 196 +++++++++++------- 1 file changed, 121 insertions(+), 75 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc 
b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 82141890d0ee5..5a96e25bf6089 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -51,6 +51,14 @@ inline int64_t GetDequantBlockSize(int64_t features, int64_t total_work) { return std::min(target_block_size, work_based_size); } +inline size_t ToSize(int64_t value) { + return onnxruntime::narrow(value); +} + +inline size_t SafeProduct(size_t lhs, size_t rhs) { + return SafeInt(lhs) * rhs; +} + bool CanUseMlasQ4Dequant(int64_t num_bits) { if (num_bits != 4) { return false; @@ -69,6 +77,9 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, return false; } + const size_t rows_size = onnxruntime::narrow(rows); + const size_t cols_size = onnxruntime::narrow(cols); + MLAS_BLK_QUANT_TYPE qtype; switch (block_size) { case 0: @@ -85,8 +96,7 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, return false; } - const size_t pack_size = - MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); + const size_t pack_size = MlasQ4GemmPackBSize(qtype, rows_size, cols_size); if (pack_size == 0) { return false; @@ -104,11 +114,15 @@ void TransposeFp32RowMajorToColumnMajor(const float* src, return; } + const size_t rows_size = onnxruntime::narrow(rows); + const size_t cols_size = onnxruntime::narrow(cols); + for (int64_t r = 0; r < rows; ++r) { - const size_t row_offset = static_cast(r) * static_cast(cols); + const size_t r_index = onnxruntime::narrow(r); + const size_t row_offset = SafeInt(r_index) * cols_size; for (int64_t c = 0; c < cols; ++c) { - dst[static_cast(c) * static_cast(rows) + static_cast(r)] = - src[row_offset + static_cast(c)]; + const size_t c_index = onnxruntime::narrow(c); + dst[c_index * rows_size + r_index] = src[row_offset + c_index]; } } } @@ -142,24 +156,28 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 4-bit quantization supported for MLAS Q4 format conversion"); } - auto temp_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + const size_t rows_size = onnxruntime::narrow(rows); + const size_t cols_size = onnxruntime::narrow(cols); + const size_t total_elements = SafeInt(rows_size) * cols_size; + + auto temp_float_buffer = IAllocator::MakeUniquePtr(allocator, total_elements); float* temp_float = temp_float_buffer.get(); DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr); IAllocatorUniquePtr temp_float_col_major_buffer = - IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + IAllocator::MakeUniquePtr(allocator, total_elements); float* temp_float_col_major = temp_float_col_major_buffer.get(); TransposeFp32RowMajorToColumnMajor(temp_float, temp_float_col_major, rows, cols); - size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); + size_t packed_size = MlasQ4GemmPackBSize(qtype, rows_size, cols_size); if (packed_size == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); } mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float_col_major, static_cast(rows), static_cast(cols), static_cast(rows)); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float_col_major, rows_size, cols_size, rows_size); return Status::OK(); } @@ -175,14 +193,20 @@ Status 
DirectQ4Gemm(const float* A, MLAS_THREADPOOL* thread_pool) { MLAS_Q4_GEMM_DATA_PARAMS params; params.A = A; - params.lda = static_cast(K); + params.lda = onnxruntime::narrow(K); params.B = mlas_packed_B; params.Bias = bias; params.C = C; - params.ldc = static_cast(N); + params.ldc = onnxruntime::narrow(N); params.OutputProcessor = nullptr; - MlasQ4GemmBatch(qtype, static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, thread_pool); + MlasQ4GemmBatch(qtype, + onnxruntime::narrow(M), + onnxruntime::narrow(N), + onnxruntime::narrow(K), + 1, + ¶ms, + thread_pool); return Status::OK(); } @@ -283,7 +307,7 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data, MlasDequantizeLinear( quantized_data + r * cols, dequantized_data + r * cols, - static_cast(cols), + onnxruntime::narrow(cols), scale, zero_pt); } @@ -432,6 +456,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t num_experts = moe_params.num_experts; const int64_t fc1_out_features = inter_size * (swiglu_fusion_ > 0 ? 2 : 1); + const size_t num_tokens_size = ToSize(num_tokens); + const size_t hidden_size_size = ToSize(hidden_size); + const size_t inter_size_size = ToSize(inter_size); + const size_t num_experts_size = ToSize(num_experts); + const size_t fc1_out_features_size = ToSize(fc1_out_features); + const size_t k_size = onnxruntime::narrow(k_); + // Validate block-wise quantization requirements ORT_RETURN_IF_ERROR(moe_helper::ValidateBlockwiseQuantization(block_size_, hidden_size, inter_size)); @@ -441,25 +472,26 @@ Status QMoECPU::Compute(OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - const size_t output_buffer_size = static_cast(output->Shape().Size()); + const size_t output_buffer_size = onnxruntime::narrow(output->Shape().Size()); const T* input_data = input->Data(); IAllocatorUniquePtr router_logits_float_buffer; const float* router_logits_float; if constexpr (std::is_same_v) { - router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + const size_t logits_count = SafeProduct(num_tokens_size, num_experts_size); + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, logits_count); router_logits_float = router_logits_float_buffer.get(); MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs->Data()), const_cast(router_logits_float), - static_cast(num_tokens * num_experts)); + logits_count); } else { router_logits_float = reinterpret_cast(router_probs->Data()); } - auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, SafeProduct(num_tokens_size, k_size)); int* route_expert = route_expert_ptr.get(); - auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, SafeProduct(num_tokens_size, k_size)); float* route_scale = route_scale_ptr.get(); const int max_threads = tp ? 
concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; @@ -470,7 +502,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { std::vector>> thread_local_expert_token_maps(num_routing_threads); for (auto& map : thread_local_expert_token_maps) { - map.resize(static_cast(num_experts)); + map.resize(num_experts_size); for (auto& expert_tokens : map) { expert_tokens.reserve(32); } @@ -480,52 +512,52 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto work = concurrency::ThreadPool::PartitionWork(narrow(thread_id), num_routing_threads, static_cast(num_tokens)); auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; - std::vector> sorted_logits(static_cast(num_experts)); - std::vector top_k_exp(static_cast(k_)); + std::vector> sorted_logits(num_experts_size); + std::vector top_k_exp(k_size); for (int64_t i = work.start; i < work.end; ++i) { const float* logits = router_logits_float + i * num_experts; - for (size_t j = 0; j < narrow(num_experts); ++j) { + for (size_t j = 0; j < num_experts_size; ++j) { sorted_logits[j] = {logits[j], j}; } - std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + narrow(k_), sorted_logits.end(), std::greater<>()); float max_logit = sorted_logits[0].first; float sum_exp = 0.0f; - for (size_t j = 0; j < narrow(k_); ++j) { + for (size_t j = 0; j < k_size; ++j) { top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit); sum_exp += top_k_exp[j]; } const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); - for (size_t j = 0; j < narrow(k_); ++j) { + for (size_t j = 0; j < k_size; ++j) { int64_t expert_idx = sorted_logits[j].second; int64_t route_idx = i * k_ + narrow(j); route_expert[route_idx] = narrow(expert_idx); route_scale[route_idx] = top_k_exp[j] * inv_sum; if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights - local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + local_expert_token_map[ToSize(expert_idx)].push_back(route_idx); } } } }); - std::vector> expert_token_map(static_cast(num_experts)); + std::vector> expert_token_map(num_experts_size); for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { size_t total_tokens_for_expert = 0; for (int t = 0; t < num_routing_threads; ++t) { - total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + total_tokens_for_expert += thread_local_expert_token_maps[t][ToSize(expert_idx)].size(); } - expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + expert_token_map[ToSize(expert_idx)].reserve(total_tokens_for_expert); for (int t = 0; t < num_routing_threads; ++t) { - auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + auto& local_tokens = thread_local_expert_token_maps[t][ToSize(expert_idx)]; if (!local_tokens.empty()) { - expert_token_map[static_cast(expert_idx)].insert( - expert_token_map[static_cast(expert_idx)].end(), + expert_token_map[ToSize(expert_idx)].insert( + expert_token_map[ToSize(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); } } @@ -534,18 +566,19 @@ Status QMoECPU::Compute(OpKernelContext* context) const { IAllocatorUniquePtr input_float_buffer; const float* input_float; if constexpr (std::is_same_v) { - input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + const size_t input_count = SafeProduct(num_tokens_size, hidden_size_size); + input_float_buffer = 
IAllocator::MakeUniquePtr(allocator, input_count); input_float = input_float_buffer.get(); MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data), const_cast(input_float), - static_cast(num_tokens * hidden_size)); + input_count); } else { input_float = reinterpret_cast(input_data); } const int max_expert_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; const int64_t total_expert_work = std::accumulate(expert_token_map.begin(), expert_token_map.end(), 0LL, - [](int64_t sum, const std::vector& tokens) { return sum + static_cast(tokens.size()); }); + [](int64_t sum, const std::vector& tokens) { return sum + onnxruntime::narrow(tokens.size()); }); // Expert threading uses larger divisor (8x vs 4x for routing) because expert processing involves // heavier computation (GEMM operations) that benefits more from parallelization const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8); @@ -556,9 +589,11 @@ Status QMoECPU::Compute(OpKernelContext* context) const { int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 1 : std::min(narrow(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(narrow(num_experts), max_expert_threads)); if (num_expert_threads == 0) num_expert_threads = 1; - auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, + SafeProduct(onnxruntime::narrow(num_expert_threads), output_buffer_size)); float* thread_local_outputs = thread_local_outputs_ptr.get(); - std::memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + const size_t thread_local_outputs_bytes = SafeProduct(SafeProduct(onnxruntime::narrow(num_expert_threads), output_buffer_size), sizeof(float)); + std::memset(thread_local_outputs, 0, thread_local_outputs_bytes); size_t max_tokens_per_expert = 0; for (const auto& tokens : expert_token_map) { @@ -569,21 +604,23 @@ Status QMoECPU::Compute(OpKernelContext* context) const { return (size + 63) & ~63; }; - const size_t A1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); - const size_t C1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(fc1_out_features)); - const size_t A2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(inter_size)); - const size_t C2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); - const size_t B1_dequant_size = align_size(static_cast(fc1_out_features) * static_cast(hidden_size)); - const size_t B2_dequant_size = align_size(static_cast(hidden_size) * static_cast(inter_size)); + const size_t A1_size = align_size(SafeProduct(max_tokens_per_expert, hidden_size_size)); + const size_t C1_size = align_size(SafeProduct(max_tokens_per_expert, fc1_out_features_size)); + const size_t A2_size = align_size(SafeProduct(max_tokens_per_expert, inter_size_size)); + const size_t C2_size = align_size(SafeProduct(max_tokens_per_expert, hidden_size_size)); + const size_t B1_dequant_size = align_size(SafeProduct(fc1_out_features_size, hidden_size_size)); + const size_t B2_dequant_size = align_size(SafeProduct(hidden_size_size, inter_size_size)); const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size; - auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * 
workspace_elements_per_thread); + auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, + SafeProduct(onnxruntime::narrow(num_expert_threads), workspace_elements_per_thread)); float* workspace = workspace_ptr.get(); auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr(allocator, - static_cast(num_expert_threads) * (static_cast(fc1_out_features) + static_cast(hidden_size))); + SafeProduct(onnxruntime::narrow(num_expert_threads), + fc1_out_features_size + hidden_size_size)); float* bias_conversion_buffers = bias_conversion_buffers_ptr.get(); const auto& fc1_scales_dims = fc1_scales->Shape().GetDims(); @@ -608,7 +645,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { size_t total_work = 0; for (int64_t i = 0; i < num_experts; ++i) { - const size_t token_count = expert_token_map[static_cast(i)].size(); + const size_t token_count = expert_token_map[ToSize(i)].size(); if (token_count > 0) { expert_workload.emplace_back(i, token_count); total_work += token_count; @@ -640,24 +677,27 @@ Status QMoECPU::Compute(OpKernelContext* context) const { std::sort(expert_workload.begin(), expert_workload.end(), [](const auto& a, const auto& b) { return a.second > b.second; }); + const size_t num_expert_threads_size = onnxruntime::narrow(num_expert_threads); + std::vector> expert_batches(num_expert_threads); size_t thread_idx = 0; for (const auto& work : expert_workload) { expert_batches[thread_idx].push_back(work.first); - thread_idx = (thread_idx + 1) % static_cast(num_expert_threads); + thread_idx = (thread_idx + 1) % num_expert_threads_size; } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { const int thread_id = narrow(thread_id_pd); - const auto& expert_batch = expert_batches[static_cast(thread_id)]; + const auto& expert_batch = expert_batches[onnxruntime::narrow(thread_id)]; - float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; + const size_t thread_offset = SafeProduct(onnxruntime::narrow(thread_id), workspace_elements_per_thread); + float* thread_workspace = workspace + thread_offset; - float* thread_bias1_buffer = bias_conversion_buffers + static_cast(thread_id) * (static_cast(fc1_out_features) + static_cast(hidden_size)); - float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); + float* thread_bias1_buffer = bias_conversion_buffers + SafeProduct(onnxruntime::narrow(thread_id), fc1_out_features_size + hidden_size_size); + float* thread_bias2_buffer = thread_bias1_buffer + fc1_out_features_size; for (int64_t expert_idx : expert_batch) { - const auto& routes = expert_token_map[static_cast(expert_idx)]; + const auto& routes = expert_token_map[ToSize(expert_idx)]; if (routes.empty()) { continue; } @@ -680,51 +720,54 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens); for (int64_t i = start_idx; i < end_idx; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; + const int64_t token_idx = routes[ToSize(i)] / k_; const float* src = input_float + token_idx * hidden_size; float* dst = A1 + i * hidden_size; - std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + std::memcpy(dst, src, SafeProduct(hidden_size_size, sizeof(float))); } }); } else { for (int64_t i = 0; i < num_expert_tokens; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; + const int64_t token_idx = routes[ToSize(i)] / k_; const float* src = input_float + token_idx * 
hidden_size; float* dst = A1 + i * hidden_size; if (ShouldUseMemcpy(hidden_size)) { - std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + std::memcpy(dst, src, SafeProduct(hidden_size_size, sizeof(float))); } else { const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); size_t j = 0; - for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (; j + unroll_factor <= hidden_size_size; j += unroll_factor) { for (size_t k = 0; k < unroll_factor; ++k) { dst[j + k] = src[j + k]; } } - for (; j < narrow(hidden_size); ++j) { + for (; j < hidden_size_size; ++j) { dst[j] = src[j]; } } } } + const int64_t fc1_blocks_per_row = is_fc1_block_wise ? (hidden_size + block_size_ - 1) / block_size_ : 1; + const size_t fc1_blocks_per_row_size = ToSize(fc1_blocks_per_row); + const T* fc1_scales_ptr; if (is_fc1_block_wise) { - const int64_t fc1_blocks_per_row = (hidden_size + block_size_ - 1) / block_size_; - fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features * fc1_blocks_per_row; + const size_t scale_offset = SafeProduct(SafeProduct(ToSize(expert_idx), fc1_out_features_size), fc1_blocks_per_row_size); + fc1_scales_ptr = fc1_scales_data + scale_offset; } else { - fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features; + fc1_scales_ptr = fc1_scales_data + SafeProduct(ToSize(expert_idx), fc1_out_features_size); } const int64_t dequant_block_size = GetDequantBlockSize(fc1_out_features, num_expert_tokens); const int64_t num_dequant_blocks = (fc1_out_features + dequant_block_size - 1) / dequant_block_size; - const size_t m = static_cast(num_expert_tokens); - const size_t n = static_cast(fc1_out_features); - const size_t k = static_cast(hidden_size); + const size_t m = ToSize(num_expert_tokens); + const size_t n = fc1_out_features_size; + const size_t k = hidden_size_size; MLAS_BLK_QUANT_TYPE q_type; bool use_direct_q4_gemm = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, @@ -781,14 +824,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * dequant_block_size; const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); - const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; + const auto offset = SafeProduct(SafeProduct(ToSize(expert_idx), fc1_out_features_size), ToSize(fc1_packed_cols)) + SafeProduct(ToSize(start_row), ToSize(fc1_packed_cols)); DequantizeBlock(fc1_weights_data + offset, - fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_blocks_per_row : start_row), + fc1_scales_ptr + (is_fc1_block_wise ? SafeProduct(ToSize(start_row), fc1_blocks_per_row_size) : ToSize(start_row)), is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, - end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, nullptr); + end_row - start_row, hidden_size, B1_dequant + SafeProduct(ToSize(start_row), hidden_size_size), nullptr); }); } else { - DequantizeBlock(fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + DequantizeBlock(fc1_weights_data + SafeProduct(SafeProduct(ToSize(expert_idx), fc1_out_features_size), ToSize(fc1_packed_cols)), fc1_scales_ptr, is_fc1_block_wise ? 
block_size_ : 0, expert_weight_bits_, fc1_out_features, hidden_size, B1_dequant, tp); @@ -875,13 +918,16 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } + const int64_t fc2_blocks_per_row = is_fc2_block_wise ? (inter_size + block_size_ - 1) / block_size_ : 1; + const size_t fc2_blocks_per_row_size = ToSize(fc2_blocks_per_row); + const T* fc2_scales_ptr; if (is_fc2_block_wise) { - const int64_t fc2_blocks_per_row = (inter_size + block_size_ - 1) / block_size_; - fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size * fc2_blocks_per_row; + const size_t scale_offset = SafeProduct(SafeProduct(ToSize(expert_idx), hidden_size_size), fc2_blocks_per_row_size); + fc2_scales_ptr = fc2_scales_data + scale_offset; } else { - fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size; + fc2_scales_ptr = fc2_scales_data + SafeProduct(ToSize(expert_idx), hidden_size_size); } const int64_t fc2_dequant_block_size = GetDequantBlockSize(hidden_size, num_expert_tokens); @@ -946,14 +992,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * fc2_dequant_block_size; const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); - const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; + const auto offset = SafeProduct(SafeProduct(ToSize(expert_idx), hidden_size_size), ToSize(fc2_packed_cols)) + SafeProduct(ToSize(start_row), ToSize(fc2_packed_cols)); DequantizeBlock(fc2_weights_data + offset, - fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_blocks_per_row : start_row), + fc2_scales_ptr + (is_fc2_block_wise ? SafeProduct(ToSize(start_row), fc2_blocks_per_row_size) : ToSize(start_row)), is_fc2_block_wise ? block_size_ : 0, expert_weight_bits_, - end_row - start_row, inter_size, B2_dequant + start_row * inter_size, nullptr); + end_row - start_row, inter_size, B2_dequant + SafeProduct(ToSize(start_row), inter_size_size), nullptr); }); } else { - DequantizeBlock(fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + DequantizeBlock(fc2_weights_data + SafeProduct(SafeProduct(ToSize(expert_idx), hidden_size_size), ToSize(fc2_packed_cols)), fc2_scales_ptr, is_fc2_block_wise ? 
block_size_ : 0, expert_weight_bits_, hidden_size, inter_size, B2_dequant, tp); From d57a7c3e1c8f65b8e2e42f23746c161b1f291b25 Mon Sep 17 00:00:00 2001 From: apsonawane Date: Fri, 26 Sep 2025 20:24:07 +0000 Subject: [PATCH 4/4] Disable quantized Mlas, still not giving good tps --- .../cpu/moe/moe_quantization_cpu.cc | 334 +----------------- 1 file changed, 5 insertions(+), 329 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 5a96e25bf6089..8d055d20c18e8 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -59,157 +59,11 @@ inline size_t SafeProduct(size_t lhs, size_t rhs) { return SafeInt(lhs) * rhs; } -bool CanUseMlasQ4Dequant(int64_t num_bits) { - if (num_bits != 4) { - return false; - } - - return true; -} - -bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, - int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) { - if (expert_weight_bits != 4) { - return false; - } - - if (rows <= 0 || cols <= 0) { - return false; - } - - const size_t rows_size = onnxruntime::narrow(rows); - const size_t cols_size = onnxruntime::narrow(cols); - - MLAS_BLK_QUANT_TYPE qtype; - switch (block_size) { - case 0: - case 32: - qtype = BlkQ4Sym; - break; - case 64: - qtype = BlkQ4Sym64; - break; - case 128: - qtype = BlkQ4Sym128; - break; - default: - return false; - } - - const size_t pack_size = MlasQ4GemmPackBSize(qtype, rows_size, cols_size); - - if (pack_size == 0) { - return false; - } - - out_qtype = qtype; - return true; -} - -void TransposeFp32RowMajorToColumnMajor(const float* src, - float* dst, - int64_t rows, - int64_t cols) { - if (rows <= 0 || cols <= 0) { - return; - } - - const size_t rows_size = onnxruntime::narrow(rows); - const size_t cols_size = onnxruntime::narrow(cols); - - for (int64_t r = 0; r < rows; ++r) { - const size_t r_index = onnxruntime::narrow(r); - const size_t row_offset = SafeInt(r_index) * cols_size; - for (int64_t c = 0; c < cols; ++c) { - const size_t c_index = onnxruntime::narrow(c); - dst[c_index * rows_size + r_index] = src[row_offset + c_index]; - } - } -} - } // namespace namespace onnxruntime { namespace contrib { -template -void DequantizeBlockWithMlas(const uint8_t* quantized_data, - const TScale* scales, - int64_t block_size, - int64_t num_bits, - int64_t rows, - int64_t cols, - float* dequantized_data, - MLAS_THREADPOOL* thread_pool); - -template -Status ConvertToMlasQ4Format(const uint8_t* quantized_data, - const TScale* scales, - int64_t block_size, - int64_t num_bits, - int64_t rows, - int64_t cols, - MLAS_BLK_QUANT_TYPE qtype, - AllocatorPtr allocator, - IAllocatorUniquePtr& mlas_packed_buffer) { - if (num_bits != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 4-bit quantization supported for MLAS Q4 format conversion"); - } - - const size_t rows_size = onnxruntime::narrow(rows); - const size_t cols_size = onnxruntime::narrow(cols); - const size_t total_elements = SafeInt(rows_size) * cols_size; - - auto temp_float_buffer = IAllocator::MakeUniquePtr(allocator, total_elements); - float* temp_float = temp_float_buffer.get(); - - DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr); - - IAllocatorUniquePtr temp_float_col_major_buffer = - IAllocator::MakeUniquePtr(allocator, total_elements); - float* temp_float_col_major = temp_float_col_major_buffer.get(); - - 
TransposeFp32RowMajorToColumnMajor(temp_float, temp_float_col_major, rows, cols); - - size_t packed_size = MlasQ4GemmPackBSize(qtype, rows_size, cols_size); - if (packed_size == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); - } - - mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float_col_major, rows_size, cols_size, rows_size); - - return Status::OK(); -} - -Status DirectQ4Gemm(const float* A, - const uint8_t* mlas_packed_B, - const float* bias, - float* C, - int64_t M, - int64_t N, - int64_t K, - MLAS_BLK_QUANT_TYPE qtype, - MLAS_THREADPOOL* thread_pool) { - MLAS_Q4_GEMM_DATA_PARAMS params; - params.A = A; - params.lda = onnxruntime::narrow(K); - params.B = mlas_packed_B; - params.Bias = bias; - params.C = C; - params.ldc = onnxruntime::narrow(N); - params.OutputProcessor = nullptr; - - MlasQ4GemmBatch(qtype, - onnxruntime::narrow(M), - onnxruntime::narrow(N), - onnxruntime::narrow(K), - 1, - ¶ms, - thread_pool); - return Status::OK(); -} - template void DequantizeBlockWithMlas(const uint8_t* quantized_data, const TScale* scales, @@ -223,82 +77,6 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data, const float zero_point = num_bits == 8 ? 128.0f : 8.0f; const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1; - if (CanUseMlasQ4Dequant(num_bits)) { - const int64_t packed_cols = (cols + 1) / 2; - - if (block_size == 0) { - for (int64_t r = 0; r < rows; ++r) { - const uint8_t* row_data = quantized_data + r * packed_cols; - float* row_output = dequantized_data + r * cols; - const float scale = static_cast(scales[r]); - - int64_t c = 0; - for (; c + 8 <= cols; c += 8) { - const uint8_t packed_val0 = row_data[(c + 0) / 2]; - const uint8_t packed_val1 = row_data[(c + 2) / 2]; - const uint8_t packed_val2 = row_data[(c + 4) / 2]; - const uint8_t packed_val3 = row_data[(c + 6) / 2]; - - row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); - row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); - row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); - row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); - row_output[c + 4] = scale * (static_cast(packed_val2 & 0x0F) - zero_point); - row_output[c + 5] = scale * (static_cast(packed_val2 >> 4) - zero_point); - row_output[c + 6] = scale * (static_cast(packed_val3 & 0x0F) - zero_point); - row_output[c + 7] = scale * (static_cast(packed_val3 >> 4) - zero_point); - } - - for (; c < cols; c += 2) { - const uint8_t packed_val = row_data[c / 2]; - const uint8_t val0 = packed_val & 0x0F; - const uint8_t val1 = packed_val >> 4; - - row_output[c] = scale * (static_cast(val0) - zero_point); - if (c + 1 < cols) { - row_output[c + 1] = scale * (static_cast(val1) - zero_point); - } - } - } - return; - } else { - for (int64_t r = 0; r < rows; ++r) { - const uint8_t* row_data = quantized_data + r * packed_cols; - float* row_output = dequantized_data + r * cols; - - for (int64_t block_start = 0; block_start < cols; block_start += block_size) { - const int64_t block_end = std::min(block_start + block_size, cols); - const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); - const int64_t scale_idx = r * blocks_per_row + block_idx; - const float scale = static_cast(scales[scale_idx]); - - int64_t c = block_start; - for (; c + 4 <= block_end; c += 4) { - const uint8_t 
packed_val0 = row_data[(c + 0) / 2]; - const uint8_t packed_val1 = row_data[(c + 2) / 2]; - - row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); - row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); - row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); - row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); - } - - for (; c < block_end; c += 2) { - const uint8_t packed_val = row_data[c / 2]; - const uint8_t val0 = packed_val & 0x0F; - const uint8_t val1 = packed_val >> 4; - - row_output[c] = scale * (static_cast(val0) - zero_point); - if (c + 1 < block_end) { - row_output[c + 1] = scale * (static_cast(val1) - zero_point); - } - } - } - } - return; - } - } - if (num_bits == 8 && block_size == 0) { for (int64_t r = 0; r < rows; ++r) { const float scale = static_cast(scales[r]); @@ -769,55 +547,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t n = fc1_out_features_size; const size_t k = hidden_size_size; - MLAS_BLK_QUANT_TYPE q_type; - bool use_direct_q4_gemm = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, q_type); - bool fc1_used_direct_q4 = false; - bool fc1_bias_handled_by_q4_gemm = false; - - if (use_direct_q4_gemm) { - IAllocatorUniquePtr mlas_packed_fc1; - Status convert_status = ConvertToMlasQ4Format( - fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, - fc1_scales_ptr, - is_fc1_block_wise ? block_size_ : 0, - expert_weight_bits_, - fc1_out_features, - hidden_size, - q_type, - allocator, - mlas_packed_fc1); - - if (convert_status.IsOK()) { - float* fc1_bias_float = nullptr; - IAllocatorUniquePtr fc1_bias_buffer; - - if (has_fc1_bias) { - const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; - fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); - fc1_bias_float = fc1_bias_buffer.get(); - - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); - } else { - for (int64_t i = 0; i < fc1_out_features; ++i) { - fc1_bias_float[i] = static_cast(B1_bias[i]); - } - } - } - - Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, - num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); - - if (gemm_status.IsOK()) { - fc1_used_direct_q4 = true; - goto fc1_gemm_done; - } - } - // If direct Q4 GEMM failed, fall back to traditional approach - } - - // Traditional approach: dequantize + regular GEMM + // Quantized MLAS is disabled, using traditional approach: dequantize + regular GEMM // Use parallel dequantization when we have multiple blocks and sufficient work per thread // Threshold of 32 features ensures each thread has meaningful work to justify threading overhead if (num_dequant_blocks > 1 && fc1_out_features >= 32) { @@ -844,8 +574,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); - fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; - if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); @@ -882,8 +611,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } - fc1_gemm_done: - // Activation threshold scales inversely with inter_size to balance work per thread // 
For large inter_size, fewer tokens justify parallel activation; for small inter_size, more tokens needed // Base value of 256 ensures reasonable work distribution across different model configurations @@ -937,55 +664,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t n2 = static_cast(hidden_size); const size_t k2 = static_cast(inter_size); - MLAS_BLK_QUANT_TYPE q_type2; - bool use_direct_q4_gemm_fc2 = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, q_type2); - bool fc2_used_direct_q4 = false; - - if (use_direct_q4_gemm_fc2) { - IAllocatorUniquePtr mlas_packed_fc2; - Status convert_status = ConvertToMlasQ4Format( - fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, - fc2_scales_ptr, - is_fc2_block_wise ? block_size_ : 0, - expert_weight_bits_, - hidden_size, - inter_size, - q_type2, - allocator, - mlas_packed_fc2); - - if (convert_status.IsOK()) { - float* fc2_bias_float = nullptr; - IAllocatorUniquePtr fc2_bias_buffer; - - if (has_fc2_bias) { - const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; - fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); - fc2_bias_float = fc2_bias_buffer.get(); - - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); - } else { - for (int64_t i = 0; i < hidden_size; ++i) { - fc2_bias_float[i] = static_cast(B2_bias[i]); - } - } - } - - Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, - num_expert_tokens, hidden_size, inter_size, q_type2, tp); - - if (gemm_status.IsOK()) { - fc2_used_direct_q4 = true; - goto fc2_gemm_done; - } - } - - // If direct Q4 GEMM failed, fall back to traditional approach - } - - // Traditional approach: dequantize + regular GEMM + // Quantized MLAS is disabled, using traditional approach: dequantize + regular GEMM // Use parallel dequantization when we have multiple blocks and sufficient work per thread // Threshold of 32 for hidden_size ensures each thread has meaningful work to justify threading overhead if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { @@ -1012,10 +691,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C2, n2, tp); - fc2_gemm_done: - - bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); @@ -1050,7 +726,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias) { const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); size_t j = 0; for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) {
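
For reference, below is a minimal standalone sketch of the symmetric 4-bit block dequantization that the fallback (dequantize + regular GEMM) path settled on in PATCH 4 relies on. The low-nibble-first packing, the zero point of 8, and the ceil-divided blocks-per-row follow the DequantizeBlockWithMlas kernel in this file; the function name DequantizeQ4Rows and the numbers in main() are illustrative only, not part of the patch.

#include <algorithm>
#include <cstdint>

// Dequantize symmetric 4-bit weights packed two-per-byte (low nibble = even column).
// block_size == 0 means one scale per row; otherwise scales are laid out
// rows x blocks_per_row with blocks_per_row = ceil(cols / block_size).
void DequantizeQ4Rows(const uint8_t* packed,   // rows x ((cols + 1) / 2) bytes
                      const float* scales,     // rows x blocks_per_row floats
                      int64_t block_size,
                      int64_t rows, int64_t cols,
                      float* out) {            // rows x cols floats
  const float zero_point = 8.0f;               // 4-bit symmetric convention used above
  const int64_t packed_cols = (cols + 1) / 2;
  const int64_t blocks_per_row =
      (block_size > 0) ? (cols + block_size - 1) / block_size : 1;
  for (int64_t r = 0; r < rows; ++r) {
    const uint8_t* row = packed + r * packed_cols;
    float* dst = out + r * cols;
    for (int64_t c = 0; c < cols; ++c) {
      const int64_t block_idx =
          (block_size > 0) ? std::min(c / block_size, blocks_per_row - 1) : 0;
      const float scale = scales[r * blocks_per_row + block_idx];
      const uint8_t byte = row[c / 2];
      const uint8_t nibble = (c % 2 == 0) ? static_cast<uint8_t>(byte & 0x0F)
                                          : static_cast<uint8_t>(byte >> 4);
      dst[c] = scale * (static_cast<float>(nibble) - zero_point);
    }
  }
}

int main() {
  // One row, 4 columns, one block of size 4: bytes 0x21, 0x43 encode nibbles 1,2,3,4.
  const uint8_t packed[2] = {0x21, 0x43};
  const float scales[1] = {0.5f};
  float out[4] = {};
  DequantizeQ4Rows(packed, scales, /*block_size*/ 4, /*rows*/ 1, /*cols*/ 4, out);
  // out = {0.5*(1-8), 0.5*(2-8), 0.5*(3-8), 0.5*(4-8)} = {-3.5, -3.0, -2.5, -2.0}
  return 0;
}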