From cb6939de344e28150b660246ba036cc3ce2634dd Mon Sep 17 00:00:00 2001
From: asonawane
Date: Mon, 15 Sep 2025 22:15:50 +0000
Subject: [PATCH 1/3] QMoE kernel further optimizations

---
 .../cpu/moe/moe_quantization_cpu.cc           | 85 ++++++++++++++++---
 1 file changed, 73 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
index 8195c9438d408..4c8b758fc248d 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -204,6 +204,14 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,
       const int64_t block_end = std::min(block_start + block_size, cols);
       const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
       const int64_t scale_idx = r * blocks_per_row + block_idx;
+
+      // Validate scale index bounds
+      const int64_t max_scale_idx = rows * blocks_per_row;
+      if (scale_idx < 0 || scale_idx >= max_scale_idx) {
+        // Skip this block if scale index is invalid
+        continue;
+      }
+
       const float scale = static_cast<float>(scales[scale_idx]);

       int64_t c = block_start;
@@ -257,6 +265,14 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,
         const int64_t block_end = std::min(block_start + block_size, cols);
         const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
         const int64_t scale_idx = r * blocks_per_row + block_idx;
+
+        // Validate scale index bounds for 8-bit case
+        const int64_t max_scale_idx = rows * blocks_per_row;
+        if (scale_idx < 0 || scale_idx >= max_scale_idx) {
+          // Skip this block if scale index is invalid
+          continue;
+        }
+
         const float scale = static_cast<float>(scales[scale_idx]);

         for (c = block_start; c + 4 <= block_end; c += 4) {
@@ -297,6 +313,14 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,
         const int64_t block_end = std::min(block_start + block_size, cols);
         const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
         const int64_t scale_idx = r * blocks_per_row + block_idx;
+
+        // Validate scale index bounds for 4-bit case
+        const int64_t max_scale_idx = rows * blocks_per_row;
+        if (scale_idx < 0 || scale_idx >= max_scale_idx) {
+          // Skip this block if scale index is invalid
+          continue;
+        }
+
         const float scale = static_cast<float>(scales[scale_idx]);

         for (int64_t c = block_start; c < block_end; c += 2) {
@@ -516,16 +540,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
     max_tokens_per_expert = std::max(max_tokens_per_expert, tokens.size());
   }

-  const auto align_size = [](size_t size) -> size_t {
-    return (size + 63) & ~63;
-  };
-
-  const size_t A1_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size));
-  const size_t C1_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(fc1_out_features));
-  const size_t A2_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(inter_size));
-  const size_t C2_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size));
-  const size_t B1_dequant_size = align_size(static_cast<size_t>(fc1_out_features) * static_cast<size_t>(hidden_size));
-  const size_t B2_dequant_size = align_size(static_cast<size_t>(hidden_size) * static_cast<size_t>(inter_size));
+  // Use consistent buffer sizes - no alignment needed for float arrays
+  const size_t A1_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size);
+  const size_t C1_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(fc1_out_features);
+  const size_t A2_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(inter_size);
+  const size_t C2_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size);
+  const size_t B1_dequant_size = static_cast<size_t>(fc1_out_features) * static_cast<size_t>(hidden_size);
+  const size_t B2_dequant_size = static_cast<size_t>(hidden_size) * static_cast<size_t>(inter_size);

   const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size;

@@ -601,6 +622,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

       const int64_t num_expert_tokens = static_cast<int64_t>(routes.size());

+      // Validate that the number of tokens doesn't exceed our allocation
+      if (static_cast<size_t>(num_expert_tokens) > max_tokens_per_expert) {
+        LOGS_DEFAULT(ERROR) << "Expert " << expert_idx << " has " << num_expert_tokens
+                            << " tokens but workspace allocated for max " << max_tokens_per_expert;
+        continue;
+      }
+
       float* A1 = thread_workspace;
       float* C1 = A1 + A1_size;
       float* A2 = C1 + C1_size;
@@ -617,7 +645,15 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

         const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens);
         for (int64_t i = start_idx; i < end_idx; ++i) {
-          const int64_t token_idx = routes[static_cast<size_t>(i)] / k_;
+          const int64_t route_idx = routes[static_cast<size_t>(i)];
+          // Validate route index to prevent division by zero or invalid token indices
+          if (route_idx < 0 || k_ <= 0) {
+            continue;  // Skip invalid routes
+          }
+          const int64_t token_idx = route_idx / k_;
+          if (token_idx < 0 || token_idx >= num_tokens) {
+            continue;  // Skip out-of-bounds token indices
+          }
           const float* src = input_float + token_idx * hidden_size;
           float* dst = A1 + i * hidden_size;

@@ -626,7 +662,15 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
         });
       } else {
         for (int64_t i = 0; i < num_expert_tokens; ++i) {
-          const int64_t token_idx = routes[static_cast<size_t>(i)] / k_;
+          const int64_t route_idx = routes[static_cast<size_t>(i)];
+          // Validate route index to prevent division by zero or invalid token indices
+          if (route_idx < 0 || k_ <= 0) {
+            continue;  // Skip invalid routes
+          }
+          const int64_t token_idx = route_idx / k_;
+          if (token_idx < 0 || token_idx >= num_tokens) {
+            continue;  // Skip out-of-bounds token indices
+          }
           const float* src = input_float + token_idx * hidden_size;
           float* dst = A1 + i * hidden_size;

@@ -931,7 +975,18 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

       for (int64_t i = 0; i < num_expert_tokens; ++i) {
         const int64_t route_idx = routes[static_cast<size_t>(i)];
+        // Validate route index to prevent division by zero or invalid token indices
+        if (route_idx < 0 || k_ <= 0) {
+          continue;  // Skip invalid routes
+        }
         const int64_t token_idx = route_idx / k_;
+        if (token_idx < 0 || token_idx >= num_tokens) {
+          continue;  // Skip out-of-bounds token indices
+        }
+        // Validate route_scale array bounds
+        if (route_idx < 0 || route_idx >= num_tokens * k_) {
+          continue;  // Skip if route_idx would be out of bounds for route_scale array
+        }
         const float weight = route_scale[route_idx];

         if (token_idx < 0 || token_idx >= num_tokens) continue;
@@ -939,6 +994,12 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

         const size_t buffer_offset = static_cast<size_t>(token_idx) * static_cast<size_t>(hidden_size);
         if (buffer_offset + static_cast<size_t>(hidden_size) > output_buffer_size) continue;
+        // Simplified thread buffer validation
+        if (thread_id < 0 || thread_id >= num_expert_threads) continue;
+
+        // Validate source buffer bounds
+        if (i < 0 || i >= num_expert_tokens) continue;
+
         float* dest = thread_local_outputs +
                       static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
         const float* src = C2 + i * hidden_size;
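The recurring edit in the patch above is the guarded scale lookup inside the block-wise dequantization loops. Purely as a standalone illustration of that pattern (the function name, the plain float scales, and the uint8 zero point of 128 are assumptions made for brevity, not the kernel's actual types or packing), the guard reads as follows:

#include <algorithm>
#include <cstdint>

// Dequantize one row of block-quantized 8-bit weights, skipping any block
// whose scale index would fall outside the scales array -- the same bounds
// check the patch inserts ahead of "const float scale = ...".
void DequantizeRowSketch(const uint8_t* quantized, const float* scales,
                         int64_t row, int64_t cols, int64_t block_size,
                         int64_t blocks_per_row, int64_t rows, float* out) {
  if (block_size <= 0 || blocks_per_row <= 0) return;
  for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
    const int64_t block_end = std::min(block_start + block_size, cols);
    const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
    const int64_t scale_idx = row * blocks_per_row + block_idx;

    // Validate scale index bounds; skip the block rather than read out of range.
    if (scale_idx < 0 || scale_idx >= rows * blocks_per_row) {
      continue;
    }
    const float scale = scales[scale_idx];
    for (int64_t c = block_start; c < block_end; ++c) {
      // Assumed symmetric uint8 encoding with a zero point of 128.
      out[c] = (static_cast<float>(quantized[row * cols + c]) - 128.0f) * scale;
    }
  }
}

The clamped block_idx above mirrors this patch; the third patch in the series later simplifies it to a plain division once the bounds check makes the clamp redundant.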
From 4bfc50894ce758e3896fb9f05552ad889455fc6a Mon Sep 17 00:00:00 2001
From: asonawane
Date: Mon, 15 Sep 2025 23:19:00 +0000
Subject: [PATCH 2/3] Address comments

---
 .../cpu/moe/moe_quantization_cpu.cc           | 30 ++++++++++++++-----
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
index 4c8b758fc248d..2a45a8a4ee51a 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -444,7 +444,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
   const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1;
   const int64_t thread_divisor = std::max(1, max_threads * 4);
   const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast<int64_t>(num_tokens / thread_divisor));
-  const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast<int>(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads);
+  const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast<int>(num_tokens / min_work_per_thread), max_threads);
   const int num_routing_threads = std::max(1, optimal_routing_threads);

   std::vector>> thread_local_expert_token_maps(num_routing_threads);
@@ -587,6 +587,12 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
     }
   }

+  // Adjust thread count based on total work to avoid thread overhead for small workloads
+  // These thresholds are based on empirical performance testing:
+  // - < 48 tokens: Single thread is most efficient due to low overhead
+  // - 48-191 tokens: Cap at 2 threads to balance parallelism vs overhead
+  // - 192-511 tokens: Cap at 4 threads for good CPU utilization
+  // - >= 512 tokens: Use full calculated thread count for maximum parallelism
   if (total_work < 48) {
     num_expert_threads = 1;
   } else if (total_work < 192) {
@@ -738,8 +744,12 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
       if constexpr (std::is_same_v<T, MLFloat16>) {
         MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast<size_t>(fc1_out_features));
       } else {
-        for (int64_t i = 0; i < fc1_out_features; ++i) {
-          fc1_bias_float[i] = static_cast<float>(B1_bias[i]);
+        if (ShouldUseMemcpy(fc1_out_features)) {
+          std::memcpy(fc1_bias_float, B1_bias, static_cast<size_t>(fc1_out_features) * sizeof(float));
+        } else {
+          for (int64_t i = 0; i < fc1_out_features; ++i) {
+            fc1_bias_float[i] = static_cast<float>(B1_bias[i]);
+          }
         }
       }
     }
@@ -760,6 +770,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
       }

       // Traditional approach: dequantize + regular GEMM
+      // Use parallel dequantization when:
+      // 1. num_dequant_blocks > 1: Multiple blocks to parallelize across
+      // 2. fc1_out_features >= 32: Sufficient work per thread to justify overhead
+      //    (32 features * hidden_size elements = substantial work per block)
       if (num_dequant_blocks > 1 && fc1_out_features >= 32) {
         concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast<std::ptrdiff_t>(num_dequant_blocks), [&](std::ptrdiff_t block_idx) {
           const int64_t start_row = block_idx * dequant_block_size;
@@ -824,7 +838,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

    fc1_gemm_done:

-      const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size));
+      const int64_t activation_threshold = std::max(int64_t{4}, 256 / inter_size);
       if (num_expert_tokens >= activation_threshold && tp != nullptr) {
         const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold));
         const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size;
@@ -901,9 +915,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
       if constexpr (std::is_same_v<T, MLFloat16>) {
         MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast<size_t>(hidden_size));
       } else {
-        for (int64_t i = 0; i < hidden_size; ++i) {
-          fc2_bias_float[i] = static_cast<float>(B2_bias[i]);
-        }
+        std::memcpy(fc2_bias_float, B2_bias, static_cast<size_t>(hidden_size) * sizeof(float));
       }
     }

@@ -924,6 +936,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
      }

       // Traditional approach: dequantize + regular GEMM
+      // Use parallel dequantization when:
+      // 1. num_fc2_dequant_blocks > 1: Multiple blocks to parallelize across
+      // 2. hidden_size >= 32: Sufficient work per thread to justify overhead
+      //    (32 features * inter_size elements = substantial work per block)
       if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) {
         concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast<std::ptrdiff_t>(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) {
           const int64_t start_row = block_idx * fc2_dequant_block_size;
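The thresholds documented in the new comment block (48, 192, and 512 units of total work) amount to a small capping rule on the expert thread count. Restated as a helper purely for illustration (the function name and the calculated_threads parameter are ours, not identifiers in the kernel):

#include <algorithm>
#include <cstdint>

// Cap the expert-thread count by total work, following the thresholds in the
// patch comments: 1 thread below 48, at most 2 below 192, at most 4 below 512,
// otherwise the fully calculated count.
int CapExpertThreads(int64_t total_work, int calculated_threads) {
  if (total_work < 48) return 1;
  if (total_work < 192) return std::min(calculated_threads, 2);
  if (total_work < 512) return std::min(calculated_threads, 4);
  return calculated_threads;
}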
From 5c04dc95f0bea4851b823fe1a9928e1e6ecf5550 Mon Sep 17 00:00:00 2001
From: asonawane
Date: Wed, 17 Sep 2025 04:35:17 +0000
Subject: [PATCH 3/3] address comments

---
 .../contrib_ops/cpu/moe/moe_base_cpu.h        |  1 +
 .../cpu/moe/moe_quantization_cpu.cc           | 82 +++++++++++--------
 2 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
index 84580b310f6b3..8e7473ff804a6 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -24,6 +24,7 @@ class MoEBaseCPU {
  protected:
   MoEBaseCPU(const OpKernelInfo& op_kernel_info) {
     ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK());
+    ORT_ENFORCE(k_ > 0, "k must be positive, got: ", k_);

     std::string activation_type_str;
     ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK());
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
index 2a45a8a4ee51a..2140db8490018 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -64,10 +64,12 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size,
     return false;
   }

+  // Disable direct MLAS Q4 GEMM for block-wise quantization to avoid double conversion errors
+  // Use traditional dequantization path which has been fixed for correct scale indexing
   if (block_size == 64) {
-    out_qtype = BlkQ4Sym64;
+    return false;  // Force traditional path
   } else if (block_size == 128) {
-    out_qtype = BlkQ4Sym128;
+    return false;  // Force traditional path
   } else if (block_size == 0) {
     out_qtype = BlkQ4Sym;
   } else {
@@ -202,7 +204,7 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,

     for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
       const int64_t block_end = std::min(block_start + block_size, cols);
-      const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
+      const int64_t block_idx = block_start / block_size;
       const int64_t scale_idx = r * blocks_per_row + block_idx;

       // Validate scale index bounds
@@ -263,7 +265,7 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,
     if (block_size > 0) {
       for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
         const int64_t block_end = std::min(block_start + block_size, cols);
-        const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
+        const int64_t block_idx = block_start / block_size;
         const int64_t scale_idx = r * blocks_per_row + block_idx;

         // Validate scale index bounds for 8-bit case
@@ -311,7 +313,7 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,
     if (block_size > 0) {
       for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
         const int64_t block_end = std::min(block_start + block_size, cols);
-        const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
+        const int64_t block_idx = block_start / block_size;
         const int64_t scale_idx = r * blocks_per_row + block_idx;

         // Validate scale index bounds for 4-bit case
@@ -443,8 +445,18 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

   const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1;
   const int64_t thread_divisor = std::max(1, max_threads * 4);
-  const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast<int64_t>(num_tokens / thread_divisor));
-  const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast<int>(num_tokens / min_work_per_thread), max_threads);
+
+  // For decoding (small num_tokens), use more aggressive parallelization
+  // For prefill (large num_tokens), ensure sufficient work per thread
+  int optimal_routing_threads;
+  if (num_tokens <= 4) {
+    // Small token counts (decoding): use up to 4 threads for better latency
+    optimal_routing_threads = (tp == nullptr) ? 1 : std::min(4, max_threads);
+  } else {
+    // Larger token counts: ensure minimum work per thread to avoid overhead
+    const int64_t min_work_per_thread = std::max(int64_t{8}, static_cast<int64_t>(num_tokens / thread_divisor));
+    optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast<int>(num_tokens / min_work_per_thread), max_threads);
+  }
   const int num_routing_threads = std::max(1, optimal_routing_threads);

   std::vector>> thread_local_expert_token_maps(num_routing_threads);
@@ -554,6 +566,32 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

   auto workspace_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_expert_threads) * workspace_elements_per_thread);
   float* workspace = workspace_ptr.get();
+  // Only zero-initialize the dequantization buffers that need it, not the entire workspace
+  // A1, C1, A2, C2 don't need initialization since they're always fully overwritten
+  const size_t dequant_buffers_size = B1_dequant_size + B2_dequant_size;
+  const size_t workspace_data_size = A1_size + C1_size + A2_size + C2_size;
+
+  // Zero only the dequantization buffers for each thread
+  // Use parallel initialization for large buffers to improve performance
+  if (dequant_buffers_size > 0) {
+    const size_t total_dequant_size = static_cast<size_t>(num_expert_threads) * dequant_buffers_size;
+    const size_t parallel_threshold = 64 * 1024;  // 64KB threshold
+
+    if (total_dequant_size > parallel_threshold && tp != nullptr && num_expert_threads > 1) {
+      // Parallel initialization for large buffers
+      concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t t) {
+        float* thread_dequant_start = workspace + static_cast<size_t>(t) * workspace_elements_per_thread + workspace_data_size;
+        std::memset(thread_dequant_start, 0, dequant_buffers_size * sizeof(float));
+      });
+    } else {
+      // Sequential initialization for smaller buffers
+      for (int t = 0; t < num_expert_threads; ++t) {
+        float* thread_dequant_start = workspace + static_cast<size_t>(t) * workspace_elements_per_thread + workspace_data_size;
+        std::memset(thread_dequant_start, 0, dequant_buffers_size * sizeof(float));
+      }
+    }
+  }
+
   auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_expert_threads) * (static_cast<size_t>(fc1_out_features) + static_cast<size_t>(hidden_size)));
   float* bias_conversion_buffers = bias_conversion_buffers_ptr.get();

@@ -652,14 +690,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

         for (int64_t i = start_idx; i < end_idx; ++i) {
           const int64_t route_idx = routes[static_cast<size_t>(i)];
-          // Validate route index to prevent division by zero or invalid token indices
-          if (route_idx < 0 || k_ <= 0) {
-            continue;  // Skip invalid routes
-          }
           const int64_t token_idx = route_idx / k_;
-          if (token_idx < 0 || token_idx >= num_tokens) {
-            continue;  // Skip out-of-bounds token indices
-          }
           const float* src = input_float + token_idx * hidden_size;
           float* dst = A1 + i * hidden_size;

@@ -669,12 +700,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
       } else {
         for (int64_t i = 0; i < num_expert_tokens; ++i) {
           const int64_t route_idx = routes[static_cast<size_t>(i)];
-          // Validate route index to prevent division by zero or invalid token indices
-          if (route_idx < 0 || k_ <= 0) {
-            continue;  // Skip invalid routes
-          }
           const int64_t token_idx = route_idx / k_;
-          if (token_idx < 0 || token_idx >= num_tokens) {
+          if (token_idx >= num_tokens) {
             continue;  // Skip out-of-bounds token indices
           }
           const float* src = input_float + token_idx * hidden_size;
@@ -991,31 +1018,18 @@ Status QMoECPU::Compute(OpKernelContext* context) const {

       for (int64_t i = 0; i < num_expert_tokens; ++i) {
         const int64_t route_idx = routes[static_cast<size_t>(i)];
-        // Validate route index to prevent division by zero or invalid token indices
-        if (route_idx < 0 || k_ <= 0) {
-          continue;  // Skip invalid routes
-        }
         const int64_t token_idx = route_idx / k_;
-        if (token_idx < 0 || token_idx >= num_tokens) {
-          continue;  // Skip out-of-bounds token indices
-        }
-        // Validate route_scale array bounds
-        if (route_idx < 0 || route_idx >= num_tokens * k_) {
-          continue;  // Skip if route_idx would be out of bounds for route_scale array
+        if (token_idx >= num_tokens || route_idx >= num_tokens * k_) {
+          continue;  // Skip out-of-bounds indices
         }
         const float weight = route_scale[route_idx];

-        if (token_idx < 0 || token_idx >= num_tokens) continue;
-
         const size_t buffer_offset = static_cast<size_t>(token_idx) * static_cast<size_t>(hidden_size);
         if (buffer_offset + static_cast<size_t>(hidden_size) > output_buffer_size) continue;
         // Simplified thread buffer validation
         if (thread_id < 0 || thread_id >= num_expert_threads) continue;

-        // Validate source buffer bounds
-        if (i < 0 || i >= num_expert_tokens) continue;
-
         float* dest = thread_local_outputs +
                       static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
         const float* src = C2 + i * hidden_size;
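Taken together, the final patch leaves each expert thread with a flat float workspace laid out as [A1 | C1 | A2 | C2 | B1_dequant | B2_dequant] and zero-initializes only the two dequantization regions. A rough sketch of that layout and of the sequential zeroing branch (the struct and function names are illustrative, not part of the kernel):

#include <cstddef>
#include <cstring>

// Per-thread workspace partition sizes, in float elements.
struct WorkspaceSizes {
  size_t a1, c1, a2, c2, b1_dequant, b2_dequant;
  size_t PerThread() const { return a1 + c1 + a2 + c2 + b1_dequant + b2_dequant; }
};

// Zero only the dequantization slice of each thread's workspace; A1/C1/A2/C2
// are skipped because they are fully overwritten before being read.
void ZeroDequantRegions(float* workspace, const WorkspaceSizes& s, int num_threads) {
  const size_t data_size = s.a1 + s.c1 + s.a2 + s.c2;       // left untouched
  const size_t dequant_size = s.b1_dequant + s.b2_dequant;  // zero-initialized
  for (int t = 0; t < num_threads; ++t) {
    float* dequant_start = workspace + static_cast<size_t>(t) * s.PerThread() + data_size;
    std::memset(dequant_start, 0, dequant_size * sizeof(float));
  }
}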