diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index 257c5a189b3bd..bd30418030dc2 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -35,44 +35,86 @@ struct MoEParameters { }; namespace moe_helper { +// Helper to check shape dimensions +#define ASSERT_SHAPE_DIMENSION(shape_ptr, dim, name) \ + if (shape_ptr != nullptr) { \ + if (shape_ptr->NumDimensions() != dim) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have ", dim, " dimensions, got ", \ + shape_ptr->NumDimensions()); \ + } \ + } + +#define ASSERT_SHAPE_3D(shape_ptr, name) ASSERT_SHAPE_DIMENSION(shape_ptr, 3, name) + +#define CHECK_SHAPE(shape_ptr, name, ...) \ + if (shape_ptr != nullptr) { \ + const TensorShape& expected_shape = make_shape(__VA_ARGS__); \ + if (*shape_ptr != expected_shape) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have shape ", expected_shape, \ + ", got ", *shape_ptr); \ + } \ + } + template Status CheckInputs(MoEParameters& parameters, - const Tensor* input, // required - const Tensor* router_probs, // required - const Tensor* fc1_experts_weights, // required - const Tensor* fc1_experts_bias, // optional - const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc1_zero_points, // optional, for qMoE - const Tensor* fc2_experts_weights, // required - const Tensor* fc2_experts_bias, // optional - const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc2_zero_points, // optional, for qMoE - const Tensor* fc3_experts_weights, // optional - const Tensor* fc3_experts_bias, // optional - const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc3_zero_points, // optional, for qMoE - const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const Tensor* input, // required + const Tensor* router_probs, // required + const TensorShape* fc1_experts_weights_shape, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const TensorShape* fc2_experts_weights_shape, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const TensorShape* fc3_experts_weights_shape, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) const bool is_fused_swiglu, const int64_t block_size = 0) { // block size for block-wise quantization - // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. 
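// Clarifying note (added for exposition, not part of the patch): switching the weight arguments
// from Tensor* to TensorShape* lets this validation run even after the weight tensors have been
// released following PrePack, using the shapes captured in fc1_shape_/fc2_shape_ on the CPU QMoE
// kernel; the Tensor*-based overload further below simply forwards to this shape-based version.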
+ // Required inputs + if (input == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required."); + } ASSERT_TENSOR_2D_OR_3D(input); - ASSERT_TENSOR_3D(fc1_experts_weights); - ASSERT_TENSOR_3D(fc2_experts_weights); + + if (router_probs == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'router_probs' is required."); + } ASSERT_TENSOR_2D(router_probs); + if (fc1_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc1_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc1_experts_weights_shape, "fc1_experts_weights"); + + if (fc2_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc2_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc2_experts_weights_shape, "fc2_experts_weights"); + const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; int64_t hidden_size = input_dims[input_dims.size() - 1]; - int64_t local_num_experts = fc1_experts_weights_dims[0]; int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size; - const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || - (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + int64_t local_num_experts = fc1_experts_weights_shape->GetDims()[0]; + + int64_t inter_size = (fc2_experts_weights_shape->GetDims()[1] * + fc2_experts_weights_shape->GetDims()[2] * pack_size) / + hidden_size; + + bool legacy_shape = false; + const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims(); + legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || + (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); // Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one. const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size; @@ -80,13 +122,13 @@ Status CheckInputs(MoEParameters& parameters, if (legacy_shape) { // legacy shape does not match column major memory layout. This is for backward compatibility. 
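// Clarifying example (added for exposition): with num_experts = 32, hidden_size = 2880,
// inter_size = 2880, fused SwiGLU (fc1_inter_size = 5760) and pack_size = 2, the legacy fc1
// weight shape checked in this branch is [32, 2880, 2880], while the column-major layout in the
// else branch expects [32, 5760, 1440].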
- CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, hidden_size, fc1_inter_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, hidden_size, inter_size / pack_size); } else { - CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, fc1_inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, inter_size, hidden_size / pack_size); } CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts); @@ -168,9 +210,11 @@ Status CheckInputs(MoEParameters& parameters, } } - if (fc3_experts_weights == nullptr) { + if (fc3_experts_weights_shape == nullptr) { + // If fc3 weights are not provided, ensure no other fc3 parameters are provided ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr && fc3_zero_points == nullptr); } else { + // If fc3 weights are provided, ensure scales logic is consistent ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales } @@ -200,6 +244,36 @@ Status CheckInputs(MoEParameters& parameters, return Status::OK(); } +template +Status CheckInputs(MoEParameters& parameters, + const Tensor* input, // required + const Tensor* router_probs, // required + const Tensor* fc1_experts_weights, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const Tensor* fc2_experts_weights, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const Tensor* fc3_experts_weights, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const bool is_fused_swiglu, + const int64_t block_size = 0) { // block size for block-wise quantization + + const TensorShape* fc1_shape = (fc1_experts_weights != nullptr) ? &fc1_experts_weights->Shape() : nullptr; + const TensorShape* fc2_shape = (fc2_experts_weights != nullptr) ? &fc2_experts_weights->Shape() : nullptr; + const TensorShape* fc3_shape = (fc3_experts_weights != nullptr) ? 
&fc3_experts_weights->Shape() : nullptr; + + return CheckInputs(parameters, input, router_probs, fc1_shape, fc1_experts_bias, fc1_experts_scales, fc1_zero_points, + fc2_shape, fc2_experts_bias, fc2_experts_scales, fc2_zero_points, + fc3_shape, fc3_experts_bias, fc3_experts_scales, fc3_zero_points, + pack_size, is_fused_swiglu, block_size); +} + } // namespace moe_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 14bddaf324ae7..81d2b0f8efdc6 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -13,6 +13,7 @@ #include "core/common/narrow.h" #include "core/framework/tensor_type_and_shape.h" #include "core/util/math.h" +#include "core/platform/env_var_utils.h" #include "contrib_ops/cpu/moe/moe_utils.h" #include "contrib_ops/cpu/moe/moe_helper.h" @@ -69,13 +70,13 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, out_qtype = BlkQ4Sym64; } else if (block_size == 128) { out_qtype = BlkQ4Sym128; - } else if (block_size == 0) { + } else if (block_size == 0 || block_size == 32) { out_qtype = BlkQ4Sym; } else { return false; } - size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); + size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(rows), static_cast(cols)); return expected_size > 0; } @@ -84,6 +85,8 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, namespace onnxruntime { namespace contrib { +constexpr const char* kUseMlasQ4GemmMoe = "ORT_USE_MLAS_Q4_GEMM_MOE"; + template void DequantizeBlockWithMlas(const uint8_t* quantized_data, const TScale* scales, @@ -364,6 +367,257 @@ void DequantizeBlock(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, dequantized_data, thread_pool); } +template +void DequantizePrePacked(const uint8_t* prepacked_data, + const TScale* scales, + const uint8_t* zero_points, + int64_t block_size, + int64_t rows, + int64_t cols, + float* dequantized_data, + const gsl::span& scale_dims) { + // prepacked_data is [cols, rows] (transposed, unpacked) + // dequantized_data is [cols, rows] (transposed) + // scales, zero_points correspond to original [rows, cols] layout + + const float default_zp_4bit = 8.0f; + const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1; + const int64_t zp_pack_size = 2; // Always 2 for 4-bit + + // Iterate over Columns (K) then Rows (N) because prepacked_data is [K, N] + for (int64_t c = 0; c < cols; ++c) { + for (int64_t r = 0; r < rows; ++r) { + uint8_t val = prepacked_data[c * rows + r]; + + int64_t block_idx = (block_size > 0) ? 
(c / block_size) : 0; + if (block_size > 0) block_idx = std::min(block_idx, blocks_per_row - 1); + + int64_t scale_idx; + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + scale_idx = r * blocks_per_row + block_idx; + } else { // per-channel + scale_idx = r; + } + + float scale = static_cast(scales[scale_idx]); + float zp = default_zp_4bit; + + if (zero_points != nullptr) { + int64_t zp_idx; + bool is_lower_nibble; + + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + int64_t zp_blocks_packed = (blocks_per_row + zp_pack_size - 1) / zp_pack_size; + zp_idx = r * zp_blocks_packed + block_idx / 2; + is_lower_nibble = (block_idx % 2 == 0); + } else { + zp_idx = r / 2; + is_lower_nibble = (r % 2 == 0); + } + + uint8_t packed_zp = zero_points[zp_idx]; + zp = is_lower_nibble ? static_cast(packed_zp & 0x0F) : static_cast(packed_zp >> 4); + } + + dequantized_data[c * rows + r] = scale * (static_cast(val) - zp); + } + } +} + +template +Status BuildDirectQ4PackedBCache(const uint8_t* prepacked_weights, + const TScale* scales_data, + int64_t num_experts, + int64_t rows, + int64_t cols, + int64_t block_size, + const gsl::span& scales_dims, + MLAS_BLK_QUANT_TYPE qtype, + AllocatorPtr allocator, + IAllocatorUniquePtr& packed_b) { + const size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); + if (packed_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to compute MLAS Q4 packed size for cache"); + } + + const bool is_block_wise = (scales_dims.size() == 3 && scales_dims[2] > 1); + const int64_t scales_expert_stride = is_block_wise ? (rows * scales_dims[2]) : rows; + const size_t prepacked_expert_stride = static_cast(rows * cols); + const size_t total_packed_size = packed_size * static_cast(num_experts); + + packed_b = IAllocator::MakeUniquePtr(allocator, total_packed_size, true); + uint8_t* packed_b_ptr = static_cast(packed_b.get()); + + std::vector dequantized_transposed(static_cast(rows * cols)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const uint8_t* expert_prepacked = prepacked_weights + static_cast(expert_idx) * prepacked_expert_stride; + const TScale* expert_scales = scales_data + expert_idx * scales_expert_stride; + + DequantizePrePacked(expert_prepacked, expert_scales, nullptr, block_size, rows, cols, + dequantized_transposed.data(), scales_dims); + + MlasQ4GemmPackB(qtype, packed_b_ptr + expert_idx * packed_size, dequantized_transposed.data(), + static_cast(rows), static_cast(cols), static_cast(rows)); + } + + return Status::OK(); +} + +template +Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + + // If scales are prepacked, they are constant initializers. 
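// Clarifying note (added for exposition) on the QMoE input indices assumed throughout PrePack and
// Compute: 2 = fc1_experts_weights, 3 = fc1_scales, 5 = fc2_experts_weights, 6 = fc2_scales,
// 11 = fc1_zero_points, 12 = fc2_zero_points.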
+ if (input_idx == 3) { + return Status::OK(); + } + if (input_idx == 6) { + return Status::OK(); + } + + // Only support PrePack for FC1 (2) and FC2 (5) weights + // and only if expert_weight_bits_ == 4 (since we unpack to uint8) + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if (input_idx == 2 || input_idx == 5) { + const auto& shape = tensor.Shape(); + const int64_t num_experts = shape[0]; + const int64_t rows = shape[1]; + const int64_t cols_packed = shape[2]; + const int64_t cols = cols_packed * 2; + + size_t packed_size = static_cast(num_experts * rows * cols); + auto packed_buffer = IAllocator::MakeUniquePtr(alloc, packed_size, true); + uint8_t* dst_base = static_cast(packed_buffer.get()); + const uint8_t* src_base = static_cast(tensor.DataRaw()); + + for (int64_t i = 0; i < num_experts; ++i) { + const uint8_t* src = src_base + i * rows * cols_packed; + uint8_t* dst = dst_base + i * rows * cols; + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + uint8_t packed_val = src[r * cols_packed + (c / 2)]; + uint8_t val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4); + + dst[c * rows + r] = val; + } + } + } + + if (input_idx == 2) { + fc1_shape_ = shape; + } else if (input_idx == 5) { + fc2_shape_ = shape; + } + + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_buffer)); + prepacked_weights->buffer_sizes_.push_back(packed_size); + is_packed = true; + + // Pack Shape (Buffer 1) + auto dims = shape.GetDims(); + size_t rank_bytes = sizeof(int64_t); + size_t dims_bytes = dims.size() * sizeof(int64_t); + size_t shape_size = rank_bytes + dims_bytes; + + auto shape_buffer = IAllocator::MakeUniquePtr(alloc, shape_size); + int64_t* buffer_data = static_cast(shape_buffer.get()); + *buffer_data = static_cast(dims.size()); + memcpy(buffer_data + 1, dims.data(), dims_bytes); + + prepacked_weights->buffers_.push_back(std::move(shape_buffer)); + prepacked_weights->buffer_sizes_.push_back(shape_size); + + // Try build MLAS Q4 cache if scales are available + if (use_mlas_q4_gemm_) { + const Tensor* scales_tensor = nullptr; + MLAS_BLK_QUANT_TYPE qtype = BlkQ4Sym; + int scales_idx = -1; + int zp_idx = -1; + + if (input_idx == 2) { // FC1 + scales_idx = 3; + zp_idx = 11; + } else if (input_idx == 5) { // FC2 + scales_idx = 6; + zp_idx = 12; + } + + if (scales_idx != -1 && + (zp_idx >= static_cast(Info().node().InputDefs().size()) || !Info().node().InputDefs()[zp_idx]->Exists()) && + Info().TryGetConstantInput(scales_idx, &scales_tensor) && + scales_tensor != nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype)) { + IAllocatorUniquePtr cache_buffer; + const auto& scales_dims = scales_tensor->Shape().GetDims(); + const T* scales_data = scales_tensor->Data(); + // Use the simple packed buffer we just created (buffer 0) as input + const uint8_t* simple_packed = dst_base; + + if (BuildDirectQ4PackedBCache(simple_packed, scales_data, num_experts, rows, cols, + block_size_, scales_dims, qtype, + alloc, cache_buffer) + .IsOK()) { + // Store the MLAS Q4 cache as buffer 2 (after unpacked weights and shape). 
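// Clarifying note (added for exposition) on the shared-buffer layout: buffer 0 holds the unpacked,
// transposed uint8 weights, buffer 1 is a small [rank, dims...] record used to recover the original
// tensor shape, and buffer 2 (when present) is the MLAS Q4 packed cache consumed later in
// UseSharedPrePackedBuffers_V2.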
+ size_t cache_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)) * static_cast(num_experts); + prepacked_weights->buffers_.push_back(std::move(cache_buffer)); + prepacked_weights->buffer_sizes_.push_back(cache_size); + } + } + } + } + } + + return Status::OK(); +} + +template +Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if ((input_idx == 2 || input_idx == 5) && !prepacked_buffers.empty()) { + auto parse_shape = [&](TensorShape& shape) { + if (prepacked_buffers.size() > 1) { + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(static_cast(rank)); + memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); + shape = TensorShape(dims); + } + }; + + if (input_idx == 2) { + packed_fc1_ = std::move(prepacked_buffers[0]); + parse_shape(fc1_shape_); + if (prepacked_buffers.size() > 2) { + packed_fc1_mlas_cache_ = std::move(prepacked_buffers[2]); + } + } else if (input_idx == 5) { + packed_fc2_ = std::move(prepacked_buffers[0]); + parse_shape(fc2_shape_); + if (prepacked_buffers.size() > 2) { + packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]); + } + } + used_shared_buffers = true; + } + + return Status::OK(); +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), @@ -372,21 +626,32 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "Attribute 'expert_weight_bits' must be 4 or 8."); block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); + ORT_ENFORCE(block_size_ >= 0); if (block_size_ > 0) { ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } + + const auto use_mlas_q4_gemm = ParseEnvironmentVariable(kUseMlasQ4GemmMoe); + if (use_mlas_q4_gemm.has_value()) { + use_mlas_q4_gemm_ = *use_mlas_q4_gemm; + use_mlas_q4_gemm_overridden_ = true; + } else { + // Default policy: enable fast path unless this run hits a known accuracy-loss configuration. + use_mlas_q4_gemm_ = true; + use_mlas_q4_gemm_overridden_ = false; + } } template Status QMoECPU::Compute(OpKernelContext* context) const { const auto* input = context->Input(0); const auto* router_probs = context->Input(1); - const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_experts_weights = packed_fc1_ ? nullptr : context->Input(2); const auto* fc1_scales = context->Input(3); const auto* fc1_experts_bias = context->Input(4); - const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_experts_weights = packed_fc2_ ? nullptr : context->Input(5); const auto* fc2_scales = context->Input(6); const auto* fc2_experts_bias = context->Input(7); const auto* fc3_experts_weights = context->Input(8); @@ -396,17 +661,21 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const auto* fc2_zero_points = context->Input(12); const auto* fc3_zero_points = context->Input(13); + const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr); + const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? 
&fc2_experts_weights->Shape() : nullptr); + const TensorShape* fc3_shape_ptr = fc3_experts_weights ? &fc3_experts_weights->Shape() : nullptr; + MoEParameters moe_params; ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias, fc1_scales, fc1_zero_points, - fc2_experts_weights, fc2_experts_bias, fc2_scales, fc2_zero_points, - fc3_experts_weights, fc3_experts_bias, fc3_scales, fc3_zero_points, + fc1_shape_ptr, fc1_experts_bias, fc1_scales, fc1_zero_points, + fc2_shape_ptr, fc2_experts_bias, fc2_scales, fc2_zero_points, + fc3_shape_ptr, fc3_experts_bias, fc3_scales, fc3_zero_points, expert_weight_bits_ == 4 ? 2 : 1, - true, + activation_type_ == ActivationType::SwiGLU, block_size_)); - if (fc3_experts_weights || fc3_experts_bias || fc3_scales || fc3_zero_points) { + if (fc3_shape_ptr || fc3_experts_bias || fc3_scales || fc3_zero_points) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); } @@ -569,8 +838,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); - const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); - const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->template Data(); + const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->template Data(); const T* fc1_scales_data = fc1_scales->Data(); const T* fc2_scales_data = fc2_scales->Data(); const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; @@ -578,6 +847,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const uint8_t* fc1_zp_data = fc1_zero_points ? fc1_zero_points->Data() : nullptr; const uint8_t* fc2_zp_data = fc2_zero_points ? fc2_zero_points->Data() : nullptr; + // Known loss-prone case from parity testing: 4-bit symmetric path (row-wise and block-wise). + const bool known_accuracy_loss_case = (expert_weight_bits_ == 4) && + (fc1_zp_data == nullptr) && (fc2_zp_data == nullptr); + const bool use_mlas_q4_gemm_effective = use_mlas_q4_gemm_overridden_ + ? use_mlas_q4_gemm_ + : (use_mlas_q4_gemm_ && !known_accuracy_loss_case); + const int64_t pack_unit = (8 / expert_weight_bits_); const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; @@ -605,6 +881,22 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_zp_expert_stride = (hidden_size + zp_pack_size - 1) / zp_pack_size; } + MLAS_BLK_QUANT_TYPE fc1_direct_qtype = BlkQ4Sym; + MLAS_BLK_QUANT_TYPE fc2_direct_qtype = BlkQ4Sym; + + // Use pre-packed MLAS cache if available + const void* fc1_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_effective && packed_fc1_mlas_cache_ && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, fc1_direct_qtype)) { + fc1_direct_q4_cache_ptr = packed_fc1_mlas_cache_.get(); + } + + const void* fc2_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_effective && packed_fc2_mlas_cache_ && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? 
block_size_ : 0, hidden_size, inter_size, fc2_direct_qtype)) { + fc2_direct_q4_cache_ptr = packed_fc2_mlas_cache_.get(); + } + std::vector> expert_workload; size_t total_work = 0; @@ -718,10 +1010,57 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k = static_cast(hidden_size); MLAS_BLK_QUANT_TYPE q_type = BlkQ4Sym; // Initialize to default - // Direct Q4 GEMM only supports symmetric quantization, so we disable it if zero_points are provided. - bool use_direct_q4_gemm = (fc1_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, q_type); + bool use_direct_q4_gemm = use_mlas_q4_gemm_effective && + ((fc1_direct_q4_cache_ptr != nullptr) || + ((packed_fc1_ == nullptr) && (fc1_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type))); + + if (packed_fc1_ != nullptr) { + if (use_mlas_q4_gemm_effective && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type)) { + if (fc1_direct_q4_cache_ptr != nullptr) { + float* fc1_bias_float = nullptr; + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); + } else { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + fc1_bias_float = thread_bias1_buffer; + } + + size_t packed_size = MlasQ4GemmPackBSize(q_type, static_cast(fc1_out_features), static_cast(hidden_size)); + const uint8_t* packed_b = static_cast(fc1_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A1, packed_b, fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, fc1_direct_qtype, tp); + if (gemm_status.IsOK()) { + goto fc1_gemm_done; + } + } + } + + // Fallback: Dequantize from PrePacked (transposed, unpacked) -> MlasGemm + const uint8_t* current_packed_ptr = static_cast(packed_fc1_.get()) + expert_idx * fc1_out_features * hidden_size; + + DequantizePrePacked(current_packed_ptr, fc1_scales_ptr, fc1_zp_ptr, + is_fc1_block_wise ? 
block_size_ : 0, + fc1_out_features, hidden_size, + B1_dequant, fc1_scales_dims); + + // Use MlasGemm with B1_dequant (which is already float transposed) + MlasGemm(CblasNoTrans, CblasNoTrans, + m, n, k, + 1.0f, A1, k, + B1_dequant, n, + 0.0f, C1, n, + tp); + + goto fc1_bias_handling; + } if (use_direct_q4_gemm) { IAllocatorUniquePtr mlas_packed_fc1; @@ -739,12 +1078,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (convert_status.IsOK()) { float* fc1_bias_float = nullptr; - IAllocatorUniquePtr fc1_bias_buffer; if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; - fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); - fc1_bias_float = fc1_bias_buffer.get(); + fc1_bias_float = thread_bias1_buffer; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); @@ -805,6 +1142,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); + fc1_bias_handling: + if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { @@ -844,22 +1183,30 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc1_gemm_done: - const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); - if (num_expert_tokens >= activation_threshold && tp != nullptr) { - const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); - const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; - - if (num_activation_blocks > 1) { - concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { - const int64_t start_token = block_idx * activation_block_size; - const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); - - for (int64_t i = start_token; i < end_token; ++i) { + if (activation_type_ == ActivationType::SwiGLU) { + const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + if (num_expert_tokens >= activation_threshold && tp != nullptr) { + const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); + const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; + + if (num_activation_blocks > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_token = block_idx * activation_block_size; + const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); + + for (int64_t i = start_token; i < end_token; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { const float* C1_token = C1 + i * fc1_out_features; float* A2_token = A2 + i * inter_size; ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); } - }); + } } else { for (int64_t i = 0; i < num_expert_tokens; ++i) { const float* C1_token = C1 + i * fc1_out_features; @@ -868,11 +1215,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } } else { - for (int64_t i = 0; i < 
num_expert_tokens; ++i) { - const float* C1_token = C1 + i * fc1_out_features; - float* A2_token = A2 + i * inter_size; - ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); - } + ApplyActivationVectorized(C1, num_expert_tokens * fc1_out_features); + std::copy(C1, C1 + (num_expert_tokens * fc1_out_features), A2); } const T* fc2_scales_ptr; @@ -895,9 +1239,58 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k2 = static_cast(inter_size); MLAS_BLK_QUANT_TYPE q_type2 = BlkQ4Sym; // Initialize to default - bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, q_type2); + bool use_direct_q4_gemm_fc2 = use_mlas_q4_gemm_effective && + ((fc2_direct_q4_cache_ptr != nullptr) || + ((packed_fc2_ == nullptr) && (fc2_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2))); + + if (packed_fc2_ != nullptr) { + if (use_mlas_q4_gemm_effective && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2)) { + if (fc2_direct_q4_cache_ptr != nullptr) { + float* fc2_bias_float = nullptr; + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); + } else { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + fc2_bias_float = thread_bias2_buffer; + } + + size_t packed_size = MlasQ4GemmPackBSize(q_type2, static_cast(hidden_size), static_cast(inter_size)); + const uint8_t* packed_b = static_cast(fc2_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A2, packed_b, fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, fc2_direct_qtype, tp); + if (gemm_status.IsOK()) { + fc2_bias_added_by_mlas = true; + goto fc2_gemm_done; + } + } + } + + // Dequantize from PrePacked (transposed, unpacked) + const uint8_t* current_packed_ptr = static_cast(packed_fc2_.get()) + expert_idx * hidden_size * inter_size; + + DequantizePrePacked(current_packed_ptr, fc2_scales_ptr, fc2_zp_ptr, + is_fc2_block_wise ? 
block_size_ : 0, + hidden_size, inter_size, + B2_dequant, fc2_scales_dims); + + // Fallback + MlasGemm(CblasNoTrans, CblasNoTrans, + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, n2, + 0.0f, C2, n2, + tp); + + goto fc2_gemm_done; + } if (use_direct_q4_gemm_fc2) { IAllocatorUniquePtr mlas_packed_fc2; @@ -915,12 +1308,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (convert_status.IsOK()) { float* fc2_bias_float = nullptr; - IAllocatorUniquePtr fc2_bias_buffer; if (has_fc2_bias) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; - fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); - fc2_bias_float = fc2_bias_buffer.get(); + fc2_bias_float = thread_bias2_buffer; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); @@ -1114,10 +1505,22 @@ Status QMoECPU::Compute(OpKernelContext* context) const { return Status::OK(); } +template +void QMoECPU::ApplyActivationVectorized(float* data, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + data[i] = ApplyActivation(data[i], activation_type_); + } +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); + template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); // Kernel Registration ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 890580e051a8e..f678a27190c90 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -5,7 +5,9 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas_q4.h" #include "contrib_ops/cpu/moe/moe_base_cpu.h" +#include namespace onnxruntime { namespace contrib { @@ -26,8 +28,30 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { Status Compute(OpKernelContext* context) const override; private: + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + void ApplyActivationVectorized(float* data, int64_t size) const; + int64_t expert_weight_bits_; int64_t block_size_; + bool use_mlas_q4_gemm_{false}; + bool use_mlas_q4_gemm_overridden_{false}; + + IAllocatorUniquePtr packed_fc1_; + IAllocatorUniquePtr packed_fc2_; + + TensorShape fc1_shape_; + TensorShape fc2_shape_; + + IAllocatorUniquePtr packed_fc1_mlas_cache_; + IAllocatorUniquePtr packed_fc2_mlas_cache_; }; } // namespace contrib diff --git 
a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index 38dd8de01147c..5137c22d6cf61 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -621,8 +621,8 @@ void DumpNodeInputs( std::cout << " is non-tensor type.\n"; } } else { - // this could happen with an empty Optional input - std::cout << " was missing data type\n"; + // this could happen with an empty Optional input or the tensor is removed after pre-packing. + std::cout << " was missing data type (maybe pre-packed).\n"; } } else { std::cout << "Input " << i << " is optional and was not provided.\n"; diff --git a/onnxruntime/test/python/transformers/benchmark_qmoe.py b/onnxruntime/test/python/transformers/benchmark_qmoe.py new file mode 100644 index 0000000000000..b96c9cdcf5c3a --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_qmoe.py @@ -0,0 +1,191 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import os +import sys +import time +import unittest + +import numpy +import torch + +# Add current directory to path to allow importing from test_qmoe_cpu +current_dir = os.path.dirname(os.path.abspath(__file__)) +if current_dir not in sys.path: + sys.path.append(current_dir) + +from test_qmoe_cpu import PhiMoEConfig, PhiMoESparseMoeBlock, TensorProto # noqa: E402 + +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + + +@unittest.skipIf(pipeline_mode, "Skip benchmark in CI pipeline.") +class TestQMoESwiGLUBenchmark(unittest.TestCase): + """Benchmark tests for QMoE SwiGLU performance measurement.""" + + def test_qmoe_swiglu_throughput_benchmark(self): + """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" + print("\n=== QMoE SwiGLU Throughput Benchmark ===") + + # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) + configs = [ + ("Medium-4bit", 2880, 2880, 32, 4, 4), + ("Medium-8bit", 2880, 2880, 32, 4, 8), + ] + + batch_size = 1 + sequence_length = 512 + num_runs = 1000 + + results = [] + + for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs: + torch.manual_seed(42) + numpy.random.seed(42) + + torch_output = None + ort_output = None + + print(f"\nTesting {config_name}:") + print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit") + + try: + # Create config and model + config = PhiMoEConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + ) + + qmoe_swiglu = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + # Create test input with fixed sequence length to match ONNX model + full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32) + + # For TTFT simulation, we'll measure single forward pass time + # This represents the time to process one token in autoregressive generation + + # Warm up with full context + for _ in range(3): + _ = 
qmoe_swiglu.forward(full_hidden_states) + + # Benchmark PyTorch TTFT (Time to First Token) + # Measure time for a single forward pass (represents token generation time) + torch.manual_seed(42) + + start_time = time.time() + for _ in range(num_runs): + torch_output = qmoe_swiglu.forward(full_hidden_states) + end_time = time.time() + torch_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second (throughput) + # For sequence generation, this represents the rate at which we can generate tokens + torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000) + + print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)") + print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec") + + # Benchmark ONNX Runtime + ort_ttft_ms = 0 + ort_tokens_per_sec = 0 + speedup = 0 + throughput_ratio = 0 + max_diff = 0 + + model_updated = qmoe_swiglu.recreate_onnx_model() + if model_updated and qmoe_swiglu.ort_sess is not None: + # Warm up ORT with full context + for _ in range(3): + _ = qmoe_swiglu.ort_forward(full_hidden_states) + + torch.manual_seed(42) + + # Measure ONNX Runtime TTFT (Time to First Token) + start_time = time.time() + for _ in range(num_runs): + ort_output = qmoe_swiglu.ort_forward(full_hidden_states) + end_time = time.time() + ort_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second for ONNX Runtime + ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000) + + speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0 + throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 0 + + print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)") + print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec") + print(f" TTFT Speedup: {speedup:.2f}x") + print(f" Throughput Gain: {throughput_ratio:.2f}x") + else: + print(" ONNX RT: Not available") + + # Calculate max difference if both outputs available + if torch_output is not None and ort_output is not None: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item() + print(f" Max diff: {max_diff:.6f}") + + results.append( + { + "config": config_name, + "torch_ttft_ms": torch_ttft_ms, + "torch_tokens_per_sec": torch_tokens_per_sec, + "ort_ttft_ms": ort_ttft_ms, + "ort_tokens_per_sec": ort_tokens_per_sec, + "speedup": speedup, + "throughput_ratio": throughput_ratio, + "max_diff": max_diff, + } + ) + + except Exception as e: + print(f" Error: {e}") + continue + + # Summary + print("\n=== Token Generation Time & Throughput Summary ===") + print( + f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}" + ) + print("-" * 105) + for result in results: + config = result["config"] + torch_ttft = result["torch_ttft_ms"] + torch_tps = result["torch_tokens_per_sec"] + ort_ttft = result["ort_ttft_ms"] + ort_tps = result["ort_tokens_per_sec"] + speedup = result["speedup"] + throughput_ratio = result["throughput_ratio"] + max_diff = result["max_diff"] + + ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A" + ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A" + speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A" + throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A" + + print( + f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}" + ) + + 
print("\nNotes:") + print("- Time: Token generation time in ms (lower is better)") + print("- tok/s: Tokens per second throughput (higher is better)") + print("- Time Gain: ORT speedup for latency (higher is better)") + print("- Throughput: ORT throughput improvement (higher is better)") + + +if __name__ == "__main__": + benchmark = TestQMoESwiGLUBenchmark() + benchmark.test_qmoe_swiglu_throughput_benchmark() diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 238ac4d1f077d..8415c7b08b77c 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -23,9 +23,11 @@ # normalization on the selected experts. This provides proper weight distribution # while maintaining computational efficiency. # -------------------------------------------------------------------------- +import os import time import unittest from collections import OrderedDict +from contextlib import contextmanager import numpy import torch @@ -76,6 +78,8 @@ class TensorProtoPlaceholder: ort_provider = ["CPUExecutionProvider"] +ORT_USE_MLAS_Q4_GEMM_MOE = "ORT_USE_MLAS_Q4_GEMM_MOE" + torch.manual_seed(42) numpy.random.seed(42) @@ -1137,6 +1141,43 @@ def small_test_cases(): yield batch_size, sequence_length +def with_mlas_q4_mode(test_cases): + expanded_cases = [] + for case in test_cases: + quant_bits = case[2] + if quant_bits == 4: + expanded_cases.append((*case, None)) + expanded_cases.append((*case, False)) + expanded_cases.append((*case, True)) + else: + expanded_cases.append((*case, None)) + return expanded_cases + + +@contextmanager +def scoped_env_var(name: str, value: str): + previous = os.environ.get(name) + os.environ[name] = value + try: + yield + finally: + if previous is None: + os.environ.pop(name, None) + else: + os.environ[name] = previous + + +def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool | None): + if enable_mlas_q4_gemm is None: # No env var + test_runner() + else: + env_value = "1" if enable_mlas_q4_gemm else "0" + mode = "enabled" if enable_mlas_q4_gemm else "disabled" + print(f"DirectQ4 mode ({ORT_USE_MLAS_Q4_GEMM_MOE}) is {mode}") + with scoped_env_var(ORT_USE_MLAS_Q4_GEMM_MOE, env_value): + test_runner() + + class SwigluMoEBlock(SparseMoeBlockORTHelper): def __init__( self, @@ -1381,8 +1422,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states -disable_cpu_qmoe_tests = False - # Define test cases for different MoE types phi3_test_cases = [ (1, 32, 4), @@ -1400,10 +1439,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ] -@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): - @parameterized.expand(phi3_test_cases) - def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(phi3_test_cases)) + def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): # Create unique seed based on test parameters to ensure different inputs for each test base_seed = 2000 # Different base seed from other tests param_hash = hash((batch_size, sequence_length, quant_bits)) @@ -1438,10 +1476,10 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - phi3_moe.parity_check() + 
run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_test_cases) - def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(phi3_test_cases)) + def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): base_seed = 3000 param_hash = hash((batch_size, sequence_length, quant_bits)) unique_seed = base_seed + abs(param_hash) % 1000 @@ -1463,10 +1501,12 @@ def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quan onnx_dtype=TensorProto.FLOAT, use_asymmetric_quant=True, ) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_blockwise_test_cases) - def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases)) + def test_phi3_qmoe_blockwise_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(42) numpy.random.seed(42) @@ -1495,10 +1535,12 @@ def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_blockwise_test_cases) - def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases)) + def test_phi3_qmoe_blockwise_asymmetric_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(43) numpy.random.seed(43) @@ -1516,10 +1558,8 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le block_size=block_size, use_asymmetric_quant=True, ) - phi3_moe.parity_check() - + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) -disable_cpu_qmoe_tests = False swiglu_test_cases = [ (1, 32, 4), @@ -1537,10 +1577,9 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le ] -@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): - @parameterized.expand(swiglu_test_cases) - def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases)) + def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): # Create unique seed based on test parameters to ensure different inputs for each test base_seed = 1000 # Different base seed from regular MoE tests param_hash = hash((batch_size, sequence_length, quant_bits)) @@ -1574,10 +1613,10 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_test_cases) - def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases)) + def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): 
base_seed = 1100 param_hash = hash((batch_size, sequence_length, quant_bits)) unique_seed = base_seed + abs(param_hash) % 1000 @@ -1599,10 +1638,12 @@ def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, qu onnx_dtype=TensorProto.FLOAT, use_asymmetric_quant=True, ) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_blockwise_test_cases) - def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases)) + def test_swiglu_qmoe_blockwise_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(42) numpy.random.seed(42) @@ -1630,10 +1671,12 @@ def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, qua self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_blockwise_test_cases) - def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases)) + def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(43) numpy.random.seed(43) @@ -1651,7 +1694,7 @@ def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_ block_size=block_size, use_asymmetric_quant=True, ) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) @unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") @@ -1660,9 +1703,6 @@ class TestQMoESwiGLUBenchmark(unittest.TestCase): def test_qmoe_swiglu_throughput_benchmark(self): """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" - if disable_cpu_qmoe_tests: - self.skipTest("QMoE CPU tests disabled") - print("\n=== QMoE SwiGLU Throughput Benchmark ===") # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits)
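Note on the 4-bit weight handling above (a minimal sketch, not the ONNX Runtime implementation): PrePack splits each packed byte into two nibbles and stores the result transposed as [cols, rows]; the fallback path then dequantizes with per-channel (or block-wise) scales and a default zero point of 8 before calling MlasGemm. The standalone function below mirrors the per-channel case under those assumptions; UnpackAndDequant4Bit is a hypothetical name introduced here, and cols is assumed even.

#include <cstdint>
#include <vector>

// Unpack a row-major [rows, cols/2] buffer of packed 4-bit weights (low nibble = even column)
// into a transposed [cols, rows] float buffer, applying per-row (per-channel) scales with the
// symmetric default zero point of 8 used by the QMoE CPU kernel's fallback path.
std::vector<float> UnpackAndDequant4Bit(const uint8_t* packed,  // [rows, cols/2], two weights per byte
                                        const float* scales,    // [rows], one scale per output row
                                        int64_t rows, int64_t cols) {
  std::vector<float> out(static_cast<size_t>(rows * cols));
  const float zero_point = 8.0f;  // default for symmetric 4-bit quantization
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      const uint8_t byte = packed[r * (cols / 2) + c / 2];
      const uint8_t q = (c % 2 == 0) ? static_cast<uint8_t>(byte & 0x0F)
                                     : static_cast<uint8_t>(byte >> 4);
      // Write transposed so the dequantized matrix is laid out [cols, rows], matching the
      // pre-packed layout that feeds MlasGemm in the fallback path.
      out[static_cast<size_t>(c * rows + r)] = scales[r] * (static_cast<float>(q) - zero_point);
    }
  }
  return out;
}

At runtime, the new ORT_USE_MLAS_Q4_GEMM_MOE environment variable ("1"/"0") overrides the default policy, which otherwise disables the direct MLAS Q4 path for the known accuracy-loss configuration (4-bit symmetric quantization with no zero points).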