diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 5a7934fd37278..e42181f61a771 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -547,6 +547,7 @@ Do not modify directly.*
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(float16), tensor(uint8)
**T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)|
+|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 34410a5f42630..d959d11e3fd43 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -108,6 +108,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE);
// ******** End: Quantization ******************* //
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -275,6 +276,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE)>,
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
index eae96c186d471..84580b310f6b3 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -6,7 +6,7 @@
#include "core/common/common.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/op_kernel.h"
-#include "contrib_ops/cpu/moe/moe_helper.h"
+#include "moe_helper.h"
#include
namespace onnxruntime {
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
new file mode 100644
index 0000000000000..73be099181aa9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
@@ -0,0 +1,605 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/moe/moe_cpu.h"
+#include "contrib_ops/cpu/moe/moe_utils.h"
+#include "contrib_ops/cpu/moe/moe_helper.h"
+#include "core/framework/op_kernel.h"
+#include "core/providers/common.h"
+#include "core/providers/cpu/math/gemm_helper.h"
+#include "core/util/math_cpuonly.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/framework/float16.h"
+#include "core/framework/allocator.h"
+#include "core/platform/threadpool.h"
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+MoE<T>::MoE(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), MoEBaseCPU(op_kernel_info) {
+ if (activation_type_ == ActivationType::SwiGLU && swiglu_fusion_ != 1) {
+ ORT_THROW("CPU MoE only supports interleaved SwiGLU format. Please set swiglu_fusion=1.");
+ }
+}
+
+template <typename T>
+Status MoE<T>::Compute(OpKernelContext* context) const {
+ const Tensor* input = context->Input<Tensor>(0);
+ const Tensor* router_probs = context->Input<Tensor>(1);
+ const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
+ const Tensor* fc1_experts_bias = context->Input<Tensor>(3);
+ const Tensor* fc2_experts_weights = context->Input<Tensor>(4);
+ const Tensor* fc2_experts_bias = context->Input<Tensor>(5);
+ const Tensor* fc3_experts_weights = context->Input<Tensor>(6);
+ const Tensor* fc3_experts_bias = context->Input<Tensor>(7);
+
+ // FC3 not supported
+ if (fc3_experts_weights != nullptr || fc3_experts_bias != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+ "FC3 is not implemented for CPU MoE.");
+ }
+
+ MoEParameters moe_params;
+ ORT_RETURN_IF_ERROR(moe_helper::CheckInputs(
+ moe_params, input, router_probs,
+ fc1_experts_weights, fc1_experts_bias, nullptr,
+ fc2_experts_weights, fc2_experts_bias, nullptr,
+ fc3_experts_weights, fc3_experts_bias, nullptr,
+ 1,
+ activation_type_ == ActivationType::SwiGLU));
+
+ Tensor* output = context->Output(0, input->Shape());
+
+ return ComputeMoE(context, input, router_probs, fc1_experts_weights, fc1_experts_bias,
+ fc2_experts_weights, fc2_experts_bias, output);
+}
+
+template <typename T>
+Status MoE<T>::ComputeMoE(const OpKernelContext* context,
+ const Tensor* input,
+ const Tensor* router_probs,
+ const Tensor* fc1_experts_weights,
+ const Tensor* fc1_experts_bias,
+ const Tensor* fc2_experts_weights,
+ const Tensor* fc2_experts_bias,
+ Tensor* output) const {
+ const auto& input_shape = input->Shape();
+ const auto& router_shape = router_probs->Shape();
+ const auto& fc2_shape = fc2_experts_weights->Shape();
+
+ const int64_t num_tokens = input_shape.Size() / input_shape[input_shape.NumDimensions() - 1];
+ const int64_t hidden_size = input_shape[input_shape.NumDimensions() - 1];
+ const int64_t num_experts = router_shape[1];
+ const int64_t inter_size = (fc2_shape[1] * fc2_shape[2]) / hidden_size;
+ const bool is_swiglu = activation_type_ == ActivationType::SwiGLU;
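+ // With interleaved SwiGLU, FC1 produces the gate and linear halves together, so its output width is 2 * inter_size.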
+ const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size;
+
+ const T* input_data = input->Data<T>();
+ const T* router_data = router_probs->Data<T>();
+ const T* fc1_weights_data = fc1_experts_weights->Data<T>();
+ const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data<T>() : nullptr;
+ const T* fc2_weights_data = fc2_experts_weights->Data<T>();
+ const T* fc2_bias_data = fc2_experts_bias ? fc2_experts_bias->Data<T>() : nullptr;
+ T* output_data = output->MutableData<T>();
+
+ AllocatorPtr allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+
+ const T* input_data_to_use = input_data;
+ IAllocatorUniquePtr<T> input_data_copy_ptr;
+ if (normalize_routing_weights_) {
+ input_data_copy_ptr = IAllocator::MakeUniquePtr<T>(allocator, static_cast<size_t>(num_tokens * hidden_size));
+ T* input_data_copy = input_data_copy_ptr.get();
+ std::copy(input_data, input_data + (num_tokens * hidden_size), input_data_copy);
+ input_data_to_use = input_data_copy;
+ }
+
+ std::fill_n(output_data, output->Shape().Size(), T{});
+
+ IAllocatorUniquePtr<float> router_logits_float_buffer;
+ const float* router_logits_float = nullptr;
+ if constexpr (std::is_same_v<T, MLFloat16>) {
+ router_logits_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * num_experts));
+ router_logits_float = router_logits_float_buffer.get();
+ MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(router_data), const_cast<float*>(router_logits_float), static_cast<size_t>(num_tokens * num_experts));
+ } else {
+ router_logits_float = reinterpret_cast<const float*>(router_data);
+ }
+
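+ // route_expert[i * k_ + j] and route_scale[i * k_ + j] hold the expert index and routing weight of the j-th selection for token i.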
+ auto route_expert_ptr = IAllocator::MakeUniquePtr<int>(allocator, static_cast<size_t>(num_tokens * k_));
+ int* route_expert = route_expert_ptr.get();
+ auto route_scale_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * k_));
+ float* route_scale = route_scale_ptr.get();
+
+ auto* tp = context->GetOperatorThreadPool();
+ int num_routing_threads = 1;
+ if (tp != nullptr && num_tokens >= 1024) {
+ int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp);
+ num_routing_threads = std::min(static_cast<int>(num_tokens / 512), max_threads);
+ num_routing_threads = std::max(1, num_routing_threads);
+ }
+
+ std::vector<std::vector<std::vector<int64_t>>> thread_local_expert_token_maps(num_routing_threads);
+ for (auto& map : thread_local_expert_token_maps) {
+ map.resize(static_cast<size_t>(num_experts));
+ for (auto& expert_map : map) {
+ expert_map.reserve(static_cast<size_t>(std::max(static_cast<int64_t>(1), num_tokens / num_experts / num_routing_threads * 2)));
+ }
+ }
+
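+ // Each routing thread fills its own per-expert list of token routes; the lists are merged after the parallel section.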
+ concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) {
+ auto work = concurrency::ThreadPool::PartitionWork(static_cast<int>(thread_id), num_routing_threads, static_cast<std::ptrdiff_t>(num_tokens));
+ auto& local_expert_token_map = thread_local_expert_token_maps[thread_id];
+
+ std::vector<std::pair<float, int64_t>> sorted_logits(static_cast<size_t>(num_experts));
+ std::vector<float> full_softmax(static_cast<size_t>(num_experts));
+
+ for (int64_t i = work.start; i < work.end; ++i) {
+ const float* logits = router_logits_float + i * num_experts;
+
+ float max_logit = logits[0];
+ for (int64_t j = 1; j < num_experts; ++j) {
+ max_logit = std::max(max_logit, logits[j]);
+ }
+
+ float sum_exp = 0.0f;
+ for (int64_t j = 0; j < num_experts; ++j) {
+ full_softmax[static_cast<size_t>(j)] = std::exp(logits[j] - max_logit);
+ sum_exp += full_softmax[static_cast<size_t>(j)];
+ }
+
+ const float inv_sum_exp = 1.0f / sum_exp;
+ for (int64_t j = 0; j < num_experts; ++j) {
+ full_softmax[static_cast<size_t>(j)] *= inv_sum_exp;
+ sorted_logits[static_cast<size_t>(j)] = {full_softmax[static_cast<size_t>(j)], j};
+ }
+
+ std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast<std::ptrdiff_t>(k_), sorted_logits.end(), std::greater<>());
+
+ if (normalize_routing_weights_) {
+ float top_k_sum = 0.0f;
+ for (int64_t j = 0; j < k_; ++j) {
+ top_k_sum += sorted_logits[static_cast<size_t>(j)].first;
+ }
+ const float inv_top_k_sum = 1.0f / top_k_sum;
+
+ for (int64_t j = 0; j < k_; ++j) {
+ int64_t expert_idx = sorted_logits[static_cast<size_t>(j)].second;
+ int64_t route_idx = i * k_ + j;
+ float normalized_weight = sorted_logits[static_cast<size_t>(j)].first * inv_top_k_sum;
+
+ route_expert[route_idx] = static_cast<int>(expert_idx);
+ route_scale[route_idx] = normalized_weight;
+ if (normalized_weight > 0.0f) {
+ local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+ }
+ }
+ } else {
+ for (int64_t j = 0; j < k_; ++j) {
+ int64_t expert_idx = sorted_logits[static_cast<size_t>(j)].second;
+ int64_t route_idx = i * k_ + j;
+ float weight = sorted_logits[static_cast<size_t>(j)].first;
+
+ route_expert[route_idx] = static_cast<int>(expert_idx);
+ route_scale[route_idx] = weight;
+ if (weight > 0.0f) {
+ local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+ }
+ }
+ }
+ }
+ });
+
+ std::vector<std::vector<int64_t>> expert_token_map(static_cast<size_t>(num_experts));
+
+ for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
+ size_t total_tokens_for_expert = 0;
+ for (int t = 0; t < num_routing_threads; ++t) {
+ total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast<size_t>(expert_idx)].size();
+ }
+ expert_token_map[static_cast<size_t>(expert_idx)].reserve(total_tokens_for_expert);
+ }
+
+ for (int t = 0; t < num_routing_threads; ++t) {
+ for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
+ auto& local_tokens = thread_local_expert_token_maps[t][static_cast<size_t>(expert_idx)];
+ if (!local_tokens.empty()) {
+ auto& expert_map = expert_token_map[static_cast<size_t>(expert_idx)];
+ expert_map.insert(expert_map.end(),
+ std::make_move_iterator(local_tokens.begin()),
+ std::make_move_iterator(local_tokens.end()));
+ }
+ }
+ }
+
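+ // Sort each expert's routes so tokens are visited in ascending order, keeping accumulation deterministic and memory access sequential.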
+ for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
+ auto& expert_map = expert_token_map[static_cast<size_t>(expert_idx)];
+ if (!expert_map.empty()) {
+ std::sort(expert_map.begin(), expert_map.end());
+ }
+ }
+
+ IAllocatorUniquePtr<float> input_float_buffer;
+ const float* input_float;
+ if constexpr (std::is_same_v<T, MLFloat16>) {
+ input_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * hidden_size));
+ input_float = input_float_buffer.get();
+ MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(input_data_to_use), const_cast<float*>(input_float), static_cast<size_t>(num_tokens * hidden_size));
+ } else {
+ input_float = reinterpret_cast<const float*>(input_data_to_use);
+ }
+
+ int num_expert_threads = 1;
+ if (tp != nullptr) {
+ int total_active_experts = 0;
+ for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
+ if (!expert_token_map[static_cast<size_t>(expert_idx)].empty()) {
+ total_active_experts++;
+ }
+ }
+
+ if (total_active_experts > 0) {
+ int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp);
+ num_expert_threads = std::min(total_active_experts, max_threads);
+ num_expert_threads = std::min(num_expert_threads, 8);
+ }
+ }
+
+ // Calculate maximum possible tokens per expert for buffer sizing
+ int64_t max_tokens_per_expert = 0;
+ for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
+ const auto& routes = expert_token_map[static_cast<size_t>(expert_idx)];
+ max_tokens_per_expert = std::max(max_tokens_per_expert, static_cast<int64_t>(routes.size()));
+ }
+
+ // Thread-local buffer pool for expert processing
+ struct ThreadLocalBuffers {
+ IAllocatorUniquePtr<float> A1_buffer;
+ IAllocatorUniquePtr<float> batch_weights_buffer;
+ IAllocatorUniquePtr<int64_t> token_ids_buffer;
+ IAllocatorUniquePtr<T> A1_t_buffer;
+ IAllocatorUniquePtr<T> C2_buffer;
+ // Additional buffers for ProcessExpertBatch to avoid repeated allocations
+ IAllocatorUniquePtr<T> fc1_output_buffer;
+ IAllocatorUniquePtr<T> activation_output_buffer;
+ int64_t current_capacity = 0;
+ int64_t current_fc1_capacity = 0;
+ int64_t current_activation_capacity = 0;
+
+ void EnsureCapacity(AllocatorPtr& allocator, int64_t required_tokens, int64_t hidden_size,
+ int64_t fc1_output_size, int64_t inter_size) {
+ if (required_tokens > current_capacity) {
+ // Use high watermark approach - allocate more than needed for future reuse
+ int64_t new_capacity = std::max(required_tokens * 2, current_capacity + 512);
+
+ A1_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(new_capacity * hidden_size));
+ batch_weights_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(new_capacity));
+ token_ids_buffer = IAllocator::MakeUniquePtr<int64_t>(allocator, static_cast<size_t>(new_capacity));
+ A1_t_buffer = IAllocator::MakeUniquePtr<T>(allocator, static_cast<size_t>(new_capacity * hidden_size));
+ C2_buffer = IAllocator::MakeUniquePtr<T>(allocator, static_cast<size_t>(new_capacity * hidden_size));
+
+ current_capacity = new_capacity;
+ }
+
+ // Ensure ProcessExpertBatch buffers have sufficient capacity
+ int64_t required_fc1_capacity = required_tokens * fc1_output_size;
+ int64_t required_activation_capacity = required_tokens * inter_size;
+
+ if (required_fc1_capacity > current_fc1_capacity) {
+ int64_t new_fc1_capacity = std::max(required_fc1_capacity * 2, current_fc1_capacity + (512 * fc1_output_size));
+ fc1_output_buffer = IAllocator::MakeUniquePtr<T>(allocator, static_cast<size_t>(new_fc1_capacity));
+ current_fc1_capacity = new_fc1_capacity;
+ }
+
+ if (required_activation_capacity > current_activation_capacity) {
+ int64_t new_activation_capacity = std::max(required_activation_capacity * 2, current_activation_capacity + (512 * inter_size));
+ activation_output_buffer = IAllocator::MakeUniquePtr<T>(allocator, static_cast<size_t>(new_activation_capacity));
+ current_activation_capacity = new_activation_capacity;
+ }
+ }
+ };
+
+ // Pre-allocate thread-local buffer pools
+ std::vector<ThreadLocalBuffers> thread_buffers(num_expert_threads);
+ for (int i = 0; i < num_expert_threads; ++i) {
+ thread_buffers[i].EnsureCapacity(allocator, max_tokens_per_expert, hidden_size, fc1_output_size, inter_size);
+ }
+
+ const size_t output_buffer_size = static_cast<size_t>(output->Shape().Size());
+ auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_expert_threads) * output_buffer_size);
+ float* thread_local_outputs = thread_local_outputs_ptr.get();
+
+ // Initialize thread-local outputs with vectorized operation
+ std::fill_n(thread_local_outputs, static_cast<size_t>(num_expert_threads) * output_buffer_size, 0.0f);
+
+ // Optimized expert processing with thread-local buffer reuse
+ concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) {
+ int thread_id = static_cast<int>(thread_id_pd);
+ auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast<std::ptrdiff_t>(num_experts));
+
+ float* local_output = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size;
+ ThreadLocalBuffers& buffers = thread_buffers[thread_id];
+
+ for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) {
+ const auto& routes = expert_token_map[static_cast<size_t>(expert_idx)];
+ if (routes.empty()) continue;
+
+ const int64_t num_expert_tokens = static_cast<int64_t>(routes.size());
+
+ // Ensure thread-local buffers have sufficient capacity
+ buffers.EnsureCapacity(allocator, num_expert_tokens, hidden_size, fc1_output_size, inter_size);
+
+ // Use pre-allocated buffers from thread-local pool
+ float* A1 = buffers.A1_buffer.get();
+ float* batch_weights = buffers.batch_weights_buffer.get();
+ int64_t* token_ids = buffers.token_ids_buffer.get();
+ T* A1_t = buffers.A1_t_buffer.get();
+ T* C2 = buffers.C2_buffer.get();
+ T* fc1_output = buffers.fc1_output_buffer.get();
+ T* activation_output = buffers.activation_output_buffer.get();
+
+ // Optimized data gathering with better memory access patterns
+ for (int64_t r = 0; r < num_expert_tokens; ++r) {
+ int64_t route_idx = routes[static_cast<size_t>(r)];
+ int64_t token = route_idx / k_;
+
+ token_ids[r] = token;
+ batch_weights[r] = route_scale[route_idx];
+
+ // Use SIMD-friendly copy for better performance
+ const float* src = input_float + token * hidden_size;
+ float* dst = A1 + static_cast<size_t>(r) * static_cast<size_t>(hidden_size);
+ std::copy(src, src + hidden_size, dst);
+ }
+
+ const T* fc1_expert_weights = fc1_weights_data + expert_idx * fc1_output_size * hidden_size;
+ const T* fc1_expert_bias = fc1_bias_data ? fc1_bias_data + expert_idx * fc1_output_size : nullptr;
+ const T* fc2_expert_weights = fc2_weights_data + expert_idx * hidden_size * inter_size;
+ const T* fc2_expert_bias = fc2_bias_data ? fc2_bias_data + expert_idx * hidden_size : nullptr;
+
+ // Convert input to T only when needed for computation
+ for (size_t i = 0; i < static_cast<size_t>(num_expert_tokens * hidden_size); ++i) {
+ A1_t[i] = static_cast<T>(A1[i]);
+ }
+
+ ORT_IGNORE_RETURN_VALUE(ProcessExpertBatch(A1_t, token_ids, batch_weights,
+ num_expert_tokens, expert_idx,
+ fc1_expert_weights, fc1_expert_bias,
+ fc2_expert_weights, fc2_expert_bias,
+ C2, hidden_size, inter_size,
+ fc1_output, activation_output));
+
+ // Optimized output accumulation with vectorized operations
+ for (int64_t r = 0; r < num_expert_tokens; ++r) {
+ int64_t token = token_ids[r];
+ const T* expert_output_t = C2 + static_cast<size_t>(r) * static_cast<size_t>(hidden_size);
+ float w = batch_weights[r];
+ float* dest = local_output + static_cast<size_t>(token) * static_cast<size_t>(hidden_size);
+
+ // Use explicit loop for better vectorization opportunities
+ for (int64_t j = 0; j < hidden_size; ++j) {
+ dest[j] += w * static_cast<float>(expert_output_t[j]);
+ }
+ }
+ }
+ });
+
+ auto accumulate = [&](float* buffer) {
+ std::fill_n(buffer, output_buffer_size, 0.0f);
+
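+ // Kahan (compensated) summation across the per-thread partial outputs to limit floating-point error.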
+ for (size_t j = 0; j < output_buffer_size; ++j) {
+ double sum = 0.0;
+ double c = 0.0;
+
+ for (int i = 0; i < num_expert_threads; ++i) {
+ const size_t thread_offset = static_cast<size_t>(i) * output_buffer_size;
+ double y = static_cast<double>(thread_local_outputs[thread_offset + j]) - c;
+ double t = sum + y;
+ c = (t - sum) - y;
+ sum = t;
+ }
+ buffer[j] = static_cast<float>(sum);
+ }
+ };
+
+ if constexpr (std::is_same_v<T, MLFloat16>) {
+ auto final_output_float_ptr = IAllocator::MakeUniquePtr<float>(allocator, output_buffer_size);
+ float* final_output_float = final_output_float_ptr.get();
+ accumulate(final_output_float);
+
+ MlasConvertFloatToHalfBuffer(final_output_float,
+ reinterpret_cast<MLFloat16*>(output->MutableData<T>()),
+ static_cast<size_t>(output_buffer_size));
+ } else {
+ auto final_output_float_ptr = IAllocator::MakeUniquePtr<float>(allocator, output_buffer_size);
+ float* final_output_float = final_output_float_ptr.get();
+ accumulate(final_output_float);
+
+ float* out_ptr = reinterpret_cast<float*>(output->MutableData<T>());
+ memcpy(out_ptr, final_output_float, output_buffer_size * sizeof(float));
+ }
+ return Status::OK();
+}
+template <typename T>
+Status MoE<T>::ProcessExpertBatch(const T* input_tokens,
+ const int64_t* token_expert_ids,
+ const float* token_weights,
+ int64_t batch_size,
+ int64_t expert_id,
+ const T* fc1_weights,
+ const T* fc1_bias,
+ const T* fc2_weights,
+ const T* fc2_bias,
+ T* output_buffer,
+ int64_t hidden_size,
+ int64_t inter_size,
+ T* fc1_output_buffer,
+ T* activation_output_buffer) const {
+ const bool is_swiglu = activation_type_ == ActivationType::SwiGLU;
+ const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size;
+
+ // Use the caller-provided scratch buffers (sized via EnsureCapacity) instead of allocating on every call.
+ T* fc1_output = fc1_output_buffer;
+ T* activation_output = activation_output_buffer;
+
+ ORT_RETURN_IF_ERROR(ComputeGEMM(input_tokens, fc1_weights, fc1_output,
+ batch_size, hidden_size, fc1_output_size, true));
+
+ if (fc1_bias) {
+ for (int64_t batch = 0; batch < batch_size; ++batch) {
+ T* batch_output = fc1_output + batch * fc1_output_size;
+ // Explicit loop for better vectorization
+ for (int64_t i = 0; i < fc1_output_size; ++i) {
+ batch_output[i] = static_cast<T>(static_cast<float>(batch_output[i]) +
+ static_cast<float>(fc1_bias[i]));
+ }
+ }
+ }
+
+ if (is_swiglu) {
+ for (int64_t batch = 0; batch < batch_size; ++batch) {
+ ApplySwiGLUVectorized(fc1_output + batch * fc1_output_size,
+ activation_output + batch * inter_size,
+ inter_size);
+ }
+ } else {
+ ApplyActivationVectorized(fc1_output, batch_size * fc1_output_size);
+ std::copy(fc1_output, fc1_output + (batch_size * fc1_output_size), activation_output);
+ }
+
+ ORT_RETURN_IF_ERROR(ComputeGEMM(activation_output, fc2_weights, output_buffer,
+ batch_size, inter_size, hidden_size, true));
+
+ if (fc2_bias) {
+ for (int64_t batch = 0; batch < batch_size; ++batch) {
+ T* batch_output = output_buffer + batch * hidden_size;
+ for (int64_t i = 0; i < hidden_size; ++i) {
+ batch_output[i] = static_cast<T>(static_cast<float>(batch_output[i]) +
+ static_cast<float>(fc2_bias[i]));
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+template <>
+Status MoE<float>::ComputeGEMM(const float* A, const float* B, float* C,
+ int64_t M, int64_t K, int64_t N, bool transpose_B) const {
+ MLAS_SGEMM_DATA_PARAMS params;
+ params.A = A;
+ params.lda = static_cast<size_t>(K);
+ params.alpha = 1.0f;
+ params.beta = 0.0f;
+ params.C = C;
+ params.ldc = static_cast<size_t>(N);
+ params.B = B;
+
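+ // Pass a null thread pool so each GEMM runs single-threaded; parallelism comes from the per-expert partitioning above.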
+ if (transpose_B) {
+ params.ldb = static_cast<size_t>(K);
+ MlasGemm(CblasNoTrans, CblasTrans, static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), params, nullptr);
+ } else {
+ params.ldb = static_cast<size_t>(N);
+ MlasGemm(CblasNoTrans, CblasNoTrans, static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), params, nullptr);
+ }
+
+ return Status::OK();
+}
+
+template <>
+Status MoE<MLFloat16>::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C,
+ int64_t M, int64_t K, int64_t N, bool transpose_B) const {
+ MLAS_HALF_GEMM_DATA_PARAMS params;
+ params.A = A;
+ params.lda = static_cast<size_t>(K);
+ params.C = C;
+ params.ldc = static_cast<size_t>(N);
+ params.AIsfp32 = false;
+ params.BIsfp32 = false;
+ params.B = B;
+
+ if (transpose_B) {
+ params.ldb = static_cast<size_t>(K);
+ } else {
+ params.ldb = static_cast<size_t>(N);
+ }
+
+ MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, nullptr);
+ return Status::OK();
+}
+
+template <typename T>
+void MoE<T>::ApplyActivationVectorized(T* data, int64_t size) const {
+ for (int64_t i = 0; i < size; ++i) {
+ float val = static_cast<float>(data[i]);
+ data[i] = static_cast<T>(ApplyActivation(val, activation_type_));
+ }
+}
+
+template <typename T>
+void MoE<T>::ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const {
+ for (int64_t i = 0; i < size; ++i) {
+ float gate = static_cast<float>(input[2 * i]);
+ float linear = static_cast<float>(input[2 * i + 1]);
+
+ gate = std::min(gate, swiglu_limit_);
+ linear = std::clamp(linear, -swiglu_limit_, swiglu_limit_);
+
+ float sigmoid_arg = activation_alpha_ * gate;
+ float sigmoid_out;
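+ // Numerically stable sigmoid: branch on the sign of the argument so std::exp never overflows.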
+ if (sigmoid_arg > 0) {
+ float exp_neg = std::exp(-sigmoid_arg);
+ sigmoid_out = 1.0f / (1.0f + exp_neg);
+ } else {
+ float exp_pos = std::exp(sigmoid_arg);
+ sigmoid_out = exp_pos / (1.0f + exp_pos);
+ }
+
+ float swish_out = gate * sigmoid_out;
+ output[i] = static_cast<T>(swish_out * (linear + activation_beta_));
+ }
+}
+
+template <>
+void MoE<float>::ApplySwiGLUVectorized(const float* input, float* output, int64_t size) const {
+ ApplySwiGLUActivation(input, output, size, true,
+ activation_alpha_, activation_beta_, swiglu_limit_);
+}
+
+template class MoE<float>;
+template class MoE<MLFloat16>;
+
+#define REGISTER_KERNEL_TYPED(type) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ MoE, kMSDomain, 1, type, kCpuExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .MayInplace(0, 0) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), \
+ MoE<type>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h
new file mode 100644
index 0000000000000..60d8217015b5b
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "contrib_ops/cpu/moe/moe_base_cpu.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+class MoE final : public OpKernel, public MoEBaseCPU {
+ public:
+ explicit MoE(const OpKernelInfo& op_kernel_info);
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ Status ComputeMoE(const OpKernelContext* context,
+ const Tensor* input,
+ const Tensor* router_probs,
+ const Tensor* fc1_experts_weights,
+ const Tensor* fc1_experts_bias,
+ const Tensor* fc2_experts_weights,
+ const Tensor* fc2_experts_bias,
+ Tensor* output) const;
+
+ Status ProcessExpertBatch(const T* input_tokens,
+ const int64_t* token_expert_ids,
+ const float* token_weights,
+ int64_t num_tokens,
+ int64_t expert_id,
+ const T* fc1_weights,
+ const T* fc1_bias,
+ const T* fc2_weights,
+ const T* fc2_bias,
+ T* output_buffer,
+ int64_t hidden_size,
+ int64_t inter_size,
+ T* fc1_output_buffer,
+ T* activation_output_buffer) const;
+
+ Status ComputeGEMM(const T* A, const T* B, T* C,
+ int64_t M, int64_t K, int64_t N,
+ bool transpose_B = false) const;
+
+ void ApplyActivationVectorized(T* data, int64_t size) const;
+ void ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
index 2c59210bfabd4..5a3c5d1dd0364 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -37,10 +37,18 @@ void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t
gate_val = std::min(gate_val, clamp_limit);
linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit);
+ // Use numerically stable sigmoid computation (matches CUDA kernel behavior)
float sigmoid_arg = activation_alpha * gate_val;
- float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
- float swish_out = gate_val * sigmoid_out;
+ float sigmoid_out;
+ if (sigmoid_arg > 0) {
+ float exp_neg = std::exp(-sigmoid_arg);
+ sigmoid_out = 1.0f / (1.0f + exp_neg);
+ } else {
+ float exp_pos = std::exp(sigmoid_arg);
+ sigmoid_out = exp_pos / (1.0f + exp_pos);
+ }
+ float swish_out = gate_val * sigmoid_out;
output_data[i] = swish_out * (linear_val + activation_beta);
}
} else {
diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc
index 0690b8894eb7a..ab740ea38fb74 100644
--- a/onnxruntime/test/contrib_ops/moe_test.cc
+++ b/onnxruntime/test/contrib_ops/moe_test.cc
@@ -1690,6 +1690,97 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {
#endif
}
+// Test for CPU MoE implementation
+static void RunMoECpuTest(const std::vector<float>& input, const std::vector<float>& router_probs,
+ const std::vector<float>& fc1_experts_weights, const std::vector<float>& fc2_experts_weights,
+ const std::vector<float>& fc3_experts_weights, const std::vector<float>& fc1_experts_bias,
+ const std::vector<float>& fc2_experts_bias, const std::vector<float>& output_data, int num_rows,
+ int num_experts, int hidden_size, int inter_size, std::string activation_type,
+ int normalize_routing_weights = 1, int top_k = 1) {
+ OpTester tester("MoE", 1, onnxruntime::kMSDomain);
+ tester.AddAttribute("k", static_cast(top_k));
+ tester.AddAttribute("activation_type", activation_type);
+ tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights));
+
+ bool is_swiglu = (activation_type == "swiglu");
+
+ if (is_swiglu) {
+ tester.AddAttribute("swiglu_fusion", static_cast(1));
+ tester.AddAttribute("activation_beta", 1.0f);
+ }
+
+ std::vector<int64_t> input_dims = {num_rows, hidden_size};
+ std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
+
+ int64_t fc1_output_size = is_swiglu ? (2 * inter_size) : inter_size;
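+ // Fused SwiGLU stores gate and linear projections together, so FC1 weights/bias are 2 * inter_size wide.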
+
+ std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, fc1_output_size};
+ std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
+ std::vector<int64_t> fc3_experts_weights_dims = fc1_experts_weights_dims;
+ std::vector<int64_t> fc1_experts_bias_dims = {num_experts, fc1_output_size};
+ std::vector<int64_t> fc2_experts_bias_dims = {num_experts, hidden_size};
+ std::vector<int64_t> output_dims = {num_rows, hidden_size};
+
+ tester.AddInput("input", input_dims, input);
+ tester.AddInput("router_probs", router_probs_dims, router_probs);
+ tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights);
+ if (!fc1_experts_bias.empty()) {
+ tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias);
+ } else {
+ tester.AddOptionalInputEdge();
+ }
+ tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
+ if (!fc2_experts_bias.empty()) {
+ tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias);
+ } else {
+ tester.AddOptionalInputEdge();
+ }
+ if (!fc3_experts_weights.empty()) {
+ tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights);
+ } else {
+ tester.AddOptionalInputEdge();
+ }
+ tester.AddOptionalInputEdge(); // fc3_experts_bias
+
+ tester.AddOutput("output", output_dims, output_data);
+ tester.SetOutputTolerance(0.05f);
+
+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MoETest, MoECpuTest_BasicSwiGLU) {
+ int num_rows = 2;
+ int num_experts = 2;
+ int hidden_size = 4;
+ int inter_size = 8;
+
+ // Simple test data
+ const std::vector<float> input = {
+ 1.0f, 2.0f, 3.0f, 4.0f,
+ 5.0f, 6.0f, 7.0f, 8.0f};
+
+ const std::vector<float> router_probs = {
+ 0.8f, 0.2f,
+ 0.3f, 0.7f};
+
+ const std::vector<float> fc1_experts_weights(num_experts * hidden_size * (2 * inter_size), 0.1f);
+
+ const std::vector<float> fc2_experts_weights(num_experts * inter_size * hidden_size, 0.1f);
+
+ const std::vector<float> fc3_experts_weights = {};  // No FC3
+ const std::vector<float> fc1_experts_bias = {};  // No bias
+ const std::vector<float> fc2_experts_bias = {};  // No bias
+
+ const std::vector<float> output_data = {
+ 1.169694f, 1.169694f, 1.169694f, 1.169694f,
+ 6.970291f, 6.970291f, 6.970291f, 6.970291f};
+
+ RunMoECpuTest(input, router_probs, fc1_experts_weights, fc2_experts_weights,
+ fc3_experts_weights, fc1_experts_bias, fc2_experts_bias, output_data,
+ num_rows, num_experts, hidden_size, inter_size, "swiglu");
+}
#endif
} // namespace test
diff --git a/onnxruntime/test/python/transformers/test_moe_cpu.py b/onnxruntime/test/python/transformers/test_moe_cpu.py
new file mode 100644
index 0000000000000..d6cbcc64733d4
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_moe_cpu.py
@@ -0,0 +1,473 @@
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+#
+# Regular MoE CPU kernel testing implementation - SwiGLU Interleaved Only
+#
+# This file tests the non-quantized MoE CPU implementation with SwiGLU
+# activation in interleaved format and validates parity between
+# PyTorch reference implementation and ONNX Runtime CPU kernel.
+#
+# Based on the CUDA test structure for consistency.
+# --------------------------------------------------------------------------
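+# To run only this file (assuming a CPU build of onnxruntime with contrib ops is installed):
+#   python onnxruntime/test/python/transformers/test_moe_cpu.py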
+
+import itertools
+import time
+import unittest
+
+import numpy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from onnx import TensorProto, helper
+from parameterized import parameterized
+
+from onnxruntime import InferenceSession, SessionOptions
+
+# Device and provider settings for CPU
+device = torch.device("cpu")
+ort_provider = ["CPUExecutionProvider"]
+
+torch.manual_seed(42)
+numpy.random.seed(42)
+
+onnx_to_torch_type_map = {
+ TensorProto.FLOAT16: torch.float16,
+ TensorProto.FLOAT: torch.float,
+ TensorProto.BFLOAT16: torch.bfloat16,
+ TensorProto.UINT8: torch.uint8,
+}
+
+ort_to_numpy_type_map = {
+ TensorProto.FLOAT16: numpy.float16,
+ TensorProto.FLOAT: numpy.float32,
+ TensorProto.UINT8: numpy.uint8,
+}
+
+ort_dtype_name_map = {
+ TensorProto.FLOAT16: "FP16",
+ TensorProto.FLOAT: "FP32",
+ TensorProto.BFLOAT16: "BF16",
+}
+
+
+class SwigluMoeConfig:
+ def __init__(
+ self,
+ hidden_size=2048,
+ intermediate_size=2048,
+ num_experts_per_token=2,
+ num_local_experts=8,
+ ):
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_experts_per_token = num_experts_per_token
+ self.num_local_experts = num_local_experts
+
+
+def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
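+ # Interleaved layout: even channels are the gate, odd channels are the linear component.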
+ x_glu = x[..., ::2]
+ x_linear = x[..., 1::2]
+
+ if limit is not None:
+ x_glu = x_glu.clamp(max=limit)
+ x_linear = x_linear.clamp(min=-limit, max=limit)
+
+ y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)
+ return y
+
+
+class SwigluMlp(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_size = config.intermediate_size
+ self.hidden_dim = config.hidden_size
+ self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True)
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True)
+
+ def forward(self, x):
+ x1 = self.w1(x)
+ y = swiglu(x1)
+ y = self.w2(y)
+ return y
+
+
+def make_onnx_initializer(name: str, tensor: torch.Tensor, shape, onnx_dtype):
+ torch_dtype = onnx_to_torch_type_map[onnx_dtype]
+ if torch_dtype == torch.bfloat16:
+ numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy()
+ initializer = helper.make_tensor(
+ name=name,
+ data_type=TensorProto.BFLOAT16,
+ dims=shape,
+ vals=numpy_vals_uint16.tobytes(),
+ raw=True,
+ )
+ else:
+ initializer = helper.make_tensor(
+ name=name,
+ data_type=onnx_dtype,
+ dims=shape,
+ vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist()
+ if onnx_dtype == TensorProto.UINT8
+ else tensor.detach().to(torch_dtype).flatten().tolist(),
+ raw=False,
+ )
+ return initializer
+
+
+def create_swiglu_moe_onnx_graph(
+ num_tokens: int,
+ num_experts: int,
+ hidden_size: int,
+ inter_size: int,
+ topk: int,
+ onnx_dtype: int,
+ quant_bits: int,
+ fc1_experts_weights: torch.Tensor,
+ fc1_experts_bias: torch.Tensor,
+ fc2_experts_weights: torch.Tensor,
+ fc2_experts_bias: torch.Tensor,
+ fc1_experts_weight_scale: torch.Tensor = None,
+ fc2_experts_weight_scale: torch.Tensor = None,
+):
+ use_quant = quant_bits > 0
+ op_name = "QMoE" if use_quant else "MoE"
+
+ inputs = (
+ [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "fc1_experts_weight_scale",
+ "fc1_experts_bias",
+ "fc2_experts_weights",
+ "fc2_experts_weight_scale",
+ "fc2_experts_bias",
+ ]
+ if use_quant
+ else [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "fc1_experts_bias",
+ "fc2_experts_weights",
+ "fc2_experts_bias",
+ ]
+ )
+
+ nodes = [
+ helper.make_node(
+ op_name,
+ inputs,
+ ["output"],
+ "MoE_0",
+ k=topk,
+ normalize_routing_weights=1,
+ activation_type="swiglu",
+ swiglu_fusion=1,
+ activation_alpha=1.702,
+ activation_beta=1.0,
+ domain="com.microsoft",
+ ),
+ ]
+
+ if use_quant:
+ nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)])
+
+ components = 2 if quant_bits == 4 else 1
+ fc1_weight_shape = [num_experts, 2 * inter_size, hidden_size // components]
+ fc1_bias_shape = [num_experts, 2 * inter_size]
+ fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size]
+
+ fc2_weight_shape = [num_experts, hidden_size, inter_size // components]
+ fc2_bias_shape = [num_experts, hidden_size]
+ fc2_experts_weight_scale_shape = [num_experts, hidden_size]
+
+ weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype
+
+ torch_dtype = onnx_to_torch_type_map[onnx_dtype]
+ weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type]
+
+ initializers = [
+ make_onnx_initializer(
+ "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type
+ ),
+ make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype),
+ make_onnx_initializer(
+ "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type
+ ),
+ make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype),
+ ]
+
+ if use_quant:
+ initializers.extend(
+ [
+ make_onnx_initializer(
+ "fc1_experts_weight_scale",
+ fc1_experts_weight_scale.to(torch_dtype),
+ fc1_experts_weight_scale_shape,
+ onnx_dtype,
+ ),
+ make_onnx_initializer(
+ "fc2_experts_weight_scale",
+ fc2_experts_weight_scale.to(torch_dtype),
+ fc2_experts_weight_scale_shape,
+ onnx_dtype,
+ ),
+ ]
+ )
+
+ graph_inputs = [
+ helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]),
+ ]
+
+ graph_inputs.append(
+ helper.make_tensor_value_info(
+ "router_probs",
+ onnx_dtype,
+ [num_tokens, num_experts],
+ )
+ )
+
+ graph_outputs = [
+ helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]),
+ ]
+
+ graph = helper.make_graph(
+ nodes,
+ "MoE_Graph",
+ graph_inputs,
+ graph_outputs,
+ initializers,
+ )
+
+ model = helper.make_model(graph)
+ return model.SerializeToString()
+
+
+class SparseMoeBlockORTHelper(nn.Module):
+ def __init__(self, quant_bits=0, onnx_dtype=None):
+ super().__init__()
+ self.quant_bits = quant_bits
+ if onnx_dtype is None:
+ self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT
+ else:
+ self.onnx_dtype = onnx_dtype
+ self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32
+
+ def create_ort_session(self, moe_onnx_graph):
+ sess_options = SessionOptions()
+ sess_options.log_severity_level = 2
+
+ try:
+ ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider)
+ except Exception as e:
+ print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}")
+ print("Skipping ONNX Runtime execution for this test case.")
+ return None
+
+ return ort_session
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ pass
+
+ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor:
+ if self.ort_sess is None:
+ return None
+
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states_flat = hidden_states.view(-1, hidden_dim)
+ router_logits = self.gate(hidden_states_flat)
+
+ torch_dtype = onnx_to_torch_type_map[self.onnx_dtype]
+
+ tensors = {
+ "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype),
+ "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype),
+ "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype),
+ }
+
+ ort_inputs = {
+ "input": tensors["input"].detach().cpu().numpy(),
+ "router_probs": tensors["router_probs"].detach().cpu().numpy(),
+ }
+
+ if enable_performance_test:
+ repeat = 1000
+ s = time.time()
+ for _ in range(repeat):
+ self.ort_sess.run(None, ort_inputs)
+ e = time.time()
+ print(f"MoE CPU kernel time: {(e - s) / repeat * 1000} ms")
+ ort_outputs = self.ort_sess.run(None, ort_inputs)
+ else:
+ ort_outputs = self.ort_sess.run(None, ort_inputs)
+
+ output_tensor = torch.from_numpy(ort_outputs[0]).to(device)
+
+ return output_tensor.reshape(batch_size, sequence_length, hidden_dim)
+
+ def parity_check(self):
+ hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device)
+ torch_output = self.forward(hidden_state)
+ ort_output = self.ort_forward(hidden_state)
+
+ dtype_str = ort_dtype_name_map[self.onnx_dtype]
+
+ ort_dtype_quant_bits_tolerance_map = {
+ "FP32:0": (5e-3, 1e-3),
+ }
+
+ atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"]
+ if ort_output is not None:
+ print(
+ f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str},"
+ f" batch: {self.batch_size}, seq_len: {self.sequence_length},"
+ f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}"
+ )
+ torch.testing.assert_close(
+ ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol
+ )
+
+ def benchmark_ort(self):
+ hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device)
+ self.ort_forward(hidden_state, enable_performance_test=True)
+
+
+class SwigluMoEBlock(SparseMoeBlockORTHelper):
+ def __init__(
+ self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None
+ ):
+ super().__init__(quant_bits, onnx_dtype=onnx_dtype)
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.top_k = config.num_experts_per_token
+
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True)
+
+ self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)])
+
+ fc1_w_list, fc2_w_list = [], []
+ fc1_b_list, fc2_b_list = [], []
+
+ for expert in self.experts:
+ w1_weight = expert.w1.weight.data.clone()
+ w2_weight = expert.w2.weight.data.clone()
+ w1_bias = expert.w1.bias.data.clone()
+ w2_bias = expert.w2.bias.data.clone()
+
+ fc1_w_list.append(w1_weight)
+ fc2_w_list.append(w2_weight)
+ fc1_b_list.append(w1_bias)
+ fc2_b_list.append(w2_bias)
+
+ fc1_experts_weights = torch.stack(fc1_w_list, dim=0)
+ fc2_experts_weights = torch.stack(fc2_w_list, dim=0)
+ fc1_experts_bias = torch.stack(fc1_b_list, dim=0)
+ fc2_experts_bias = torch.stack(fc2_b_list, dim=0)
+
+ self.batch_size = batch_size
+ self.sequence_length = sequence_length
+
+ self.moe_onnx_graph = create_swiglu_moe_onnx_graph(
+ num_tokens=self.batch_size * self.sequence_length,
+ num_experts=self.num_experts,
+ hidden_size=self.hidden_dim,
+ inter_size=self.ffn_dim,
+ topk=self.top_k,
+ onnx_dtype=self.onnx_dtype,
+ quant_bits=self.quant_bits,
+ fc1_experts_weights=fc1_experts_weights,
+ fc1_experts_bias=fc1_experts_bias,
+ fc2_experts_weights=fc2_experts_weights,
+ fc2_experts_bias=fc2_experts_bias,
+ )
+
+ self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ router_logits = self.gate(hidden_states)
+
+ # Compute full softmax over all experts (same as CUDA)
+ full_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(full_probs, self.top_k, dim=-1)
+
+ # For normalize_routing_weights=1: normalize by sum of top-k values (same as CUDA)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ for expert_idx in range(self.num_experts):
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx])
+
+ if top_x.shape[0] == 0:
+ continue
+
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states
+
+
+swiglu_test_cases = list(
+ itertools.product(
+ [1, 2], # batch_size
+ [16, 32], # sequence_length
+ [0], # quant_bits (CPU kernel only supports float32)
+ )
+)
+
+perf_test_cases = list(
+ itertools.product(
+ [1], # batch_size
+ [128], # sequence_length
+ [0], # quant_bits (CPU kernel only supports float32)
+ )
+)
+
+
+class TestSwigluMoECPU(unittest.TestCase):
+ @parameterized.expand(swiglu_test_cases)
+ def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits):
+ config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4)
+ moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits)
+ moe.to(device)
+ moe.parity_check()
+
+
+class TestSwigluMoECPUPerf(unittest.TestCase):
+ @parameterized.expand(perf_test_cases)
+ def test_swiglu_moe_perf(self, batch_size, sequence_length, quant_bits):
+ hidden_size = 1024
+ intermediate_size = 2048
+ num_experts_per_token = 4
+ num_local_experts = 16
+ config = SwigluMoeConfig(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_experts_per_token=num_experts_per_token,
+ num_local_experts=num_local_experts,
+ )
+ moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits)
+ moe.to(device)
+ moe.benchmark_ort()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py
index efaaca29a01b6..403becbe0616a 100644
--- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py
+++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py
@@ -254,7 +254,8 @@ def create_cpu_moe_onnx_graph(
inter_size = intermediate_size
topk = top_k
- use_quant = True
+    # use_quant comes from the caller; forcing it to True here broke the regular (non-quantized) MoE path.
if fc1_scales is None and use_quant:
return None
@@ -263,30 +264,52 @@ def create_cpu_moe_onnx_graph(
if not has_onnx:
return None
- assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE"
- assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE"
- assert fc1_scales is not None, "FC1 scales must be provided for QMoE"
- assert fc2_scales is not None, "FC2 scales must be provided for QMoE"
- assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE"
- assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE"
+ if use_quant:
+ # Assertions only apply to quantized MoE
+ assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE"
+ assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE"
+ assert fc1_scales is not None, "FC1 scales must be provided for QMoE"
+ assert fc2_scales is not None, "FC2 scales must be provided for QMoE"
+ assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE"
+ assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE"
if not has_onnx:
return None
- op_name = "QMoE"
- inputs = [
- "input",
- "router_probs",
- "fc1_experts_weights",
- "fc1_scales",
- "",
- "fc2_experts_weights",
- "fc2_scales",
- "",
- ]
+ # Set operator name and inputs based on quantization mode
+ if use_quant:
+ op_name = "QMoE"
+ inputs = [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "fc1_scales",
+ "",
+ "fc2_experts_weights",
+ "fc2_scales",
+ "",
+ ]
+ else:
+ # For regular (non-quantized) MoE, use different operator and input layout
+ op_name = "MoE" # Regular MoE operator
+ inputs = [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "fc1_experts_bias" if fc1_bias is not None else "", # fc1_bias as input 3
+ "fc2_experts_weights",
+ "fc2_experts_bias" if fc2_bias is not None else "", # fc2_bias as input 5
+ "", # fc3_experts_weights (not used)
+ "", # fc3_experts_bias (not used)
+ ]
activation = "swiglu" if use_swiglu else "silu"
+ # Set normalization behavior based on operator type:
+ # - QMoE: Raw logits passed, needs normalization in C++ kernel
+ # - Regular MoE: Pre-computed probabilities passed, no additional normalization needed
+ normalize_routing = 1 if use_quant else 0
+
nodes = [
helper.make_node(
op_name,
@@ -294,13 +317,14 @@ def create_cpu_moe_onnx_graph(
["output"],
"MoE_0",
k=topk,
- normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior
+ normalize_routing_weights=normalize_routing,
activation_type=activation,
# Add new attributes with backwards-compatible default values
- swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved
+ swiglu_fusion=1 if use_swiglu else 0, # 1 if using SwiGLU activation
swiglu_limit=7.0,
activation_alpha=1.702,
activation_beta=1.0,
+            swiglu_interleaved=1 if swiglu_interleaved else 0,  # interleaved gate/linear weight layout
domain="com.microsoft",
),
]
@@ -339,79 +363,106 @@ def create_cpu_moe_onnx_graph(
fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size)
fc2_scale_size = num_experts * hidden_size
- # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions
- # Handle different possible scale tensor structures for fc1_scales
- if len(fc1_scales.shape) == 4:
- # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output
- if use_swiglu:
- fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy()
+ # Handle scale tensors based on quantization mode
+ if use_quant:
+ # Handle different possible scale tensor structures for fc1_scales
+ if len(fc1_scales.shape) == 4:
+ # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output
+ if use_swiglu:
+ fc1_scale_tensor = (
+ fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy()
+ )
+ else:
+ fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy()
+ elif len(fc1_scales.shape) == 2:
+ # 2D case: already flattened, just ensure correct size
+ fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy()
+ if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size:
+ # For SwiGLU, duplicate the scales to cover both gate and value components
+ fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten()
+ elif fc1_scale_tensor.size > fc1_scale_size:
+ # Truncate to expected size
+ fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size]
else:
- fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy()
- elif len(fc1_scales.shape) == 2:
- # 2D case: already flattened, just ensure correct size
- fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy()
- if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size:
- # For SwiGLU, duplicate the scales to cover both gate and value components
- fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten()
- elif fc1_scale_tensor.size > fc1_scale_size:
- # Truncate to expected size
- fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size]
- else:
- # Other cases: flatten and truncate/pad as needed
- fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy()
- if fc1_scale_tensor.size > fc1_scale_size:
- fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size]
- elif fc1_scale_tensor.size < fc1_scale_size:
- # Pad with ones if too small
- pad_size = fc1_scale_size - fc1_scale_tensor.size
- fc1_scale_tensor = numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)])
-
- # Process scale tensor for proper shape
- fc1_scale_data_list = fc1_scale_tensor.tolist()
- fc1_scale_data = fc1_scale_data_list
-
- # Handle different possible scale tensor structures for fc2_scales
- if len(fc2_scales.shape) == 4:
- # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output
- fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy()
- elif len(fc2_scales.shape) == 2:
- # 2D case: already flattened, just ensure correct size
- fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy()
- if fc2_scale_tensor.size > fc2_scale_size:
- # Truncate to expected size
- fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size]
+ # Other cases: flatten and truncate/pad as needed
+ fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy()
+ if fc1_scale_tensor.size > fc1_scale_size:
+ fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size]
+ elif fc1_scale_tensor.size < fc1_scale_size:
+ # Pad with ones if too small
+ pad_size = fc1_scale_size - fc1_scale_tensor.size
+ fc1_scale_tensor = numpy.concatenate(
+ [fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]
+ )
+
+ # Process scale tensor for proper shape
+ fc1_scale_data_list = fc1_scale_tensor.tolist()
+ fc1_scale_data = fc1_scale_data_list
+
+ # Handle different possible scale tensor structures for fc2_scales
+ if len(fc2_scales.shape) == 4:
+ # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output
+ fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy()
+ elif len(fc2_scales.shape) == 2:
+ # 2D case: already flattened, just ensure correct size
+ fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy()
+ if fc2_scale_tensor.size > fc2_scale_size:
+ # Truncate to expected size
+ fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size]
+ else:
+ # Other cases: flatten and truncate/pad as needed
+ fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy()
+ if fc2_scale_tensor.size > fc2_scale_size:
+ fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size]
+ elif fc2_scale_tensor.size < fc2_scale_size:
+ # Pad with ones if too small
+ pad_size = fc2_scale_size - fc2_scale_tensor.size
+ fc2_scale_tensor = numpy.concatenate(
+ [fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]
+ )
+
+ # Process scale tensor for proper shape
+ fc2_scale_data_list = fc2_scale_tensor.tolist()
+ fc2_scale_data = fc2_scale_data_list
+
+ initializers.extend(
+ [
+ helper.make_tensor(
+ "fc1_scales",
+ onnx_dtype,
+ fc1_scale_shape,
+ fc1_scale_data,
+ raw=False,
+ ),
+ helper.make_tensor(
+ "fc2_scales",
+ onnx_dtype,
+ fc2_scale_shape,
+ fc2_scale_data,
+ raw=False,
+ ),
+ ]
+ )
else:
- # Other cases: flatten and truncate/pad as needed
- fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy()
- if fc2_scale_tensor.size > fc2_scale_size:
- fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size]
- elif fc2_scale_tensor.size < fc2_scale_size:
- # Pad with ones if too small
- pad_size = fc2_scale_size - fc2_scale_tensor.size
- fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)])
-
- # Process scale tensor for proper shape
- fc2_scale_data_list = fc2_scale_tensor.tolist()
- fc2_scale_data = fc2_scale_data_list
-
- initializers.extend(
- [
- helper.make_tensor(
- "fc1_scales",
- onnx_dtype,
- fc1_scale_shape,
- fc1_scale_data,
- raw=False,
- ),
- helper.make_tensor(
- "fc2_scales",
- onnx_dtype,
- fc2_scale_shape,
- fc2_scale_data,
- raw=False,
- ),
- ]
- )
+ # For non-quantized mode, add bias tensors if provided
+ if fc1_bias is not None:
+ initializers.append(
+ helper.make_tensor(
+ "fc1_experts_bias",
+ onnx_dtype,
+ list(fc1_bias.shape),
+ fc1_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(),
+ )
+ )
+ if fc2_bias is not None:
+ initializers.append(
+ helper.make_tensor(
+ "fc2_experts_bias",
+ onnx_dtype,
+ list(fc2_bias.shape),
+ fc2_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(),
+ )
+ )
graph_inputs = [
helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]),
@@ -619,17 +670,54 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor:
if self.ort_sess is None:
+ print(f"ERROR: ORT session is None for {self.__class__.__name__}")
return None
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states_flat)
+ # Different routing logic for QMoE vs regular MoE:
+ # - QMoE expects raw logits (does its own softmax internally)
+ # - Regular MoE expects pre-computed routing probabilities
+ if hasattr(self, "quant_bits") and self.quant_bits > 0:
+ # QMoE: Pass raw logits directly (QMoE does softmax internally)
+ router_input = router_logits
+ # print("DEBUG: Using QMoE routing (raw logits)")
+ else:
+ # Regular MoE: Apply the same routing logic as PyTorch reference
+ # This converts raw logits to proper routing probabilities
+ routing_weights, selected_experts = masked_sampling_omp_inference(
+ router_logits,
+ top_k=self.top_k,
+ jitter_eps=self.router_jitter_noise,
+ training=False,
+ )
+
+ # IMPORTANT: The routing weights from masked_sampling_omp_inference sum to top_k,
+ # but ONNX Runtime expects normalized probabilities that sum to 1.0
+ # Normalize the routing weights per token
+ routing_weights = routing_weights / routing_weights.sum(dim=1, keepdim=True)
+
+ # Create proper router probabilities tensor that matches PyTorch routing
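+ # The explicit loop below is kept for readability; an equivalent vectorized form
+ # would be router_input.scatter_(1, selected_experts, routing_weights)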
+ router_input = torch.zeros_like(router_logits)
+ for i in range(router_logits.shape[0]): # For each token
+ for j in range(self.top_k): # For each top-k expert
+ expert_idx = selected_experts[i, j]
+ router_input[i, expert_idx] = routing_weights[i, j]
+
+ # print("DEBUG: Using regular MoE routing (processed probabilities)")
+
+ # print(f"DEBUG: router_input stats: mean={router_input.mean():.6f}, std={router_input.std():.6f}")
+ # print(
+ # f"DEBUG: hidden_states_flat stats: mean={hidden_states_flat.mean():.6f}, std={hidden_states_flat.std():.6f}"
+ # )
+
torch_dtype = onnx_to_torch_type_map[self.onnx_dtype]
tensors = {
"input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype),
- "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype),
+ "router_probs": router_input.clone().to(device=device, dtype=torch_dtype),
"output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype),
}
@@ -656,10 +744,14 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False
buffer_ptr=tensor.data_ptr(),
)
+ # print("DEBUG: About to run ORT inference...")
+
iobinding.synchronize_inputs()
self.ort_sess.run_with_iobinding(iobinding)
iobinding.synchronize_outputs()
+ # print("DEBUG: ORT inference completed successfully")
+
if enable_performance_test:
repeat = 100
s = time.time()
@@ -691,28 +783,10 @@ def recreate_onnx_model(self):
w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit)
if self.use_swiglu:
- if self.swiglu_interleaved:
- pass
- else:
- w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit)
-
- gate_weights = pre_qweight1
- value_weights = pre_qweight3
- gate_scales = w1_scale
- value_scales = w3_scale
-
- pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0)
- w1_scale = torch.cat([gate_scales, value_scales], dim=0)
-
- if self.swiglu_interleaved:
- self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone())
-
- else:
- intermediate_size = self.experts[i].w1.weight.shape[0]
- gate_dequant = w1_qdq[:intermediate_size].contiguous().clone()
- value_dequant = w1_qdq[intermediate_size:].contiguous().clone()
- self.experts[i].w1.weight.data = gate_dequant
- self.experts[i].w3.weight.data = value_dequant
+ # For SwiGLU, the CPU kernel now always expects the interleaved format.
+ # SwigluMlp weights are already interleaved ([gate_0, linear_0, gate_1, linear_1, ...]),
+ # so no conversion is needed - both the CPU and CUDA kernels use the same layout.
+ self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone())
else:
self.experts[i].w1.weight.data = w1_qdq.contiguous().clone()
@@ -754,7 +828,7 @@ def recreate_onnx_model(self):
use_swiglu=self.use_swiglu,
use_quant=True, # Always use QMoE
quant_bits=self.quant_bits,
- swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False,
+ swiglu_interleaved=True, # CPU kernel now always expects interleaved format
)
except Exception:
self.moe_onnx_graph = None
@@ -1043,10 +1117,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class TestPhiQMoECPU(unittest.TestCase):
@parameterized.expand(phi3_test_cases)
def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
- torch.manual_seed(42)
- numpy.random.seed(42)
+ # Create unique seed based on test parameters to ensure different inputs for each test
+ base_seed = 2000 # Different base seed from other tests
+ param_hash = hash((batch_size, sequence_length, quant_bits))
+ unique_seed = base_seed + abs(param_hash) % 1000
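+ # hash() of a tuple of ints is deterministic across runs (hash randomization only
+ # affects str/bytes), so each parameter combination maps to a stable, reproducible seed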
- test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}"
+ torch.manual_seed(unique_seed)
+ numpy.random.seed(unique_seed)
+
+ test_config = (
+ f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}"
+ )
print(f"Running Phi3 QMoE test: {test_config}")
config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2)
@@ -1086,10 +1167,17 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
class TestSwigluQMoECPU(unittest.TestCase):
@parameterized.expand(swiglu_test_cases)
def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
- torch.manual_seed(42)
- numpy.random.seed(42)
+ # Create unique seed based on test parameters to ensure different inputs for each test
+ base_seed = 1000 # Different base seed from regular MoE tests
+ param_hash = hash((batch_size, sequence_length, quant_bits))
+ unique_seed = base_seed + abs(param_hash) % 1000
+
+ torch.manual_seed(unique_seed)
+ numpy.random.seed(unique_seed)
- test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}"
+ test_config = (
+ f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}"
+ )
print(f"Running SwiGLU test: {test_config}")
config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2)
@@ -1114,5 +1202,173 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits):
swiglu_moe.parity_check()
+@unittest.skipIf(True, "QMoE CPU benchmark is disabled by default; remove this skip to run it manually")
+class TestQMoESwiGLUBenchmark(unittest.TestCase):
+ """Benchmark tests for QMoE SwiGLU performance measurement."""
+
+ def test_qmoe_swiglu_throughput_benchmark(self):
+ """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations."""
+ if disable_cpu_qmoe_tests:
+ self.skipTest("QMoE CPU tests disabled")
+
+ print("\n=== QMoE SwiGLU Throughput Benchmark ===")
+
+ # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits)
+ configs = [
+ ("Medium-4bit", 2880, 2880, 32, 4, 4),
+ ("Medium-8bit", 2880, 2880, 32, 4, 8),
+ ]
+
+ batch_size = 1
+ sequence_length = 512
+ num_runs = 30
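+ # Each timing below averages num_runs forward passes after a short warm-up to reduce timer noise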
+
+ results = []
+
+ for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs:
+ torch.manual_seed(42)
+ numpy.random.seed(42)
+
+ print(f"\nTesting {config_name}:")
+ print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}")
+ print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit")
+
+ try:
+ # Create config and model
+ config = PhiMoEConfig(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_local_experts=num_experts,
+ num_experts_per_tok=top_k,
+ )
+
+ qmoe_swiglu = PhiMoESparseMoeBlock(
+ config,
+ batch_size=batch_size,
+ sequence_length=sequence_length,
+ quant_bits=quant_bits,
+ onnx_dtype=TensorProto.FLOAT,
+ )
+
+ # Create test input with fixed sequence length to match ONNX model
+ full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32)
+
+ # For TTFT simulation, measure the time of one full forward pass over the context;
+ # this approximates the prefill latency before the first generated token
+
+ # Outputs captured from the timed loops below, reused for the parity diff at the end
+ torch_output = None
+ ort_output = None
+
+ # Warm up with full context
+ for _ in range(3):
+ _ = qmoe_swiglu.forward(full_hidden_states)
+
+ # Benchmark PyTorch TTFT (Time to First Token)
+ # Measure time for a single forward pass (represents token generation time)
+ torch.manual_seed(42)
+
+ start_time = time.time()
+ for _ in range(num_runs):
+ torch_output = qmoe_swiglu.forward(full_hidden_states)
+ end_time = time.time()
+ torch_ttft_ms = (end_time - start_time) / num_runs * 1000
+
+ # Convert latency to a rate: one generated token per full forward pass,
+ # so this is a per-step generation rate rather than raw token throughput
+ torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000)
+
+ print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)")
+ print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec")
+
+ # Benchmark ONNX Runtime
+ ort_ttft_ms = 0
+ ort_tokens_per_sec = 0
+ speedup = 0
+ throughput_ratio = 0
+ max_diff = 0
+
+ model_updated = qmoe_swiglu.recreate_onnx_model()
+ if model_updated and qmoe_swiglu.ort_sess is not None:
+ # Warm up ORT with full context
+ for _ in range(3):
+ _ = qmoe_swiglu.ort_forward(full_hidden_states)
+
+ torch.manual_seed(42)
+
+ # Measure ONNX Runtime TTFT (Time to First Token)
+ start_time = time.time()
+ for _ in range(num_runs):
+ ort_output = qmoe_swiglu.ort_forward(full_hidden_states)
+ end_time = time.time()
+ ort_ttft_ms = (end_time - start_time) / num_runs * 1000
+
+ # Calculate tokens per second for ONNX Runtime
+ ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000)
+
+ speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0
+ throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 0
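+ # speedup > 1.0 means ORT is faster; throughput_ratio expresses the same comparison in tokens/sec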
+
+ print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)")
+ print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec")
+ print(f" TTFT Speedup: {speedup:.2f}x")
+ print(f" Throughput Gain: {throughput_ratio:.2f}x")
+ else:
+ print(" ONNX RT: Not available")
+
+ # Calculate max difference if both outputs available
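+ # Informational only: no tolerance is asserted here; numerical parity is covered by the parity tests above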
+ if torch_output is not None and ort_output is not None:
+ max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item()
+ print(f" Max diff: {max_diff:.6f}")
+
+ results.append(
+ {
+ "config": config_name,
+ "torch_ttft_ms": torch_ttft_ms,
+ "torch_tokens_per_sec": torch_tokens_per_sec,
+ "ort_ttft_ms": ort_ttft_ms,
+ "ort_tokens_per_sec": ort_tokens_per_sec,
+ "speedup": speedup,
+ "throughput_ratio": throughput_ratio,
+ "max_diff": max_diff,
+ }
+ )
+
+ except Exception as e:
+ print(f" Error: {e}")
+ continue
+
+ # Summary
+ print("\n=== Token Generation Time & Throughput Summary ===")
+ print(
+ f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}"
+ )
+ print("-" * 105)
+ for result in results:
+ config = result["config"]
+ torch_ttft = result["torch_ttft_ms"]
+ torch_tps = result["torch_tokens_per_sec"]
+ ort_ttft = result["ort_ttft_ms"]
+ ort_tps = result["ort_tokens_per_sec"]
+ speedup = result["speedup"]
+ throughput_ratio = result["throughput_ratio"]
+ max_diff = result["max_diff"]
+
+ ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A"
+ ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A"
+ speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A"
+ throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A"
+
+ print(
+ f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}"
+ )
+
+ print("\nNotes:")
+ print("- Time: Token generation time in ms (lower is better)")
+ print("- tok/s: Tokens per second throughput (higher is better)")
+ print("- Time Gain: ORT speedup for latency (higher is better)")
+ print("- Throughput: ORT throughput improvement (higher is better)")
+
+
if __name__ == "__main__":
unittest.main()