diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index 39249f842e632..b3801fda0598e 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -33,8 +33,22 @@ struct MoEParameters { MoEParallelType parallel_type{MoEParallelType::None}; int64_t tensor_shards{1}; }; + namespace moe_helper { +// Validate block-wise quantization requirements according to op spec +inline Status ValidateBlockwiseQuantization(int64_t block_size, int64_t hidden_size, int64_t inter_size) { + if (block_size > 0) { + ORT_ENFORCE(hidden_size % block_size == 0, + "For block-wise quantization, hidden_size (", hidden_size, + ") must be divisible by block_size (", block_size, ")."); + ORT_ENFORCE(inter_size % block_size == 0, + "For block-wise quantization, inter_size (", inter_size, + ") must be divisible by block_size (", block_size, ")."); + } + return Status::OK(); +} + template Status CheckInputs(MoEParameters& parameters, const Tensor* input, // required diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 8a3c3f6d9f37a..8d055d20c18e8 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -51,32 +51,12 @@ inline int64_t GetDequantBlockSize(int64_t features, int64_t total_work) { return std::min(target_block_size, work_based_size); } -bool CanUseMlasQ4Dequant(int64_t num_bits) { - if (num_bits != 4) { - return false; - } - - return true; +inline size_t ToSize(int64_t value) { + return onnxruntime::narrow(value); } -bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, - int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) { - if (expert_weight_bits != 4) { - return false; - } - - if (block_size == 64) { - out_qtype = BlkQ4Sym64; - } else if (block_size == 128) { - out_qtype = BlkQ4Sym128; - } else if (block_size == 0) { - out_qtype = BlkQ4Sym; - } else { - return false; - } - - size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); - return expected_size > 0; +inline size_t SafeProduct(size_t lhs, size_t rhs) { + return SafeInt(lhs) * rhs; } } // namespace @@ -84,68 +64,6 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, namespace onnxruntime { namespace contrib { -template -void DequantizeBlockWithMlas(const uint8_t* quantized_data, - const TScale* scales, - int64_t block_size, - int64_t num_bits, - int64_t rows, - int64_t cols, - float* dequantized_data, - MLAS_THREADPOOL* thread_pool); - -template -Status ConvertToMlasQ4Format(const uint8_t* quantized_data, - const TScale* scales, - int64_t block_size, - int64_t num_bits, - int64_t rows, - int64_t cols, - MLAS_BLK_QUANT_TYPE qtype, - AllocatorPtr allocator, - IAllocatorUniquePtr& mlas_packed_buffer) { - if (num_bits != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 4-bit quantization supported for MLAS Q4 format conversion"); - } - - auto temp_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); - float* temp_float = temp_float_buffer.get(); - - DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr); - - size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); - if (packed_size == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); 
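The new ToSize and SafeProduct helpers wrap onnxruntime::narrow and SafeInt so that buffer-size arithmetic is overflow-checked rather than silently wrapping. A minimal standalone sketch of the same idea, assuming only the standard library (checked_size, checked_mul and buffer_elements are illustrative names, not ORT APIs):

#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical stand-ins for the narrow/SafeInt-based helpers: convert
// int64_t dimensions to size_t and multiply size_t values, throwing on
// overflow instead of wrapping.
inline size_t checked_size(int64_t v) {
  if (v < 0) throw std::range_error("negative dimension");
  return static_cast<size_t>(v);
}

inline size_t checked_mul(size_t a, size_t b) {
  if (a != 0 && b > std::numeric_limits<size_t>::max() / a)
    throw std::range_error("size_t overflow in buffer-size computation");
  return a * b;
}

// Example: element count for a [num_tokens, hidden_size] float buffer.
inline size_t buffer_elements(int64_t num_tokens, int64_t hidden_size) {
  return checked_mul(checked_size(num_tokens), checked_size(hidden_size));
}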
- } - - mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); - - return Status::OK(); -} - -Status DirectQ4Gemm(const float* A, - const uint8_t* mlas_packed_B, - const float* bias, - float* C, - int64_t M, - int64_t N, - int64_t K, - MLAS_BLK_QUANT_TYPE qtype, - MLAS_THREADPOOL* thread_pool) { - MLAS_Q4_GEMM_DATA_PARAMS params; - params.A = A; - params.lda = static_cast(K); - params.B = mlas_packed_B; - params.Bias = bias; - params.C = C; - params.ldc = static_cast(N); - params.OutputProcessor = nullptr; - - MlasQ4GemmBatch(qtype, static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, thread_pool); - return Status::OK(); -} - template void DequantizeBlockWithMlas(const uint8_t* quantized_data, const TScale* scales, @@ -159,82 +77,6 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data, const float zero_point = num_bits == 8 ? 128.0f : 8.0f; const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1; - if (CanUseMlasQ4Dequant(num_bits)) { - const int64_t packed_cols = (cols + 1) / 2; - - if (block_size == 0) { - for (int64_t r = 0; r < rows; ++r) { - const uint8_t* row_data = quantized_data + r * packed_cols; - float* row_output = dequantized_data + r * cols; - const float scale = static_cast(scales[r]); - - int64_t c = 0; - for (; c + 8 <= cols; c += 8) { - const uint8_t packed_val0 = row_data[(c + 0) / 2]; - const uint8_t packed_val1 = row_data[(c + 2) / 2]; - const uint8_t packed_val2 = row_data[(c + 4) / 2]; - const uint8_t packed_val3 = row_data[(c + 6) / 2]; - - row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); - row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); - row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); - row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); - row_output[c + 4] = scale * (static_cast(packed_val2 & 0x0F) - zero_point); - row_output[c + 5] = scale * (static_cast(packed_val2 >> 4) - zero_point); - row_output[c + 6] = scale * (static_cast(packed_val3 & 0x0F) - zero_point); - row_output[c + 7] = scale * (static_cast(packed_val3 >> 4) - zero_point); - } - - for (; c < cols; c += 2) { - const uint8_t packed_val = row_data[c / 2]; - const uint8_t val0 = packed_val & 0x0F; - const uint8_t val1 = packed_val >> 4; - - row_output[c] = scale * (static_cast(val0) - zero_point); - if (c + 1 < cols) { - row_output[c + 1] = scale * (static_cast(val1) - zero_point); - } - } - } - return; - } else { - for (int64_t r = 0; r < rows; ++r) { - const uint8_t* row_data = quantized_data + r * packed_cols; - float* row_output = dequantized_data + r * cols; - - for (int64_t block_start = 0; block_start < cols; block_start += block_size) { - const int64_t block_end = std::min(block_start + block_size, cols); - const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); - const int64_t scale_idx = r * blocks_per_row + block_idx; - const float scale = static_cast(scales[scale_idx]); - - int64_t c = block_start; - for (; c + 4 <= block_end; c += 4) { - const uint8_t packed_val0 = row_data[(c + 0) / 2]; - const uint8_t packed_val1 = row_data[(c + 2) / 2]; - - row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); - row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); - row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); - 
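For reference, the hand-rolled path being deleted here handled the symmetric 4-bit layout directly: two values per byte, low nibble first, zero point 8, one scale per output row. A compact sketch of that row-wise (block_size == 0) case, without the manual unrolling:

#include <cstdint>
#include <vector>

// Row-wise symmetric 4-bit dequantization: each byte packs two quantized
// values (low nibble is the even column), and the zero point is 8.
std::vector<float> DequantizeRowwiseQ4(const uint8_t* packed,    // rows x ((cols + 1) / 2) bytes
                                       const float* row_scales,  // one scale per row
                                       int64_t rows, int64_t cols) {
  const int64_t packed_cols = (cols + 1) / 2;
  std::vector<float> out(static_cast<size_t>(rows * cols));
  for (int64_t r = 0; r < rows; ++r) {
    const uint8_t* src = packed + r * packed_cols;
    float* dst = out.data() + r * cols;
    const float scale = row_scales[r];
    for (int64_t c = 0; c < cols; ++c) {
      const uint8_t byte = src[c / 2];
      const uint8_t q = (c % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
      dst[c] = scale * (static_cast<float>(q) - 8.0f);
    }
  }
  return out;
}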
row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); - } - - for (; c < block_end; c += 2) { - const uint8_t packed_val = row_data[c / 2]; - const uint8_t val0 = packed_val & 0x0F; - const uint8_t val1 = packed_val >> 4; - - row_output[c] = scale * (static_cast(val0) - zero_point); - if (c + 1 < block_end) { - row_output[c + 1] = scale * (static_cast(val1) - zero_point); - } - } - } - } - return; - } - } - if (num_bits == 8 && block_size == 0) { for (int64_t r = 0; r < rows; ++r) { const float scale = static_cast(scales[r]); @@ -243,7 +85,7 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data, MlasDequantizeLinear( quantized_data + r * cols, dequantized_data + r * cols, - static_cast(cols), + onnxruntime::narrow(cols), scale, zero_pt); } @@ -392,31 +234,42 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t num_experts = moe_params.num_experts; const int64_t fc1_out_features = inter_size * (swiglu_fusion_ > 0 ? 2 : 1); + const size_t num_tokens_size = ToSize(num_tokens); + const size_t hidden_size_size = ToSize(hidden_size); + const size_t inter_size_size = ToSize(inter_size); + const size_t num_experts_size = ToSize(num_experts); + const size_t fc1_out_features_size = ToSize(fc1_out_features); + const size_t k_size = onnxruntime::narrow(k_); + + // Validate block-wise quantization requirements + ORT_RETURN_IF_ERROR(moe_helper::ValidateBlockwiseQuantization(block_size_, hidden_size, inter_size)); + auto* output = context->Output(0, input_shape); auto* tp = context->GetOperatorThreadPool(); AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - const size_t output_buffer_size = static_cast(output->Shape().Size()); + const size_t output_buffer_size = onnxruntime::narrow(output->Shape().Size()); const T* input_data = input->Data(); IAllocatorUniquePtr router_logits_float_buffer; const float* router_logits_float; if constexpr (std::is_same_v) { - router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + const size_t logits_count = SafeProduct(num_tokens_size, num_experts_size); + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, logits_count); router_logits_float = router_logits_float_buffer.get(); MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs->Data()), const_cast(router_logits_float), - static_cast(num_tokens * num_experts)); + logits_count); } else { router_logits_float = reinterpret_cast(router_probs->Data()); } - auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, SafeProduct(num_tokens_size, k_size)); int* route_expert = route_expert_ptr.get(); - auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, SafeProduct(num_tokens_size, k_size)); float* route_scale = route_scale_ptr.get(); const int max_threads = tp ? 
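The block-wise path differs from the row-wise one only in how the scale is chosen: scales are laid out as [rows, blocks_per_row] within each expert's slice, and column c uses the scale of block c / block_size. A small sketch of that lookup, assuming float scales and block_size > 0:

#include <algorithm>
#include <cstdint>

// Scale lookup for block-wise dequantization: column c of row r uses the
// scale of block c / block_size, clamped to the last block of the row.
inline float BlockwiseScale(const float* scales, int64_t r, int64_t c,
                            int64_t cols, int64_t block_size) {
  const int64_t blocks_per_row = (cols + block_size - 1) / block_size;
  const int64_t block_idx = std::min(c / block_size, blocks_per_row - 1);
  return scales[r * blocks_per_row + block_idx];
}

// Dequantizing one 4-bit value q at (r, c) then becomes:
//   float value = BlockwiseScale(scales, r, c, cols, block_size) * (float(q) - 8.0f);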
concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; @@ -427,7 +280,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { std::vector>> thread_local_expert_token_maps(num_routing_threads); for (auto& map : thread_local_expert_token_maps) { - map.resize(static_cast(num_experts)); + map.resize(num_experts_size); for (auto& expert_tokens : map) { expert_tokens.reserve(32); } @@ -437,52 +290,52 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto work = concurrency::ThreadPool::PartitionWork(narrow(thread_id), num_routing_threads, static_cast(num_tokens)); auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; - std::vector> sorted_logits(static_cast(num_experts)); - std::vector top_k_exp(static_cast(k_)); + std::vector> sorted_logits(num_experts_size); + std::vector top_k_exp(k_size); for (int64_t i = work.start; i < work.end; ++i) { const float* logits = router_logits_float + i * num_experts; - for (size_t j = 0; j < narrow(num_experts); ++j) { + for (size_t j = 0; j < num_experts_size; ++j) { sorted_logits[j] = {logits[j], j}; } - std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + narrow(k_), sorted_logits.end(), std::greater<>()); float max_logit = sorted_logits[0].first; float sum_exp = 0.0f; - for (size_t j = 0; j < narrow(k_); ++j) { + for (size_t j = 0; j < k_size; ++j) { top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit); sum_exp += top_k_exp[j]; } const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); - for (size_t j = 0; j < narrow(k_); ++j) { + for (size_t j = 0; j < k_size; ++j) { int64_t expert_idx = sorted_logits[j].second; int64_t route_idx = i * k_ + narrow(j); route_expert[route_idx] = narrow(expert_idx); route_scale[route_idx] = top_k_exp[j] * inv_sum; if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights - local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + local_expert_token_map[ToSize(expert_idx)].push_back(route_idx); } } } }); - std::vector> expert_token_map(static_cast(num_experts)); + std::vector> expert_token_map(num_experts_size); for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { size_t total_tokens_for_expert = 0; for (int t = 0; t < num_routing_threads; ++t) { - total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + total_tokens_for_expert += thread_local_expert_token_maps[t][ToSize(expert_idx)].size(); } - expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + expert_token_map[ToSize(expert_idx)].reserve(total_tokens_for_expert); for (int t = 0; t < num_routing_threads; ++t) { - auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + auto& local_tokens = thread_local_expert_token_maps[t][ToSize(expert_idx)]; if (!local_tokens.empty()) { - expert_token_map[static_cast(expert_idx)].insert( - expert_token_map[static_cast(expert_idx)].end(), + expert_token_map[ToSize(expert_idx)].insert( + expert_token_map[ToSize(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); } } @@ -491,27 +344,34 @@ Status QMoECPU::Compute(OpKernelContext* context) const { IAllocatorUniquePtr input_float_buffer; const float* input_float; if constexpr (std::is_same_v) { - input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + const size_t input_count = SafeProduct(num_tokens_size, hidden_size_size); + input_float_buffer = 
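The routing loop selects the top-k experts per token with std::partial_sort and then normalizes only those k logits with a numerically stable softmax. A standalone sketch of that per-token step (RouteToken is an illustrative name):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Per-token routing: pick the top-k logits, then softmax over just those k
// entries, subtracting the max logit for numerical stability.
void RouteToken(const float* logits, size_t num_experts, size_t k,
                std::vector<int>& expert_ids, std::vector<float>& weights) {
  std::vector<std::pair<float, size_t>> sorted(num_experts);
  for (size_t j = 0; j < num_experts; ++j) sorted[j] = {logits[j], j};
  std::partial_sort(sorted.begin(), sorted.begin() + static_cast<ptrdiff_t>(k),
                    sorted.end(), std::greater<>());

  const float max_logit = sorted[0].first;
  std::vector<float> exps(k);
  float sum_exp = 0.0f;
  for (size_t j = 0; j < k; ++j) {
    exps[j] = std::exp(sorted[j].first - max_logit);
    sum_exp += exps[j];
  }
  const float inv_sum = (sum_exp == 0.0f) ? 0.0f : 1.0f / sum_exp;

  expert_ids.resize(k);
  weights.resize(k);
  for (size_t j = 0; j < k; ++j) {
    expert_ids[j] = static_cast<int>(sorted[j].second);
    weights[j] = exps[j] * inv_sum;  // near-zero weights are skipped later (1e-8f threshold)
  }
}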
IAllocator::MakeUniquePtr(allocator, input_count); input_float = input_float_buffer.get(); MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data), const_cast(input_float), - static_cast(num_tokens * hidden_size)); + input_count); } else { input_float = reinterpret_cast(input_data); } const int max_expert_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; const int64_t total_expert_work = std::accumulate(expert_token_map.begin(), expert_token_map.end(), 0LL, - [](int64_t sum, const std::vector& tokens) { return sum + static_cast(tokens.size()); }); + [](int64_t sum, const std::vector& tokens) { return sum + onnxruntime::narrow(tokens.size()); }); + // Expert threading uses larger divisor (8x vs 4x for routing) because expert processing involves + // heavier computation (GEMM operations) that benefits more from parallelization const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8); + // Minimum work per expert thread is smaller (16 vs 32 for routing) because expert work + // involves matrix operations that are more compute-intensive per token const int64_t min_expert_work_per_thread = std::max(int64_t{16}, total_expert_work / expert_thread_divisor); int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 1 : std::min(narrow(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(narrow(num_experts), max_expert_threads)); if (num_expert_threads == 0) num_expert_threads = 1; - auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, + SafeProduct(onnxruntime::narrow(num_expert_threads), output_buffer_size)); float* thread_local_outputs = thread_local_outputs_ptr.get(); - std::memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + const size_t thread_local_outputs_bytes = SafeProduct(SafeProduct(onnxruntime::narrow(num_expert_threads), output_buffer_size), sizeof(float)); + std::memset(thread_local_outputs, 0, thread_local_outputs_bytes); size_t max_tokens_per_expert = 0; for (const auto& tokens : expert_token_map) { @@ -522,21 +382,23 @@ Status QMoECPU::Compute(OpKernelContext* context) const { return (size + 63) & ~63; }; - const size_t A1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); - const size_t C1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(fc1_out_features)); - const size_t A2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(inter_size)); - const size_t C2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); - const size_t B1_dequant_size = align_size(static_cast(fc1_out_features) * static_cast(hidden_size)); - const size_t B2_dequant_size = align_size(static_cast(hidden_size) * static_cast(inter_size)); + const size_t A1_size = align_size(SafeProduct(max_tokens_per_expert, hidden_size_size)); + const size_t C1_size = align_size(SafeProduct(max_tokens_per_expert, fc1_out_features_size)); + const size_t A2_size = align_size(SafeProduct(max_tokens_per_expert, inter_size_size)); + const size_t C2_size = align_size(SafeProduct(max_tokens_per_expert, hidden_size_size)); + const size_t B1_dequant_size = align_size(SafeProduct(fc1_out_features_size, hidden_size_size)); + const size_t B2_dequant_size = align_size(SafeProduct(hidden_size_size, inter_size_size)); const size_t workspace_elements_per_thread = 
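Each worker thread gets one contiguous float slab holding six scratch buffers (A1, C1, A2, C2 and the two dequantized weight matrices), with every buffer rounded up to a multiple of 64 elements by the align_size lambda. A sketch of that sizing, with ExpertWorkspace as a hypothetical helper:

#include <cstddef>

// Rounds an element count up to a multiple of 64, as the align_size lambda does.
inline size_t AlignTo64(size_t n) { return (n + 63) & ~size_t{63}; }

// Hypothetical helper: per-thread scratch sizes, all in float elements.
struct ExpertWorkspace {
  size_t A1, C1, A2, C2, B1, B2;
  size_t per_thread() const { return A1 + C1 + A2 + C2 + B1 + B2; }
};

ExpertWorkspace MakeWorkspace(size_t max_tokens, size_t hidden, size_t inter, size_t fc1_out) {
  return {AlignTo64(max_tokens * hidden),   // A1: gathered input rows
          AlignTo64(max_tokens * fc1_out),  // C1: FC1 output
          AlignTo64(max_tokens * inter),    // A2: activation output
          AlignTo64(max_tokens * hidden),   // C2: FC2 output
          AlignTo64(fc1_out * hidden),      // B1: dequantized FC1 weights
          AlignTo64(hidden * inter)};       // B2: dequantized FC2 weights
}

// A thread's buffers are then carved out of one allocation:
//   float* base = workspace + thread_id * ws.per_thread();
//   float* A1 = base;  float* C1 = A1 + ws.A1;  /* ... */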
A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size; - auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * workspace_elements_per_thread); + auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, + SafeProduct(onnxruntime::narrow(num_expert_threads), workspace_elements_per_thread)); float* workspace = workspace_ptr.get(); auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr(allocator, - static_cast(num_expert_threads) * (static_cast(fc1_out_features) + static_cast(hidden_size))); + SafeProduct(onnxruntime::narrow(num_expert_threads), + fc1_out_features_size + hidden_size_size)); float* bias_conversion_buffers = bias_conversion_buffers_ptr.get(); const auto& fc1_scales_dims = fc1_scales->Shape().GetDims(); @@ -561,42 +423,59 @@ Status QMoECPU::Compute(OpKernelContext* context) const { size_t total_work = 0; for (int64_t i = 0; i < num_experts; ++i) { - const size_t token_count = expert_token_map[static_cast(i)].size(); + const size_t token_count = expert_token_map[ToSize(i)].size(); if (token_count > 0) { expert_workload.emplace_back(i, token_count); total_work += token_count; } } - if (total_work < 48) { + // Adaptive thresholds optimized for both decoding (1 token) and batch inference + // For single-token decoding: allow more parallelization across experts since each expert + // represents substantial work (especially for large hidden_size) + // For batch: higher thresholds ensure efficient thread utilization + if (total_work < 4) { + // Very small workload - single token with few experts, use 1 thread num_expert_threads = 1; - } else if (total_work < 192) { - num_expert_threads = std::min(num_expert_threads, 2); - } else if (total_work < 512) { + } else if (total_work < 12) { + // Small workload - single token or small batch, use up to 4 threads + // This covers typical decoding case with top-k=2-4 where each expert is substantial work num_expert_threads = std::min(num_expert_threads, 4); + } else if (total_work < 32) { + // Medium workload - use up to 6 threads + num_expert_threads = std::min(num_expert_threads, 6); + } else if (total_work < 128) { + // Large workload - use up to 8 threads + num_expert_threads = std::min(num_expert_threads, 8); + } else if (total_work < 384) { + // Very large workload - use more threads + num_expert_threads = std::min(num_expert_threads, 12); } std::sort(expert_workload.begin(), expert_workload.end(), [](const auto& a, const auto& b) { return a.second > b.second; }); + const size_t num_expert_threads_size = onnxruntime::narrow(num_expert_threads); + std::vector> expert_batches(num_expert_threads); size_t thread_idx = 0; for (const auto& work : expert_workload) { expert_batches[thread_idx].push_back(work.first); - thread_idx = (thread_idx + 1) % static_cast(num_expert_threads); + thread_idx = (thread_idx + 1) % num_expert_threads_size; } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { const int thread_id = narrow(thread_id_pd); - const auto& expert_batch = expert_batches[static_cast(thread_id)]; + const auto& expert_batch = expert_batches[onnxruntime::narrow(thread_id)]; - float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; + const size_t thread_offset = SafeProduct(onnxruntime::narrow(thread_id), workspace_elements_per_thread); + float* thread_workspace = workspace + thread_offset; - float* thread_bias1_buffer = bias_conversion_buffers + static_cast(thread_id) * 
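Experts are balanced across threads by sorting them by descending token count and dealing them out round-robin, so the heaviest experts land on different threads. A standalone sketch of that assignment (BalanceExperts is an illustrative name):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Sort (expert_id, token_count) pairs by descending work and assign them
// round-robin to the worker threads.
std::vector<std::vector<int64_t>> BalanceExperts(
    std::vector<std::pair<int64_t, size_t>> workload,  // non-empty experts only
    size_t num_threads) {
  std::sort(workload.begin(), workload.end(),
            [](const auto& a, const auto& b) { return a.second > b.second; });
  std::vector<std::vector<int64_t>> batches(num_threads);
  size_t t = 0;
  for (const auto& w : workload) {
    batches[t].push_back(w.first);
    t = (t + 1) % num_threads;
  }
  return batches;
}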
(static_cast(fc1_out_features) + static_cast(hidden_size)); - float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); + float* thread_bias1_buffer = bias_conversion_buffers + SafeProduct(onnxruntime::narrow(thread_id), fc1_out_features_size + hidden_size_size); + float* thread_bias2_buffer = thread_bias1_buffer + fc1_out_features_size; for (int64_t expert_idx : expert_batch) { - const auto& routes = expert_token_map[static_cast(expert_idx)]; + const auto& routes = expert_token_map[ToSize(expert_idx)]; if (routes.empty()) { continue; } @@ -619,113 +498,70 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens); for (int64_t i = start_idx; i < end_idx; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; + const int64_t token_idx = routes[ToSize(i)] / k_; const float* src = input_float + token_idx * hidden_size; float* dst = A1 + i * hidden_size; - std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + std::memcpy(dst, src, SafeProduct(hidden_size_size, sizeof(float))); } }); } else { for (int64_t i = 0; i < num_expert_tokens; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; + const int64_t token_idx = routes[ToSize(i)] / k_; const float* src = input_float + token_idx * hidden_size; float* dst = A1 + i * hidden_size; if (ShouldUseMemcpy(hidden_size)) { - std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + std::memcpy(dst, src, SafeProduct(hidden_size_size, sizeof(float))); } else { const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); size_t j = 0; - for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (; j + unroll_factor <= hidden_size_size; j += unroll_factor) { for (size_t k = 0; k < unroll_factor; ++k) { dst[j + k] = src[j + k]; } } - for (; j < narrow(hidden_size); ++j) { + for (; j < hidden_size_size; ++j) { dst[j] = src[j]; } } } } + const int64_t fc1_blocks_per_row = is_fc1_block_wise ? (hidden_size + block_size_ - 1) / block_size_ : 1; + const size_t fc1_blocks_per_row_size = ToSize(fc1_blocks_per_row); + const T* fc1_scales_ptr; if (is_fc1_block_wise) { - const int64_t fc1_blocks_per_row = fc1_scales_dims[2]; - fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features * fc1_blocks_per_row; + const size_t scale_offset = SafeProduct(SafeProduct(ToSize(expert_idx), fc1_out_features_size), fc1_blocks_per_row_size); + fc1_scales_ptr = fc1_scales_data + scale_offset; } else { - fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features; + fc1_scales_ptr = fc1_scales_data + SafeProduct(ToSize(expert_idx), fc1_out_features_size); } const int64_t dequant_block_size = GetDequantBlockSize(fc1_out_features, num_expert_tokens); const int64_t num_dequant_blocks = (fc1_out_features + dequant_block_size - 1) / dequant_block_size; - const size_t m = static_cast(num_expert_tokens); - const size_t n = static_cast(fc1_out_features); - const size_t k = static_cast(hidden_size); - - MLAS_BLK_QUANT_TYPE q_type; - bool use_direct_q4_gemm = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, q_type); - bool fc1_used_direct_q4 = false; - bool fc1_bias_handled_by_q4_gemm = false; - - if (use_direct_q4_gemm) { - IAllocatorUniquePtr mlas_packed_fc1; - Status convert_status = ConvertToMlasQ4Format( - fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, - fc1_scales_ptr, - is_fc1_block_wise ? 
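Each route entry is a flattened (token, top-k slot) index, so the source row for entry i is routes[i] / k; the rows are gathered into the contiguous A1 buffer before the FC1 GEMM. A simplified sketch using plain memcpy (the real loop also has an unrolled copy path for small hidden_size):

#include <cstdint>
#include <cstring>
#include <vector>

// Gather the rows belonging to one expert into a dense activation matrix
// A1 of shape [routes.size(), hidden_size].
void GatherExpertRows(const float* input,  // [num_tokens, hidden_size]
                      const std::vector<int64_t>& routes,
                      int64_t k, int64_t hidden_size,
                      float* A1) {
  for (size_t i = 0; i < routes.size(); ++i) {
    const int64_t token_idx = routes[i] / k;  // undo the token_idx * k + slot flattening
    std::memcpy(A1 + i * hidden_size,
                input + token_idx * hidden_size,
                static_cast<size_t>(hidden_size) * sizeof(float));
  }
}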
block_size_ : 0, - expert_weight_bits_, - fc1_out_features, - hidden_size, - q_type, - allocator, - mlas_packed_fc1); - - if (convert_status.IsOK()) { - float* fc1_bias_float = nullptr; - IAllocatorUniquePtr fc1_bias_buffer; - - if (has_fc1_bias) { - const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; - fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); - fc1_bias_float = fc1_bias_buffer.get(); - - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); - } else { - for (int64_t i = 0; i < fc1_out_features; ++i) { - fc1_bias_float[i] = static_cast(B1_bias[i]); - } - } - } - - Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, - num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); + const size_t m = ToSize(num_expert_tokens); + const size_t n = fc1_out_features_size; + const size_t k = hidden_size_size; - if (gemm_status.IsOK()) { - fc1_used_direct_q4 = true; - goto fc1_gemm_done; - } - } - // If direct Q4 GEMM failed, fall back to traditional approach - } - - // Traditional approach: dequantize + regular GEMM + // Quantized MLAS is disabled, using traditional approach: dequantize + regular GEMM + // Use parallel dequantization when we have multiple blocks and sufficient work per thread + // Threshold of 32 features ensures each thread has meaningful work to justify threading overhead if (num_dequant_blocks > 1 && fc1_out_features >= 32) { concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * dequant_block_size; const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); - const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; + const auto offset = SafeProduct(SafeProduct(ToSize(expert_idx), fc1_out_features_size), ToSize(fc1_packed_cols)) + SafeProduct(ToSize(start_row), ToSize(fc1_packed_cols)); DequantizeBlock(fc1_weights_data + offset, - fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_scales_dims[2] : start_row), + fc1_scales_ptr + (is_fc1_block_wise ? SafeProduct(ToSize(start_row), fc1_blocks_per_row_size) : ToSize(start_row)), is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, - end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, tp); + end_row - start_row, hidden_size, B1_dequant + SafeProduct(ToSize(start_row), hidden_size_size), nullptr); }); } else { - DequantizeBlock(fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + DequantizeBlock(fc1_weights_data + SafeProduct(SafeProduct(ToSize(expert_idx), fc1_out_features_size), ToSize(fc1_packed_cols)), fc1_scales_ptr, is_fc1_block_wise ? 
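When a worker dequantizes only a slice of rows, both the packed weights and the scales are addressed per row: for 4-bit weights there are (hidden_size + 1) / 2 bytes per row, and block-wise scales have blocks_per_row entries per row. A sketch of that offset arithmetic (Fc1Slice and DequantSlice are illustrative helpers; the 8-bit case would use hidden_size packed bytes per row instead):

#include <cstdint>

// Offsets for dequantizing rows [start_row, end_row) of one expert's FC1 weights.
struct DequantSlice {
  int64_t weight_offset;  // into the packed uint8_t weight tensor
  int64_t scale_offset;   // into the expert's scale slice
};

DequantSlice Fc1Slice(int64_t expert_idx, int64_t fc1_out_features, int64_t hidden_size,
                      int64_t block_size, int64_t start_row, bool block_wise) {
  const int64_t packed_cols = (hidden_size + 1) / 2;  // 4-bit: two values per byte
  const int64_t blocks_per_row =
      block_wise ? (hidden_size + block_size - 1) / block_size : 1;
  DequantSlice s;
  s.weight_offset = expert_idx * fc1_out_features * packed_cols + start_row * packed_cols;
  s.scale_offset = block_wise ? start_row * blocks_per_row : start_row;
  return s;
}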
block_size_ : 0, expert_weight_bits_, fc1_out_features, hidden_size, B1_dequant, tp); @@ -738,8 +574,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); - fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; - if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); @@ -776,9 +611,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } - fc1_gemm_done: - - const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + // Activation threshold scales inversely with inter_size to balance work per thread + // For large inter_size, fewer tokens justify parallel activation; for small inter_size, more tokens needed + // Base value of 256 ensures reasonable work distribution across different model configurations + const int64_t activation_threshold = std::max(int64_t{4}, 256 / inter_size); if (num_expert_tokens >= activation_threshold && tp != nullptr) { const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; @@ -809,13 +645,16 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } + const int64_t fc2_blocks_per_row = is_fc2_block_wise ? (inter_size + block_size_ - 1) / block_size_ : 1; + const size_t fc2_blocks_per_row_size = ToSize(fc2_blocks_per_row); + const T* fc2_scales_ptr; if (is_fc2_block_wise) { - const int64_t fc2_blocks_per_row = fc2_scales_dims[2]; - fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size * fc2_blocks_per_row; + const size_t scale_offset = SafeProduct(SafeProduct(ToSize(expert_idx), hidden_size_size), fc2_blocks_per_row_size); + fc2_scales_ptr = fc2_scales_data + scale_offset; } else { - fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size; + fc2_scales_ptr = fc2_scales_data + SafeProduct(ToSize(expert_idx), hidden_size_size); } const int64_t fc2_dequant_block_size = GetDequantBlockSize(hidden_size, num_expert_tokens); @@ -825,67 +664,21 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t n2 = static_cast(hidden_size); const size_t k2 = static_cast(inter_size); - MLAS_BLK_QUANT_TYPE q_type2; - bool use_direct_q4_gemm_fc2 = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, q_type2); - bool fc2_used_direct_q4 = false; - - if (use_direct_q4_gemm_fc2) { - IAllocatorUniquePtr mlas_packed_fc2; - Status convert_status = ConvertToMlasQ4Format( - fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, - fc2_scales_ptr, - is_fc2_block_wise ? 
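After dequantization, FC1 is an ordinary float GEMM over the gathered tokens. Assuming the dequantized weights are stored as [out_features, in_features], which is what the dequantization shapes above suggest, the fallback computes C = A * B^T plus an optional bias; a naive reference version, useful only for checking small shapes against the MLAS path:

#include <cstdint>

// Reference for the dequantize-then-GEMM fallback: A is [M, K] gathered
// activations, B is [N, K] dequantized weights (out_features x in_features),
// so C[m][n] = dot(A row m, B row n) + bias[n].
void ReferenceGemmBias(const float* A, const float* B, const float* bias,
                       float* C, int64_t M, int64_t N, int64_t K) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = bias ? bias[n] : 0.0f;
      for (int64_t kk = 0; kk < K; ++kk) acc += A[m * K + kk] * B[n * K + kk];
      C[m * N + n] = acc;
    }
  }
}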
block_size_ : 0, - expert_weight_bits_, - hidden_size, - inter_size, - q_type2, - allocator, - mlas_packed_fc2); - - if (convert_status.IsOK()) { - float* fc2_bias_float = nullptr; - IAllocatorUniquePtr fc2_bias_buffer; - - if (has_fc2_bias) { - const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; - fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); - fc2_bias_float = fc2_bias_buffer.get(); - - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); - } else { - for (int64_t i = 0; i < hidden_size; ++i) { - fc2_bias_float[i] = static_cast(B2_bias[i]); - } - } - } - - Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, - num_expert_tokens, hidden_size, inter_size, q_type2, tp); - - if (gemm_status.IsOK()) { - fc2_used_direct_q4 = true; - goto fc2_gemm_done; - } - } - - // If direct Q4 GEMM failed, fall back to traditional approach - } - - // Traditional approach: dequantize + regular GEMM + // Quantized MLAS is disabled, using traditional approach: dequantize + regular GEMM + // Use parallel dequantization when we have multiple blocks and sufficient work per thread + // Threshold of 32 for hidden_size ensures each thread has meaningful work to justify threading overhead if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * fc2_dequant_block_size; const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); - const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; + const auto offset = SafeProduct(SafeProduct(ToSize(expert_idx), hidden_size_size), ToSize(fc2_packed_cols)) + SafeProduct(ToSize(start_row), ToSize(fc2_packed_cols)); DequantizeBlock(fc2_weights_data + offset, - fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_scales_dims[2] : start_row), + fc2_scales_ptr + (is_fc2_block_wise ? SafeProduct(ToSize(start_row), fc2_blocks_per_row_size) : ToSize(start_row)), is_fc2_block_wise ? block_size_ : 0, expert_weight_bits_, - end_row - start_row, inter_size, B2_dequant + start_row * inter_size, tp); + end_row - start_row, inter_size, B2_dequant + SafeProduct(ToSize(start_row), inter_size_size), nullptr); }); } else { - DequantizeBlock(fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + DequantizeBlock(fc2_weights_data + SafeProduct(SafeProduct(ToSize(expert_idx), hidden_size_size), ToSize(fc2_packed_cols)), fc2_scales_ptr, is_fc2_block_wise ? 
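The FC2 fallback uses the same pattern as FC1: split the weight rows into blocks sized by GetDequantBlockSize and dequantize each block independently. A simplified sketch of that row-block parallel loop, using std::thread as a stand-in for ORT's ThreadPool::TrySimpleParallelFor (the real code reuses pool threads rather than spawning one per block):

#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Run dequant_rows(start, end) over row blocks of the weight matrix in parallel.
void ParallelRowBlocks(int64_t rows, int64_t block_rows,
                       const std::function<void(int64_t, int64_t)>& dequant_rows) {
  const int64_t num_blocks = (rows + block_rows - 1) / block_rows;
  std::vector<std::thread> workers;
  workers.reserve(static_cast<size_t>(num_blocks));
  for (int64_t b = 0; b < num_blocks; ++b) {
    const int64_t start = b * block_rows;
    const int64_t end = std::min(start + block_rows, rows);
    workers.emplace_back([=, &dequant_rows] { dequant_rows(start, end); });
  }
  for (auto& w : workers) w.join();
}

// Usage sketch:
//   ParallelRowBlocks(hidden_size, fc2_dequant_block_size,
//                     [&](int64_t start, int64_t end) { /* dequantize rows [start, end) */ });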
block_size_ : 0, expert_weight_bits_, hidden_size, inter_size, B2_dequant, tp); @@ -898,10 +691,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C2, n2, tp); - fc2_gemm_done: - - bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); @@ -936,7 +726,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias) { const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); size_t j = 0; for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) {
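Each worker accumulates its experts' FC2 rows into a private copy of the output tensor, scaled by the token's routing weight, which avoids locks or atomics on the shared output; presumably the per-thread copies are then summed into the real output in a final pass (that reduction is outside the hunk shown here). A sketch of both steps under that assumption:

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Add one expert's output row for a token into this thread's private output copy,
// weighted by the token's routing weight.
void AccumulateRow(float* thread_output,  // this thread's [num_tokens, hidden_size] copy
                   int64_t token_idx, int64_t hidden_size,
                   const float* expert_row, float routing_weight) {
  float* dst = thread_output + token_idx * hidden_size;
  for (int64_t j = 0; j < hidden_size; ++j) dst[j] += routing_weight * expert_row[j];
}

// Assumed final reduction: sum the per-thread copies into the kernel output.
void ReduceThreadOutputs(const float* thread_outputs,  // [num_threads, output_elems]
                         size_t num_threads, size_t output_elems, float* output) {
  std::fill(output, output + output_elems, 0.0f);
  for (size_t t = 0; t < num_threads; ++t) {
    const float* src = thread_outputs + t * output_elems;
    for (size_t i = 0; i < output_elems; ++i) output[i] += src[i];
  }
}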