diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 149ecdb969bd5..90c59bd6ddf51 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -4613,7 +4613,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Whether to use sparse mixer
-#### Inputs (7 - 14)
+#### Inputs (7 - 15)
- input : T
@@ -4644,6 +4644,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- 2D tensor with shape (num_experts, hidden_size / pack_size), or 3D tensor with shape (num_experts, hidden_size, inter_size / block_size / pack_size) when block_size is provided.
- fc3_zero_points (optional) : T1
- 2D optional tensor with shape (num_experts, inter_size / pack_size), or 3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.
+- router_weights (optional) : T
+- 2D optional tensor with shape (num_tokens, num_experts). When provided, router_probs is used only for Top-K expert selection, and router_weights is used for aggregating expert outputs (the values at the selected expert indices are gathered and used as mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for selection and aggregation. When not provided, router_probs is used for both selection and aggregation (backward compatible).
#### Outputs
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 507c6722bc349..f70f04ab9b344 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -602,7 +602,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
+|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *in* router_weights:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
@@ -1023,7 +1023,7 @@ Do not modify directly.*
|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)|
-|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
+|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *in* router_weights:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
index 9a9c61f863efc..1e52fddba5f73 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -660,6 +660,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
   const auto* fc1_zero_points = context->Input<Tensor>(11);
   const auto* fc2_zero_points = context->Input<Tensor>(12);
   const auto* fc3_zero_points = context->Input<Tensor>(13);
+  const auto* router_weights = context->Input<Tensor>(14);
const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr);
const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr);
@@ -708,6 +709,28 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
     router_logits_float = reinterpret_cast<const float*>(router_probs->Data<T>());
}
+  // Handle optional router_weights input for separate selection/aggregation tensors
+  const bool has_router_weights = (router_weights != nullptr);
+  IAllocatorUniquePtr<float> router_weights_float_buffer;
+  const float* router_weights_float = nullptr;
+  if (has_router_weights) {
+    const auto& rw_shape = router_weights->Shape();
+    if (rw_shape.NumDimensions() != 2 || rw_shape[0] != num_tokens || rw_shape[1] != num_experts) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'router_weights' is expected to have shape (",
+                             num_tokens, ", ", num_experts, "), got ", rw_shape);
+    }
+    if constexpr (std::is_same_v<T, MLFloat16>) {
+      router_weights_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * num_experts));
+      router_weights_float = router_weights_float_buffer.get();
+      MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLAS_FP16*>(router_weights->Data<T>()),
+                                   const_cast<float*>(router_weights_float),
+                                   static_cast<size_t>(num_tokens * num_experts));
+    } else {
+      router_weights_float = reinterpret_cast<const float*>(router_weights->Data<T>());
+    }
+  }
+
   auto route_expert_ptr = IAllocator::MakeUniquePtr<int>(allocator, static_cast<size_t>(num_tokens * k_));
   int* route_expert = route_expert_ptr.get();
   auto route_scale_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * k_));
@@ -743,22 +766,58 @@ Status QMoECPU::Compute(OpKernelContext* context) const {
       std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast<ptrdiff_t>(k_),
                         sorted_logits.end(), std::greater<>());
-      float max_logit = sorted_logits[0].first;
+      if (has_router_weights) {
+        // When router_weights is provided, use it for aggregation weights instead of softmax of router_probs.
+        // Gather weights from router_weights at the selected expert indices.
+        // Note: top_k_exp is reused here as a scratch buffer for the gathered weights.
+        const float* weights_row = router_weights_float + i * num_experts;
+        if (normalize_routing_weights_) {
+          float weight_sum = 0.0f;
+          for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+            int64_t expert_idx = sorted_logits[j].second;
+            top_k_exp[j] = weights_row[expert_idx];
+            weight_sum += top_k_exp[j];
+          }
+          const float inv_weight_sum = (weight_sum == 0.0f) ? 0.0f : (1.0f / weight_sum);
+          for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+            int64_t expert_idx = sorted_logits[j].second;
+            int64_t route_idx = i * k_ + narrow<int64_t>(j);
+            route_expert[route_idx] = narrow<int>(expert_idx);
+            route_scale[route_idx] = top_k_exp[j] * inv_weight_sum;
+            if (route_scale[route_idx] > 1e-8f) {
+              local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+            }
+          }
+        } else {
+          for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+            int64_t expert_idx = sorted_logits[j].second;
+            int64_t route_idx = i * k_ + narrow<int64_t>(j);
+            route_expert[route_idx] = narrow<int>(expert_idx);
+            route_scale[route_idx] = weights_row[expert_idx];
+            if (route_scale[route_idx] > 1e-8f) {
+              local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+            }
+          }
+        }
+      } else {
+        // Default path: compute softmax weights from router_probs for aggregation.
+        float max_logit = sorted_logits[0].first;
-      float sum_exp = 0.0f;
-      for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
-        top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
-        sum_exp += top_k_exp[j];
-      }
+        float sum_exp = 0.0f;
+        for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+          top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
+          sum_exp += top_k_exp[j];
+        }
-      const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
-      for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
-        int64_t expert_idx = sorted_logits[j].second;
-        int64_t route_idx = i * k_ + narrow<int64_t>(j);
-        route_expert[route_idx] = narrow<int>(expert_idx);
-        route_scale[route_idx] = top_k_exp[j] * inv_sum;
-        if (route_scale[route_idx] > 1e-8f) {  // Use small threshold to avoid zero weights
-          local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
-        }
-      }
+        const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
+        for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+          int64_t expert_idx = sorted_logits[j].second;
+          int64_t route_idx = i * k_ + narrow<int64_t>(j);
+          route_expert[route_idx] = narrow<int>(expert_idx);
+          route_scale[route_idx] = top_k_exp[j] * inv_sum;
+          if (route_scale[route_idx] > 1e-8f) {  // Use small threshold to avoid zero weights
+            local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+          }
+        }
+      }
     }
}
diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
index 79b25ba91ebbb..3b9f3680a839e 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
@@ -149,8 +149,11 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
   const Tensor* fc1_zero_points = context->Input<Tensor>(11);
   const Tensor* fc2_zero_points = context->Input<Tensor>(12);
   const Tensor* fc3_zero_points = context->Input<Tensor>(13);
+  const Tensor* router_weights = context->Input<Tensor>(14);
ORT_ENFORCE(fc1_zero_points == nullptr && fc2_zero_points == nullptr && fc3_zero_points == nullptr,
"Zero points are not yet implemented on CUDA for QMoE.");
+ ORT_ENFORCE(router_weights == nullptr,
+ "Separate router_weights is not yet implemented on CUDA for QMoE.");
MoEParameters moe_params;
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs(
diff --git a/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc b/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
index 39a1d1230ddf1..12608c817b201 100755
--- a/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
+++ b/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
@@ -184,6 +184,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
   const Tensor* fc1_zero_points = context.Input<Tensor>(11);
   const Tensor* fc2_zero_points = context.Input<Tensor>(12);
   const Tensor* fc3_zero_points = context.Input<Tensor>(13);
+  const Tensor* router_weights = context.Input<Tensor>(14);
MoEParameters moe_params;
@@ -192,6 +193,11 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
"zero_points for QMoE are not yet supported on WebGPU.");
}
+ if (router_weights) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+ "Separate router_weights is not yet implemented on WebGPU for QMoE.");
+ }
+
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs(
moe_params, hidden_state, router_logits,
fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, fc1_zero_points,
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index cb562806faecf..5ccba675b4ecf 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1569,6 +1569,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.",
"T1",
OpSchema::Optional)
+ .Input(14,
+ "router_weights",
+ "2D optional tensor with shape (num_tokens, num_experts). "
+ "When provided, router_probs is used only for Top-K expert selection, and router_weights is used "
+ "for aggregating expert outputs (the values at the selected expert indices are gathered and used as "
+ "mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for "
+ "selection and aggregation. When not provided, router_probs is used for both selection and aggregation "
+ "(backward compatible).",
+ "T",
+ OpSchema::Optional)
.Output(0,
"output",
"output tensor with same shape of input",
diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc
index ab740ea38fb74..a22faa4fc7906 100644
--- a/onnxruntime/test/contrib_ops/moe_test.cc
+++ b/onnxruntime/test/contrib_ops/moe_test.cc
@@ -1690,6 +1690,93 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {
#endif
}
+TEST(MoETest, QMoETest_CPU_RouterWeights) {
+ // Test that separate router_weights for aggregation works correctly.
+ // router_probs is used only for Top-K expert selection.
+ // router_weights is used only for weighting expert outputs.
+ auto cpu_ep = DefaultCpuExecutionProvider();
+ if (!cpu_ep) {
+ GTEST_SKIP() << "CPU execution provider not available";
+ }
+
+ int num_rows = 2;
+ int num_experts = 2;
+ int hidden_size = 16;
+ int inter_size = 16;
+
+  const std::vector<float> input = {
+ -0.5f, 0.2f, 1.1f, -0.3f, 0.8f, -0.1f, 0.4f, -0.7f, 0.9f, -0.2f, 0.6f, 0.1f, -0.4f, 0.3f, -0.8f, 0.7f,
+ 0.1f, 0.7f, -0.4f, 0.2f, 0.8f, -0.3f, 0.5f, -0.1f, 0.6f, 0.4f, -0.7f, 0.3f, 0.9f, -0.2f, 0.1f, 0.8f};
+
+ // router_probs is only used for Top-K selection.
+  const std::vector<float> router_probs = {0.1f, 0.9f, 0.8f, 0.2f};
+
+ // router_weights is used for aggregation and intentionally differs from router_probs.
+  const std::vector<float> router_weights = {3.0f, 1.0f, 1.0f, 2.0f};
+
+ // Use zero-valued int8 weights and expert-specific FC2 bias so the final output is
+ // a simple weighted combination of constant expert outputs.
+  std::vector<uint8_t> fc1_experts_weights(num_experts * 2 * inter_size * hidden_size, 128);
+  std::vector<uint8_t> fc2_experts_weights(num_experts * inter_size * hidden_size, 128);
+
+  std::vector<float> fc1_scales(num_experts * inter_size * 2, 0.1f);
+  std::vector<float> fc2_scales(num_experts * hidden_size, 0.1f);
+  std::vector<float> fc2_experts_bias;
+  fc2_experts_bias.insert(fc2_experts_bias.end(), static_cast<size_t>(hidden_size), 1.0f);
+  fc2_experts_bias.insert(fc2_experts_bias.end(), static_cast<size_t>(hidden_size), 3.0f);
+
+  std::vector<int64_t> input_dims = {num_rows, hidden_size};
+  std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
+  std::vector<int64_t> router_weights_dims = {num_rows, num_experts};
+  std::vector<int64_t> fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size};
+  std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
+  std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size * 2};
+  std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size};
+  std::vector<int64_t> fc2_experts_bias_dims = {num_experts, hidden_size};
+  std::vector<int64_t> output_dims = {num_rows, hidden_size};
+
+  auto run_test = [&](int64_t normalize_routing_weights, const std::vector<float>& expected_output) {
+    OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain);
+    cpu_tester.AddAttribute<int64_t>("k", 2);
+    cpu_tester.AddAttribute("activation_type", "swiglu");
+    cpu_tester.AddAttribute("swiglu_fusion", static_cast<int64_t>(1));
+    cpu_tester.AddAttribute("normalize_routing_weights", normalize_routing_weights);
+    cpu_tester.AddAttribute<int64_t>("expert_weight_bits", 8);
+
+    cpu_tester.AddInput<MLFloat16>("input", input_dims, ToFloat16(input));
+    cpu_tester.AddInput<MLFloat16>("router_probs", router_probs_dims, ToFloat16(router_probs));
+    cpu_tester.AddInput<uint8_t>("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights);
+    cpu_tester.AddInput<float>("fc1_scales", fc1_scales_dims, fc1_scales);
+    cpu_tester.AddOptionalInputEdge<MLFloat16>();  // fc1_experts_bias
+    cpu_tester.AddInput<uint8_t>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
+    cpu_tester.AddInput<float>("fc2_scales", fc2_scales_dims, fc2_scales);
+    cpu_tester.AddInput<MLFloat16>("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias));
+    cpu_tester.AddOptionalInputEdge<uint8_t>();  // fc3_experts_weights
+    cpu_tester.AddOptionalInputEdge<float>();  // fc3_scales
+    cpu_tester.AddOptionalInputEdge<MLFloat16>();  // fc3_experts_bias
+    cpu_tester.AddOptionalInputEdge<uint8_t>();  // fc1_zero_points
+    cpu_tester.AddOptionalInputEdge<uint8_t>();  // fc2_zero_points
+    cpu_tester.AddOptionalInputEdge<uint8_t>();  // fc3_zero_points
+    cpu_tester.AddInput<MLFloat16>("router_weights", router_weights_dims, ToFloat16(router_weights));
+    cpu_tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(expected_output));
+    cpu_tester.SetOutputTolerance(0.05f);
+
+    std::vector<std::unique_ptr<IExecutionProvider>> cpu_execution_providers;
+    cpu_execution_providers.push_back(DefaultCpuExecutionProvider());
+    cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers);
+  };
+ };
+
+  std::vector<float> expected_output_no_norm;
+  expected_output_no_norm.insert(expected_output_no_norm.end(), static_cast<size_t>(hidden_size), 6.0f);
+  expected_output_no_norm.insert(expected_output_no_norm.end(), static_cast<size_t>(hidden_size), 7.0f);
+  run_test(0, expected_output_no_norm);
+
+  std::vector<float> expected_output_norm;
+  expected_output_norm.insert(expected_output_norm.end(), static_cast<size_t>(hidden_size), 1.5f);
+  expected_output_norm.insert(expected_output_norm.end(), static_cast<size_t>(hidden_size), 7.0f / 3.0f);
+ run_test(1, expected_output_norm);
+}
+
// Test for CPU MoE implementation
 static void RunMoECpuTest(const std::vector<float>& input, const std::vector<float>& router_probs,
                           const std::vector<float>& fc1_experts_weights, const std::vector<float>& fc2_experts_weights,