From ed0478ad64d96da419dba46f18a1fa8a9fea8644 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 16 Feb 2026 19:30:00 +0000 Subject: [PATCH 01/11] Fix QMoE CPU --- .../cpu/moe/moe_quantization_cpu.cc | 29 +-- onnxruntime/core/mlas/inc/mlas_q4.h | 8 +- .../test/python/transformers/test_qmoe_cpu.py | 183 ++++++++++-------- 3 files changed, 126 insertions(+), 94 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 6d1d191689466..14bddaf324ae7 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -118,13 +118,23 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, temp_float, nullptr); - size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); + // Transpose from N x K (weights) to K x N. + // DirectQ4Gemm expects weights to be packed in a specific layout ([K, N] logically) + auto transposed_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + float* transposed_float = transposed_float_buffer.get(); + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + transposed_float[c * rows + r] = temp_float[r * cols + c]; + } + } + + size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); if (packed_size == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); } mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), transposed_float, static_cast(rows), static_cast(cols), static_cast(rows)); return Status::OK(); } @@ -634,6 +644,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); for (int64_t expert_idx : expert_batch) { + bool fc2_bias_added_by_mlas = false; const auto& routes = expert_token_map[static_cast(expert_idx)]; if (routes.empty()) { continue; @@ -711,8 +722,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { bool use_direct_q4_gemm = (fc1_zp_data == nullptr) && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, q_type); - bool fc1_used_direct_q4 = false; - bool fc1_bias_handled_by_q4_gemm = false; if (use_direct_q4_gemm) { IAllocatorUniquePtr mlas_packed_fc1; @@ -750,7 +759,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); if (gemm_status.IsOK()) { - fc1_used_direct_q4 = true; goto fc1_gemm_done; } } @@ -797,8 +805,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); - fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; - if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); @@ -891,7 +898,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, q_type2); - bool fc2_used_direct_q4 = false; if (use_direct_q4_gemm_fc2) { IAllocatorUniquePtr mlas_packed_fc2; @@ -929,7 +935,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { num_expert_tokens, hidden_size, inter_size, q_type2, tp); if (gemm_status.IsOK()) { - fc2_used_direct_q4 = true; + fc2_bias_added_by_mlas = true; goto fc2_gemm_done; } } @@ -979,8 +985,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_gemm_done: - bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias && !fc2_bias_added_by_mlas) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); @@ -1015,7 +1020,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias && !fc2_bias_added_by_mlas) { const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); size_t j = 0; for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 69f0435615079..d60e5b0164fe8 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -57,10 +57,10 @@ MlasQ4GemmPackBSize( * * @param QType type of block quantization * @param PackedBuf destination buffer - * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B + * @param FpData the pointer to fp32 matrix, with shape [K, N]. + * @param N the number of columns of matrix B (Output Channels). + * @param K the number of rows of matrix B (Input Channels). + * @param ldb leading dimension of FpData (usually N) */ void MLASCALL diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 90ebb148a26a5..238ac4d1f077d 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -364,7 +364,7 @@ def create_cpu_moe_onnx_graph( use_swiglu=False, use_quant=False, quant_bits=4, - swiglu_interleaved=False, + swiglu_fusion=0, block_size=0, ): if not has_onnx: @@ -400,10 +400,10 @@ def create_cpu_moe_onnx_graph( "router_probs", # 1 "fc1_experts_weights", # 2 "fc1_scales", # 3 - "", # 4: fc1_bias + "fc1_experts_bias" if fc1_bias is not None else "", # 4 "fc2_experts_weights", # 5 "fc2_scales", # 6 - "", # 7: fc2_bias + "fc2_experts_bias" if fc2_bias is not None else "", # 7 "", # 8: fc3_weights "", # 9: fc3_scales "", # 10: fc3_bias @@ -442,11 +442,10 @@ def create_cpu_moe_onnx_graph( normalize_routing_weights=normalize_routing, activation_type=activation, # Add new attributes with backwards-compatible default values - swiglu_fusion=1 if use_swiglu else 0, # 1 if using SwiGLU activation + swiglu_fusion=swiglu_fusion, swiglu_limit=7.0, activation_alpha=1.702, activation_beta=1.0, - swiglu_interleaved=1 if swiglu_interleaved else 0, # Enable this attribute domain="com.microsoft", ), ] @@ -559,6 +558,30 @@ def create_cpu_moe_onnx_graph( ) ) + if fc1_bias is not None: + fc1_bias_np = fc1_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]) + initializers.append( + helper.make_tensor( + "fc1_experts_bias", + onnx_dtype, + list(fc1_bias.shape), + fc1_bias_np.flatten().tolist(), + raw=False, + ) + ) + + if fc2_bias is not None: + fc2_bias_np = fc2_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]) + initializers.append( + helper.make_tensor( + "fc2_experts_bias", + onnx_dtype, + list(fc2_bias.shape), + fc2_bias_np.flatten().tolist(), + raw=False, + ) + ) + graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] @@ -626,7 +649,7 @@ def __init__( self.num_experts_per_token = num_experts_per_token -def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): +def swiglu(x: torch.Tensor, alpha: float = 1.702, beta: float = 1.0, limit: float = 7.0): dim = x.shape[-1] x = x.view(-1, dim // 2, 2) x_glu, x_linear = x[..., 0], x[..., 1] @@ -635,8 +658,8 @@ def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): x_glu = x_glu.clamp(max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) - y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) - return y + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta) + return y.view(-1, dim // 2) class MoEBlockSparseTop2MLP(nn.Module): @@ -855,7 +878,7 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False e = time.time() time_ms = (e - s) / repeat * 1000 is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu - is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + is_interleaved = hasattr(self, "swiglu_fusion") and self.swiglu_fusion == 1 act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference") @@ -868,62 +891,80 @@ def recreate_onnx_model(self): """Recreate the ONNX model with the current weights to reflect any changes to the quantization code.""" w1_list, w2_list = [], [] + w1_bias_list, w2_bias_list = [], [] w1_scale_list, w2_scale_list = [], [] w1_zp_list, w2_zp_list = [], [] is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - if self.block_size > 0: - # Use block-wise quantization - w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( - self.experts[i].w1.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( - self.experts[i].w2.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) + if hasattr(self.experts[i], "w3"): + w1, w3 = self.experts[i].w1.weight, self.experts[i].w3.weight + w2 = self.experts[i].w2.weight + w1_bias = self.experts[i].w1.bias + w3_bias = getattr(self.experts[i].w3, "bias", None) + + # Combine and interleave w1 and w3 for the fused kernel + w1_combined = torch.cat([w1, w3], dim=0) # [2*inter, hidden] + if getattr(self, "swiglu_fusion", 0) == 1: + w1_combined = w1_combined.view(2, -1, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + + if self.block_size > 0: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( + w1_combined, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( + w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + else: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( + w1_combined, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( + w2, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + + if w1_bias is not None and w3_bias is not None: + b1_combined = torch.cat([w1_bias, w3_bias], dim=0) + if getattr(self, "swiglu_fusion", 0) == 1: + b1_combined = b1_combined.view(2, -1).transpose(0, 1).reshape(-1) + w1_bias_list.append(b1_combined.detach().cpu()) + elif w1_bias is not None: + w1_bias_list.append(w1_bias.detach().cpu()) else: - # Use row-wise quantization - w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( - self.experts[i].w1.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( - self.experts[i].w2.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) + # PhiMoESwiGLUMLP already has interleaved weights in w1 + w1 = self.experts[i].w1.weight + w2 = self.experts[i].w2.weight + w1_bias = self.experts[i].w1.bias - if self.use_swiglu: - if self.swiglu_interleaved: - pass + if self.block_size > 0: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( + w1, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( + w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) else: - if self.block_size > 0: - w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant_blockwise( - self.experts[i].w3.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - else: - w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant( - self.experts[i].w3.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - - gate_weights = pre_qweight1 - value_weights = pre_qweight3 - gate_scales = w1_scale - value_scales = w3_scale - gate_zp = w1_zp - value_zp = w3_zp - - pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) - w1_scale = torch.cat([gate_scales, value_scales], dim=0) - if w1_zp is not None and w3_zp is not None: - w1_zp = torch.cat([gate_zp, value_zp], dim=0) - - if self.swiglu_interleaved: - self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( + w1, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( + w2, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + if w1_bias is not None: + w1_bias_list.append(w1_bias.detach().cpu()) + if self.use_swiglu: + if getattr(self, "swiglu_fusion", 0) == 1: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) else: intermediate_size = self.experts[i].w1.weight.shape[0] gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() value_dequant = w1_qdq[intermediate_size:].contiguous().clone() - self.experts[i].w1.weight.data = gate_dequant - self.experts[i].w3.weight.data = value_dequant + if hasattr(self.experts[i], "w3"): + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant + else: + self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() else: self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() @@ -931,6 +972,9 @@ def recreate_onnx_model(self): w1_list.append(pre_qweight1) w2_list.append(pre_qweight2) + + if self.experts[i].w2.bias is not None: + w2_bias_list.append(self.experts[i].w2.bias) w1_scale_list.append(w1_scale) w2_scale_list.append(w2_scale) if w1_zp is not None: @@ -963,9 +1007,9 @@ def recreate_onnx_model(self): onnx_dtype=self.onnx_dtype, fc1_experts_weights=self.moe_experts_weight1, fc2_experts_weights=self.moe_experts_weight2, - # Biases are not used in QMoE - fc1_bias=None, - fc2_bias=None, + # Pass collected biases + fc1_bias=torch.stack(w1_bias_list, dim=0) if w1_bias_list else None, + fc2_bias=torch.stack(w2_bias_list, dim=0) if w2_bias_list else None, # Scales are used for dequantization fc1_scales=moe_experts_weight_scale1, fc2_scales=moe_experts_weight_scale2, @@ -975,7 +1019,7 @@ def recreate_onnx_model(self): use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, - swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + swiglu_fusion=getattr(self, "swiglu_fusion", 0), block_size=self.block_size, # Add block_size for block-wise quantization ) except Exception: @@ -1020,7 +1064,7 @@ def parity_check(self): max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu - is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + is_interleaved = getattr(self, "swiglu_fusion", 0) == 1 act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" quant_type = "Asymmetric" if self.use_asymmetric_quant else "Symmetric" block_type = f"Block({self.block_size})" if self.block_size > 0 else "Row" @@ -1047,24 +1091,6 @@ def parity_check(self): ) print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item()) print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item()) - # Print routing and per-expert contributions for this token from the PyTorch reference - try: - hidden_states_flat = hidden_state.view(-1, hidden_dim) - token_vec = hidden_states_flat[i : i + 1] - gate_logits = self.gate(token_vec) - topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1) - topk_soft = F.softmax(topk_vals, dim=1) - print("Gate logits:", gate_logits.detach().cpu().numpy()) - print("Selected experts:", topk_experts.detach().cpu().numpy()) - print("Routing weights:", topk_soft.detach().cpu().numpy()) - # Compute per-expert contributions for selected experts - for idx_e, e in enumerate(topk_experts[0].tolist()): - expert_layer = self.experts[e] - expert_out = expert_layer(token_vec) - contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item() - print(f"Expert {e} contrib at hidden {k}: {contrib}") - except Exception as _: - pass ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), @@ -1128,7 +1154,7 @@ def __init__( self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_token self.use_swiglu = True - self.swiglu_interleaved = True + self.swiglu_fusion = 1 self.block_size = block_size use_quant = self.quant_bits > 0 @@ -1232,7 +1258,7 @@ def __init__( self.top_k = config.num_experts_per_tok self.router_jitter_noise = config.router_jitter_noise self.use_swiglu = True - self.swiglu_interleaved = True + self.swiglu_fusion = 1 self.block_size = block_size use_quant = self.quant_bits > 0 @@ -1314,7 +1340,8 @@ def __init__( use_swiglu=self.use_swiglu, use_quant=use_quant, quant_bits=self.quant_bits, - swiglu_interleaved=self.swiglu_interleaved, + # swiglu_fusion=1 means fused and interleaved, which is the standard for QMoE. + swiglu_fusion=getattr(self, "swiglu_fusion", 0), block_size=self.block_size, ) From 2c838293216189fcf52f97e0dc570b539299a46c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 16 Feb 2026 20:02:02 +0000 Subject: [PATCH 02/11] prepack and cache to improve QMoE perf --- onnxruntime/contrib_ops/cpu/moe/moe_helper.h | 54 ++- .../cpu/moe/moe_quantization_cpu.cc | 430 +++++++++++++++++- .../cpu/moe/moe_quantization_cpu.h | 32 ++ .../debug_node_inputs_outputs_utils.cc | 4 +- .../python/transformers/benchmark_qmoe.py | 187 ++++++++ .../test/python/transformers/test_qmoe_cpu.py | 91 +++- 6 files changed, 746 insertions(+), 52 deletions(-) create mode 100644 onnxruntime/test/python/transformers/benchmark_qmoe.py diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index 257c5a189b3bd..611d2f989d576 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -56,23 +56,49 @@ Status CheckInputs(MoEParameters& parameters, const int64_t block_size = 0) { // block size for block-wise quantization // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. ASSERT_TENSOR_2D_OR_3D(input); - ASSERT_TENSOR_3D(fc1_experts_weights); - ASSERT_TENSOR_3D(fc2_experts_weights); + if (fc1_experts_weights) ASSERT_TENSOR_3D(fc1_experts_weights); + if (fc2_experts_weights) ASSERT_TENSOR_3D(fc2_experts_weights); ASSERT_TENSOR_2D(router_probs); const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; int64_t hidden_size = input_dims[input_dims.size() - 1]; - int64_t local_num_experts = fc1_experts_weights_dims[0]; int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size; - const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || - (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + int64_t local_num_experts; + if (fc1_experts_weights != nullptr) { + local_num_experts = fc1_experts_weights->Shape().GetDims()[0]; + } else if (fc1_experts_scales != nullptr) { + local_num_experts = fc1_experts_scales->Shape().GetDims()[0]; + } else { + // Fallback for non-quantized MoE without weights (should not happen in current code paths) + // or if only bias is provided? + local_num_experts = num_experts; + } + + int64_t inter_size; + if (fc2_experts_weights != nullptr) { + const auto& dims = fc2_experts_weights->Shape().GetDims(); + inter_size = (dims[1] * dims[2] * pack_size) / hidden_size; + } else if (fc3_experts_scales != nullptr) { + inter_size = fc3_experts_scales->Shape().GetDims()[1]; + } else if (fc1_experts_scales != nullptr) { + int64_t fc1_inter_size = fc1_experts_scales->Shape().GetDims()[1]; + inter_size = is_fused_swiglu ? fc1_inter_size / 2 : fc1_inter_size; + } else { + // Should not happen for valid QMoE calls + inter_size = 0; + } + + bool legacy_shape = false; + if (fc2_experts_weights != nullptr && fc1_experts_weights != nullptr) { + const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || + (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + } // Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one. const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size; @@ -80,13 +106,13 @@ Status CheckInputs(MoEParameters& parameters, if (legacy_shape) { // legacy shape does not match column major memory layout. This is for backward compatibility. - CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); + if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); + if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); + if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); } else { - CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); + if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); + if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); + if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); } CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts); diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 14bddaf324ae7..43a01da0bffef 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -13,6 +13,7 @@ #include "core/common/narrow.h" #include "core/framework/tensor_type_and_shape.h" #include "core/util/math.h" +#include "core/platform/env_var_utils.h" #include "contrib_ops/cpu/moe/moe_utils.h" #include "contrib_ops/cpu/moe/moe_helper.h" @@ -84,6 +85,8 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, namespace onnxruntime { namespace contrib { +constexpr char* kUseMlasQ4GemmMoe = "ORT_USE_MLAS_Q4_GEMM_MOE"; + template void DequantizeBlockWithMlas(const uint8_t* quantized_data, const TScale* scales, @@ -364,6 +367,184 @@ void DequantizeBlock(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, dequantized_data, thread_pool); } +template +void DequantizePrePacked(const uint8_t* prepacked_data, + const TScale* scales, + const uint8_t* zero_points, + int64_t block_size, + int64_t rows, + int64_t cols, + float* dequantized_data, + const gsl::span& scale_dims) { + // prepacked_data is [cols, rows] (transposed, unpacked) + // dequantized_data is [cols, rows] (transposed) + // scales, zero_points correspond to original [rows, cols] layout + + const float default_zp_4bit = 8.0f; + const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1; + const int64_t zp_pack_size = 2; // Always 2 for 4-bit + + // Iterate over Columns (K) then Rows (N) because prepacked_data is [K, N] + for (int64_t c = 0; c < cols; ++c) { + for (int64_t r = 0; r < rows; ++r) { + uint8_t val = prepacked_data[c * rows + r]; + + int64_t block_idx = (block_size > 0) ? (c / block_size) : 0; + if (block_size > 0) block_idx = std::min(block_idx, blocks_per_row - 1); + + int64_t scale_idx; + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + scale_idx = r * blocks_per_row + block_idx; + } else { // per-channel + scale_idx = r; + } + + float scale = static_cast(scales[scale_idx]); + float zp = default_zp_4bit; + + if (zero_points != nullptr) { + int64_t zp_idx; + bool is_lower_nibble; + + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + int64_t zp_blocks_packed = (blocks_per_row + zp_pack_size - 1) / zp_pack_size; + zp_idx = r * zp_blocks_packed + block_idx / 2; + is_lower_nibble = (block_idx % 2 == 0); + } else { + zp_idx = r / 2; + is_lower_nibble = (r % 2 == 0); + } + + uint8_t packed_zp = zero_points[zp_idx]; + zp = is_lower_nibble ? static_cast(packed_zp & 0x0F) : static_cast(packed_zp >> 4); + } + + dequantized_data[c * rows + r] = scale * (static_cast(val) - zp); + } + } +} + +template +Status BuildDirectQ4PackedBCache(const uint8_t* prepacked_weights, + const TScale* scales_data, + int64_t num_experts, + int64_t rows, + int64_t cols, + int64_t block_size, + const gsl::span& scales_dims, + MLAS_BLK_QUANT_TYPE qtype, + std::vector>& packed_b_by_expert) { + const size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); + if (packed_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to compute MLAS Q4 packed size for cache"); + } + + const bool is_block_wise = (scales_dims.size() == 3 && scales_dims[2] > 1); + const int64_t scales_expert_stride = is_block_wise ? (rows * scales_dims[2]) : rows; + const size_t prepacked_expert_stride = static_cast(rows * cols); + + packed_b_by_expert.clear(); + packed_b_by_expert.resize(static_cast(num_experts)); + + std::vector dequantized_transposed(static_cast(rows * cols)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const uint8_t* expert_prepacked = prepacked_weights + static_cast(expert_idx) * prepacked_expert_stride; + const TScale* expert_scales = scales_data + expert_idx * scales_expert_stride; + + DequantizePrePacked(expert_prepacked, expert_scales, nullptr, block_size, rows, cols, + dequantized_transposed.data(), scales_dims); + + auto& packed_b = packed_b_by_expert[static_cast(expert_idx)]; + packed_b.resize(packed_size); + MlasQ4GemmPackB(qtype, packed_b.data(), dequantized_transposed.data(), + static_cast(rows), static_cast(cols), static_cast(rows)); + } + + return Status::OK(); +} + +template +Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + + // If scales are prepacked, they are constant initializers. This enables safe shared cache usage. + if (input_idx == 3) { + has_prepacked_fc1_scales_ = true; + return Status::OK(); + } + if (input_idx == 6) { + has_prepacked_fc2_scales_ = true; + return Status::OK(); + } + + // Only support PrePack for FC1 (2), FC2 (5), and FC3 (8) weights + // and only if expert_weight_bits_ == 4 (since we unpack to uint8) + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if (input_idx == 2 || input_idx == 5 || input_idx == 8) { + const auto& shape = tensor.Shape(); + const int64_t num_experts = shape[0]; + const int64_t rows = shape[1]; + const int64_t cols_packed = shape[2]; + const int64_t cols = cols_packed * 2; + + size_t packed_size = static_cast(num_experts * rows * cols); + auto packed_buffer = IAllocator::MakeUniquePtr(alloc, packed_size, true); + uint8_t* dst_base = static_cast(packed_buffer.get()); + const uint8_t* src_base = static_cast(tensor.DataRaw()); + + for (int64_t i = 0; i < num_experts; ++i) { + const uint8_t* src = src_base + i * rows * cols_packed; + uint8_t* dst = dst_base + i * rows * cols; + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + uint8_t packed_val = src[r * cols_packed + (c / 2)]; + uint8_t val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4); + + dst[c * rows + r] = val; + } + } + } + + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_buffer)); + prepacked_weights->buffer_sizes_.push_back(packed_size); + is_packed = true; + } + } + + return Status::OK(); +} + +template +Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if (input_idx == 2 && !prepacked_buffers.empty()) { + packed_fc1_ = std::move(prepacked_buffers[0]); + used_shared_buffers = true; + } else if (input_idx == 5 && !prepacked_buffers.empty()) { + packed_fc2_ = std::move(prepacked_buffers[0]); + used_shared_buffers = true; + } else if (input_idx == 8 && !prepacked_buffers.empty()) { + packed_fc3_ = std::move(prepacked_buffers[0]); + used_shared_buffers = true; + } + + return Status::OK(); +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), @@ -377,19 +558,21 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } + + use_mlas_q4_gemm_ = ParseEnvironmentVariableWithDefault(kUseMlasQ4GemmMoe, true); } template Status QMoECPU::Compute(OpKernelContext* context) const { const auto* input = context->Input(0); const auto* router_probs = context->Input(1); - const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_experts_weights = packed_fc1_ ? nullptr : context->Input(2); const auto* fc1_scales = context->Input(3); const auto* fc1_experts_bias = context->Input(4); - const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_experts_weights = packed_fc2_ ? nullptr : context->Input(5); const auto* fc2_scales = context->Input(6); const auto* fc2_experts_bias = context->Input(7); - const auto* fc3_experts_weights = context->Input(8); + const auto* fc3_experts_weights = packed_fc3_ ? nullptr : context->Input(8); const auto* fc3_scales = context->Input(9); const auto* fc3_experts_bias = context->Input(10); const auto* fc1_zero_points = context->Input(11); @@ -569,8 +752,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); - const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); - const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->Data(); + const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->Data(); const T* fc1_scales_data = fc1_scales->Data(); const T* fc2_scales_data = fc2_scales->Data(); const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; @@ -605,6 +788,63 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_zp_expert_stride = (hidden_size + zp_pack_size - 1) / zp_pack_size; } + const std::vector>* fc1_direct_q4_cache = nullptr; + const std::vector>* fc2_direct_q4_cache = nullptr; + MLAS_BLK_QUANT_TYPE fc1_direct_qtype = BlkQ4Sym; + MLAS_BLK_QUANT_TYPE fc2_direct_qtype = BlkQ4Sym; + + if (use_mlas_q4_gemm_ && has_prepacked_fc1_scales_ && packed_fc1_ != nullptr && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, fc1_direct_qtype)) { + std::lock_guard guard(direct_q4_cache_mu_); + auto& cache = fc1_direct_q4_cache_; + if (cache.packed_b_by_expert.empty() || + cache.scales_data_ptr != static_cast(fc1_scales_data) || + cache.rows != fc1_out_features || cache.cols != hidden_size || + cache.num_experts != num_experts || cache.qtype != fc1_direct_qtype) { + std::vector> rebuilt_cache; + ORT_RETURN_IF_ERROR(BuildDirectQ4PackedBCache( + static_cast(packed_fc1_.get()), fc1_scales_data, + num_experts, fc1_out_features, hidden_size, + is_fc1_block_wise ? block_size_ : 0, + fc1_scales_dims, fc1_direct_qtype, rebuilt_cache)); + + cache.scales_data_ptr = static_cast(fc1_scales_data); + cache.rows = fc1_out_features; + cache.cols = hidden_size; + cache.num_experts = num_experts; + cache.qtype = fc1_direct_qtype; + cache.packed_b_by_expert = std::move(rebuilt_cache); + } + fc1_direct_q4_cache = &cache.packed_b_by_expert; + } + + if (use_mlas_q4_gemm_ && has_prepacked_fc2_scales_ && packed_fc2_ != nullptr && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, fc2_direct_qtype)) { + std::lock_guard guard(direct_q4_cache_mu_); + auto& cache = fc2_direct_q4_cache_; + if (cache.packed_b_by_expert.empty() || + cache.scales_data_ptr != static_cast(fc2_scales_data) || + cache.rows != hidden_size || cache.cols != inter_size || + cache.num_experts != num_experts || cache.qtype != fc2_direct_qtype) { + std::vector> rebuilt_cache; + ORT_RETURN_IF_ERROR(BuildDirectQ4PackedBCache( + static_cast(packed_fc2_.get()), fc2_scales_data, + num_experts, hidden_size, inter_size, + is_fc2_block_wise ? block_size_ : 0, + fc2_scales_dims, fc2_direct_qtype, rebuilt_cache)); + + cache.scales_data_ptr = static_cast(fc2_scales_data); + cache.rows = hidden_size; + cache.cols = inter_size; + cache.num_experts = num_experts; + cache.qtype = fc2_direct_qtype; + cache.packed_b_by_expert = std::move(rebuilt_cache); + } + fc2_direct_q4_cache = &cache.packed_b_by_expert; + } + std::vector> expert_workload; size_t total_work = 0; @@ -718,10 +958,90 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k = static_cast(hidden_size); MLAS_BLK_QUANT_TYPE q_type = BlkQ4Sym; // Initialize to default - // Direct Q4 GEMM only supports symmetric quantization, so we disable it if zero_points are provided. - bool use_direct_q4_gemm = (fc1_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, q_type); + bool use_direct_q4_gemm = use_mlas_q4_gemm_ && + ((fc1_direct_q4_cache != nullptr) || + ((packed_fc1_ == nullptr) && (fc1_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type))); + + if (packed_fc1_ != nullptr) { + if (fc1_direct_q4_cache != nullptr) { + float* fc1_bias_float = nullptr; + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); + } else { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + fc1_bias_float = thread_bias1_buffer; + } + + const auto& packed_b = (*fc1_direct_q4_cache)[static_cast(expert_idx)]; + Status gemm_status = DirectQ4Gemm(A1, packed_b.data(), fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, fc1_direct_qtype, tp); + if (gemm_status.IsOK()) { + goto fc1_gemm_done; + } + } + + if (use_mlas_q4_gemm_ && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type)) { + // Safe non-cached direct path for dynamic scales. + const uint8_t* current_packed_ptr = static_cast(packed_fc1_.get()) + expert_idx * fc1_out_features * hidden_size; + DequantizePrePacked(current_packed_ptr, fc1_scales_ptr, nullptr, + is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, + B1_dequant, fc1_scales_dims); + + IAllocatorUniquePtr mlas_packed_fc1; + size_t packed_size = MlasQ4GemmPackBSize(q_type, static_cast(fc1_out_features), static_cast(hidden_size)); + mlas_packed_fc1 = IAllocator::MakeUniquePtr(allocator, packed_size); + MlasQ4GemmPackB(q_type, mlas_packed_fc1.get(), B1_dequant, + static_cast(fc1_out_features), static_cast(hidden_size), + static_cast(fc1_out_features)); + + float* fc1_bias_float = nullptr; + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); + } else { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + fc1_bias_float = thread_bias1_buffer; + } + + Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); + if (gemm_status.IsOK()) { + goto fc1_gemm_done; + } + } + + // Dequantize from PrePacked (transposed, unpacked) + const uint8_t* current_packed_ptr = static_cast(packed_fc1_.get()) + expert_idx * fc1_out_features * hidden_size; + + DequantizePrePacked(current_packed_ptr, fc1_scales_ptr, fc1_zp_ptr, + is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, + B1_dequant, fc1_scales_dims); + + // Use MlasGemm with B1_dequant (which is already float transposed) + // GEMM is C = A * B. A [M, K], B [K, N]. + // B1_dequant is [K, N] (RowMajor). + // So we use CblasNoTrans for B. + MlasGemm(CblasNoTrans, CblasNoTrans, + m, n, k, + 1.0f, A1, k, + B1_dequant, n, + 0.0f, C1, n, + tp); + + // Skip the check for packed_b fallback logic below + goto fc1_bias_handling; + } if (use_direct_q4_gemm) { IAllocatorUniquePtr mlas_packed_fc1; @@ -805,6 +1125,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); + fc1_bias_handling: + if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { @@ -895,9 +1217,88 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k2 = static_cast(inter_size); MLAS_BLK_QUANT_TYPE q_type2 = BlkQ4Sym; // Initialize to default - bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, q_type2); + bool use_direct_q4_gemm_fc2 = use_mlas_q4_gemm_ && + ((fc2_direct_q4_cache != nullptr) || + ((packed_fc2_ == nullptr) && (fc2_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2))); + + if (packed_fc2_ != nullptr) { + if (fc2_direct_q4_cache != nullptr) { + float* fc2_bias_float = nullptr; + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); + } else { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + fc2_bias_float = thread_bias2_buffer; + } + + const auto& packed_b = (*fc2_direct_q4_cache)[static_cast(expert_idx)]; + Status gemm_status = DirectQ4Gemm(A2, packed_b.data(), fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, fc2_direct_qtype, tp); + if (gemm_status.IsOK()) { + fc2_bias_added_by_mlas = true; + goto fc2_gemm_done; + } + } + + if (use_mlas_q4_gemm_ && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2)) { + // Safe non-cached direct path for dynamic scales. + const uint8_t* current_packed_ptr = static_cast(packed_fc2_.get()) + expert_idx * hidden_size * inter_size; + DequantizePrePacked(current_packed_ptr, fc2_scales_ptr, nullptr, + is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, + B2_dequant, fc2_scales_dims); + + IAllocatorUniquePtr mlas_packed_fc2; + size_t packed_size = MlasQ4GemmPackBSize(q_type2, static_cast(hidden_size), static_cast(inter_size)); + mlas_packed_fc2 = IAllocator::MakeUniquePtr(allocator, packed_size); + MlasQ4GemmPackB(q_type2, mlas_packed_fc2.get(), B2_dequant, + static_cast(hidden_size), static_cast(inter_size), + static_cast(hidden_size)); + + float* fc2_bias_float = nullptr; + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); + } else { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + fc2_bias_float = thread_bias2_buffer; + } + + Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, q_type2, tp); + if (gemm_status.IsOK()) { + fc2_bias_added_by_mlas = true; + goto fc2_gemm_done; + } + } + + // Dequantize from PrePacked (transposed, unpacked) + const uint8_t* current_packed_ptr = static_cast(packed_fc2_.get()) + expert_idx * hidden_size * inter_size; + + DequantizePrePacked(current_packed_ptr, fc2_scales_ptr, fc2_zp_ptr, + is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, + B2_dequant, fc2_scales_dims); + + // Fallback + MlasGemm(CblasNoTrans, CblasNoTrans, + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, n2, + 0.0f, C2, n2, + tp); + + goto fc2_gemm_done; + } if (use_direct_q4_gemm_fc2) { IAllocatorUniquePtr mlas_packed_fc2; @@ -1115,9 +1516,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); + template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, bool& used_shared_buffers); template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, bool& used_shared_buffers); // Kernel Registration ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 890580e051a8e..0a23ba8217905 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -5,7 +5,10 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas_q4.h" #include "contrib_ops/cpu/moe/moe_base_cpu.h" +#include +#include namespace onnxruntime { namespace contrib { @@ -26,8 +29,37 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { Status Compute(OpKernelContext* context) const override; private: + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + private: + struct DirectQ4Cache { + const void* scales_data_ptr{nullptr}; + int64_t rows{0}; + int64_t cols{0}; + int64_t num_experts{0}; + MLAS_BLK_QUANT_TYPE qtype{BlkQ4Sym}; + std::vector> packed_b_by_expert; + }; + int64_t expert_weight_bits_; int64_t block_size_; + bool use_mlas_q4_gemm_{false}; + bool has_prepacked_fc1_scales_{false}; + bool has_prepacked_fc2_scales_{false}; + + IAllocatorUniquePtr packed_fc1_; + IAllocatorUniquePtr packed_fc2_; + IAllocatorUniquePtr packed_fc3_; + + mutable std::mutex direct_q4_cache_mu_; + mutable DirectQ4Cache fc1_direct_q4_cache_; + mutable DirectQ4Cache fc2_direct_q4_cache_; }; } // namespace contrib diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index 38dd8de01147c..5137c22d6cf61 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -621,8 +621,8 @@ void DumpNodeInputs( std::cout << " is non-tensor type.\n"; } } else { - // this could happen with an empty Optional input - std::cout << " was missing data type\n"; + // this could happen with an empty Optional input or the tensor is removed after pre-packing. + std::cout << " was missing data type (maybe pre-packed).\n"; } } else { std::cout << "Input " << i << " is optional and was not provided.\n"; diff --git a/onnxruntime/test/python/transformers/benchmark_qmoe.py b/onnxruntime/test/python/transformers/benchmark_qmoe.py new file mode 100644 index 0000000000000..1ae9f7b358a1e --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_qmoe.py @@ -0,0 +1,187 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import os +import sys +import time +import unittest + +import numpy +import torch + +# Add current directory to path to allow importing from test_qmoe_cpu +current_dir = os.path.dirname(os.path.abspath(__file__)) +if current_dir not in sys.path: + sys.path.append(current_dir) + +from test_qmoe_cpu import PhiMoEConfig, PhiMoESparseMoeBlock, TensorProto, disable_cpu_qmoe_tests # noqa: E402 + + +class TestQMoESwiGLUBenchmark(unittest.TestCase): + """Benchmark tests for QMoE SwiGLU performance measurement.""" + + def test_qmoe_swiglu_throughput_benchmark(self): + """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" + if disable_cpu_qmoe_tests: + self.skipTest("QMoE CPU tests disabled") + + print("\n=== QMoE SwiGLU Throughput Benchmark ===") + + # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) + configs = [ + ("Medium-4bit", 2880, 2880, 32, 4, 4), + ("Medium-8bit", 2880, 2880, 32, 4, 8), + ] + + batch_size = 1 + sequence_length = 512 + num_runs = 30 + + results = [] + + for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs: + torch.manual_seed(42) + numpy.random.seed(42) + + print(f"\nTesting {config_name}:") + print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit") + + try: + # Create config and model + config = PhiMoEConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + ) + + qmoe_swiglu = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + # Create test input with fixed sequence length to match ONNX model + full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32) + + # For TTFT simulation, we'll measure single forward pass time + # This represents the time to process one token in autoregressive generation + + # Warm up with full context + for _ in range(3): + _ = qmoe_swiglu.forward(full_hidden_states) + + # Benchmark PyTorch TTFT (Time to First Token) + # Measure time for a single forward pass (represents token generation time) + torch.manual_seed(42) + + start_time = time.time() + for _ in range(num_runs): + torch_output = qmoe_swiglu.forward(full_hidden_states) + end_time = time.time() + torch_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second (throughput) + # For sequence generation, this represents the rate at which we can generate tokens + torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000) + + print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)") + print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec") + + # Benchmark ONNX Runtime + ort_ttft_ms = 0 + ort_tokens_per_sec = 0 + speedup = 0 + throughput_ratio = 0 + max_diff = 0 + + model_updated = qmoe_swiglu.recreate_onnx_model() + if model_updated and qmoe_swiglu.ort_sess is not None: + # Warm up ORT with full context + for _ in range(3): + _ = qmoe_swiglu.ort_forward(full_hidden_states) + + torch.manual_seed(42) + + # Measure ONNX Runtime TTFT (Time to First Token) + start_time = time.time() + for _ in range(num_runs): + ort_output = qmoe_swiglu.ort_forward(full_hidden_states) + end_time = time.time() + ort_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second for ONNX Runtime + ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000) + + speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0 + throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 0 + + print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)") + print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec") + print(f" TTFT Speedup: {speedup:.2f}x") + print(f" Throughput Gain: {throughput_ratio:.2f}x") + else: + print(" ONNX RT: Not available") + + # Calculate max difference if both outputs available + if torch_output is not None and ort_output is not None: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item() + print(f" Max diff: {max_diff:.6f}") + + results.append( + { + "config": config_name, + "torch_ttft_ms": torch_ttft_ms, + "torch_tokens_per_sec": torch_tokens_per_sec, + "ort_ttft_ms": ort_ttft_ms, + "ort_tokens_per_sec": ort_tokens_per_sec, + "speedup": speedup, + "throughput_ratio": throughput_ratio, + "max_diff": max_diff, + } + ) + + except Exception as e: + print(f" Error: {e}") + continue + + # Summary + print("\n=== Token Generation Time & Throughput Summary ===") + print( + f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}" + ) + print("-" * 105) + for result in results: + config = result["config"] + torch_ttft = result["torch_ttft_ms"] + torch_tps = result["torch_tokens_per_sec"] + ort_ttft = result["ort_ttft_ms"] + ort_tps = result["ort_tokens_per_sec"] + speedup = result["speedup"] + throughput_ratio = result["throughput_ratio"] + max_diff = result["max_diff"] + + ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A" + ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A" + speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A" + throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A" + + print( + f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}" + ) + + print("\nNotes:") + print("- Time: Token generation time in ms (lower is better)") + print("- tok/s: Tokens per second throughput (higher is better)") + print("- Time Gain: ORT speedup for latency (higher is better)") + print("- Throughput: ORT throughput improvement (higher is better)") + + +if __name__ == "__main__": + benchmark = TestQMoESwiGLUBenchmark() + benchmark.test_qmoe_swiglu_throughput_benchmark() diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 238ac4d1f077d..e5fc5c120a7db 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -23,9 +23,11 @@ # normalization on the selected experts. This provides proper weight distribution # while maintaining computational efficiency. # -------------------------------------------------------------------------- +import os import time import unittest from collections import OrderedDict +from contextlib import contextmanager import numpy import torch @@ -76,6 +78,8 @@ class TensorProtoPlaceholder: ort_provider = ["CPUExecutionProvider"] +ORT_USE_MLAS_Q4_GEMM_MOE = "ORT_USE_MLAS_Q4_GEMM_MOE" + torch.manual_seed(42) numpy.random.seed(42) @@ -1137,6 +1141,37 @@ def small_test_cases(): yield batch_size, sequence_length +def with_mlas_q4_mode(test_cases): + expanded_cases = [] + for case in test_cases: + quant_bits = case[2] + expanded_cases.append((*case, False)) + if quant_bits == 4: + expanded_cases.append((*case, True)) + return expanded_cases + + +@contextmanager +def scoped_env_var(name: str, value: str): + previous = os.environ.get(name) + os.environ[name] = value + try: + yield + finally: + if previous is None: + os.environ.pop(name, None) + else: + os.environ[name] = previous + + +def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool): + env_value = "1" if enable_mlas_q4_gemm else "0" + mode = "enabled" if enable_mlas_q4_gemm else "disabled" + print(f"DirectQ4 mode ({ORT_USE_MLAS_Q4_GEMM_MOE}) is {mode}") + with scoped_env_var(ORT_USE_MLAS_Q4_GEMM_MOE, env_value): + test_runner() + + class SwigluMoEBlock(SparseMoeBlockORTHelper): def __init__( self, @@ -1402,8 +1437,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): - @parameterized.expand(phi3_test_cases) - def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(phi3_test_cases)) + def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): # Create unique seed based on test parameters to ensure different inputs for each test base_seed = 2000 # Different base seed from other tests param_hash = hash((batch_size, sequence_length, quant_bits)) @@ -1438,10 +1473,10 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_test_cases) - def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(phi3_test_cases)) + def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): base_seed = 3000 param_hash = hash((batch_size, sequence_length, quant_bits)) unique_seed = base_seed + abs(param_hash) % 1000 @@ -1463,10 +1498,12 @@ def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quan onnx_dtype=TensorProto.FLOAT, use_asymmetric_quant=True, ) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_blockwise_test_cases) - def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases)) + def test_phi3_qmoe_blockwise_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(42) numpy.random.seed(42) @@ -1495,10 +1532,12 @@ def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_blockwise_test_cases) - def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases)) + def test_phi3_qmoe_blockwise_asymmetric_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(43) numpy.random.seed(43) @@ -1516,7 +1555,7 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le block_size=block_size, use_asymmetric_quant=True, ) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) disable_cpu_qmoe_tests = False @@ -1539,8 +1578,8 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): - @parameterized.expand(swiglu_test_cases) - def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases)) + def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): # Create unique seed based on test parameters to ensure different inputs for each test base_seed = 1000 # Different base seed from regular MoE tests param_hash = hash((batch_size, sequence_length, quant_bits)) @@ -1574,10 +1613,10 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_test_cases) - def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases)) + def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): base_seed = 1100 param_hash = hash((batch_size, sequence_length, quant_bits)) unique_seed = base_seed + abs(param_hash) % 1000 @@ -1599,10 +1638,12 @@ def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, qu onnx_dtype=TensorProto.FLOAT, use_asymmetric_quant=True, ) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_blockwise_test_cases) - def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases)) + def test_swiglu_qmoe_blockwise_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(42) numpy.random.seed(42) @@ -1630,10 +1671,12 @@ def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, qua self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_blockwise_test_cases) - def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases)) + def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(43) numpy.random.seed(43) @@ -1651,7 +1694,7 @@ def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_ block_size=block_size, use_asymmetric_quant=True, ) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) @unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") From 2b25601705801dc3f9cd240da57e73ad701143f1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 16 Feb 2026 21:09:55 +0000 Subject: [PATCH 03/11] move cache to prepack; allow q4gemm for block_size=32 --- .../cpu/moe/moe_quantization_cpu.cc | 270 +++++++----------- .../cpu/moe/moe_quantization_cpu.h | 14 +- 2 files changed, 109 insertions(+), 175 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 43a01da0bffef..de24579b39ba4 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -70,7 +70,7 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, out_qtype = BlkQ4Sym64; } else if (block_size == 128) { out_qtype = BlkQ4Sym128; - } else if (block_size == 0) { + } else if (block_size == 0 || block_size == 32) { out_qtype = BlkQ4Sym; } else { return false; @@ -85,7 +85,7 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, namespace onnxruntime { namespace contrib { -constexpr char* kUseMlasQ4GemmMoe = "ORT_USE_MLAS_Q4_GEMM_MOE"; +constexpr const char* kUseMlasQ4GemmMoe = "ORT_USE_MLAS_Q4_GEMM_MOE"; template void DequantizeBlockWithMlas(const uint8_t* quantized_data, @@ -433,7 +433,8 @@ Status BuildDirectQ4PackedBCache(const uint8_t* prepacked_weights, int64_t block_size, const gsl::span& scales_dims, MLAS_BLK_QUANT_TYPE qtype, - std::vector>& packed_b_by_expert) { + AllocatorPtr allocator, + IAllocatorUniquePtr& packed_b) { const size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); if (packed_size == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to compute MLAS Q4 packed size for cache"); @@ -442,9 +443,10 @@ Status BuildDirectQ4PackedBCache(const uint8_t* prepacked_weights, const bool is_block_wise = (scales_dims.size() == 3 && scales_dims[2] > 1); const int64_t scales_expert_stride = is_block_wise ? (rows * scales_dims[2]) : rows; const size_t prepacked_expert_stride = static_cast(rows * cols); + const size_t total_packed_size = packed_size * static_cast(num_experts); - packed_b_by_expert.clear(); - packed_b_by_expert.resize(static_cast(num_experts)); + packed_b = IAllocator::MakeUniquePtr(allocator, total_packed_size, true); + uint8_t* packed_b_ptr = static_cast(packed_b.get()); std::vector dequantized_transposed(static_cast(rows * cols)); for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { @@ -454,9 +456,7 @@ Status BuildDirectQ4PackedBCache(const uint8_t* prepacked_weights, DequantizePrePacked(expert_prepacked, expert_scales, nullptr, block_size, rows, cols, dequantized_transposed.data(), scales_dims); - auto& packed_b = packed_b_by_expert[static_cast(expert_idx)]; - packed_b.resize(packed_size); - MlasQ4GemmPackB(qtype, packed_b.data(), dequantized_transposed.data(), + MlasQ4GemmPackB(qtype, packed_b_ptr + expert_idx * packed_size, dequantized_transposed.data(), static_cast(rows), static_cast(cols), static_cast(rows)); } @@ -469,7 +469,7 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; - // If scales are prepacked, they are constant initializers. This enables safe shared cache usage. + // If scales are prepacked, they are constant initializers. if (input_idx == 3) { has_prepacked_fc1_scales_ = true; return Status::OK(); @@ -515,6 +515,48 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all prepacked_weights->buffers_.push_back(std::move(packed_buffer)); prepacked_weights->buffer_sizes_.push_back(packed_size); is_packed = true; + + // Try build MLAS Q4 cache if scales are available + if (use_mlas_q4_gemm_) { + const Tensor* scales_tensor = nullptr; + MLAS_BLK_QUANT_TYPE qtype = BlkQ4Sym; + int scales_idx = -1; + int zp_idx = -1; + + if (input_idx == 2) { // FC1 + scales_idx = 3; + zp_idx = 11; + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype); + } else if (input_idx == 5) { // FC2 + scales_idx = 6; + zp_idx = 12; + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype); + } + // FC3 (8) not supported for now + + if (scales_idx != -1 && + !Info().node().InputDefs()[zp_idx]->Exists() && + Info().TryGetConstantInput(scales_idx, &scales_tensor) && + scales_tensor != nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_ > 0 ? block_size_ : 0, rows, cols, qtype)) { + IAllocatorUniquePtr cache_buffer; + const auto& scales_dims = scales_tensor->Shape().GetDims(); + const T* scales_data = scales_tensor->Data(); + // Use the simple packed buffer we just created (buffer 0) as input + const uint8_t* simple_packed = dst_base; + + if (BuildDirectQ4PackedBCache(simple_packed, scales_data, num_experts, rows, cols, + block_size_ > 0 ? block_size_ : 0, scales_dims, qtype, + alloc, cache_buffer) + .IsOK()) { + // Store the size so we can verify later? Container holds size. + // We push it as a SECOND buffer. + size_t cache_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)) * static_cast(num_experts); + prepacked_weights->buffers_.push_back(std::move(cache_buffer)); + prepacked_weights->buffer_sizes_.push_back(cache_size); + } + } + } } } @@ -533,9 +575,15 @@ Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepa if (input_idx == 2 && !prepacked_buffers.empty()) { packed_fc1_ = std::move(prepacked_buffers[0]); + if (prepacked_buffers.size() > 1) { + packed_fc1_mlas_cache_ = std::move(prepacked_buffers[1]); + } used_shared_buffers = true; } else if (input_idx == 5 && !prepacked_buffers.empty()) { packed_fc2_ = std::move(prepacked_buffers[0]); + if (prepacked_buffers.size() > 1) { + packed_fc2_mlas_cache_ = std::move(prepacked_buffers[1]); + } used_shared_buffers = true; } else if (input_idx == 8 && !prepacked_buffers.empty()) { packed_fc3_ = std::move(prepacked_buffers[0]); @@ -559,7 +607,7 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } - use_mlas_q4_gemm_ = ParseEnvironmentVariableWithDefault(kUseMlasQ4GemmMoe, true); + use_mlas_q4_gemm_ = ParseEnvironmentVariableWithDefault(kUseMlasQ4GemmMoe, false); } template @@ -788,61 +836,20 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_zp_expert_stride = (hidden_size + zp_pack_size - 1) / zp_pack_size; } - const std::vector>* fc1_direct_q4_cache = nullptr; - const std::vector>* fc2_direct_q4_cache = nullptr; MLAS_BLK_QUANT_TYPE fc1_direct_qtype = BlkQ4Sym; MLAS_BLK_QUANT_TYPE fc2_direct_qtype = BlkQ4Sym; - if (use_mlas_q4_gemm_ && has_prepacked_fc1_scales_ && packed_fc1_ != nullptr && fc1_zp_data == nullptr && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, fc1_direct_qtype)) { - std::lock_guard guard(direct_q4_cache_mu_); - auto& cache = fc1_direct_q4_cache_; - if (cache.packed_b_by_expert.empty() || - cache.scales_data_ptr != static_cast(fc1_scales_data) || - cache.rows != fc1_out_features || cache.cols != hidden_size || - cache.num_experts != num_experts || cache.qtype != fc1_direct_qtype) { - std::vector> rebuilt_cache; - ORT_RETURN_IF_ERROR(BuildDirectQ4PackedBCache( - static_cast(packed_fc1_.get()), fc1_scales_data, - num_experts, fc1_out_features, hidden_size, - is_fc1_block_wise ? block_size_ : 0, - fc1_scales_dims, fc1_direct_qtype, rebuilt_cache)); - - cache.scales_data_ptr = static_cast(fc1_scales_data); - cache.rows = fc1_out_features; - cache.cols = hidden_size; - cache.num_experts = num_experts; - cache.qtype = fc1_direct_qtype; - cache.packed_b_by_expert = std::move(rebuilt_cache); - } - fc1_direct_q4_cache = &cache.packed_b_by_expert; + // Use pre-packed MLAS cache if available + const void* fc1_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_ && packed_fc1_mlas_cache_ && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, fc1_direct_qtype)) { + fc1_direct_q4_cache_ptr = packed_fc1_mlas_cache_.get(); } - if (use_mlas_q4_gemm_ && has_prepacked_fc2_scales_ && packed_fc2_ != nullptr && fc2_zp_data == nullptr && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, fc2_direct_qtype)) { - std::lock_guard guard(direct_q4_cache_mu_); - auto& cache = fc2_direct_q4_cache_; - if (cache.packed_b_by_expert.empty() || - cache.scales_data_ptr != static_cast(fc2_scales_data) || - cache.rows != hidden_size || cache.cols != inter_size || - cache.num_experts != num_experts || cache.qtype != fc2_direct_qtype) { - std::vector> rebuilt_cache; - ORT_RETURN_IF_ERROR(BuildDirectQ4PackedBCache( - static_cast(packed_fc2_.get()), fc2_scales_data, - num_experts, hidden_size, inter_size, - is_fc2_block_wise ? block_size_ : 0, - fc2_scales_dims, fc2_direct_qtype, rebuilt_cache)); - - cache.scales_data_ptr = static_cast(fc2_scales_data); - cache.rows = hidden_size; - cache.cols = inter_size; - cache.num_experts = num_experts; - cache.qtype = fc2_direct_qtype; - cache.packed_b_by_expert = std::move(rebuilt_cache); - } - fc2_direct_q4_cache = &cache.packed_b_by_expert; + const void* fc2_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_ && packed_fc2_mlas_cache_ && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, fc2_direct_qtype)) { + fc2_direct_q4_cache_ptr = packed_fc2_mlas_cache_.get(); } std::vector> expert_workload; @@ -959,68 +966,39 @@ Status QMoECPU::Compute(OpKernelContext* context) const { MLAS_BLK_QUANT_TYPE q_type = BlkQ4Sym; // Initialize to default bool use_direct_q4_gemm = use_mlas_q4_gemm_ && - ((fc1_direct_q4_cache != nullptr) || + ((fc1_direct_q4_cache_ptr != nullptr) || ((packed_fc1_ == nullptr) && (fc1_zp_data == nullptr) && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, q_type))); if (packed_fc1_ != nullptr) { - if (fc1_direct_q4_cache != nullptr) { - float* fc1_bias_float = nullptr; - if (has_fc1_bias) { - const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); - } else { - std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); - } - fc1_bias_float = thread_bias1_buffer; - } - - const auto& packed_b = (*fc1_direct_q4_cache)[static_cast(expert_idx)]; - Status gemm_status = DirectQ4Gemm(A1, packed_b.data(), fc1_bias_float, C1, - num_expert_tokens, fc1_out_features, hidden_size, fc1_direct_qtype, tp); - if (gemm_status.IsOK()) { - goto fc1_gemm_done; - } - } - if (use_mlas_q4_gemm_ && fc1_zp_data == nullptr && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, q_type)) { - // Safe non-cached direct path for dynamic scales. - const uint8_t* current_packed_ptr = static_cast(packed_fc1_.get()) + expert_idx * fc1_out_features * hidden_size; - DequantizePrePacked(current_packed_ptr, fc1_scales_ptr, nullptr, - is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, - B1_dequant, fc1_scales_dims); - - IAllocatorUniquePtr mlas_packed_fc1; - size_t packed_size = MlasQ4GemmPackBSize(q_type, static_cast(fc1_out_features), static_cast(hidden_size)); - mlas_packed_fc1 = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(q_type, mlas_packed_fc1.get(), B1_dequant, - static_cast(fc1_out_features), static_cast(hidden_size), - static_cast(fc1_out_features)); - - float* fc1_bias_float = nullptr; - if (has_fc1_bias) { - const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); - } else { - std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + if (fc1_direct_q4_cache_ptr != nullptr) { + float* fc1_bias_float = nullptr; + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); + } else { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + fc1_bias_float = thread_bias1_buffer; } - fc1_bias_float = thread_bias1_buffer; - } - Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, - num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); - if (gemm_status.IsOK()) { - goto fc1_gemm_done; + size_t packed_size = MlasQ4GemmPackBSize(q_type, static_cast(fc1_out_features), static_cast(hidden_size)); + const uint8_t* packed_b = static_cast(fc1_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A1, packed_b, fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, fc1_direct_qtype, tp); + if (gemm_status.IsOK()) { + goto fc1_gemm_done; + } } } - // Dequantize from PrePacked (transposed, unpacked) + // Fallback: Dequantize from PrePacked (transposed, unpacked) -> MlasGemm const uint8_t* current_packed_ptr = static_cast(packed_fc1_.get()) + expert_idx * fc1_out_features * hidden_size; DequantizePrePacked(current_packed_ptr, fc1_scales_ptr, fc1_zp_ptr, @@ -1029,9 +1007,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { B1_dequant, fc1_scales_dims); // Use MlasGemm with B1_dequant (which is already float transposed) - // GEMM is C = A * B. A [M, K], B [K, N]. - // B1_dequant is [K, N] (RowMajor). - // So we use CblasNoTrans for B. MlasGemm(CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, A1, k, @@ -1039,7 +1014,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); - // Skip the check for packed_b fallback logic below goto fc1_bias_handling; } @@ -1218,66 +1192,36 @@ Status QMoECPU::Compute(OpKernelContext* context) const { MLAS_BLK_QUANT_TYPE q_type2 = BlkQ4Sym; // Initialize to default bool use_direct_q4_gemm_fc2 = use_mlas_q4_gemm_ && - ((fc2_direct_q4_cache != nullptr) || + ((fc2_direct_q4_cache_ptr != nullptr) || ((packed_fc2_ == nullptr) && (fc2_zp_data == nullptr) && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, q_type2))); if (packed_fc2_ != nullptr) { - if (fc2_direct_q4_cache != nullptr) { - float* fc2_bias_float = nullptr; - if (has_fc2_bias) { - const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); - } else { - std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); - } - fc2_bias_float = thread_bias2_buffer; - } - - const auto& packed_b = (*fc2_direct_q4_cache)[static_cast(expert_idx)]; - Status gemm_status = DirectQ4Gemm(A2, packed_b.data(), fc2_bias_float, C2, - num_expert_tokens, hidden_size, inter_size, fc2_direct_qtype, tp); - if (gemm_status.IsOK()) { - fc2_bias_added_by_mlas = true; - goto fc2_gemm_done; - } - } - if (use_mlas_q4_gemm_ && fc2_zp_data == nullptr && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, q_type2)) { - // Safe non-cached direct path for dynamic scales. - const uint8_t* current_packed_ptr = static_cast(packed_fc2_.get()) + expert_idx * hidden_size * inter_size; - DequantizePrePacked(current_packed_ptr, fc2_scales_ptr, nullptr, - is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, - B2_dequant, fc2_scales_dims); - - IAllocatorUniquePtr mlas_packed_fc2; - size_t packed_size = MlasQ4GemmPackBSize(q_type2, static_cast(hidden_size), static_cast(inter_size)); - mlas_packed_fc2 = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(q_type2, mlas_packed_fc2.get(), B2_dequant, - static_cast(hidden_size), static_cast(inter_size), - static_cast(hidden_size)); - - float* fc2_bias_float = nullptr; - if (has_fc2_bias) { - const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; - if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); - } else { - std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + if (fc2_direct_q4_cache_ptr != nullptr) { + float* fc2_bias_float = nullptr; + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); + } else { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + fc2_bias_float = thread_bias2_buffer; } - fc2_bias_float = thread_bias2_buffer; - } - Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, - num_expert_tokens, hidden_size, inter_size, q_type2, tp); - if (gemm_status.IsOK()) { - fc2_bias_added_by_mlas = true; - goto fc2_gemm_done; + size_t packed_size = MlasQ4GemmPackBSize(q_type2, static_cast(hidden_size), static_cast(inter_size)); + const uint8_t* packed_b = static_cast(fc2_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A2, packed_b, fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, fc2_direct_qtype, tp); + if (gemm_status.IsOK()) { + fc2_bias_added_by_mlas = true; + goto fc2_gemm_done; + } } } diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 0a23ba8217905..ed9f150df4bfe 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -38,15 +38,6 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { /*out*/ bool& used_shared_buffers) override; private: - struct DirectQ4Cache { - const void* scales_data_ptr{nullptr}; - int64_t rows{0}; - int64_t cols{0}; - int64_t num_experts{0}; - MLAS_BLK_QUANT_TYPE qtype{BlkQ4Sym}; - std::vector> packed_b_by_expert; - }; - int64_t expert_weight_bits_; int64_t block_size_; bool use_mlas_q4_gemm_{false}; @@ -57,9 +48,8 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { IAllocatorUniquePtr packed_fc2_; IAllocatorUniquePtr packed_fc3_; - mutable std::mutex direct_q4_cache_mu_; - mutable DirectQ4Cache fc1_direct_q4_cache_; - mutable DirectQ4Cache fc2_direct_q4_cache_; + IAllocatorUniquePtr packed_fc1_mlas_cache_; + IAllocatorUniquePtr packed_fc2_mlas_cache_; }; } // namespace contrib From e9dcd411ce4b7ebfa160c737e6081a7f61c2d7d4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 16 Feb 2026 22:27:02 +0000 Subject: [PATCH 04/11] Turn on q4 gemm by default if there is no accuracy loss --- .../cpu/moe/moe_quantization_cpu.cc | 29 ++++++++++++++----- .../cpu/moe/moe_quantization_cpu.h | 1 + .../test/python/transformers/test_qmoe_cpu.py | 18 ++++++++---- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index de24579b39ba4..f820253adacef 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -607,7 +607,15 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } - use_mlas_q4_gemm_ = ParseEnvironmentVariableWithDefault(kUseMlasQ4GemmMoe, false); + const auto use_mlas_q4_gemm = ParseEnvironmentVariable(kUseMlasQ4GemmMoe); + if (use_mlas_q4_gemm.has_value()) { + use_mlas_q4_gemm_ = *use_mlas_q4_gemm; + use_mlas_q4_gemm_overridden_ = true; + } else { + // Default policy: enable fast path unless this run hits a known accuracy-loss configuration. + use_mlas_q4_gemm_ = true; + use_mlas_q4_gemm_overridden_ = false; + } } template @@ -809,6 +817,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const uint8_t* fc1_zp_data = fc1_zero_points ? fc1_zero_points->Data() : nullptr; const uint8_t* fc2_zp_data = fc2_zero_points ? fc2_zero_points->Data() : nullptr; + // Known loss-prone case from parity testing: 4-bit symmetric path (row-wise and block-wise). + const bool known_accuracy_loss_case = (expert_weight_bits_ == 4) && + (fc1_zp_data == nullptr) && (fc2_zp_data == nullptr); + const bool use_mlas_q4_gemm_effective = use_mlas_q4_gemm_overridden_ + ? use_mlas_q4_gemm_ + : (use_mlas_q4_gemm_ && !known_accuracy_loss_case); + const int64_t pack_unit = (8 / expert_weight_bits_); const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; @@ -841,13 +856,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { // Use pre-packed MLAS cache if available const void* fc1_direct_q4_cache_ptr = nullptr; - if (use_mlas_q4_gemm_ && packed_fc1_mlas_cache_ && fc1_zp_data == nullptr && + if (use_mlas_q4_gemm_effective && packed_fc1_mlas_cache_ && fc1_zp_data == nullptr && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, fc1_direct_qtype)) { fc1_direct_q4_cache_ptr = packed_fc1_mlas_cache_.get(); } const void* fc2_direct_q4_cache_ptr = nullptr; - if (use_mlas_q4_gemm_ && packed_fc2_mlas_cache_ && fc2_zp_data == nullptr && + if (use_mlas_q4_gemm_effective && packed_fc2_mlas_cache_ && fc2_zp_data == nullptr && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, fc2_direct_qtype)) { fc2_direct_q4_cache_ptr = packed_fc2_mlas_cache_.get(); } @@ -965,14 +980,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k = static_cast(hidden_size); MLAS_BLK_QUANT_TYPE q_type = BlkQ4Sym; // Initialize to default - bool use_direct_q4_gemm = use_mlas_q4_gemm_ && + bool use_direct_q4_gemm = use_mlas_q4_gemm_effective && ((fc1_direct_q4_cache_ptr != nullptr) || ((packed_fc1_ == nullptr) && (fc1_zp_data == nullptr) && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, q_type))); if (packed_fc1_ != nullptr) { - if (use_mlas_q4_gemm_ && fc1_zp_data == nullptr && + if (use_mlas_q4_gemm_effective && fc1_zp_data == nullptr && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, q_type)) { if (fc1_direct_q4_cache_ptr != nullptr) { @@ -1191,14 +1206,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k2 = static_cast(inter_size); MLAS_BLK_QUANT_TYPE q_type2 = BlkQ4Sym; // Initialize to default - bool use_direct_q4_gemm_fc2 = use_mlas_q4_gemm_ && + bool use_direct_q4_gemm_fc2 = use_mlas_q4_gemm_effective && ((fc2_direct_q4_cache_ptr != nullptr) || ((packed_fc2_ == nullptr) && (fc2_zp_data == nullptr) && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, q_type2))); if (packed_fc2_ != nullptr) { - if (use_mlas_q4_gemm_ && fc2_zp_data == nullptr && + if (use_mlas_q4_gemm_effective && fc2_zp_data == nullptr && CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, hidden_size, inter_size, q_type2)) { if (fc2_direct_q4_cache_ptr != nullptr) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index ed9f150df4bfe..5a8059b3eb894 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -41,6 +41,7 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { int64_t expert_weight_bits_; int64_t block_size_; bool use_mlas_q4_gemm_{false}; + bool use_mlas_q4_gemm_overridden_{false}; bool has_prepacked_fc1_scales_{false}; bool has_prepacked_fc2_scales_{false}; diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index e5fc5c120a7db..fa7b8725c572f 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -1145,9 +1145,12 @@ def with_mlas_q4_mode(test_cases): expanded_cases = [] for case in test_cases: quant_bits = case[2] - expanded_cases.append((*case, False)) if quant_bits == 4: + expanded_cases.append((*case, None)) + expanded_cases.append((*case, False)) expanded_cases.append((*case, True)) + else: + expanded_cases.append((*case, None)) return expanded_cases @@ -1164,12 +1167,15 @@ def scoped_env_var(name: str, value: str): os.environ[name] = previous -def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool): - env_value = "1" if enable_mlas_q4_gemm else "0" - mode = "enabled" if enable_mlas_q4_gemm else "disabled" - print(f"DirectQ4 mode ({ORT_USE_MLAS_Q4_GEMM_MOE}) is {mode}") - with scoped_env_var(ORT_USE_MLAS_Q4_GEMM_MOE, env_value): +def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool | None): + if enable_mlas_q4_gemm is None: # No env var test_runner() + else: + env_value = "1" if enable_mlas_q4_gemm else "0" + mode = "enabled" if enable_mlas_q4_gemm else "disabled" + print(f"DirectQ4 mode ({ORT_USE_MLAS_Q4_GEMM_MOE}) is {mode}") + with scoped_env_var(ORT_USE_MLAS_Q4_GEMM_MOE, env_value): + test_runner() class SwigluMoEBlock(SparseMoeBlockORTHelper): From dec886772a7fdd3037a0c6669cdffcc160e9584d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 17 Feb 2026 00:15:15 +0000 Subject: [PATCH 05/11] refine tests --- .../test/python/transformers/benchmark_qmoe.py | 11 ++++++----- onnxruntime/test/python/transformers/test_qmoe_cpu.py | 9 --------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/onnxruntime/test/python/transformers/benchmark_qmoe.py b/onnxruntime/test/python/transformers/benchmark_qmoe.py index 1ae9f7b358a1e..53854e053ef93 100644 --- a/onnxruntime/test/python/transformers/benchmark_qmoe.py +++ b/onnxruntime/test/python/transformers/benchmark_qmoe.py @@ -16,17 +16,18 @@ if current_dir not in sys.path: sys.path.append(current_dir) -from test_qmoe_cpu import PhiMoEConfig, PhiMoESparseMoeBlock, TensorProto, disable_cpu_qmoe_tests # noqa: E402 +from test_qmoe_cpu import PhiMoEConfig, PhiMoESparseMoeBlock, TensorProto # noqa: E402 +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + +@unittest.skipIf(pipeline_mode, "Skip benchmark in CI pipeline.") class TestQMoESwiGLUBenchmark(unittest.TestCase): """Benchmark tests for QMoE SwiGLU performance measurement.""" def test_qmoe_swiglu_throughput_benchmark(self): """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" - if disable_cpu_qmoe_tests: - self.skipTest("QMoE CPU tests disabled") - print("\n=== QMoE SwiGLU Throughput Benchmark ===") # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) @@ -37,7 +38,7 @@ def test_qmoe_swiglu_throughput_benchmark(self): batch_size = 1 sequence_length = 512 - num_runs = 30 + num_runs = 1000 results = [] diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index fa7b8725c572f..8415c7b08b77c 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -1422,8 +1422,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states -disable_cpu_qmoe_tests = False - # Define test cases for different MoE types phi3_test_cases = [ (1, 32, 4), @@ -1441,7 +1439,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ] -@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): @parameterized.expand(with_mlas_q4_mode(phi3_test_cases)) def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): @@ -1564,8 +1561,6 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu( run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) -disable_cpu_qmoe_tests = False - swiglu_test_cases = [ (1, 32, 4), (1, 32, 8), @@ -1582,7 +1577,6 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu( ] -@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases)) def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): @@ -1709,9 +1703,6 @@ class TestQMoESwiGLUBenchmark(unittest.TestCase): def test_qmoe_swiglu_throughput_benchmark(self): """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" - if disable_cpu_qmoe_tests: - self.skipTest("QMoE CPU tests disabled") - print("\n=== QMoE SwiGLU Throughput Benchmark ===") # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) From 4af13153f3cc8f0f034fd6263601377d8895addb Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 17 Feb 2026 07:45:46 +0000 Subject: [PATCH 06/11] refactor CheckInputs --- onnxruntime/contrib_ops/cpu/moe/moe_helper.h | 142 +++++++++++++----- .../cpu/moe/moe_quantization_cpu.cc | 67 +++++++-- .../cpu/moe/moe_quantization_cpu.h | 7 +- 3 files changed, 166 insertions(+), 50 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index 611d2f989d576..f9122406be633 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -35,31 +35,68 @@ struct MoEParameters { }; namespace moe_helper { +// Helper to check shape dimensions +#define ASSERT_SHAPE_DIMENSION(shape_ptr, dim, name) \ + if (shape_ptr != nullptr) { \ + if (shape_ptr->NumDimensions() != dim) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have ", dim, " dimensions, got ", \ + shape_ptr->NumDimensions()); \ + } \ + } + +#define ASSERT_SHAPE_3D(shape_ptr, name) ASSERT_SHAPE_DIMENSION(shape_ptr, 3, name) + +#define CHECK_SHAPE(shape_ptr, name, ...) \ + if (shape_ptr != nullptr) { \ + const TensorShape& expected_shape = make_shape(__VA_ARGS__); \ + if (*shape_ptr != expected_shape) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have shape ", expected_shape, \ + ", got ", *shape_ptr); \ + } \ + } + template Status CheckInputs(MoEParameters& parameters, - const Tensor* input, // required - const Tensor* router_probs, // required - const Tensor* fc1_experts_weights, // required - const Tensor* fc1_experts_bias, // optional - const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc1_zero_points, // optional, for qMoE - const Tensor* fc2_experts_weights, // required - const Tensor* fc2_experts_bias, // optional - const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc2_zero_points, // optional, for qMoE - const Tensor* fc3_experts_weights, // optional - const Tensor* fc3_experts_bias, // optional - const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc3_zero_points, // optional, for qMoE - const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const Tensor* input, // required + const Tensor* router_probs, // required + const TensorShape* fc1_experts_weights_shape, + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const TensorShape* fc2_experts_weights_shape, + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const TensorShape* fc3_experts_weights_shape, + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) const bool is_fused_swiglu, const int64_t block_size = 0) { // block size for block-wise quantization // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. + if (input == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required."); + } ASSERT_TENSOR_2D_OR_3D(input); - if (fc1_experts_weights) ASSERT_TENSOR_3D(fc1_experts_weights); - if (fc2_experts_weights) ASSERT_TENSOR_3D(fc2_experts_weights); + + if (router_probs == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'router_probs' is required."); + } ASSERT_TENSOR_2D(router_probs); + if (fc1_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc1_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc1_experts_weights_shape, "fc1_experts_weights"); + + if (fc2_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc2_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc2_experts_weights_shape, "fc2_experts_weights"); + const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); @@ -68,19 +105,19 @@ Status CheckInputs(MoEParameters& parameters, int64_t num_experts = router_probs_dims[1]; int64_t local_num_experts; - if (fc1_experts_weights != nullptr) { - local_num_experts = fc1_experts_weights->Shape().GetDims()[0]; + if (fc1_experts_weights_shape != nullptr) { + local_num_experts = fc1_experts_weights_shape->GetDims()[0]; } else if (fc1_experts_scales != nullptr) { local_num_experts = fc1_experts_scales->Shape().GetDims()[0]; } else { - // Fallback for non-quantized MoE without weights (should not happen in current code paths) - // or if only bias is provided? - local_num_experts = num_experts; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid MoE configuration: both fc1_experts_weights and fc1_experts_scales are null. " + "At least one must be provided."); } int64_t inter_size; - if (fc2_experts_weights != nullptr) { - const auto& dims = fc2_experts_weights->Shape().GetDims(); + if (fc2_experts_weights_shape != nullptr) { + const auto& dims = fc2_experts_weights_shape->GetDims(); inter_size = (dims[1] * dims[2] * pack_size) / hidden_size; } else if (fc3_experts_scales != nullptr) { inter_size = fc3_experts_scales->Shape().GetDims()[1]; @@ -88,14 +125,15 @@ Status CheckInputs(MoEParameters& parameters, int64_t fc1_inter_size = fc1_experts_scales->Shape().GetDims()[1]; inter_size = is_fused_swiglu ? fc1_inter_size / 2 : fc1_inter_size; } else { - // Should not happen for valid QMoE calls - inter_size = 0; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid MoE configuration: unable to infer inter_size because " + "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null."); } bool legacy_shape = false; - if (fc2_experts_weights != nullptr && fc1_experts_weights != nullptr) { - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + if (fc2_experts_weights_shape != nullptr && fc1_experts_weights_shape != nullptr) { + const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims(); legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); } @@ -106,13 +144,13 @@ Status CheckInputs(MoEParameters& parameters, if (legacy_shape) { // legacy shape does not match column major memory layout. This is for backward compatibility. - if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); - if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); - if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, hidden_size, fc1_inter_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, hidden_size, inter_size / pack_size); } else { - if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); - if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); - if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, fc1_inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, inter_size, hidden_size / pack_size); } CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts); @@ -194,9 +232,11 @@ Status CheckInputs(MoEParameters& parameters, } } - if (fc3_experts_weights == nullptr) { + if (fc3_experts_weights_shape == nullptr) { + // If fc3 weights are not provided, ensure no other fc3 parameters are provided ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr && fc3_zero_points == nullptr); } else { + // If fc3 weights are provided, ensure scales logic is consistent ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales } @@ -226,6 +266,36 @@ Status CheckInputs(MoEParameters& parameters, return Status::OK(); } +template +Status CheckInputs(MoEParameters& parameters, + const Tensor* input, // required + const Tensor* router_probs, // required + const Tensor* fc1_experts_weights, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const Tensor* fc2_experts_weights, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const Tensor* fc3_experts_weights, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const bool is_fused_swiglu, + const int64_t block_size = 0) { // block size for block-wise quantization + + const TensorShape* fc1_shape = (fc1_experts_weights != nullptr) ? &fc1_experts_weights->Shape() : nullptr; + const TensorShape* fc2_shape = (fc2_experts_weights != nullptr) ? &fc2_experts_weights->Shape() : nullptr; + const TensorShape* fc3_shape = (fc3_experts_weights != nullptr) ? &fc3_experts_weights->Shape() : nullptr; + + return CheckInputs(parameters, input, router_probs, fc1_shape, fc1_experts_bias, fc1_experts_scales, fc1_zero_points, + fc2_shape, fc2_experts_bias, fc2_experts_scales, fc2_zero_points, + fc3_shape, fc3_experts_bias, fc3_experts_scales, fc3_zero_points, + pack_size, is_fused_swiglu, block_size); +} + } // namespace moe_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index f820253adacef..483e5184f63ac 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -471,11 +471,9 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all // If scales are prepacked, they are constant initializers. if (input_idx == 3) { - has_prepacked_fc1_scales_ = true; return Status::OK(); } if (input_idx == 6) { - has_prepacked_fc2_scales_ = true; return Status::OK(); } @@ -511,11 +509,33 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all } } + if (input_idx == 2) { + fc1_shape_ = shape; + } else if (input_idx == 5) { + fc2_shape_ = shape; + } else if (input_idx == 8) { + fc3_shape_ = shape; + } + if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_buffer)); prepacked_weights->buffer_sizes_.push_back(packed_size); is_packed = true; + // Pack Shape (Buffer 1) + auto dims = shape.GetDims(); + size_t rank_bytes = sizeof(int64_t); + size_t dims_bytes = dims.size() * sizeof(int64_t); + size_t shape_size = rank_bytes + dims_bytes; + + auto shape_buffer = IAllocator::MakeUniquePtr(alloc, shape_size); + int64_t* buffer_data = static_cast(shape_buffer.get()); + *buffer_data = static_cast(dims.size()); + memcpy(buffer_data + 1, dims.data(), dims_bytes); + + prepacked_weights->buffers_.push_back(std::move(shape_buffer)); + prepacked_weights->buffer_sizes_.push_back(shape_size); + // Try build MLAS Q4 cache if scales are available if (use_mlas_q4_gemm_) { const Tensor* scales_tensor = nullptr; @@ -550,7 +570,7 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all alloc, cache_buffer) .IsOK()) { // Store the size so we can verify later? Container holds size. - // We push it as a SECOND buffer. + // We push it as a THIRD buffer (Buffer 2) now. size_t cache_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)) * static_cast(num_experts); prepacked_weights->buffers_.push_back(std::move(cache_buffer)); prepacked_weights->buffer_sizes_.push_back(cache_size); @@ -576,17 +596,38 @@ Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepa if (input_idx == 2 && !prepacked_buffers.empty()) { packed_fc1_ = std::move(prepacked_buffers[0]); if (prepacked_buffers.size() > 1) { - packed_fc1_mlas_cache_ = std::move(prepacked_buffers[1]); + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(rank); + memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + fc1_shape_ = TensorShape(dims); + } + if (prepacked_buffers.size() > 2) { + packed_fc1_mlas_cache_ = std::move(prepacked_buffers[2]); } used_shared_buffers = true; } else if (input_idx == 5 && !prepacked_buffers.empty()) { packed_fc2_ = std::move(prepacked_buffers[0]); if (prepacked_buffers.size() > 1) { - packed_fc2_mlas_cache_ = std::move(prepacked_buffers[1]); + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(rank); + memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + fc2_shape_ = TensorShape(dims); + } + if (prepacked_buffers.size() > 2) { + packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]); } used_shared_buffers = true; } else if (input_idx == 8 && !prepacked_buffers.empty()) { packed_fc3_ = std::move(prepacked_buffers[0]); + if (prepacked_buffers.size() > 1) { + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(rank); + memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + fc3_shape_ = TensorShape(dims); + } used_shared_buffers = true; } @@ -635,17 +676,21 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const auto* fc2_zero_points = context->Input(12); const auto* fc3_zero_points = context->Input(13); + const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr); + const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr); + const TensorShape* fc3_shape_ptr = packed_fc3_ ? &fc3_shape_ : (fc3_experts_weights ? &fc3_experts_weights->Shape() : nullptr); + MoEParameters moe_params; ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias, fc1_scales, fc1_zero_points, - fc2_experts_weights, fc2_experts_bias, fc2_scales, fc2_zero_points, - fc3_experts_weights, fc3_experts_bias, fc3_scales, fc3_zero_points, + fc1_shape_ptr, fc1_experts_bias, fc1_scales, fc1_zero_points, + fc2_shape_ptr, fc2_experts_bias, fc2_scales, fc2_zero_points, + fc3_shape_ptr, fc3_experts_bias, fc3_scales, fc3_zero_points, expert_weight_bits_ == 4 ? 2 : 1, true, block_size_)); - if (fc3_experts_weights || fc3_experts_bias || fc3_scales || fc3_zero_points) { + if (fc3_shape_ptr || fc3_experts_bias || fc3_scales || fc3_zero_points) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); } @@ -808,8 +853,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); - const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->Data(); - const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->Data(); + const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->template Data(); + const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->template Data(); const T* fc1_scales_data = fc1_scales->Data(); const T* fc2_scales_data = fc2_scales->Data(); const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 5a8059b3eb894..94105a4661ec1 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -7,7 +7,6 @@ #include "core/framework/op_kernel.h" #include "core/mlas/inc/mlas_q4.h" #include "contrib_ops/cpu/moe/moe_base_cpu.h" -#include #include namespace onnxruntime { @@ -42,13 +41,15 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { int64_t block_size_; bool use_mlas_q4_gemm_{false}; bool use_mlas_q4_gemm_overridden_{false}; - bool has_prepacked_fc1_scales_{false}; - bool has_prepacked_fc2_scales_{false}; IAllocatorUniquePtr packed_fc1_; IAllocatorUniquePtr packed_fc2_; IAllocatorUniquePtr packed_fc3_; + TensorShape fc1_shape_; + TensorShape fc2_shape_; + TensorShape fc3_shape_; + IAllocatorUniquePtr packed_fc1_mlas_cache_; IAllocatorUniquePtr packed_fc2_mlas_cache_; }; From 23cdf2d9c63e2a8c9af37a8a3c4cd1eba25412ea Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 17 Feb 2026 09:30:42 +0000 Subject: [PATCH 07/11] fix build --- .../contrib_ops/cpu/moe/moe_quantization_cpu.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 483e5184f63ac..4fe45bd692185 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -598,8 +598,8 @@ Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepa if (prepacked_buffers.size() > 1) { int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); int64_t rank = buffer_data[0]; - std::vector dims(rank); - memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + std::vector dims(static_cast(rank)); + memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); fc1_shape_ = TensorShape(dims); } if (prepacked_buffers.size() > 2) { @@ -611,8 +611,8 @@ Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepa if (prepacked_buffers.size() > 1) { int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); int64_t rank = buffer_data[0]; - std::vector dims(rank); - memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + std::vector dims(static_cast(rank)); + memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); fc2_shape_ = TensorShape(dims); } if (prepacked_buffers.size() > 2) { @@ -624,8 +624,8 @@ Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepa if (prepacked_buffers.size() > 1) { int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); int64_t rank = buffer_data[0]; - std::vector dims(rank); - memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t)); + std::vector dims(static_cast(rank)); + memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); fc3_shape_ = TensorShape(dims); } used_shared_buffers = true; From 6d351fed32585bb9c242697ccc8371a71e53c5f2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 19 Feb 2026 04:27:40 +0000 Subject: [PATCH 08/11] robust shape storage --- onnxruntime/contrib_ops/cpu/moe/moe_helper.h | 70 +++++++------------ .../cpu/moe/moe_quantization_cpu.cc | 63 ++++++----------- .../cpu/moe/moe_quantization_cpu.h | 7 +- 3 files changed, 49 insertions(+), 91 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index f9122406be633..bd30418030dc2 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -59,24 +59,24 @@ namespace moe_helper { template Status CheckInputs(MoEParameters& parameters, - const Tensor* input, // required - const Tensor* router_probs, // required - const TensorShape* fc1_experts_weights_shape, - const Tensor* fc1_experts_bias, // optional - const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc1_zero_points, // optional, for qMoE - const TensorShape* fc2_experts_weights_shape, - const Tensor* fc2_experts_bias, // optional - const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc2_zero_points, // optional, for qMoE - const TensorShape* fc3_experts_weights_shape, - const Tensor* fc3_experts_bias, // optional - const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc3_zero_points, // optional, for qMoE - const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const Tensor* input, // required + const Tensor* router_probs, // required + const TensorShape* fc1_experts_weights_shape, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const TensorShape* fc2_experts_weights_shape, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const TensorShape* fc3_experts_weights_shape, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) const bool is_fused_swiglu, const int64_t block_size = 0) { // block size for block-wise quantization - // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. + // Required inputs if (input == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required."); } @@ -104,39 +104,17 @@ Status CheckInputs(MoEParameters& parameters, int64_t hidden_size = input_dims[input_dims.size() - 1]; int64_t num_experts = router_probs_dims[1]; - int64_t local_num_experts; - if (fc1_experts_weights_shape != nullptr) { - local_num_experts = fc1_experts_weights_shape->GetDims()[0]; - } else if (fc1_experts_scales != nullptr) { - local_num_experts = fc1_experts_scales->Shape().GetDims()[0]; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Invalid MoE configuration: both fc1_experts_weights and fc1_experts_scales are null. " - "At least one must be provided."); - } + int64_t local_num_experts = fc1_experts_weights_shape->GetDims()[0]; - int64_t inter_size; - if (fc2_experts_weights_shape != nullptr) { - const auto& dims = fc2_experts_weights_shape->GetDims(); - inter_size = (dims[1] * dims[2] * pack_size) / hidden_size; - } else if (fc3_experts_scales != nullptr) { - inter_size = fc3_experts_scales->Shape().GetDims()[1]; - } else if (fc1_experts_scales != nullptr) { - int64_t fc1_inter_size = fc1_experts_scales->Shape().GetDims()[1]; - inter_size = is_fused_swiglu ? fc1_inter_size / 2 : fc1_inter_size; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Invalid MoE configuration: unable to infer inter_size because " - "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null."); - } + int64_t inter_size = (fc2_experts_weights_shape->GetDims()[1] * + fc2_experts_weights_shape->GetDims()[2] * pack_size) / + hidden_size; bool legacy_shape = false; - if (fc2_experts_weights_shape != nullptr && fc1_experts_weights_shape != nullptr) { - const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims(); - legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || - (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); - } + const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims(); + legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || + (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); // Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one. const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size; diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 4fe45bd692185..3c4e04b06b45b 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -12,6 +12,7 @@ #include "core/common/safeint.h" #include "core/common/narrow.h" #include "core/framework/tensor_type_and_shape.h" +#include "core/framework/tensorprotoutils.h" #include "core/util/math.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/cpu/moe/moe_utils.h" @@ -522,20 +523,6 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all prepacked_weights->buffer_sizes_.push_back(packed_size); is_packed = true; - // Pack Shape (Buffer 1) - auto dims = shape.GetDims(); - size_t rank_bytes = sizeof(int64_t); - size_t dims_bytes = dims.size() * sizeof(int64_t); - size_t shape_size = rank_bytes + dims_bytes; - - auto shape_buffer = IAllocator::MakeUniquePtr(alloc, shape_size); - int64_t* buffer_data = static_cast(shape_buffer.get()); - *buffer_data = static_cast(dims.size()); - memcpy(buffer_data + 1, dims.data(), dims_bytes); - - prepacked_weights->buffers_.push_back(std::move(shape_buffer)); - prepacked_weights->buffer_sizes_.push_back(shape_size); - // Try build MLAS Q4 cache if scales are available if (use_mlas_q4_gemm_) { const Tensor* scales_tensor = nullptr; @@ -584,9 +571,10 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all } template -Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, - int input_idx, - /*out*/ bool& used_shared_buffers) { +Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; if (expert_weight_bits_ != 4) { @@ -596,38 +584,17 @@ Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepa if (input_idx == 2 && !prepacked_buffers.empty()) { packed_fc1_ = std::move(prepacked_buffers[0]); if (prepacked_buffers.size() > 1) { - int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); - int64_t rank = buffer_data[0]; - std::vector dims(static_cast(rank)); - memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); - fc1_shape_ = TensorShape(dims); - } - if (prepacked_buffers.size() > 2) { - packed_fc1_mlas_cache_ = std::move(prepacked_buffers[2]); + packed_fc1_mlas_cache_ = std::move(prepacked_buffers[1]); } used_shared_buffers = true; } else if (input_idx == 5 && !prepacked_buffers.empty()) { packed_fc2_ = std::move(prepacked_buffers[0]); if (prepacked_buffers.size() > 1) { - int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); - int64_t rank = buffer_data[0]; - std::vector dims(static_cast(rank)); - memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); - fc2_shape_ = TensorShape(dims); - } - if (prepacked_buffers.size() > 2) { - packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]); + packed_fc2_mlas_cache_ = std::move(prepacked_buffers[1]); } used_shared_buffers = true; } else if (input_idx == 8 && !prepacked_buffers.empty()) { packed_fc3_ = std::move(prepacked_buffers[0]); - if (prepacked_buffers.size() > 1) { - int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); - int64_t rank = buffer_data[0]; - std::vector dims(static_cast(rank)); - memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); - fc3_shape_ = TensorShape(dims); - } used_shared_buffers = true; } @@ -648,6 +615,18 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } + // Initialize shapes from InputDefs + const auto& input_defs = op_kernel_info.node().InputDefs(); + if (input_defs.size() > 2 && input_defs[2]->Exists() && input_defs[2]->Shape()) { + fc1_shape_ = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_defs[2]->Shape()); + } + if (input_defs.size() > 5 && input_defs[5]->Exists() && input_defs[5]->Shape()) { + fc2_shape_ = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_defs[5]->Shape()); + } + if (input_defs.size() > 8 && input_defs[8]->Exists() && input_defs[8]->Shape()) { + fc3_shape_ = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_defs[8]->Shape()); + } + const auto use_mlas_q4_gemm = ParseEnvironmentVariable(kUseMlasQ4GemmMoe); if (use_mlas_q4_gemm.has_value()) { use_mlas_q4_gemm_ = *use_mlas_q4_gemm; @@ -1523,11 +1502,11 @@ template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); -template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, bool& used_shared_buffers); +template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); -template Status QMoECPU::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, bool& used_shared_buffers); +template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); // Kernel Registration ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 94105a4661ec1..fc4991fe697d1 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -32,9 +32,10 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) override; - Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, - int input_idx, - /*out*/ bool& used_shared_buffers) override; + Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) override; private: int64_t expert_weight_bits_; From 0b4c783abda5f9a8255f60b9b0972b2684dd9f03 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 19 Feb 2026 07:36:32 +0000 Subject: [PATCH 09/11] refactoring --- .../cpu/moe/moe_quantization_cpu.cc | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 3c4e04b06b45b..b44a583157dfc 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -12,7 +12,6 @@ #include "core/common/safeint.h" #include "core/common/narrow.h" #include "core/framework/tensor_type_and_shape.h" -#include "core/framework/tensorprotoutils.h" #include "core/util/math.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/cpu/moe/moe_utils.h" @@ -523,6 +522,20 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all prepacked_weights->buffer_sizes_.push_back(packed_size); is_packed = true; + // Pack Shape (Buffer 1) + auto dims = shape.GetDims(); + size_t rank_bytes = sizeof(int64_t); + size_t dims_bytes = dims.size() * sizeof(int64_t); + size_t shape_size = rank_bytes + dims_bytes; + + auto shape_buffer = IAllocator::MakeUniquePtr(alloc, shape_size); + int64_t* buffer_data = static_cast(shape_buffer.get()); + *buffer_data = static_cast(dims.size()); + memcpy(buffer_data + 1, dims.data(), dims_bytes); + + prepacked_weights->buffers_.push_back(std::move(shape_buffer)); + prepacked_weights->buffer_sizes_.push_back(shape_size); + // Try build MLAS Q4 cache if scales are available if (use_mlas_q4_gemm_) { const Tensor* scales_tensor = nullptr; @@ -581,21 +594,34 @@ Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& pr return Status::OK(); } - if (input_idx == 2 && !prepacked_buffers.empty()) { - packed_fc1_ = std::move(prepacked_buffers[0]); - if (prepacked_buffers.size() > 1) { - packed_fc1_mlas_cache_ = std::move(prepacked_buffers[1]); - } - used_shared_buffers = true; - } else if (input_idx == 5 && !prepacked_buffers.empty()) { - packed_fc2_ = std::move(prepacked_buffers[0]); - if (prepacked_buffers.size() > 1) { - packed_fc2_mlas_cache_ = std::move(prepacked_buffers[1]); + if ((input_idx == 2 || input_idx == 5 || input_idx == 8) && !prepacked_buffers.empty()) { + auto parse_shape = [&](TensorShape& shape) { + if (prepacked_buffers.size() > 1) { + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(static_cast(rank)); + memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); + shape = TensorShape(dims); + } + }; + + if (input_idx == 2) { + packed_fc1_ = std::move(prepacked_buffers[0]); + parse_shape(fc1_shape_); + if (prepacked_buffers.size() > 2) { + packed_fc1_mlas_cache_ = std::move(prepacked_buffers[2]); + } + } else if (input_idx == 5) { + packed_fc2_ = std::move(prepacked_buffers[0]); + parse_shape(fc2_shape_); + if (prepacked_buffers.size() > 2) { + packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]); + } + } else /*if (input_idx == 8)*/ { + packed_fc3_ = std::move(prepacked_buffers[0]); + parse_shape(fc3_shape_); } used_shared_buffers = true; - } else if (input_idx == 8 && !prepacked_buffers.empty()) { - packed_fc3_ = std::move(prepacked_buffers[0]); - used_shared_buffers = true; } return Status::OK(); @@ -615,18 +641,6 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } - // Initialize shapes from InputDefs - const auto& input_defs = op_kernel_info.node().InputDefs(); - if (input_defs.size() > 2 && input_defs[2]->Exists() && input_defs[2]->Shape()) { - fc1_shape_ = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_defs[2]->Shape()); - } - if (input_defs.size() > 5 && input_defs[5]->Exists() && input_defs[5]->Shape()) { - fc2_shape_ = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_defs[5]->Shape()); - } - if (input_defs.size() > 8 && input_defs[8]->Exists() && input_defs[8]->Shape()) { - fc3_shape_ = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_defs[8]->Shape()); - } - const auto use_mlas_q4_gemm = ParseEnvironmentVariable(kUseMlasQ4GemmMoe); if (use_mlas_q4_gemm.has_value()) { use_mlas_q4_gemm_ = *use_mlas_q4_gemm; From abef62bebbdd4a65b1e99fe8cf156f948f9f9455 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Feb 2026 20:13:30 +0000 Subject: [PATCH 10/11] review feedback --- .../cpu/moe/moe_quantization_cpu.cc | 23 +++++++------------ .../cpu/moe/moe_quantization_cpu.h | 3 --- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index b44a583157dfc..13a15379abc47 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -477,13 +477,13 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all return Status::OK(); } - // Only support PrePack for FC1 (2), FC2 (5), and FC3 (8) weights + // Only support PrePack for FC1 (2) and FC2 (5) weights // and only if expert_weight_bits_ == 4 (since we unpack to uint8) if (expert_weight_bits_ != 4) { return Status::OK(); } - if (input_idx == 2 || input_idx == 5 || input_idx == 8) { + if (input_idx == 2 || input_idx == 5) { const auto& shape = tensor.Shape(); const int64_t num_experts = shape[0]; const int64_t rows = shape[1]; @@ -513,8 +513,6 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all fc1_shape_ = shape; } else if (input_idx == 5) { fc2_shape_ = shape; - } else if (input_idx == 8) { - fc3_shape_ = shape; } if (prepacked_weights) { @@ -546,19 +544,16 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all if (input_idx == 2) { // FC1 scales_idx = 3; zp_idx = 11; - CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype); } else if (input_idx == 5) { // FC2 scales_idx = 6; zp_idx = 12; - CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype); } - // FC3 (8) not supported for now if (scales_idx != -1 && !Info().node().InputDefs()[zp_idx]->Exists() && Info().TryGetConstantInput(scales_idx, &scales_tensor) && scales_tensor != nullptr && - CanUseMlasQ4Gemm(expert_weight_bits_, block_size_ > 0 ? block_size_ : 0, rows, cols, qtype)) { + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype)) { IAllocatorUniquePtr cache_buffer; const auto& scales_dims = scales_tensor->Shape().GetDims(); const T* scales_data = scales_tensor->Data(); @@ -566,7 +561,7 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all const uint8_t* simple_packed = dst_base; if (BuildDirectQ4PackedBCache(simple_packed, scales_data, num_experts, rows, cols, - block_size_ > 0 ? block_size_ : 0, scales_dims, qtype, + block_size_, scales_dims, qtype, alloc, cache_buffer) .IsOK()) { // Store the size so we can verify later? Container holds size. @@ -594,7 +589,7 @@ Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& pr return Status::OK(); } - if ((input_idx == 2 || input_idx == 5 || input_idx == 8) && !prepacked_buffers.empty()) { + if ((input_idx == 2 || input_idx == 5) && !prepacked_buffers.empty()) { auto parse_shape = [&](TensorShape& shape) { if (prepacked_buffers.size() > 1) { int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); @@ -617,9 +612,6 @@ Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& pr if (prepacked_buffers.size() > 2) { packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]); } - } else /*if (input_idx == 8)*/ { - packed_fc3_ = std::move(prepacked_buffers[0]); - parse_shape(fc3_shape_); } used_shared_buffers = true; } @@ -635,6 +627,7 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "Attribute 'expert_weight_bits' must be 4 or 8."); block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); + ORT_ENFORCE(block_size_ >= 0); if (block_size_ > 0) { ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); @@ -662,7 +655,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const auto* fc2_experts_weights = packed_fc2_ ? nullptr : context->Input(5); const auto* fc2_scales = context->Input(6); const auto* fc2_experts_bias = context->Input(7); - const auto* fc3_experts_weights = packed_fc3_ ? nullptr : context->Input(8); + const auto* fc3_experts_weights = context->Input(8); const auto* fc3_scales = context->Input(9); const auto* fc3_experts_bias = context->Input(10); const auto* fc1_zero_points = context->Input(11); @@ -671,7 +664,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr); const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr); - const TensorShape* fc3_shape_ptr = packed_fc3_ ? &fc3_shape_ : (fc3_experts_weights ? &fc3_experts_weights->Shape() : nullptr); + const TensorShape* fc3_shape_ptr = fc3_experts_weights ? &fc3_experts_weights->Shape() : nullptr; MoEParameters moe_params; ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index fc4991fe697d1..3bbbabc26405d 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -37,7 +37,6 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { int input_idx, /*out*/ bool& used_shared_buffers) override; - private: int64_t expert_weight_bits_; int64_t block_size_; bool use_mlas_q4_gemm_{false}; @@ -45,11 +44,9 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { IAllocatorUniquePtr packed_fc1_; IAllocatorUniquePtr packed_fc2_; - IAllocatorUniquePtr packed_fc3_; TensorShape fc1_shape_; TensorShape fc2_shape_; - TensorShape fc3_shape_; IAllocatorUniquePtr packed_fc1_mlas_cache_; IAllocatorUniquePtr packed_fc2_mlas_cache_; From 11d30ce000a99716b93bc602cc506726742ec44f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Feb 2026 22:21:08 +0000 Subject: [PATCH 11/11] address AI feedback --- .../cpu/moe/moe_quantization_cpu.cc | 63 ++++++++++--------- .../cpu/moe/moe_quantization_cpu.h | 2 + .../python/transformers/benchmark_qmoe.py | 3 + 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 13a15379abc47..81d2b0f8efdc6 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -76,7 +76,7 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, return false; } - size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); + size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(rows), static_cast(cols)); return expected_size > 0; } @@ -550,7 +550,7 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all } if (scales_idx != -1 && - !Info().node().InputDefs()[zp_idx]->Exists() && + (zp_idx >= static_cast(Info().node().InputDefs().size()) || !Info().node().InputDefs()[zp_idx]->Exists()) && Info().TryGetConstantInput(scales_idx, &scales_tensor) && scales_tensor != nullptr && CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype)) { @@ -564,8 +564,7 @@ Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all block_size_, scales_dims, qtype, alloc, cache_buffer) .IsOK()) { - // Store the size so we can verify later? Container holds size. - // We push it as a THIRD buffer (Buffer 2) now. + // Store the MLAS Q4 cache as buffer 2 (after unpacked weights and shape). size_t cache_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)) * static_cast(num_experts); prepacked_weights->buffers_.push_back(std::move(cache_buffer)); prepacked_weights->buffer_sizes_.push_back(cache_size); @@ -673,7 +672,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_shape_ptr, fc2_experts_bias, fc2_scales, fc2_zero_points, fc3_shape_ptr, fc3_experts_bias, fc3_scales, fc3_zero_points, expert_weight_bits_ == 4 ? 2 : 1, - true, + activation_type_ == ActivationType::SwiGLU, block_size_)); if (fc3_shape_ptr || fc3_experts_bias || fc3_scales || fc3_zero_points) { @@ -1079,12 +1078,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (convert_status.IsOK()) { float* fc1_bias_float = nullptr; - IAllocatorUniquePtr fc1_bias_buffer; if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; - fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); - fc1_bias_float = fc1_bias_buffer.get(); + fc1_bias_float = thread_bias1_buffer; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); @@ -1186,22 +1183,30 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc1_gemm_done: - const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); - if (num_expert_tokens >= activation_threshold && tp != nullptr) { - const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); - const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; - - if (num_activation_blocks > 1) { - concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { - const int64_t start_token = block_idx * activation_block_size; - const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); - - for (int64_t i = start_token; i < end_token; ++i) { + if (activation_type_ == ActivationType::SwiGLU) { + const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + if (num_expert_tokens >= activation_threshold && tp != nullptr) { + const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); + const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; + + if (num_activation_blocks > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_token = block_idx * activation_block_size; + const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); + + for (int64_t i = start_token; i < end_token; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { const float* C1_token = C1 + i * fc1_out_features; float* A2_token = A2 + i * inter_size; ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); } - }); + } } else { for (int64_t i = 0; i < num_expert_tokens; ++i) { const float* C1_token = C1 + i * fc1_out_features; @@ -1210,11 +1215,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } } else { - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const float* C1_token = C1 + i * fc1_out_features; - float* A2_token = A2 + i * inter_size; - ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); - } + ApplyActivationVectorized(C1, num_expert_tokens * fc1_out_features); + std::copy(C1, C1 + (num_expert_tokens * fc1_out_features), A2); } const T* fc2_scales_ptr; @@ -1306,12 +1308,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (convert_status.IsOK()) { float* fc2_bias_float = nullptr; - IAllocatorUniquePtr fc2_bias_buffer; if (has_fc2_bias) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; - fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); - fc2_bias_float = fc2_bias_buffer.get(); + fc2_bias_float = thread_bias2_buffer; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); @@ -1505,6 +1505,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { return Status::OK(); } +template +void QMoECPU::ApplyActivationVectorized(float* data, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + data[i] = ApplyActivation(data[i], activation_type_); + } +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 3bbbabc26405d..f678a27190c90 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -37,6 +37,8 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { int input_idx, /*out*/ bool& used_shared_buffers) override; + void ApplyActivationVectorized(float* data, int64_t size) const; + int64_t expert_weight_bits_; int64_t block_size_; bool use_mlas_q4_gemm_{false}; diff --git a/onnxruntime/test/python/transformers/benchmark_qmoe.py b/onnxruntime/test/python/transformers/benchmark_qmoe.py index 53854e053ef93..b96c9cdcf5c3a 100644 --- a/onnxruntime/test/python/transformers/benchmark_qmoe.py +++ b/onnxruntime/test/python/transformers/benchmark_qmoe.py @@ -46,6 +46,9 @@ def test_qmoe_swiglu_throughput_benchmark(self): torch.manual_seed(42) numpy.random.seed(42) + torch_output = None + ort_output = None + print(f"\nTesting {config_name}:") print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit")