diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h
index e494719464d20..39249f842e632 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h
@@ -49,7 +49,8 @@ Status CheckInputs(MoEParameters& parameters,
                    const Tensor* fc3_experts_bias,    // optional
                    const Tensor* fc3_experts_scales,  // required for qMoE; NULL for MOE
                    const int64_t pack_size,           // number of weights packed together (like 2 for uint4 packed to uint8)
-                   const bool is_fused_swiglu) {
+                   const bool is_fused_swiglu,
+                   const int64_t block_size = 0) {  // block size for block-wise quantization
   // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
   ASSERT_TENSOR_2D_OR_3D(input);
   ASSERT_TENSOR_3D(fc1_experts_weights);
@@ -90,9 +91,63 @@ Status CheckInputs(MoEParameters& parameters,
   CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size);
   CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size);
 
-  CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
-  CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
-  CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
+  // Validate scale tensors: Handle both row-wise and block-wise quantization flexibly
+  // First, detect the actual quantization method from the tensor shapes
+  bool is_row_wise_quantization = true;
+  if (fc1_experts_scales != nullptr) {
+    const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
+    if (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1) {
+      is_row_wise_quantization = false;
+    }
+  }
+
+  if (block_size > 0 && !is_row_wise_quantization) {
+    // Block-wise quantization: 3D scale tensors
+    // For block-wise quantization, we calculate the number of blocks using ceiling division
+    // to handle cases where the dimension is not perfectly divisible by block_size
+    const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size;
+    const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size;
+    const int64_t fc3_blocks_per_row = (hidden_size + block_size - 1) / block_size;
+
+    CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, fc1_blocks_per_row);
+    CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, fc2_blocks_per_row);
+    CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, fc3_blocks_per_row);
+  } else {
+    // Row-wise quantization: 2D scale tensors or 3D with last dimension = 1
+    // Handle both {num_experts, features} and {num_experts, features, 1} shapes
+    if (fc1_experts_scales != nullptr) {
+      const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
+      if (fc1_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
+      } else if (fc1_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, 1);
+      } else {
+        ORT_THROW("fc1_experts_scales must be 2D or 3D tensor");
+      }
+    }
+
+    if (fc2_experts_scales != nullptr) {
+      const auto& fc2_scales_dims = fc2_experts_scales->Shape().GetDims();
+      if (fc2_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
+      } else if (fc2_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, 1);
+      } else {
+        ORT_THROW("fc2_experts_scales must be 2D or 3D tensor");
+      }
+    }
+
+    if (fc3_experts_scales != nullptr) {
+      const auto& fc3_scales_dims = fc3_experts_scales->Shape().GetDims();
+      if (fc3_scales_dims.size() == 2) {
+        CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
+      } else if (fc3_scales_dims.size() == 3) {
+        CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, 1);
+      } else {
+        ORT_THROW("fc3_experts_scales must be 2D or 3D tensor");
+      }
+    }
+  }
 
   if (fc3_experts_weights == nullptr) {
     ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr);
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
index 5c6c3b919b572..8195c9438d408 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -2,12 +2,16 @@
 // Licensed under the MIT License.
 
 #include "contrib_ops/cpu/moe/moe_quantization_cpu.h"
-
 #include "core/framework/allocator.h"
 #include "core/framework/float16.h"
 #include "core/mlas/inc/mlas.h"
+#include "core/mlas/inc/mlas_q4.h"
 #include "core/platform/threadpool.h"
 #include "core/providers/cpu/math/gemm_helper.h"
+#include "core/providers/cpu/activation/activations.h"
+#include "core/common/safeint.h"
+#include "core/framework/tensor_type_and_shape.h"
+#include "core/util/math.h"
 #include "contrib_ops/cpu/moe/moe_utils.h"
 #include "contrib_ops/cpu/moe/moe_helper.h"
 
@@ -17,44 +21,325 @@
 #include
 #include
 
+namespace {
+inline int64_t GetOptimalBlockSize(int64_t total_elements, int num_threads) {
+  if (total_elements <= 0 || num_threads <= 0) return 64;
+  const int64_t l1_cache_elements = 8192;  // ~32KB / 4 bytes per float
+  const int64_t divisor = std::max(1, num_threads > 1 ? 4 : 2);
+  const int64_t base_block_size = l1_cache_elements / divisor;
+  const int64_t max_block = std::max(int64_t{32}, total_elements / std::max(int64_t{1}, int64_t{4}));
+  return std::clamp(base_block_size, int64_t{32}, std::min(int64_t{512}, max_block));
+}
+
+inline int64_t GetUnrollFactor(int64_t vector_size) {
+  if (vector_size <= 0) return 2;
+  if (vector_size >= 512) return 16;
+  if (vector_size >= 128) return 8;
+  if (vector_size >= 32) return 4;
+  return 2;
+}
+
+inline bool ShouldUseMemcpy(int64_t size) {
+  return size >= 64;
+}
+
+inline int64_t GetDequantBlockSize(int64_t features, int64_t total_work) {
+  if (features <= 0 || total_work <= 0) return 16;
+  const int64_t target_block_size = std::max(int64_t{16}, features / std::max(int64_t{1}, int64_t{8}));
+  const int64_t work_based_size = std::max(int64_t{16}, total_work / std::max(int64_t{1}, int64_t{4}));
+  return std::min(target_block_size, work_based_size);
+}
+
+bool CanUseMlasQ4Dequant(int64_t num_bits, int64_t block_size) {
+  if (num_bits != 4) {
+    return false;
+  }
+
+  return true;
+}
+
+bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size,
+                      int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) {
+  if (expert_weight_bits != 4) {
+    return false;
+  }
+
+  if (block_size == 64) {
+    out_qtype = BlkQ4Sym64;
+  } else if (block_size == 128) {
+    out_qtype = BlkQ4Sym128;
+  } else if (block_size == 0) {
+    out_qtype = BlkQ4Sym;
+  } else {
+    return false;
+  }
+
+  size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast<size_t>(cols), static_cast<size_t>(rows));
+  return expected_size > 0;
+}
+
+}  // namespace
+
 namespace onnxruntime {
 namespace contrib {
 
-// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization.
-// The source quantized weights are stored as a row-major representation of the transposed
-// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix.
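// --- Editor's note (illustrative sketch, not part of the patch) ---------------------------
// The dequantization helpers that follow all assume symmetric quantization: a stored code q
// is mapped back to a float weight as scale * (q - zero_point), with zero_point = 8 for
// 4-bit weights and 128 for 8-bit weights. A minimal standalone version of that mapping,
// using a hypothetical helper name purely to make the convention explicit:
//
//   #include <cstdint>
//
//   inline float DequantizeSymmetric(uint8_t q, float scale, int num_bits) {
//     const float zero_point = (num_bits == 8) ? 128.0f : 8.0f;
//     return scale * (static_cast<float>(q) - zero_point);
//   }
//
// For 4-bit data, two codes share one byte: the even column sits in the low nibble
// (byte & 0x0F) and the odd column in the high nibble (byte >> 4), which is the unpacking
// order used by the loops below.
// -------------------------------------------------------------------------------------------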
template -void DequantizeBlock(const uint8_t* quantized_data, - const TScale* scales, - int64_t /*block_size*/, - int64_t num_bits, - int64_t rows, - int64_t cols, - float* dequantized_data) { +void DequantizeBlockWithMlas(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool); + +template +Status ConvertToMlasQ4Format(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + MLAS_BLK_QUANT_TYPE qtype, + AllocatorPtr allocator, + IAllocatorUniquePtr& mlas_packed_buffer) { + if (num_bits != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 4-bit quantization supported for MLAS Q4 format conversion"); + } + + auto temp_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + float* temp_float = temp_float_buffer.get(); + + DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr); + + size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); + if (packed_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); + } + + mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); + + return Status::OK(); +} + +Status DirectQ4Gemm(const float* A, + const uint8_t* mlas_packed_B, + const float* bias, + float* C, + int64_t M, + int64_t N, + int64_t K, + MLAS_BLK_QUANT_TYPE qtype, + MLAS_THREADPOOL* thread_pool) { + MLAS_Q4_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.B = mlas_packed_B; + params.Bias = bias; + params.C = C; + params.ldc = static_cast(N); + params.OutputProcessor = nullptr; + + MlasQ4GemmBatch(qtype, static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, thread_pool); + return Status::OK(); +} + +template +void DequantizeBlockWithMlas(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool) { const float zero_point = num_bits == 8 ? 128.0f : 8.0f; - if (num_bits == 8) { - for (int64_t r = 0; r < rows; ++r) { - const float scale = static_cast(scales[r]); - for (int64_t c = 0; c < cols; ++c) { - // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) - dequantized_data[r * cols + c] = scale * (static_cast(quantized_data[r * cols + c]) - zero_point); + const int64_t blocks_per_row = (block_size > 0) ? 
((cols + block_size - 1) / block_size) : 1; + + if (CanUseMlasQ4Dequant(num_bits, block_size)) { + const int64_t packed_cols = (cols + 1) / 2; + + if (block_size == 0) { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + const float scale = static_cast(scales[r]); + + int64_t c = 0; + for (; c + 8 <= cols; c += 8) { + const uint8_t packed_val0 = row_data[(c + 0) / 2]; + const uint8_t packed_val1 = row_data[(c + 2) / 2]; + const uint8_t packed_val2 = row_data[(c + 4) / 2]; + const uint8_t packed_val3 = row_data[(c + 6) / 2]; + + row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); + row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); + row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); + row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); + row_output[c + 4] = scale * (static_cast(packed_val2 & 0x0F) - zero_point); + row_output[c + 5] = scale * (static_cast(packed_val2 >> 4) - zero_point); + row_output[c + 6] = scale * (static_cast(packed_val3 & 0x0F) - zero_point); + row_output[c + 7] = scale * (static_cast(packed_val3 >> 4) - zero_point); + } + + for (; c < cols; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < cols) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } } + return; + } else { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + int64_t c = block_start; + for (; c + 4 <= block_end; c += 4) { + const uint8_t packed_val0 = row_data[(c + 0) / 2]; + const uint8_t packed_val1 = row_data[(c + 2) / 2]; + + row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); + row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); + row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); + row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); + } + + for (; c < block_end; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < block_end) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } + } + } + return; } - } else if (num_bits == 4) { - const int64_t packed_cols = (cols + 1) / 2; + } + + if (num_bits == 8 && block_size == 0) { for (int64_t r = 0; r < rows; ++r) { const float scale = static_cast(scales[r]); - for (int64_t c = 0; c < cols; ++c) { - const uint8_t packed_val = quantized_data[r * packed_cols + c / 2]; - // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns. - const uint8_t quantized_val = (c % 2 == 0) ? 
(packed_val & 0x0F) : (packed_val >> 4); - // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) - dequantized_data[r * cols + c] = scale * (static_cast(quantized_val) - zero_point); + const uint8_t zero_pt = static_cast(zero_point); + + MlasDequantizeLinear( + quantized_data + r * cols, + dequantized_data + r * cols, + static_cast(cols), + scale, + zero_pt); + } + } else { + if (num_bits == 8) { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * cols; + float* row_output = dequantized_data + r * cols; + + int64_t c = 0; + if (block_size > 0) { + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + for (c = block_start; c + 4 <= block_end; c += 4) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + row_output[c + 1] = scale * (static_cast(row_data[c + 1]) - zero_point); + row_output[c + 2] = scale * (static_cast(row_data[c + 2]) - zero_point); + row_output[c + 3] = scale * (static_cast(row_data[c + 3]) - zero_point); + } + for (; c < block_end; ++c) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + } + } + } else { + const float scale = static_cast(scales[r]); + for (; c + 8 <= cols; c += 8) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + row_output[c + 1] = scale * (static_cast(row_data[c + 1]) - zero_point); + row_output[c + 2] = scale * (static_cast(row_data[c + 2]) - zero_point); + row_output[c + 3] = scale * (static_cast(row_data[c + 3]) - zero_point); + row_output[c + 4] = scale * (static_cast(row_data[c + 4]) - zero_point); + row_output[c + 5] = scale * (static_cast(row_data[c + 5]) - zero_point); + row_output[c + 6] = scale * (static_cast(row_data[c + 6]) - zero_point); + row_output[c + 7] = scale * (static_cast(row_data[c + 7]) - zero_point); + } + for (; c < cols; ++c) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + } + } + } + } else if (num_bits == 4) { + const int64_t packed_cols = (cols + 1) / 2; + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + + if (block_size > 0) { + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + for (int64_t c = block_start; c < block_end; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < block_end) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } + } + } else { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < cols) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + 
} + } } } } } +template +void DequantizeBlock(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool = nullptr) { + DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, dequantized_data, thread_pool); +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), @@ -63,11 +348,15 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "Attribute 'expert_weight_bits' must be 4 or 8."); block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); + + if (block_size_ > 0) { + ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); + ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); + } } template Status QMoECPU::Compute(OpKernelContext* context) const { - // --- 1. Get Inputs and Attributes --- const auto* input = context->Input(0); const auto* router_probs = context->Input(1); const auto* fc1_experts_weights = context->Input(2); @@ -87,7 +376,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias, fc2_scales, fc3_experts_weights, fc3_experts_bias, fc3_scales, expert_weight_bits_ == 4 ? 2 : 1, - true)); + true, + block_size_)); if (fc3_experts_weights || fc3_experts_bias || fc3_scales) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); @@ -109,19 +399,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t output_buffer_size = static_cast(output->Shape().Size()); const T* input_data = input->Data(); - const T* router_probs_data = router_probs->Data(); - // --- 2. Routing Logic: Assign tokens to experts --- IAllocatorUniquePtr router_logits_float_buffer; const float* router_logits_float; if constexpr (std::is_same_v) { router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); router_logits_float = router_logits_float_buffer.get(); - MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs_data), + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs->Data()), const_cast(router_logits_float), static_cast(num_tokens * num_experts)); } else { - router_logits_float = reinterpret_cast(router_probs_data); + router_logits_float = reinterpret_cast(router_probs->Data()); } auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); @@ -129,36 +417,37 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); float* route_scale = route_scale_ptr.get(); - // Parallelize the routing logic to improve performance for large token batches. - // Minor performance regression for single-token decoding is an acceptable trade-off - int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp)); + const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const int64_t thread_divisor = std::max(1, max_threads * 4); + const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast(num_tokens / thread_divisor)); + const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 
1 : std::min(static_cast(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads); + const int num_routing_threads = std::max(1, optimal_routing_threads); std::vector>> thread_local_expert_token_maps(num_routing_threads); for (auto& map : thread_local_expert_token_maps) { map.resize(static_cast(num_experts)); + for (auto& expert_tokens : map) { + expert_tokens.reserve(32); + } } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; - // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop. std::vector> sorted_logits(static_cast(num_experts)); std::vector top_k_exp(static_cast(k_)); for (int64_t i = work.start; i < work.end; ++i) { const float* logits = router_logits_float + i * num_experts; + for (int64_t j = 0; j < num_experts; ++j) { sorted_logits[static_cast(j)] = {logits[j], j}; } - std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), + sorted_logits.end(), std::greater<>()); - float max_logit = -std::numeric_limits::infinity(); - for (int64_t j = 0; j < k_; ++j) { - if (sorted_logits[static_cast(j)].first > max_logit) { - max_logit = sorted_logits[static_cast(j)].first; - } - } + float max_logit = sorted_logits[0].first; float sum_exp = 0.0f; for (int64_t j = 0; j < k_; ++j) { @@ -166,20 +455,19 @@ Status QMoECPU::Compute(OpKernelContext* context) const { sum_exp += top_k_exp[static_cast(j)]; } - float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); + const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); for (int64_t j = 0; j < k_; ++j) { int64_t expert_idx = sorted_logits[static_cast(j)].second; int64_t route_idx = i * k_ + j; route_expert[route_idx] = static_cast(expert_idx); - route_scale[route_idx] = top_k_exp[static_cast(j)] * scale; - if (route_scale[route_idx] > 0.0f) { + route_scale[route_idx] = top_k_exp[static_cast(j)] * inv_sum; + if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); } } } }); - // Merge the maps from each thread into a single global map. std::vector> expert_token_map(static_cast(num_experts)); for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { size_t total_tokens_for_expert = 0; @@ -187,18 +475,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); } expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); - } - for (int t = 0; t < num_routing_threads; ++t) { - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + for (int t = 0; t < num_routing_threads; ++t) { auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; if (!local_tokens.empty()) { - expert_token_map[static_cast(expert_idx)].insert(expert_token_map[static_cast(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); + expert_token_map[static_cast(expert_idx)].insert( + expert_token_map[static_cast(expert_idx)].end(), + local_tokens.begin(), local_tokens.end()); } } } - // --- 3. 
Parallel Expert Computation --- IAllocatorUniquePtr input_float_buffer; const float* input_float; if constexpr (std::is_same_v) { @@ -211,118 +498,434 @@ Status QMoECPU::Compute(OpKernelContext* context) const { input_float = reinterpret_cast(input_data); } - int num_expert_threads = (tp == nullptr) ? 1 : std::min(static_cast(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp)); + const int max_expert_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const int64_t total_expert_work = std::accumulate(expert_token_map.begin(), expert_token_map.end(), 0LL, + [](int64_t sum, const std::vector& tokens) { return sum + static_cast(tokens.size()); }); + const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8); + const int64_t min_expert_work_per_thread = std::max(int64_t{16}, total_expert_work / expert_thread_divisor); + + int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 1 : std::min(static_cast(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(static_cast(num_experts), max_expert_threads)); if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); float* thread_local_outputs = thread_local_outputs_ptr.get(); - memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + std::memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); - // Pre-calculate workspace size per thread to avoid allocations inside the loop size_t max_tokens_per_expert = 0; for (const auto& tokens : expert_token_map) { - if (tokens.size() > max_tokens_per_expert) { - max_tokens_per_expert = tokens.size(); - } + max_tokens_per_expert = std::max(max_tokens_per_expert, tokens.size()); } - const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); - const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); - const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); - const size_t B2_dequant_size = static_cast(hidden_size * inter_size); - const size_t bias1_size = static_cast(fc1_out_features); - const size_t bias2_size = static_cast(hidden_size); + const auto align_size = [](size_t size) -> size_t { + return (size + 63) & ~63; + }; + + const size_t A1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t C1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(fc1_out_features)); + const size_t A2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(inter_size)); + const size_t C2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t B1_dequant_size = align_size(static_cast(fc1_out_features) * static_cast(hidden_size)); + const size_t B2_dequant_size = align_size(static_cast(hidden_size) * static_cast(inter_size)); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + + B1_dequant_size + B2_dequant_size; - const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * 
workspace_elements_per_thread); float* workspace = workspace_ptr.get(); + auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr(allocator, + static_cast(num_expert_threads) * (static_cast(fc1_out_features) + static_cast(hidden_size))); + float* bias_conversion_buffers = bias_conversion_buffers_ptr.get(); + + const auto& fc1_scales_dims = fc1_scales->Shape().GetDims(); + const auto& fc2_scales_dims = fc2_scales->Shape().GetDims(); + const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); + const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); + + const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); + const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const T* fc1_scales_data = fc1_scales->Data(); + const T* fc2_scales_data = fc2_scales->Data(); + const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; + const T* fc2_bias_data = fc2_experts_bias ? fc2_experts_bias->Data() : nullptr; + + const int64_t pack_unit = (8 / expert_weight_bits_); + const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; + const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; + const bool has_fc1_bias = (fc1_bias_data != nullptr); + const bool has_fc2_bias = (fc2_bias_data != nullptr); + + std::vector> expert_workload; + size_t total_work = 0; + + for (int64_t i = 0; i < num_experts; ++i) { + const size_t token_count = expert_token_map[static_cast(i)].size(); + if (token_count > 0) { + expert_workload.emplace_back(i, token_count); + total_work += token_count; + } + } + + if (total_work < 48) { + num_expert_threads = 1; + } else if (total_work < 192) { + num_expert_threads = std::min(num_expert_threads, 2); + } else if (total_work < 512) { + num_expert_threads = std::min(num_expert_threads, 4); + } + + std::sort(expert_workload.begin(), expert_workload.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + + std::vector> expert_batches(num_expert_threads); + size_t thread_idx = 0; + for (const auto& work : expert_workload) { + expert_batches[thread_idx].push_back(work.first); + thread_idx = (thread_idx + 1) % static_cast(num_expert_threads); + } + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { - int thread_id = static_cast(thread_id_pd); - auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + const int thread_id = static_cast(thread_id_pd); + const auto& expert_batch = expert_batches[static_cast(thread_id)]; float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; - for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + float* thread_bias1_buffer = bias_conversion_buffers + static_cast(thread_id) * (static_cast(fc1_out_features) + static_cast(hidden_size)); + float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); + + for (int64_t expert_idx : expert_batch) { const auto& routes = expert_token_map[static_cast(expert_idx)]; if (routes.empty()) { continue; } - const int64_t num_expert_tokens = routes.size(); + const int64_t num_expert_tokens = static_cast(routes.size()); - // Partition the workspace for the current expert float* A1 = thread_workspace; - float* C1 = A1 + num_expert_tokens * hidden_size; - float* A2 = C1 + num_expert_tokens * fc1_out_features; - float* C2 = A2 + num_expert_tokens * inter_size; - float* B1_dequant = C2 + 
num_expert_tokens * hidden_size; - float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; - float* bias1_float = B2_dequant + hidden_size * inter_size; - float* bias2_float = bias1_float + fc1_out_features; - - // --- Gather input tokens for the current expert --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; - memcpy(A1 + i * hidden_size, - input_float + token_idx * hidden_size, - static_cast(hidden_size) * sizeof(float)); + float* C1 = A1 + A1_size; + float* A2 = C1 + C1_size; + float* C2 = A2 + A2_size; + float* B1_dequant = C2 + C2_size; + float* B2_dequant = B1_dequant + B1_dequant_size; + + const int64_t dynamic_block_size = GetOptimalBlockSize(num_expert_tokens, tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1); + const int64_t num_blocks = (num_expert_tokens + dynamic_block_size - 1) / dynamic_block_size; + + if (num_expert_tokens >= 8 && num_blocks > 1 && tp != nullptr) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_idx = block_idx * dynamic_block_size; + const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens); + + for (int64_t i = start_idx; i < end_idx; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dst[j + k] = src[j + k]; + } + } + for (; j < hidden_size; ++j) { + dst[j] = src[j]; + } + } + } + } + + const T* fc1_scales_ptr; + + if (is_fc1_block_wise) { + const int64_t fc1_blocks_per_row = fc1_scales_dims[2]; + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features * fc1_blocks_per_row; + } else { + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features; } - // --- FC1 GEMM (X * W1^T) --- - DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), - fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? hidden_size / block_size_ : 1), - block_size_, expert_weight_bits_, - fc1_out_features, hidden_size, B1_dequant); + const int64_t dequant_block_size = GetDequantBlockSize(fc1_out_features, num_expert_tokens); + const int64_t num_dequant_blocks = (fc1_out_features + dequant_block_size - 1) / dequant_block_size; + + const size_t m = static_cast(num_expert_tokens); + const size_t n = static_cast(fc1_out_features); + const size_t k = static_cast(hidden_size); + + MLAS_BLK_QUANT_TYPE q_type; + bool use_direct_q4_gemm = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? 
block_size_ : 0, + fc1_out_features, hidden_size, q_type); + bool fc1_used_direct_q4 = false; + bool fc1_bias_handled_by_q4_gemm = false; + + if (use_direct_q4_gemm) { + IAllocatorUniquePtr mlas_packed_fc1; + Status convert_status = ConvertToMlasQ4Format( + fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? block_size_ : 0, + expert_weight_bits_, + fc1_out_features, + hidden_size, + q_type, + allocator, + mlas_packed_fc1); + + if (convert_status.IsOK()) { + float* fc1_bias_float = nullptr; + IAllocatorUniquePtr fc1_bias_buffer; + + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); + fc1_bias_float = fc1_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); + } else { + for (int64_t i = 0; i < fc1_out_features; ++i) { + fc1_bias_float[i] = static_cast(B1_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); + + if (gemm_status.IsOK()) { + fc1_used_direct_q4 = true; +#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING + LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC1 expert " << expert_idx + << " (M=" << num_expert_tokens << ", N=" << fc1_out_features << ", K=" << hidden_size << ")"; +#endif + goto fc1_gemm_done; + } + } + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_dequant_blocks > 1 && fc1_out_features >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * dequant_block_size; + const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); + const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; + DequantizeBlock(fc1_weights_data + offset, + fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_scales_dims[2] : start_row), + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, tp); + }); + } else { + DequantizeBlock(fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), - 1.0f, A1, static_cast(hidden_size), - B1_dequant, static_cast(hidden_size), - 0.0f, C1, static_cast(fc1_out_features), - nullptr); - - const T* B1_bias = (fc1_experts_bias) ? 
fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; - if (B1_bias) { + m, n, k, + 1.0f, A1, k, + B1_dequant, k, + 0.0f, C1, n, + tp); + + fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; + if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); } else { - memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + if (ShouldUseMemcpy(fc1_out_features)) { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); + int64_t j = 0; + for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + thread_bias1_buffer[j + k] = static_cast(B1_bias[j + k]); + } + } + for (; j < fc1_out_features; ++j) { + thread_bias1_buffer[j] = static_cast(B1_bias[j]); + } + } } + for (int64_t i = 0; i < num_expert_tokens; ++i) { - for (int64_t j = 0; j < fc1_out_features; ++j) { - C1[i * fc1_out_features + j] += bias1_float[j]; + float* C1_row = C1 + i * fc1_out_features; + const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); + + int64_t j = 0; + for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + C1_row[j + k] += thread_bias1_buffer[j + k]; + } + } + for (; j < fc1_out_features; ++j) { + C1_row[j] += thread_bias1_buffer[j]; } } } - // --- Activation --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const float* C1_token = C1 + i * fc1_out_features; - float* A2_token = A2 + i * inter_size; - ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + fc1_gemm_done: + + const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + if (num_expert_tokens >= activation_threshold && tp != nullptr) { + const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); + const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; + + if (num_activation_blocks > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_token = block_idx * activation_block_size; + const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); + + for (int64_t i = start_token; i < end_token; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + } + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, 
swiglu_limit_); + } + } + + const T* fc2_scales_ptr; + + if (is_fc2_block_wise) { + const int64_t fc2_blocks_per_row = fc2_scales_dims[2]; + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size * fc2_blocks_per_row; + } else { + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size; } - // --- FC2 GEMM (A2 * W2^T) --- - DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), - fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), - block_size_, expert_weight_bits_, - hidden_size, inter_size, B2_dequant); + const int64_t fc2_dequant_block_size = GetDequantBlockSize(hidden_size, num_expert_tokens); + const int64_t num_fc2_dequant_blocks = (hidden_size + fc2_dequant_block_size - 1) / fc2_dequant_block_size; + + const size_t m2 = static_cast(num_expert_tokens); + const size_t n2 = static_cast(hidden_size); + const size_t k2 = static_cast(inter_size); + + MLAS_BLK_QUANT_TYPE q_type2; + bool use_direct_q4_gemm_fc2 = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2); + bool fc2_used_direct_q4 = false; + + if (use_direct_q4_gemm_fc2) { + IAllocatorUniquePtr mlas_packed_fc2; + Status convert_status = ConvertToMlasQ4Format( + fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? block_size_ : 0, + expert_weight_bits_, + hidden_size, + inter_size, + q_type2, + allocator, + mlas_packed_fc2); + + if (convert_status.IsOK()) { + float* fc2_bias_float = nullptr; + IAllocatorUniquePtr fc2_bias_buffer; + + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); + fc2_bias_float = fc2_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); + } else { + for (int64_t i = 0; i < hidden_size; ++i) { + fc2_bias_float[i] = static_cast(B2_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, q_type2, tp); + + if (gemm_status.IsOK()) { + fc2_used_direct_q4 = true; +#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING + LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC2 expert " << expert_idx + << " (M=" << num_expert_tokens << ", N=" << hidden_size << ", K=" << inter_size << ")"; +#endif + goto fc2_gemm_done; + } + } + + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * fc2_dequant_block_size; + const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); + const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; + DequantizeBlock(fc2_weights_data + offset, + fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_scales_dims[2] : start_row), + is_fc2_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, inter_size, B2_dequant + start_row * inter_size, tp); + }); + } else { + DequantizeBlock(fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? 
block_size_ : 0, expert_weight_bits_, + hidden_size, inter_size, B2_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), - 1.0f, A2, static_cast(inter_size), - B2_dequant, static_cast(inter_size), - 0.0f, C2, static_cast(hidden_size), - nullptr); - - const T* B2_bias = (fc2_experts_bias) ? fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; - if (B2_bias) { + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, k2, + 0.0f, C2, n2, + tp); + + fc2_gemm_done: + + bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); } else { - memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + thread_bias2_buffer[j + k] = static_cast(B2_bias[j + k]); + } + } + for (; j < hidden_size; ++j) { + thread_bias2_buffer[j] = static_cast(B2_bias[j]); + } + } } } @@ -331,28 +934,89 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t token_idx = route_idx / k_; const float weight = route_scale[route_idx]; + if (token_idx < 0 || token_idx >= num_tokens) continue; + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); - if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { - // Skip this token to prevent buffer overflow - continue; - } + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) continue; float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - for (int64_t j = 0; j < hidden_size; ++j) { - dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dest[j + k] += weight * (src[j + k] + thread_bias2_buffer[j + k]); + } + } + for (; j < hidden_size; ++j) { + dest[j] += weight * (src[j] + thread_bias2_buffer[j]); + } + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dest[j + k] += weight * src[j + k]; + } + } + for (; j < hidden_size; ++j) { + dest[j] += weight * src[j]; + } } } } }); - // --- 4. 
Final Reduction (accumulate expert outputs to a float buffer) --- auto accumulate = [&](float* buffer) { - memset(buffer, 0, output_buffer_size * sizeof(float)); - for (int i = 0; i < num_expert_threads; ++i) { - const size_t thread_offset = static_cast(i) * output_buffer_size; - for (size_t j = 0; j < output_buffer_size; ++j) { - buffer[j] += thread_local_outputs[thread_offset + j]; + std::memset(buffer, 0, output_buffer_size * sizeof(float)); + + const int max_acc_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const size_t acc_thread_divisor = std::max(size_t{1}, static_cast(max_acc_threads) * 8); + const size_t min_elements_per_thread = std::max(size_t{32}, output_buffer_size / acc_thread_divisor); + const int optimal_acc_threads = (tp == nullptr || output_buffer_size < min_elements_per_thread) ? 1 : std::min(static_cast(output_buffer_size / std::max(size_t{1}, min_elements_per_thread)), max_acc_threads); + const int num_acc_threads = std::max(1, optimal_acc_threads); + + if (num_acc_threads > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_acc_threads, [&](std::ptrdiff_t acc_thread_id) { + const size_t elements_per_thread = output_buffer_size / static_cast(num_acc_threads); + const size_t start_idx = static_cast(acc_thread_id) * elements_per_thread; + const size_t end_idx = (acc_thread_id == num_acc_threads - 1) ? output_buffer_size : start_idx + elements_per_thread; + + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + const float* src = thread_local_outputs + thread_offset + start_idx; + float* dst = buffer + start_idx; + + size_t j = 0; + const size_t chunk_size = end_idx - start_idx; + const int64_t unroll_factor = GetUnrollFactor(static_cast(chunk_size)); + for (; j + static_cast(unroll_factor) <= chunk_size; j += static_cast(unroll_factor)) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dst[j + static_cast(k)] += src[j + static_cast(k)]; + } + } + for (; j < chunk_size; ++j) { + dst[j] += src[j]; + } + } + }); + } else { + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + const float* src = thread_local_outputs + thread_offset; + + size_t j = 0; + const int64_t unroll_factor = GetUnrollFactor(static_cast(output_buffer_size)); + for (; j + static_cast(unroll_factor) <= output_buffer_size; j += static_cast(unroll_factor)) { + for (int64_t k = 0; k < unroll_factor; ++k) { + buffer[j + static_cast(k)] += src[j + static_cast(k)]; + } + } + for (; j < output_buffer_size; ++j) { + buffer[j] += src[j]; + } } } }; @@ -362,18 +1026,16 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* final_output_float = final_output_float_ptr.get(); accumulate(final_output_float); - // --- 5. 
Convert final float buffer to output type T --- MlasConvertFloatToHalfBuffer(final_output_float, reinterpret_cast(output->MutableData()), static_cast(output_buffer_size)); - } else { // T is float + } else { accumulate(output->MutableData()); } return Status::OK(); } -// Explicit template instantiation template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 93d802ca05b42..167b2af946183 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -77,7 +77,8 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, nullptr, fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, 1, // no quantization so pack size is 1 - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // no block-wise quantization for sharded MoE ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index a5b9d483d5ad1..e5a064d59e360 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -45,7 +45,8 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, nullptr, fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, 1, // no quantization so pack size is 1 - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // no block-wise quantization for regular MoE using CudaT = typename OrtToCudaType::type; auto stream = context->GetComputeStream(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index dcf32bb3c5ae4..931b8ac09aa49 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -150,7 +150,8 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, expert_weight_bits_ == 4 ? 
2 : 1, - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // CUDA doesn't support block-wise quantization yet #if defined(__GNUC__) #pragma GCC diagnostic push diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 403becbe0616a..0292111b16962 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -128,6 +128,148 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): # Calculate scale like C++ implementation abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + + # Set minimum scale to avoid division by zero + scale = torch.clamp(abs_max, min=1e-6) + + # Quantization ranges for symmetric quantization + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 # Offset to make values unsigned + else: + qmin, qmax = -128, 127 + zero_point = 128 # Offset to make values unsigned + + # Quantize using double precision division and C-like rounding (half away from zero) + scaled = weights.double() / scale.double() + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized = torch.clamp((sign * quant_rounded).to(torch.int32), qmin, qmax).to(weights.dtype) + + # Convert to unsigned and pack for storage + if is_4_bit_quantization: + # Convert to unsigned 4-bit and pack into uint8 + unsigned_quantized = (quantized + zero_point).to(torch.uint8) + + # Pack two 4-bit values into one uint8 + packed_size = (weights.shape[-1] + 1) // 2 + packed_quantized = torch.zeros((*weights.shape[:-1], packed_size), dtype=torch.uint8, device=weights.device) + + for i in range(0, weights.shape[-1], 2): + val1 = unsigned_quantized[..., i] + val2 = unsigned_quantized[..., i + 1] if i + 1 < weights.shape[-1] else torch.zeros_like(val1) + packed_quantized[..., i // 2] = (val1 & 0xF) | ((val2 & 0xF) << 4) + + quantized_storage = packed_quantized + else: + # 8-bit: convert to unsigned uint8 + quantized_storage = (quantized + zero_point).to(torch.uint8) + + # Dequantize for verification (use float32 scale for higher precision) + dequantized = quantized.to(torch.float32) * scale + + return scale.squeeze(-1).to(torch.float32), quantized_storage, dequantized + + +def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True): + """ + Block-wise quantization and dequantization for testing purposes. + This function uses symmetric quantization centered around 0 (no zero-point). 
+ + Args: + weights: Input tensor of shape [rows, cols] + block_size: Size of each quantization block + is_4_bit_quantization: Whether to use 4-bit (True) or 8-bit (False) quantization + + Returns: + scales: Scale tensor of shape [rows, num_blocks] + quantized: Quantized tensor + dequantized: Dequantized tensor for verification + """ + rows, cols = weights.shape + num_blocks = (cols + block_size - 1) // block_size + + # Handle edge case of all-zero weights tensor + if torch.all(weights == 0): + scales = torch.zeros((rows, num_blocks), dtype=torch.float16, device=weights.device) + if is_4_bit_quantization: + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + dequantized = torch.zeros_like(weights) + return scales, quantized, dequantized + + # Initialize output tensors; use float32 for scales to reduce precision loss + scales = torch.zeros((rows, num_blocks), dtype=torch.float32, device=weights.device) + dequantized = torch.zeros_like(weights) + + # Quantization ranges and zero point + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + qmin, qmax = -128, 127 + zero_point = 128 + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + + # Process each block with higher-precision math to match C++ behavior + for row in range(rows): + for block_idx in range(num_blocks): + start_col = block_idx * block_size + end_col = min(start_col + block_size, cols) + + # Get block data + block_data = weights[row, start_col:end_col] + + # Calculate absolute max and ensure small epsilon to avoid div-by-zero + abs_max = block_data.abs().max() + abs_max = torch.clamp(abs_max, min=1e-8) + + # Compute scale consistent with C++: use 7.0 for 4-bit positive max, 127.0 for 8-bit + if is_4_bit_quantization: + # Use higher precision then keep as float32 for scale + scale = (abs_max.double() / 7.0).float() + 1e-12 + else: + scale = (abs_max.double() / 127.0).float() + 1e-12 + + scales[row, block_idx] = scale.to(torch.float32) + + if scale == 0: + continue + + # Quantize using double precision for the division to reduce rounding error + scaled = block_data.double() / scale.double() + # Emulate C's round() behavior (round half away from zero) to match C++ implementation + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized_block = (sign * quant_rounded).clamp(qmin, qmax).to(torch.int32) + + # Pack for 4-bit or store directly for 8-bit + if is_4_bit_quantization: + for i in range(0, end_col - start_col, 2): + col_idx = start_col + i + packed_idx = col_idx // 2 + + val1 = int(quantized_block[i]) + zero_point + val2 = int(quantized_block[i + 1]) + zero_point if i + 1 < len(quantized_block) else zero_point + + # Pack two 4-bit values into one uint8 + packed_val = (val1 & 0xF) | ((val2 & 0xF) << 4) + quantized[row, packed_idx] = packed_val + else: + quantized_vals = (quantized_block + zero_point).to(torch.uint8) + quantized[row, start_col:end_col] = quantized_vals + + # Dequantize for verification (signed quantized values multiplied by scale) + signed = quantized_block.to(torch.float32) + dequantized[row, start_col:end_col] = signed * scale + + return scales, quantized, dequantized abs_max = torch.clamp(abs_max, min=1e-8) # More 
conservative clamping for better precision if is_4_bit_quantization: @@ -247,6 +389,7 @@ def create_cpu_moe_onnx_graph( use_quant=False, quant_bits=4, swiglu_interleaved=False, + block_size=0, # New parameter for block-wise quantization ): if not has_onnx: return None @@ -264,14 +407,13 @@ def create_cpu_moe_onnx_graph( if not has_onnx: return None - if use_quant: - # Assertions only apply to quantized MoE - assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" - assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" - assert fc1_scales is not None, "FC1 scales must be provided for QMoE" - assert fc2_scales is not None, "FC2 scales must be provided for QMoE" - assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" - assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" + assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" + assert fc1_scales is not None, "FC1 scales must be provided for QMoE" + assert fc2_scales is not None, "FC2 scales must be provided for QMoE" + # Accept float16 or float32 scales; tests may produce float32 for better precision + assert fc1_scales.dtype in (torch.float16, torch.float32), "FC1 scales must be float16 or float32 for QMoE" + assert fc2_scales.dtype in (torch.float16, torch.float32), "FC2 scales must be float16 or float32 for QMoE" if not has_onnx: return None @@ -332,6 +474,10 @@ def create_cpu_moe_onnx_graph( if use_quant: nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + # Add block_size attribute for block-wise quantization + if block_size > 0: + nodes[0].attribute.extend([helper.make_attribute("block_size", block_size)]) + # Weights are store in column major order. Need pack 2 int4 values into uint8. # Use the actual tensor shapes instead of calculating them to avoid size mismatches fc1_shape = list(fc1_experts_weights.shape) @@ -342,30 +488,59 @@ def create_cpu_moe_onnx_graph( weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + # Use raw bytes from C-contiguous numpy arrays to ensure the exact memory layout + # of the packed uint8 weight tensors is preserved when writing the ONNX initializer. 
+ fc1_np = fc1_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc2_np = fc2_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc1_np = numpy.ascontiguousarray(fc1_np) + fc2_np = numpy.ascontiguousarray(fc2_np) + initializers = [ helper.make_tensor( "fc1_experts_weights", weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc1_np.tobytes(), + raw=True, ), helper.make_tensor( "fc2_experts_weights", weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc2_np.tobytes(), + raw=True, ), ] - fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] - fc2_scale_shape = [num_experts, hidden_size] + # Calculate scale tensor shapes based on block_size + if block_size > 0: + # Block-wise quantization: 3D scale tensors + fc1_blocks_per_row = (hidden_size + block_size - 1) // block_size + fc2_blocks_per_row = (inter_size + block_size - 1) // block_size - fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) - fc2_scale_size = num_experts * hidden_size + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size, fc1_blocks_per_row] + fc2_scale_shape = [num_experts, hidden_size, fc2_blocks_per_row] - # Handle scale tensors based on quantization mode - if use_quant: - # Handle different possible scale tensor structures for fc1_scales + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) * fc1_blocks_per_row + fc2_scale_size = num_experts * hidden_size * fc2_blocks_per_row + else: + # Row-wise quantization: 2D scale tensors + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] + fc2_scale_shape = [num_experts, hidden_size] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size + + # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions + # Process scale tensors based on whether block-wise quantization is used + if block_size > 0: + # For block-wise quantization, the scales are already in the correct 3D shape + # [num_experts, output_features, num_blocks] from quant_dequant_blockwise + # Convert scales to the selected ONNX dtype (prefer float32 for higher precision) + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + else: + # For row-wise quantization, handle different possible scale tensor structures for fc1_scales if len(fc1_scales.shape) == 4: # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output if use_swiglu: @@ -395,10 +570,6 @@ def create_cpu_moe_onnx_graph( [fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)] ) - # Process scale tensor for proper shape - fc1_scale_data_list = fc1_scale_tensor.tolist() - fc1_scale_data = fc1_scale_data_list - # Handle different possible scale tensor structures for fc2_scales if len(fc2_scales.shape) == 4: # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output @@ -421,48 +592,30 @@ def create_cpu_moe_onnx_graph( [fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)] ) - # Process scale tensor for proper shape - fc2_scale_data_list = fc2_scale_tensor.tolist() - fc2_scale_data = fc2_scale_data_list - - initializers.extend( - [ - 
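# Illustrative sketch (not part of the test code): how the scale-initializer
# shapes above are derived. Row-wise quantization keeps one scale per output row
# (2D); block-wise keeps one scale per block of input features (3D), with blocks
# counted by ceiling division. scale_shapes is a hypothetical helper name.
def scale_shapes(num_experts, hidden_size, inter_size, use_swiglu, block_size):
    fc1_rows = 2 * inter_size if use_swiglu else inter_size  # SwiGLU fuses gate and value rows
    if block_size > 0:
        fc1_blocks = (hidden_size + block_size - 1) // block_size  # fc1 consumes hidden_size inputs
        fc2_blocks = (inter_size + block_size - 1) // block_size   # fc2 consumes inter_size inputs
        return [num_experts, fc1_rows, fc1_blocks], [num_experts, hidden_size, fc2_blocks]
    return [num_experts, fc1_rows], [num_experts, hidden_size]

print(scale_shapes(2, 96, 192, True, 32))  # ([2, 384, 3], [2, 96, 6])
print(scale_shapes(2, 96, 192, True, 0))   # ([2, 384], [2, 96])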
helper.make_tensor( - "fc1_scales", - onnx_dtype, - fc1_scale_shape, - fc1_scale_data, - raw=False, - ), - helper.make_tensor( - "fc2_scales", - onnx_dtype, - fc2_scale_shape, - fc2_scale_data, - raw=False, - ), - ] - ) - else: - # For non-quantized mode, add bias tensors if provided - if fc1_bias is not None: - initializers.append( - helper.make_tensor( - "fc1_experts_bias", - onnx_dtype, - list(fc1_bias.shape), - fc1_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), - ) - ) - if fc2_bias is not None: - initializers.append( - helper.make_tensor( - "fc2_experts_bias", - onnx_dtype, - list(fc2_bias.shape), - fc2_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), - ) - ) + # Process scale tensors for proper data format + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + onnx_dtype, + fc1_scale_shape, + fc1_scale_data, + raw=False, + ), + helper.make_tensor( + "fc2_scales", + onnx_dtype, + fc2_scale_shape, + fc2_scale_data, + raw=False, + ), + ] + ) graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), @@ -645,10 +798,7 @@ class SparseMoeBlockORTHelper(nn.Module): def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() self.quant_bits = quant_bits - if onnx_dtype is None: - self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT - else: - self.onnx_dtype = onnx_dtype + self.onnx_dtype = onnx_dtype self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): @@ -717,8 +867,8 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), - "router_probs": router_input.clone().to(device=device, dtype=torch_dtype), - "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros((batch_size * sequence_length, hidden_dim), device=device, dtype=torch_dtype), } try: @@ -779,14 +929,47 @@ def recreate_onnx_model(self): is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant_blockwise( + self.experts[i].w1.weight, self.block_size, is_4_bit + ) + w2_scale, pre_qweight2, w2_qdq = quant_dequant_blockwise( + self.experts[i].w2.weight, self.block_size, is_4_bit + ) + else: + # Use row-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) if self.use_swiglu: - # For SwiGLU, CPU kernel now always expects interleaved format - # SwigluMlp weights are already in interleaved format [gate_0, linear_0, gate_1, linear_1, ...] 
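# Illustrative sketch (not part of the test code): the per-expert branch above
# reduces to "pick a quantizer by block_size". A hypothetical helper with the
# same return convention (scales, packed weights, dequantized weights), assuming
# the quant_dequant / quant_dequant_blockwise helpers defined earlier in this file:
def quantize_expert_weight(weight, quant_bits, block_size):
    is_4_bit = quant_bits == 4
    if block_size > 0:
        # Block-wise: scales shaped [rows, ceil(cols / block_size)]
        return quant_dequant_blockwise(weight, block_size, is_4_bit)
    # Row-wise: one scale per output row, shaped [rows]
    return quant_dequant(weight, is_4_bit)

# Usage inside the expert loop would then look like:
#   w1_scale, pre_qweight1, w1_qdq = quantize_expert_weight(
#       self.experts[i].w1.weight, self.quant_bits, self.block_size)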
- # No conversion needed - both CPU and CUDA use interleaved format - self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + if self.swiglu_interleaved: + pass + else: + if self.block_size > 0: + w3_scale, pre_qweight3, w3_qdq = quant_dequant_blockwise( + self.experts[i].w3.weight, self.block_size, is_4_bit + ) + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + + gate_weights = pre_qweight1 + value_weights = pre_qweight3 + gate_scales = w1_scale + value_scales = w3_scale + + pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) + w1_scale = torch.cat([gate_scales, value_scales], dim=0) + + if self.swiglu_interleaved: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + + else: + intermediate_size = self.experts[i].w1.weight.shape[0] + gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() + value_dequant = w1_qdq[intermediate_size:].contiguous().clone() + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant else: self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() @@ -828,7 +1011,8 @@ def recreate_onnx_model(self): use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, - swiglu_interleaved=True, # CPU kernel now always expects interleaved format + swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + block_size=self.block_size, # Add block_size for block-wise quantization ) except Exception: self.moe_onnx_graph = None @@ -877,6 +1061,45 @@ def parity_check(self): print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + # Diagnostic dump: when differences are large, show the index and nearby values + if max_diff > 1e-3: + diff = (torch_output.cpu() - ort_output.cpu()).abs() + idx = torch.argmax(diff) + flat_idx = int(idx) + # Derive coordinates (batch, seq, hidden) from flattened index + total_elems = torch_output.numel() + # Work in flattened [batch, seq, hidden] ordering + hidden_dim = self.hidden_dim + seq = self.sequence_length + # Clamp to safe bounds + flat_idx = min(flat_idx, total_elems - 1) + i = flat_idx // (hidden_dim) + j = i // seq + k = flat_idx % hidden_dim + print( + f"Diagnostic - max diff at flat_idx={flat_idx} -> sample (batch_idx={j}, seq_idx={i % seq}, hidden_idx={k})" + ) + print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + # Print routing and per-expert contributions for this token from the PyTorch reference + try: + hidden_states_flat = hidden_state.view(-1, hidden_dim) + token_vec = hidden_states_flat[i : i + 1] + gate_logits = self.gate(token_vec) + topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1) + topk_soft = F.softmax(topk_vals, dim=1) + print("Gate logits:", gate_logits.detach().cpu().numpy()) + print("Selected experts:", topk_experts.detach().cpu().numpy()) + print("Routing weights:", topk_soft.detach().cpu().numpy()) + # Compute per-expert contributions for selected experts + for idx_e, e in enumerate(topk_experts[0].tolist()): + expert_layer = self.experts[e] + expert_out = expert_layer(token_vec) + contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item() + print(f"Expert {e} contrib at hidden {k}: {contrib}") + except Exception as _: + pass + ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), "FP16:0": (5e-2, 1e-3), @@ -917,7 +1140,13 @@ def 
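# Illustrative sketch (not part of the test code): in the non-interleaved SwiGLU
# path above, gate (w1) and value (w3) projections are quantized separately and
# stacked along the output dimension, so the fused FC1 tensor is [gate rows;
# value rows]. Toy sizes below are assumptions (4-bit packing halves hidden_size):
import torch

inter_size, packed_hidden = 256, 64              # packed_hidden = hidden_size // 2 for 4-bit
gate_q = torch.zeros(inter_size, packed_hidden, dtype=torch.uint8)
value_q = torch.zeros(inter_size, packed_hidden, dtype=torch.uint8)
gate_s = torch.ones(inter_size)                  # row-wise scales; block-wise scales would be 2-D
value_s = torch.ones(inter_size)

fused_q = torch.cat([gate_q, value_q], dim=0)    # [2 * inter_size, packed_hidden]
fused_s = torch.cat([gate_s, value_s], dim=0)    # [2 * inter_size]
print(fused_q.shape, fused_s.shape)              # torch.Size([512, 64]) torch.Size([512])
# The interleaved layout alternates rows ([gate_0, value_0, gate_1, value_1, ...])
# instead of stacking the two halves.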
small_test_cases(): class SwigluMoEBlock(SparseMoeBlockORTHelper): def __init__( - self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: SwigluMoeConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -926,6 +1155,7 @@ def __init__( self.top_k = config.num_experts_per_token self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -995,7 +1225,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): def __init__( - self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: PhiMoEConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -1005,6 +1241,7 @@ def __init__( self.router_jitter_noise = config.router_jitter_noise self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -1024,8 +1261,14 @@ def __init__( else: is_4_bit = self.quant_bits == 4 - scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) - scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant_blockwise(expert.w1.weight, self.block_size, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant_blockwise(expert.w2.weight, self.block_size, is_4_bit) + else: + # Use row-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) expert.w1.weight.data = w1_qdq expert.w2.weight.data = w2_qdq @@ -1064,6 +1307,7 @@ def __init__( use_quant=use_quant, quant_bits=self.quant_bits, swiglu_interleaved=self.swiglu_interleaved, + block_size=self.block_size, # Add block_size for block-wise quantization ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None @@ -1075,9 +1319,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # Match CPU implementation: select top-k experts by logits, then softmax over those logits + routing_weights_vals, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights_vals, dim=1, dtype=torch.float) routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( @@ -1112,6 +1356,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: (2, 16, 8), ] +# Define test cases for block-wise quantization +phi3_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + 
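# Illustrative sketch (not part of the test code): the routing change above
# (top-k over the raw logits, then softmax over only the selected logits) gives
# the same weights as the previous softmax -> top-k -> renormalize order, up to
# ties in the logits, because softmax is monotonic and renormalizing a subset of
# softmax outputs equals a softmax over that subset's logits. Quick check:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
router_logits, top_k = torch.randn(5, 8), 2

vals, experts_a = torch.topk(router_logits, top_k, dim=-1)          # new order
weights_a = F.softmax(vals, dim=-1, dtype=torch.float)

probs = F.softmax(router_logits, dim=-1, dtype=torch.float)         # old order
probs_topk, experts_b = torch.topk(probs, top_k, dim=-1)
weights_b = probs_topk / probs_topk.sum(dim=-1, keepdim=True)

print(torch.equal(experts_a, experts_b))                 # True
print(torch.allclose(weights_a, weights_b, atol=1e-6))   # True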
(1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): @@ -1152,6 +1404,37 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): phi3_moe.parity_check() + @parameterized.expand(phi3_blockwise_test_cases) + def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running Phi3 QMoE block-wise test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + disable_cpu_qmoe_tests = False @@ -1162,6 +1445,14 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): (2, 16, 8), ] +# Define test cases for block-wise quantization +swiglu_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + (1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): @@ -1201,6 +1492,36 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): swiglu_moe.parity_check() + @parameterized.expand(swiglu_blockwise_test_cases) + def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running SwiGLU block-wise test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = swiglu_moe.forward(hidden_states) + + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + swiglu_moe.parity_check() + @unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") class TestQMoESwiGLUBenchmark(unittest.TestCase):
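# Illustrative sketch (not part of the test code): what the block-wise test
# configurations above imply for the 3-D scale initializers, using the same
# ceiling-division rule as the graph builder (hidden_size=128,
# intermediate_size=256, 4 experts, SwiGLU-fused FC1):
def blocks(dim, block_size):
    return (dim + block_size - 1) // block_size

num_experts, hidden_size, inter_size = 4, 128, 256
for quant_bits, block_size in [(4, 32), (8, 64)]:
    fc1_scales_shape = [num_experts, 2 * inter_size, blocks(hidden_size, block_size)]
    fc2_scales_shape = [num_experts, hidden_size, blocks(inter_size, block_size)]
    print(quant_bits, block_size, fc1_scales_shape, fc2_scales_shape)
# 4 32 [4, 512, 4] [4, 128, 8]
# 8 64 [4, 512, 2] [4, 128, 4]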