diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 149ecdb969bd5..90c59bd6ddf51 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -4613,7 +4613,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Whether to use sparse mixer
-#### Inputs (7 - 14) +#### Inputs (7 - 15)
input : T
@@ -4644,6 +4644,8 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (num_experts, hidden_size / pack_size), or 3D tensor with shape (num_experts, hidden_size, inter_size / block_size / pack_size) when block_size is provided.
fc3_zero_points (optional) : T1
2D optional tensor with shape (num_experts, inter_size / pack_size), or 3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.
+<dt><tt>router_weights</tt> (optional) : T</dt>
+<dd>2D optional tensor with shape (num_tokens, num_experts). When provided, router_probs is used only for Top-K expert selection, and router_weights is used for aggregating expert outputs (the values at the selected expert indices are gathered and used as mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for selection and aggregation. When not provided, router_probs is used for both selection and aggregation (backward compatible).</dd>
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 507c6722bc349..f70f04ab9b344 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -602,7 +602,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)| -|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*in* fc1_zero_points:**T1**
*in* fc2_zero_points:**T1**
*in* fc3_zero_points:**T1**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*in* fc1_zero_points:**T1**
*in* fc2_zero_points:**T1**
*in* fc3_zero_points:**T1**
*in* router_weights:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(float), tensor(float16)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| @@ -1023,7 +1023,7 @@ Do not modify directly.* |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*in* fc1_zero_points:**T1**
*in* fc2_zero_points:**T1**
*in* fc3_zero_points:**T1**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(bfloat16), tensor(float16)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*in* fc1_zero_points:**T1**
*in* fc2_zero_points:**T1**
*in* fc3_zero_points:**T1**
*in* router_weights:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(bfloat16), tensor(float16)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 9a9c61f863efc..1e52fddba5f73 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -660,6 +660,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const auto* fc1_zero_points = context->Input(11); const auto* fc2_zero_points = context->Input(12); const auto* fc3_zero_points = context->Input(13); + const auto* router_weights = context->Input(14); const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr); const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr); @@ -708,6 +709,28 @@ Status QMoECPU::Compute(OpKernelContext* context) const { router_logits_float = reinterpret_cast(router_probs->Data()); } + // Handle optional router_weights input for separate selection/aggregation tensors + const bool has_router_weights = (router_weights != nullptr); + IAllocatorUniquePtr router_weights_float_buffer; + const float* router_weights_float = nullptr; + if (has_router_weights) { + const auto& rw_shape = router_weights->Shape(); + if (rw_shape.NumDimensions() != 2 || rw_shape[0] != num_tokens || rw_shape[1] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'router_weights' is expected to have shape (", + num_tokens, ", ", num_experts, "), got ", rw_shape); + } + if constexpr (std::is_same_v) { + router_weights_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_weights_float = router_weights_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_weights->Data()), + const_cast(router_weights_float), + static_cast(num_tokens * num_experts)); + } else { + router_weights_float = 
reinterpret_cast<const float*>(router_weights->Data<float>());
+    }
+  }
+
   auto route_expert_ptr = IAllocator::MakeUniquePtr<int>(allocator, static_cast<size_t>(num_tokens * k_));
   int* route_expert = route_expert_ptr.get();
   auto route_scale_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * k_));
@@ -743,22 +766,58 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
       std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast<ptrdiff_t>(k_),
                         sorted_logits.end(), std::greater<>());

-      float max_logit = sorted_logits[0].first;
+      if (has_router_weights) {
+        // When router_weights is provided, use it for aggregation weights instead of softmax of router_probs.
+        // Gather weights from router_weights at the selected expert indices.
+        // Note: top_k_exp is reused here as a scratch buffer for the gathered weights.
+        const float* weights_row = router_weights_float + i * num_experts;
+        if (normalize_routing_weights_) {
+          float weight_sum = 0.0f;
+          for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+            int64_t expert_idx = sorted_logits[j].second;
+            top_k_exp[j] = weights_row[expert_idx];
+            weight_sum += top_k_exp[j];
+          }
+          const float inv_weight_sum = (weight_sum == 0.0f) ? 0.0f : (1.0f / weight_sum);
+          for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+            int64_t expert_idx = sorted_logits[j].second;
+            int64_t route_idx = i * k_ + narrow<int64_t>(j);
+            route_expert[route_idx] = narrow<int>(expert_idx);
+            route_scale[route_idx] = top_k_exp[j] * inv_weight_sum;
+            if (route_scale[route_idx] > 1e-8f) {
+              local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+            }
+          }
+        } else {
+          for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+            int64_t expert_idx = sorted_logits[j].second;
+            int64_t route_idx = i * k_ + narrow<int64_t>(j);
+            route_expert[route_idx] = narrow<int>(expert_idx);
+            route_scale[route_idx] = weights_row[expert_idx];
+            if (route_scale[route_idx] > 1e-8f) {
+              local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+            }
+          }
+        }
+      } else {
+        // Default path: compute softmax weights from router_probs for aggregation.
+        float max_logit = sorted_logits[0].first;

-      float sum_exp = 0.0f;
-      for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
-        top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
-        sum_exp += top_k_exp[j];
-      }
+        float sum_exp = 0.0f;
+        for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+          top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
+          sum_exp += top_k_exp[j];
+        }

-      const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
-      for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
-        int64_t expert_idx = sorted_logits[j].second;
-        int64_t route_idx = i * k_ + narrow<int64_t>(j);
-        route_expert[route_idx] = narrow<int>(expert_idx);
-        route_scale[route_idx] = top_k_exp[j] * inv_sum;
-        if (route_scale[route_idx] > 1e-8f) {  // Use small threshold to avoid zero weights
-          local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+        const float inv_sum = (sum_exp == 0.0f) ?
0.0f : (1.0f / sum_exp); + for (size_t j = 0; j < narrow(k_); ++j) { + int64_t expert_idx = sorted_logits[j].second; + int64_t route_idx = i * k_ + narrow(j); + route_expert[route_idx] = narrow(expert_idx); + route_scale[route_idx] = top_k_exp[j] * inv_sum; + if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } } } } diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 79b25ba91ebbb..3b9f3680a839e 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -149,8 +149,11 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc1_zero_points = context->Input(11); const Tensor* fc2_zero_points = context->Input(12); const Tensor* fc3_zero_points = context->Input(13); + const Tensor* router_weights = context->Input(14); ORT_ENFORCE(fc1_zero_points == nullptr && fc2_zero_points == nullptr && fc3_zero_points == nullptr, "Zero points are not yet implemented on CUDA for QMoE."); + ORT_ENFORCE(router_weights == nullptr, + "Separate router_weights is not yet implemented on CUDA for QMoE."); MoEParameters moe_params; ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( diff --git a/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc b/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc index 39a1d1230ddf1..12608c817b201 100755 --- a/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc +++ b/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc @@ -184,6 +184,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const { const Tensor* fc1_zero_points = context.Input(11); const Tensor* fc2_zero_points = context.Input(12); const Tensor* fc3_zero_points = context.Input(13); + const Tensor* router_weights = context.Input(14); MoEParameters moe_params; @@ -192,6 +193,11 @@ Status 
QMoE::ComputeInternal(ComputeContext& context) const { "zero_points for QMoE are not yet supported on WebGPU."); } + if (router_weights) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Separate router_weights is not yet implemented on WebGPU for QMoE."); + } + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( moe_params, hidden_state, router_logits, fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, fc1_zero_points, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index cb562806faecf..5ccba675b4ecf 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1569,6 +1569,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.", "T1", OpSchema::Optional) + .Input(14, + "router_weights", + "2D optional tensor with shape (num_tokens, num_experts). " + "When provided, router_probs is used only for Top-K expert selection, and router_weights is used " + "for aggregating expert outputs (the values at the selected expert indices are gathered and used as " + "mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for " + "selection and aggregation. When not provided, router_probs is used for both selection and aggregation " + "(backward compatible).", + "T", + OpSchema::Optional) .Output(0, "output", "output tensor with same shape of input", diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index ab740ea38fb74..a22faa4fc7906 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1690,6 +1690,93 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { #endif } +TEST(MoETest, QMoETest_CPU_RouterWeights) { + // Test that separate router_weights for aggregation works correctly. 
+ // router_probs is used only for Top-K expert selection. + // router_weights is used only for weighting expert outputs. + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + int num_rows = 2; + int num_experts = 2; + int hidden_size = 16; + int inter_size = 16; + + const std::vector input = { + -0.5f, 0.2f, 1.1f, -0.3f, 0.8f, -0.1f, 0.4f, -0.7f, 0.9f, -0.2f, 0.6f, 0.1f, -0.4f, 0.3f, -0.8f, 0.7f, + 0.1f, 0.7f, -0.4f, 0.2f, 0.8f, -0.3f, 0.5f, -0.1f, 0.6f, 0.4f, -0.7f, 0.3f, 0.9f, -0.2f, 0.1f, 0.8f}; + + // router_probs is only used for Top-K selection. + const std::vector router_probs = {0.1f, 0.9f, 0.8f, 0.2f}; + + // router_weights is used for aggregation and intentionally differs from router_probs. + const std::vector router_weights = {3.0f, 1.0f, 1.0f, 2.0f}; + + // Use zero-valued int8 weights and expert-specific FC2 bias so the final output is + // a simple weighted combination of constant expert outputs. + std::vector fc1_experts_weights(num_experts * 2 * inter_size * hidden_size, 128); + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 128); + + std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); + std::vector fc2_scales(num_experts * hidden_size, 0.1f); + std::vector fc2_experts_bias; + fc2_experts_bias.insert(fc2_experts_bias.end(), static_cast(hidden_size), 1.0f); + fc2_experts_bias.insert(fc2_experts_bias.end(), static_cast(hidden_size), 3.0f); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector router_weights_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector fc2_experts_bias_dims = 
{num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + auto run_test = [&](int64_t normalize_routing_weights, const std::vector& expected_output) { + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 2); + cpu_tester.AddAttribute("activation_type", "swiglu"); + cpu_tester.AddAttribute("swiglu_fusion", static_cast(1)); + cpu_tester.AddAttribute("normalize_routing_weights", normalize_routing_weights); + cpu_tester.AddAttribute("expert_weight_bits", 8); + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); + cpu_tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias)); + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights + cpu_tester.AddOptionalInputEdge(); // fc3_scales + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc1_zero_points + cpu_tester.AddOptionalInputEdge(); // fc2_zero_points + cpu_tester.AddOptionalInputEdge(); // fc3_zero_points + cpu_tester.AddInput("router_weights", router_weights_dims, ToFloat16(router_weights)); + cpu_tester.AddOutput("output", output_dims, ToFloat16(expected_output)); + cpu_tester.SetOutputTolerance(0.05f); + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); + }; + + std::vector expected_output_no_norm; + expected_output_no_norm.insert(expected_output_no_norm.end(), 
static_cast(hidden_size), 6.0f); + expected_output_no_norm.insert(expected_output_no_norm.end(), static_cast(hidden_size), 7.0f); + run_test(0, expected_output_no_norm); + + std::vector expected_output_norm; + expected_output_norm.insert(expected_output_norm.end(), static_cast(hidden_size), 1.5f); + expected_output_norm.insert(expected_output_norm.end(), static_cast(hidden_size), 7.0f / 3.0f); + run_test(1, expected_output_norm); +} + // Test for CPU MoE implementation static void RunMoECpuTest(const std::vector& input, const std::vector& router_probs, const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights,