diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5a7934fd37278..e42181f61a771 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -547,6 +547,7 @@ Do not modify directly.* |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(float16), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 34410a5f42630..d959d11e3fd43 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -108,6 +108,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE); // ******** End: Quantization ******************* // #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -275,6 +276,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h index eae96c186d471..84580b310f6b3 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h @@ -6,7 +6,7 @@ #include "core/common/common.h" #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" -#include "contrib_ops/cpu/moe/moe_helper.h" +#include "moe_helper.h" #include namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc new file mode 100644 index 0000000000000..73be099181aa9 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc @@ -0,0 +1,605 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_cpu.h" +#include "contrib_ops/cpu/moe/moe_utils.h" +#include "contrib_ops/cpu/moe/moe_helper.h" +#include "core/framework/op_kernel.h" +#include "core/providers/common.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "core/util/math_cpuonly.h" +#include "core/mlas/inc/mlas.h" +#include "core/framework/float16.h" +#include "core/framework/allocator.h" +#include "core/platform/threadpool.h" + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +template +MoE::MoE(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), MoEBaseCPU(op_kernel_info) { + if (activation_type_ == ActivationType::SwiGLU && swiglu_fusion_ != 1) { + ORT_THROW("CPU MoE only supports interleaved SwiGLU format. 
Please set swiglu_fusion=1."); + } +} + +template +Status MoE::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc1_experts_bias = context->Input(3); + const Tensor* fc2_experts_weights = context->Input(4); + const Tensor* fc2_experts_bias = context->Input(5); + const Tensor* fc3_experts_weights = context->Input(6); + const Tensor* fc3_experts_bias = context->Input(7); + + // FC3 not supported + if (fc3_experts_weights != nullptr || fc3_experts_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "FC3 is not implemented for CPU MoE."); + } + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias, nullptr, + fc2_experts_weights, fc2_experts_bias, nullptr, + fc3_experts_weights, fc3_experts_bias, nullptr, + 1, + activation_type_ == ActivationType::SwiGLU)); + + Tensor* output = context->Output(0, input->Shape()); + + return ComputeMoE(context, input, router_probs, fc1_experts_weights, fc1_experts_bias, + fc2_experts_weights, fc2_experts_bias, output); +} + +template +Status MoE::ComputeMoE(const OpKernelContext* context, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias, + Tensor* output) const { + const auto& input_shape = input->Shape(); + const auto& router_shape = router_probs->Shape(); + const auto& fc2_shape = fc2_experts_weights->Shape(); + + const int64_t num_tokens = input_shape.Size() / input_shape[input_shape.NumDimensions() - 1]; + const int64_t hidden_size = input_shape[input_shape.NumDimensions() - 1]; + const int64_t num_experts = router_shape[1]; + const int64_t inter_size = (fc2_shape[1] * fc2_shape[2]) / hidden_size; + const bool is_swiglu = activation_type_ == ActivationType::SwiGLU; + const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size; + + const T* input_data = input->Data(); + const T* router_data = router_probs->Data(); + const T* fc1_weights_data = fc1_experts_weights->Data(); + const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; + const T* fc2_weights_data = fc2_experts_weights->Data(); + const T* fc2_bias_data = fc2_experts_bias ? 
fc2_experts_bias->Data() : nullptr; + T* output_data = output->MutableData(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const T* input_data_to_use = input_data; + IAllocatorUniquePtr input_data_copy_ptr; + if (normalize_routing_weights_) { + input_data_copy_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + T* input_data_copy = input_data_copy_ptr.get(); + std::copy(input_data, input_data + (num_tokens * hidden_size), input_data_copy); + input_data_to_use = input_data_copy; + } + + std::fill_n(output_data, output->Shape().Size(), T{}); + + IAllocatorUniquePtr router_logits_float_buffer; + const float* router_logits_float = nullptr; + if constexpr (std::is_same_v) { + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_logits_float = router_logits_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_data), const_cast(router_logits_float), static_cast(num_tokens * num_experts)); + } else { + router_logits_float = reinterpret_cast(router_data); + } + + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + int* route_expert = route_expert_ptr.get(); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + float* route_scale = route_scale_ptr.get(); + + auto* tp = context->GetOperatorThreadPool(); + int num_routing_threads = 1; + if (tp != nullptr && num_tokens >= 1024) { + int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp); + num_routing_threads = std::min(static_cast(num_tokens / 512), max_threads); + num_routing_threads = std::max(1, num_routing_threads); + } + + std::vector>> thread_local_expert_token_maps(num_routing_threads); + for (auto& map : thread_local_expert_token_maps) { + map.resize(static_cast(num_experts)); + for (auto& expert_map : map) { + expert_map.reserve(static_cast(std::max(static_cast(1), num_tokens / num_experts / num_routing_threads * 2))); + } + } + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { + auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; + + std::vector> sorted_logits(static_cast(num_experts)); + std::vector full_softmax(static_cast(num_experts)); + + for (int64_t i = work.start; i < work.end; ++i) { + const float* logits = router_logits_float + i * num_experts; + + float max_logit = logits[0]; + for (int64_t j = 1; j < num_experts; ++j) { + max_logit = std::max(max_logit, logits[j]); + } + + float sum_exp = 0.0f; + for (int64_t j = 0; j < num_experts; ++j) { + full_softmax[static_cast(j)] = std::exp(logits[j] - max_logit); + sum_exp += full_softmax[static_cast(j)]; + } + + const float inv_sum_exp = 1.0f / sum_exp; + for (int64_t j = 0; j < num_experts; ++j) { + full_softmax[static_cast(j)] *= inv_sum_exp; + sorted_logits[static_cast(j)] = {full_softmax[static_cast(j)], j}; + } + + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + + if (normalize_routing_weights_) { + float top_k_sum = 0.0f; + for (int64_t j = 0; j < k_; ++j) { + top_k_sum += sorted_logits[static_cast(j)].first; + } + const float inv_top_k_sum = 1.0f / top_k_sum; + + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = 
sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + float normalized_weight = sorted_logits[static_cast(j)].first * inv_top_k_sum; + + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = normalized_weight; + if (normalized_weight > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } else { + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + float weight = sorted_logits[static_cast(j)].first; + + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = weight; + if (weight > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } + } + }); + + std::vector> expert_token_map(static_cast(num_experts)); + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + size_t total_tokens_for_expert = 0; + for (int t = 0; t < num_routing_threads; ++t) { + total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + } + expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + } + + for (int t = 0; t < num_routing_threads; ++t) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + if (!local_tokens.empty()) { + auto& expert_map = expert_token_map[static_cast(expert_idx)]; + expert_map.insert(expert_map.end(), + std::make_move_iterator(local_tokens.begin()), + std::make_move_iterator(local_tokens.end())); + } + } + } + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& expert_map = expert_token_map[static_cast(expert_idx)]; + if (!expert_map.empty()) { + std::sort(expert_map.begin(), expert_map.end()); + } + } + + IAllocatorUniquePtr input_float_buffer; + const float* input_float; + if constexpr (std::is_same_v) { + input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + input_float = input_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data_to_use), const_cast(input_float), static_cast(num_tokens * hidden_size)); + } else { + input_float = reinterpret_cast(input_data_to_use); + } + + int num_expert_threads = 1; + if (tp != nullptr) { + int total_active_experts = 0; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (!expert_token_map[static_cast(expert_idx)].empty()) { + total_active_experts++; + } + } + + if (total_active_experts > 0) { + int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp); + num_expert_threads = std::min(total_active_experts, max_threads); + num_expert_threads = std::min(num_expert_threads, 8); + } + } + + // Calculate maximum possible tokens per expert for buffer sizing + int64_t max_tokens_per_expert = 0; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + max_tokens_per_expert = std::max(max_tokens_per_expert, static_cast(routes.size())); + } + + // Thread-local buffer pool for expert processing + struct ThreadLocalBuffers { + IAllocatorUniquePtr A1_buffer; + IAllocatorUniquePtr batch_weights_buffer; + IAllocatorUniquePtr token_ids_buffer; + IAllocatorUniquePtr A1_t_buffer; + IAllocatorUniquePtr C2_buffer; + // Additional buffers for ProcessExpertBatch to avoid repeated allocations + IAllocatorUniquePtr fc1_output_buffer; + IAllocatorUniquePtr 
activation_output_buffer; + int64_t current_capacity = 0; + int64_t current_fc1_capacity = 0; + int64_t current_activation_capacity = 0; + + void EnsureCapacity(AllocatorPtr& allocator, int64_t required_tokens, int64_t hidden_size, + int64_t fc1_output_size, int64_t inter_size) { + if (required_tokens > current_capacity) { + // Use high watermark approach - allocate more than needed for future reuse + int64_t new_capacity = std::max(required_tokens * 2, current_capacity + 512); + + A1_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + batch_weights_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity)); + token_ids_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity)); + A1_t_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + C2_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + + current_capacity = new_capacity; + } + + // Ensure ProcessExpertBatch buffers have sufficient capacity + int64_t required_fc1_capacity = required_tokens * fc1_output_size; + int64_t required_activation_capacity = required_tokens * inter_size; + + if (required_fc1_capacity > current_fc1_capacity) { + int64_t new_fc1_capacity = std::max(required_fc1_capacity * 2, current_fc1_capacity + (512 * fc1_output_size)); + fc1_output_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_fc1_capacity)); + current_fc1_capacity = new_fc1_capacity; + } + + if (required_activation_capacity > current_activation_capacity) { + int64_t new_activation_capacity = std::max(required_activation_capacity * 2, current_activation_capacity + (512 * inter_size)); + activation_output_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_activation_capacity)); + current_activation_capacity = new_activation_capacity; + } + } + }; + + // Pre-allocate thread-local buffer pools + std::vector thread_buffers(num_expert_threads); + for (int i = 0; i < num_expert_threads; ++i) { + thread_buffers[i].EnsureCapacity(allocator, max_tokens_per_expert, hidden_size, fc1_output_size, inter_size); + } + + const size_t output_buffer_size = static_cast(output->Shape().Size()); + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + float* thread_local_outputs = thread_local_outputs_ptr.get(); + + // Initialize thread-local outputs with vectorized operation + std::fill_n(thread_local_outputs, static_cast(num_expert_threads) * output_buffer_size, 0.0f); + + // Optimized expert processing with thread-local buffer reuse + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { + int thread_id = static_cast(thread_id_pd); + auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + + float* local_output = thread_local_outputs + static_cast(thread_id) * output_buffer_size; + ThreadLocalBuffers& buffers = thread_buffers[thread_id]; + + for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + if (routes.empty()) continue; + + const int64_t num_expert_tokens = static_cast(routes.size()); + + // Ensure thread-local buffers have sufficient capacity + buffers.EnsureCapacity(allocator, num_expert_tokens, hidden_size, fc1_output_size, inter_size); + + // Use pre-allocated buffers from thread-local pool + float* A1 = buffers.A1_buffer.get(); + float* 
batch_weights = buffers.batch_weights_buffer.get(); + int64_t* token_ids = buffers.token_ids_buffer.get(); + T* A1_t = buffers.A1_t_buffer.get(); + T* C2 = buffers.C2_buffer.get(); + T* fc1_output = buffers.fc1_output_buffer.get(); + T* activation_output = buffers.activation_output_buffer.get(); + + // Optimized data gathering with better memory access patterns + for (int64_t r = 0; r < num_expert_tokens; ++r) { + int64_t route_idx = routes[static_cast(r)]; + int64_t token = route_idx / k_; + + token_ids[r] = token; + batch_weights[r] = route_scale[route_idx]; + + // Use SIMD-friendly copy for better performance + const float* src = input_float + token * hidden_size; + float* dst = A1 + static_cast(r) * static_cast(hidden_size); + std::copy(src, src + hidden_size, dst); + } + + const T* fc1_expert_weights = fc1_weights_data + expert_idx * fc1_output_size * hidden_size; + const T* fc1_expert_bias = fc1_bias_data ? fc1_bias_data + expert_idx * fc1_output_size : nullptr; + const T* fc2_expert_weights = fc2_weights_data + expert_idx * hidden_size * inter_size; + const T* fc2_expert_bias = fc2_bias_data ? fc2_bias_data + expert_idx * hidden_size : nullptr; + + // Convert input to T only when needed for computation + for (size_t i = 0; i < static_cast(num_expert_tokens * hidden_size); ++i) { + A1_t[i] = static_cast(A1[i]); + } + + ORT_IGNORE_RETURN_VALUE(ProcessExpertBatch(A1_t, token_ids, batch_weights, + num_expert_tokens, expert_idx, + fc1_expert_weights, fc1_expert_bias, + fc2_expert_weights, fc2_expert_bias, + C2, hidden_size, inter_size, + fc1_output, activation_output)); + + // Optimized output accumulation with vectorized operations + for (int64_t r = 0; r < num_expert_tokens; ++r) { + int64_t token = token_ids[r]; + const T* expert_output_t = C2 + static_cast(r) * static_cast(hidden_size); + float w = batch_weights[r]; + float* dest = local_output + static_cast(token) * static_cast(hidden_size); + + // Use explicit loop for better vectorization opportunities + for (int64_t j = 0; j < hidden_size; ++j) { + dest[j] += w * static_cast(expert_output_t[j]); + } + } + } + }); + + auto accumulate = [&](float* buffer) { + std::fill_n(buffer, output_buffer_size, 0.0f); + + for (size_t j = 0; j < output_buffer_size; ++j) { + double sum = 0.0; + double c = 0.0; + + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + double y = static_cast(thread_local_outputs[thread_offset + j]) - c; + double t = sum + y; + c = (t - sum) - y; + sum = t; + } + buffer[j] = static_cast(sum); + } + }; + + if constexpr (std::is_same_v) { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + MlasConvertFloatToHalfBuffer(final_output_float, + reinterpret_cast(output->MutableData()), + static_cast(output_buffer_size)); + } else { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + float* out_ptr = reinterpret_cast(output->MutableData()); + memcpy(out_ptr, final_output_float, output_buffer_size * sizeof(float)); + } + return Status::OK(); +} +template +Status MoE::ProcessExpertBatch(const T* input_tokens, + const int64_t* token_expert_ids, + const float* token_weights, + int64_t batch_size, + int64_t expert_id, + const T* fc1_weights, + const T* fc1_bias, + const T* fc2_weights, + const T* 
fc2_bias, + T* output_buffer, + int64_t hidden_size, + int64_t inter_size, + T* fc1_output_buffer, + T* activation_output_buffer) const { + const bool is_swiglu = activation_type_ == ActivationType::SwiGLU; + const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size; + + constexpr int64_t stack_threshold = 1024; + const bool use_stack = (batch_size * fc1_output_size) <= stack_threshold; + + std::vector fc1_output_vec; + std::vector activation_output_vec; + T* fc1_output; + T* activation_output; + + if (use_stack) { + fc1_output_vec.resize(static_cast(batch_size * fc1_output_size)); + activation_output_vec.resize(static_cast(batch_size * inter_size)); + fc1_output = fc1_output_vec.data(); + activation_output = activation_output_vec.data(); + } else { + fc1_output_vec.resize(static_cast(batch_size * fc1_output_size)); + activation_output_vec.resize(static_cast(batch_size * inter_size)); + fc1_output = fc1_output_vec.data(); + activation_output = activation_output_vec.data(); + } + + ORT_RETURN_IF_ERROR(ComputeGEMM(input_tokens, fc1_weights, fc1_output, + batch_size, hidden_size, fc1_output_size, true)); + + if (fc1_bias) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + T* batch_output = fc1_output + batch * fc1_output_size; + // Explicit loop for better vectorization + for (int64_t i = 0; i < fc1_output_size; ++i) { + batch_output[i] = static_cast(static_cast(batch_output[i]) + + static_cast(fc1_bias[i])); + } + } + } + + if (is_swiglu) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + ApplySwiGLUVectorized(fc1_output + batch * fc1_output_size, + activation_output + batch * inter_size, + inter_size); + } + } else { + ApplyActivationVectorized(fc1_output, batch_size * fc1_output_size); + std::copy(fc1_output, fc1_output + (batch_size * fc1_output_size), activation_output); + } + + ORT_RETURN_IF_ERROR(ComputeGEMM(activation_output, fc2_weights, output_buffer, + batch_size, inter_size, hidden_size, true)); + + if (fc2_bias) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + T* batch_output = output_buffer + batch * hidden_size; + for (int64_t i = 0; i < hidden_size; ++i) { + batch_output[i] = static_cast(static_cast(batch_output[i]) + + static_cast(fc2_bias[i])); + } + } + } + + return Status::OK(); +} + +template <> +Status MoE::ComputeGEMM(const float* A, const float* B, float* C, + int64_t M, int64_t K, int64_t N, bool transpose_B) const { + MLAS_SGEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.alpha = 1.0f; + params.beta = 0.0f; + params.C = C; + params.ldc = static_cast(N); + params.B = B; + + if (transpose_B) { + params.ldb = static_cast(K); + MlasGemm(CblasNoTrans, CblasTrans, static_cast(M), static_cast(N), static_cast(K), params, nullptr); + } else { + params.ldb = static_cast(N); + MlasGemm(CblasNoTrans, CblasNoTrans, static_cast(M), static_cast(N), static_cast(K), params, nullptr); + } + + return Status::OK(); +} + +template <> +Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C, + int64_t M, int64_t K, int64_t N, bool transpose_B) const { + MLAS_HALF_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.C = C; + params.ldc = static_cast(N); + params.AIsfp32 = false; + params.BIsfp32 = false; + params.B = B; + + if (transpose_B) { + params.ldb = static_cast(K); + } else { + params.ldb = static_cast(N); + } + + MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, nullptr); + return Status::OK(); +} + +template +void 
MoE::ApplyActivationVectorized(T* data, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + float val = static_cast(data[i]); + data[i] = static_cast(ApplyActivation(val, activation_type_)); + } +} + +template +void MoE::ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + float gate = static_cast(input[2 * i]); + float linear = static_cast(input[2 * i + 1]); + + gate = std::min(gate, swiglu_limit_); + linear = std::clamp(linear, -swiglu_limit_, swiglu_limit_); + + float sigmoid_arg = activation_alpha_ * gate; + float sigmoid_out; + if (sigmoid_arg > 0) { + float exp_neg = std::exp(-sigmoid_arg); + sigmoid_out = 1.0f / (1.0f + exp_neg); + } else { + float exp_pos = std::exp(sigmoid_arg); + sigmoid_out = exp_pos / (1.0f + exp_pos); + } + + float swish_out = gate * sigmoid_out; + output[i] = static_cast(swish_out * (linear + activation_beta_)); + } +} + +template <> +void MoE::ApplySwiGLUVectorized(const float* input, float* output, int64_t size) const { + ApplySwiGLUActivation(input, output, size, true, + activation_alpha_, activation_beta_, swiglu_limit_); +} + +template class MoE; +template class MoE; + +#define REGISTER_KERNEL_TYPED(type) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, kMSDomain, 1, type, kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MoE); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h new file mode 100644 index 0000000000000..60d8217015b5b --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
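
For reference, here is a minimal NumPy sketch of the interleaved SwiGLU activation that `ApplySwiGLUVectorized` above implements (swiglu_fusion=1: even elements are gate values, odd elements are the linear values). The alpha/beta/limit values (1.702, 1.0, 7.0) are taken from the attribute values used in the tests in this PR; treat them as illustrative defaults, not a spec.

```python
import numpy as np

def swiglu_interleaved(x, alpha=1.702, beta=1.0, limit=7.0):
    """Interleaved SwiGLU: x[..., 0::2] are gates, x[..., 1::2] are linear values."""
    gate = np.minimum(x[..., 0::2], limit)          # gate is only clamped from above
    linear = np.clip(x[..., 1::2], -limit, limit)   # linear is clamped on both sides
    # numerically stable sigmoid, mirroring the sigmoid_arg > 0 branch in the kernel
    z = alpha * gate
    ez = np.exp(-np.abs(z))
    sig = np.where(z >= 0, 1.0 / (1.0 + ez), ez / (1.0 + ez))
    return gate * sig * (linear + beta)

# fc1 output of one token with inter_size=2 -> 4 interleaved values
print(swiglu_interleaved(np.array([0.5, 2.0, -1.0, 0.3], dtype=np.float32)))
```
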
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +template +class MoE final : public OpKernel, public MoEBaseCPU { + public: + explicit MoE(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* context) const override; + + private: + Status ComputeMoE(const OpKernelContext* context, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias, + Tensor* output) const; + + Status ProcessExpertBatch(const T* input_tokens, + const int64_t* token_expert_ids, + const float* token_weights, + int64_t num_tokens, + int64_t expert_id, + const T* fc1_weights, + const T* fc1_bias, + const T* fc2_weights, + const T* fc2_bias, + T* output_buffer, + int64_t hidden_size, + int64_t inter_size, + T* fc1_output_buffer, + T* activation_output_buffer) const; + + Status ComputeGEMM(const T* A, const T* B, T* C, + int64_t M, int64_t K, int64_t N, + bool transpose_B = false) const; + + void ApplyActivationVectorized(T* data, int64_t size) const; + void ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc index 2c59210bfabd4..5a3c5d1dd0364 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc @@ -37,10 +37,18 @@ void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t gate_val = std::min(gate_val, clamp_limit); linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit); + // Use numerically stable sigmoid computation (matches CUDA kernel behavior) float sigmoid_arg = activation_alpha * gate_val; - float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); - float swish_out = gate_val * sigmoid_out; + float sigmoid_out; + if (sigmoid_arg > 0) { + float exp_neg = std::exp(-sigmoid_arg); + sigmoid_out = 1.0f / (1.0f + exp_neg); + } else { + float exp_pos = std::exp(sigmoid_arg); + sigmoid_out = exp_pos / (1.0f + exp_pos); + } + float swish_out = gate_val * sigmoid_out; output_data[i] = swish_out * (linear_val + activation_beta); } } else { diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 0690b8894eb7a..ab740ea38fb74 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1690,6 +1690,97 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { #endif } +// Test for CPU MoE implementation +static void RunMoECpuTest(const std::vector& input, const std::vector& router_probs, + const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights, + const std::vector& fc3_experts_weights, const std::vector& fc1_experts_bias, + const std::vector& fc2_experts_bias, const std::vector& output_data, int num_rows, + int num_experts, int hidden_size, int inter_size, std::string activation_type, + int normalize_routing_weights = 1, int top_k = 1) { + OpTester tester("MoE", 1, onnxruntime::kMSDomain); + tester.AddAttribute("k", static_cast(top_k)); + tester.AddAttribute("activation_type", activation_type); + tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); + + bool is_swiglu = (activation_type == "swiglu"); + + if (is_swiglu) { + 
tester.AddAttribute("swiglu_fusion", static_cast(1)); + tester.AddAttribute("activation_beta", 1.0f); + } + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + + int64_t fc1_output_size = is_swiglu ? (2 * inter_size) : inter_size; + + std::vector fc1_experts_weights_dims = {num_experts, hidden_size, fc1_output_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; + std::vector fc1_experts_bias_dims = {num_experts, fc1_output_size}; + std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + tester.AddInput("input", input_dims, input); + tester.AddInput("router_probs", router_probs_dims, router_probs); + tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + if (!fc1_experts_bias.empty()) { + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias); + } else { + tester.AddOptionalInputEdge(); + } + tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + if (!fc2_experts_bias.empty()) { + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias); + } else { + tester.AddOptionalInputEdge(); + } + if (!fc3_experts_weights.empty()) { + tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); + } else { + tester.AddOptionalInputEdge(); + } + tester.AddOptionalInputEdge(); // fc3_experts_bias + + tester.AddOutput("output", output_dims, output_data); + tester.SetOutputTolerance(0.05f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MoETest, MoECpuTest_BasicSwiGLU) { + int num_rows = 2; + int num_experts = 2; + int hidden_size = 4; + int inter_size = 8; + + // Simple test data + const std::vector input = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; + + const std::vector router_probs = { + 0.8f, 0.2f, + 0.3f, 0.7f}; + + const std::vector fc1_experts_weights(num_experts * hidden_size * (2 * inter_size), 0.1f); + + const std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 0.1f); + + const std::vector fc3_experts_weights = {}; // No FC3 + const std::vector fc1_experts_bias = {}; // No bias + const std::vector fc2_experts_bias = {}; // No bias + + const std::vector output_data = { + 1.169694f, 1.169694f, 1.169694f, 1.169694f, + 6.970291f, 6.970291f, 6.970291f, 6.970291f}; + + RunMoECpuTest(input, router_probs, fc1_experts_weights, fc2_experts_weights, + fc3_experts_weights, fc1_experts_bias, fc2_experts_bias, output_data, + num_rows, num_experts, hidden_size, inter_size, "swiglu"); +} #endif } // namespace test diff --git a/onnxruntime/test/python/transformers/test_moe_cpu.py b/onnxruntime/test/python/transformers/test_moe_cpu.py new file mode 100644 index 0000000000000..d6cbcc64733d4 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_moe_cpu.py @@ -0,0 +1,473 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
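
As a compact reference for what the routing stage of the CPU kernel (and the PyTorch reference in these tests) computes, the sketch below shows softmax over the router logits, top-k selection, renormalization by the top-k sum when normalize_routing_weights=1, and the weighted mixture of the selected experts. `expert_fn` is a hypothetical stand-in for the per-expert FFN; this is a sketch, not the kernel's code path.

```python
import numpy as np

def moe_route(x, router_logits, expert_fn, top_k=2, normalize=True):
    """Softmax routing with top-k selection and optional renormalization."""
    logits = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]       # expert ids per token
    top_w = np.take_along_axis(probs, top_idx, axis=-1)
    if normalize:                                           # normalize_routing_weights=1
        top_w = top_w / top_w.sum(axis=-1, keepdims=True)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for e, w in zip(top_idx[t], top_w[t]):
            out[t] += w * expert_fn(int(e), x[t])           # weighted expert mixture
    return out

# toy usage: identity "experts" scaled by (expert id + 1)
x = np.random.randn(4, 8).astype(np.float32)
logits = np.random.randn(4, 3).astype(np.float32)
y = moe_route(x, logits, lambda e, v: (e + 1) * v, top_k=2)
```
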
+# -------------------------------------------------------------------------- +# +# Regular MoE CPU kernel testing implementation - SwiGLU Interleaved Only +# +# This file tests the non-quantized MoE CPU implementation with SwiGLU +# activation in interleaved format and validates parity between +# PyTorch reference implementation and ONNX Runtime CPU kernel. +# +# Based on the CUDA test structure for consistency. +# -------------------------------------------------------------------------- + +import itertools +import time +import unittest + +import numpy +import torch +import torch.nn as nn +import torch.nn.functional as F +from onnx import TensorProto, helper +from parameterized import parameterized + +from onnxruntime import InferenceSession, SessionOptions + +# Device and provider settings for CPU +device = torch.device("cpu") +ort_provider = ["CPUExecutionProvider"] + +torch.manual_seed(42) +numpy.random.seed(42) + +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.BFLOAT16: torch.bfloat16, + TensorProto.UINT8: torch.uint8, +} + +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} + +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", + TensorProto.BFLOAT16: "BF16", +} + + +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + x_glu = x[..., ::2] + x_linear = x[..., 1::2] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype): + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + if torch_dtype == torch.bfloat16: + numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy() + initializer = helper.make_tensor( + name=name, + data_type=TensorProto.BFLOAT16, + dims=shape, + vals=numpy_vals_uint16.tobytes(), + raw=True, + ) + else: + initializer = helper.make_tensor( + name=name, + data_type=onnx_dtype, + dims=shape, + vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist() + if onnx_dtype == TensorProto.UINT8 + else tensor.detach().to(torch_dtype).flatten().tolist(), + raw=False, + ) + return initializer + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + onnx_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = 
quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + ) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + swiglu_fusion=1, + activation_alpha=1.702, + activation_beta=1.0, + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + fc1_weight_shape = [num_experts, 2 * inter_size, hidden_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, hidden_size, inter_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type] + + initializers = [ + make_onnx_intializer( + "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype), + make_onnx_intializer( + "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype), + ] + + if use_quant: + initializers.extend( + [ + make_onnx_intializer( + "fc1_experts_weight_scale", + fc1_experts_weight_scale.to(torch_dtype), + fc1_experts_weight_scale_shape, + onnx_dtype, + ), + make_onnx_intializer( + "fc2_experts_weight_scale", + fc2_experts_weight_scale.to(torch_dtype), + fc2_experts_weight_scale_shape, + onnx_dtype, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SparseMoeBlockORTHelper(nn.Module): + def __init__(self, quant_bits=0, onnx_dtype=None): + super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 + + def create_ort_session(self, moe_onnx_graph): + sess_options = SessionOptions() + sess_options.log_severity_level = 2 + + try: + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception as e: + print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print("Skipping ONNX Runtime execution for this test case.") + 
return None + + return ort_session + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + pass + + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states_flat) + + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] + + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } + + ort_inputs = { + "input": tensors["input"].detach().cpu().numpy(), + "router_probs": tensors["router_probs"].detach().cpu().numpy(), + } + + if enable_performance_test: + repeat = 1000 + s = time.time() + for _ in range(repeat): + self.ort_sess.run(None, ort_inputs) + e = time.time() + print(f"MoE CPU kernel time: {(e - s) / repeat * 1000} ms") + ort_outputs = self.ort_sess.run(None, ort_inputs) + else: + ort_outputs = self.ort_sess.run(None, ort_inputs) + + output_tensor = torch.from_numpy(ort_outputs[0]).to(device) + + return output_tensor.reshape(batch_size, sequence_length, hidden_dim) + + def parity_check(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + } + + atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"] + if ort_output is not None: + print( + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}" + ) + torch.testing.assert_close( + ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + + for expert in self.experts: + w1_weight = expert.w1.weight.data.clone() + w2_weight = expert.w2.weight.data.clone() + w1_bias = expert.w1.bias.data.clone() + w2_bias = expert.w2.bias.data.clone() + + fc1_w_list.append(w1_weight) + fc2_w_list.append(w2_weight) + fc1_b_list.append(w1_bias) + fc2_b_list.append(w2_bias) + + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + self.batch_size = 
batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + onnx_dtype=self.onnx_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=fc1_experts_weights, + fc1_experts_bias=fc1_experts_bias, + fc2_experts_weights=fc2_experts_weights, + fc2_experts_bias=fc2_experts_bias, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + # Compute full softmax over all experts (same as CUDA) + full_probs = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(full_probs, self.top_k, dim=-1) + + # For normalize_routing_weights=1: normalize by sum of top-k values (same as CUDA) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_cases = list( + itertools.product( + [1, 2], # batch_size + [16, 32], # sequence_length + [0], # quant_bits (CPU kernel only supports float32) + ) +) + +perf_test_cases = list( + itertools.product( + [1], # batch_size + [128], # sequence_length + [0], # quant_bits (CPU kernel only supports float32) + ) +) + + +class TestSwigluMoECPU(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.parity_check() + + +class TestSwigluMoECPUPerf(unittest.TestCase): + @parameterized.expand(perf_test_cases) + def test_swiglu_moe_perf(self, batch_size, sequence_length, quant_bits): + hidden_size = 1024 + intermediate_size = 2048 + num_experts_per_token = 4 + num_local_experts = 16 + config = SwigluMoeConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts_per_token=num_experts_per_token, + num_local_experts=num_local_experts, + ) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.benchmark_ort() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index efaaca29a01b6..403becbe0616a 100644 --- 
a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -254,7 +254,8 @@ def create_cpu_moe_onnx_graph( inter_size = intermediate_size topk = top_k - use_quant = True + # Only override use_quant for backward compatibility if not explicitly set + # use_quant = True # This line was causing issues for regular MoE tests if fc1_scales is None and use_quant: return None @@ -263,30 +264,52 @@ def create_cpu_moe_onnx_graph( if not has_onnx: return None - assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" - assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" - assert fc1_scales is not None, "FC1 scales must be provided for QMoE" - assert fc2_scales is not None, "FC2 scales must be provided for QMoE" - assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" - assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + if use_quant: + # Assertions only apply to quantized MoE + assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" + assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" + assert fc1_scales is not None, "FC1 scales must be provided for QMoE" + assert fc2_scales is not None, "FC2 scales must be provided for QMoE" + assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" + assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" if not has_onnx: return None - op_name = "QMoE" - inputs = [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - ] + # Set operator name and inputs based on quantization mode + if use_quant: + op_name = "QMoE" + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + ] + else: + # For regular (non-quantized) MoE, use different operator and input layout + op_name = "MoE" # Regular MoE operator + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias" if fc1_bias is not None else "", # fc1_bias as input 3 + "fc2_experts_weights", + "fc2_experts_bias" if fc2_bias is not None else "", # fc2_bias as input 5 + "", # fc3_experts_weights (not used) + "", # fc3_experts_bias (not used) + ] activation = "swiglu" if use_swiglu else "silu" + # Set normalization behavior based on operator type: + # - QMoE: Raw logits passed, needs normalization in C++ kernel + # - Regular MoE: Pre-computed probabilities passed, no additional normalization needed + normalize_routing = 1 if use_quant else 0 + nodes = [ helper.make_node( op_name, @@ -294,13 +317,14 @@ def create_cpu_moe_onnx_graph( ["output"], "MoE_0", k=topk, - normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior + normalize_routing_weights=normalize_routing, activation_type=activation, # Add new attributes with backwards-compatible default values - swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved + swiglu_fusion=1 if use_swiglu else 0, # 1 if using SwiGLU activation swiglu_limit=7.0, activation_alpha=1.702, activation_beta=1.0, + swiglu_interleaved=1 if swiglu_interleaved else 0, # Enable this attribute domain="com.microsoft", ), ] @@ -339,79 +363,106 @@ def create_cpu_moe_onnx_graph( fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) fc2_scale_size = 
num_experts * hidden_size - # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions - # Handle different possible scale tensor structures for fc1_scales - if len(fc1_scales.shape) == 4: - # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output - if use_swiglu: - fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + # Handle scale tensors based on quantization mode + if use_quant: + # Handle different possible scale tensor structures for fc1_scales + if len(fc1_scales.shape) == 4: + # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output + if use_swiglu: + fc1_scale_tensor = ( + fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + ) + else: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc1_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: + # For SwiGLU, duplicate the scales to cover both gate and value components + fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() + elif fc1_scale_tensor.size > fc1_scale_size: + # Truncate to expected size + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] else: - fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() - elif len(fc1_scales.shape) == 2: - # 2D case: already flattened, just ensure correct size - fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: - # For SwiGLU, duplicate the scales to cover both gate and value components - fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() - elif fc1_scale_tensor.size > fc1_scale_size: - # Truncate to expected size - fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] - else: - # Other cases: flatten and truncate/pad as needed - fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc1_scale_tensor.size > fc1_scale_size: - fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] - elif fc1_scale_tensor.size < fc1_scale_size: - # Pad with ones if too small - pad_size = fc1_scale_size - fc1_scale_tensor.size - fc1_scale_tensor = numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]) - - # Process scale tensor for proper shape - fc1_scale_data_list = fc1_scale_tensor.tolist() - fc1_scale_data = fc1_scale_data_list - - # Handle different possible scale tensor structures for fc2_scales - if len(fc2_scales.shape) == 4: - # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output - fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() - elif len(fc2_scales.shape) == 2: - # 2D case: already flattened, just ensure correct size - fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc2_scale_tensor.size > fc2_scale_size: - # Truncate to expected size - fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + # Other cases: flatten and truncate/pad as needed + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if 
fc1_scale_tensor.size > fc1_scale_size: + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + elif fc1_scale_tensor.size < fc1_scale_size: + # Pad with ones if too small + pad_size = fc1_scale_size - fc1_scale_tensor.size + fc1_scale_tensor = numpy.concatenate( + [fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)] + ) + + # Process scale tensor for proper shape + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + + # Handle different possible scale tensor structures for fc2_scales + if len(fc2_scales.shape) == 4: + # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output + fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc2_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + # Truncate to expected size + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + elif fc2_scale_tensor.size < fc2_scale_size: + # Pad with ones if too small + pad_size = fc2_scale_size - fc2_scale_tensor.size + fc2_scale_tensor = numpy.concatenate( + [fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)] + ) + + # Process scale tensor for proper shape + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + onnx_dtype, + fc1_scale_shape, + fc1_scale_data, + raw=False, + ), + helper.make_tensor( + "fc2_scales", + onnx_dtype, + fc2_scale_shape, + fc2_scale_data, + raw=False, + ), + ] + ) else: - # Other cases: flatten and truncate/pad as needed - fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc2_scale_tensor.size > fc2_scale_size: - fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] - elif fc2_scale_tensor.size < fc2_scale_size: - # Pad with ones if too small - pad_size = fc2_scale_size - fc2_scale_tensor.size - fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]) - - # Process scale tensor for proper shape - fc2_scale_data_list = fc2_scale_tensor.tolist() - fc2_scale_data = fc2_scale_data_list - - initializers.extend( - [ - helper.make_tensor( - "fc1_scales", - onnx_dtype, - fc1_scale_shape, - fc1_scale_data, - raw=False, - ), - helper.make_tensor( - "fc2_scales", - onnx_dtype, - fc2_scale_shape, - fc2_scale_data, - raw=False, - ), - ] - ) + # For non-quantized mode, add bias tensors if provided + if fc1_bias is not None: + initializers.append( + helper.make_tensor( + "fc1_experts_bias", + onnx_dtype, + list(fc1_bias.shape), + fc1_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), + ) + ) + if fc2_bias is not None: + initializers.append( + helper.make_tensor( + "fc2_experts_bias", + onnx_dtype, + list(fc2_bias.shape), + fc2_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), + ) + ) graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), @@ -619,17 +670,54 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def 
ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: if self.ort_sess is None: + print(f"ERROR: ORT session is None for {self.__class__.__name__}") return None batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states_flat) + # Different routing logic for QMoE vs regular MoE: + # - QMoE expects raw logits (does its own softmax internally) + # - Regular MoE expects pre-computed routing probabilities + if hasattr(self, "quant_bits") and self.quant_bits > 0: + # QMoE: Pass raw logits directly (QMoE does softmax internally) + router_input = router_logits + # print("DEBUG: Using QMoE routing (raw logits)") + else: + # Regular MoE: Apply the same routing logic as PyTorch reference + # This converts raw logits to proper routing probabilities + routing_weights, selected_experts = masked_sampling_omp_inference( + router_logits, + top_k=self.top_k, + jitter_eps=self.router_jitter_noise, + training=False, + ) + + # IMPORTANT: The routing weights from masked_sampling_omp_inference sum to top_k, + # but ONNX Runtime expects normalized probabilities that sum to 1.0 + # Normalize the routing weights per token + routing_weights = routing_weights / routing_weights.sum(dim=1, keepdim=True) + + # Create proper router probabilities tensor that matches PyTorch routing + router_input = torch.zeros_like(router_logits) + for i in range(router_logits.shape[0]): # For each token + for j in range(self.top_k): # For each top-k expert + expert_idx = selected_experts[i, j] + router_input[i, expert_idx] = routing_weights[i, j] + + # print("DEBUG: Using regular MoE routing (processed probabilities)") + + # print(f"DEBUG: router_input stats: mean={router_input.mean():.6f}, std={router_input.std():.6f}") + # print( + # f"DEBUG: hidden_states_flat stats: mean={hidden_states_flat.mean():.6f}, std={hidden_states_flat.std():.6f}" + # ) + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), - "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_input.clone().to(device=device, dtype=torch_dtype), "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), } @@ -656,10 +744,14 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False buffer_ptr=tensor.data_ptr(), ) + # print("DEBUG: About to run ORT inference...") + iobinding.synchronize_inputs() self.ort_sess.run_with_iobinding(iobinding) iobinding.synchronize_outputs() + # print("DEBUG: ORT inference completed successfully") + if enable_performance_test: repeat = 100 s = time.time() @@ -691,28 +783,10 @@ def recreate_onnx_model(self): w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) if self.use_swiglu: - if self.swiglu_interleaved: - pass - else: - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) - - gate_weights = pre_qweight1 - value_weights = pre_qweight3 - gate_scales = w1_scale - value_scales = w3_scale - - pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) - w1_scale = torch.cat([gate_scales, value_scales], dim=0) - - if self.swiglu_interleaved: - self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) - - else: - intermediate_size = self.experts[i].w1.weight.shape[0] - gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() - value_dequant = 
w1_qdq[intermediate_size:].contiguous().clone() - self.experts[i].w1.weight.data = gate_dequant - self.experts[i].w3.weight.data = value_dequant + # For SwiGLU, CPU kernel now always expects interleaved format + # SwigluMlp weights are already in interleaved format [gate_0, linear_0, gate_1, linear_1, ...] + # No conversion needed - both CPU and CUDA use interleaved format + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) else: self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() @@ -754,7 +828,7 @@ def recreate_onnx_model(self): use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, - swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + swiglu_interleaved=True, # CPU kernel now always expects interleaved format ) except Exception: self.moe_onnx_graph = None @@ -1043,10 +1117,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class TestPhiQMoECPU(unittest.TestCase): @parameterized.expand(phi3_test_cases) def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): - torch.manual_seed(42) - numpy.random.seed(42) + # Create unique seed based on test parameters to ensure different inputs for each test + base_seed = 2000 # Different base seed from other tests + param_hash = hash((batch_size, sequence_length, quant_bits)) + unique_seed = base_seed + abs(param_hash) % 1000 - test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + torch.manual_seed(unique_seed) + numpy.random.seed(unique_seed) + + test_config = ( + f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}" + ) print(f"Running Phi3 QMoE test: {test_config}") config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) @@ -1086,10 +1167,17 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): class TestSwigluQMoECPU(unittest.TestCase): @parameterized.expand(swiglu_test_cases) def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): - torch.manual_seed(42) - numpy.random.seed(42) + # Create unique seed based on test parameters to ensure different inputs for each test + base_seed = 1000 # Different base seed from regular MoE tests + param_hash = hash((batch_size, sequence_length, quant_bits)) + unique_seed = base_seed + abs(param_hash) % 1000 + + torch.manual_seed(unique_seed) + numpy.random.seed(unique_seed) - test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + test_config = ( + f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}" + ) print(f"Running SwiGLU test: {test_config}") config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) @@ -1114,5 +1202,173 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): swiglu_moe.parity_check() +@unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") +class TestQMoESwiGLUBenchmark(unittest.TestCase): + """Benchmark tests for QMoE SwiGLU performance measurement.""" + + def test_qmoe_swiglu_throughput_benchmark(self): + """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" + if disable_cpu_qmoe_tests: + self.skipTest("QMoE CPU tests disabled") + + print("\n=== QMoE SwiGLU Throughput Benchmark ===") + + # Test configurations: (name, hidden_size, 
intermediate_size, num_experts, top_k, quant_bits) + configs = [ + ("Medium-4bit", 2880, 2880, 32, 4, 4), + ("Medium-8bit", 2880, 2880, 32, 4, 8), + ] + + batch_size = 1 + sequence_length = 512 + num_runs = 30 + + results = [] + + for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs: + torch.manual_seed(42) + numpy.random.seed(42) + + print(f"\nTesting {config_name}:") + print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit") + + try: + # Create config and model + config = PhiMoEConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + ) + + qmoe_swiglu = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + # Create test input with fixed sequence length to match ONNX model + full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32) + + # For TTFT simulation, we'll measure single forward pass time + # This represents the time to process one token in autoregressive generation + + # Initialize variables + torch_output = None + ort_output = None + + # Warm up with full context + for _ in range(3): + _ = qmoe_swiglu.forward(full_hidden_states) + + # Benchmark PyTorch TTFT (Time to First Token) + # Measure time for a single forward pass (represents token generation time) + torch.manual_seed(42) + + start_time = time.time() + for _ in range(num_runs): + torch_output = qmoe_swiglu.forward(full_hidden_states) + end_time = time.time() + torch_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second (throughput) + # For sequence generation, this represents the rate at which we can generate tokens + torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000) + + print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)") + print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec") + + # Benchmark ONNX Runtime + ort_ttft_ms = 0 + ort_tokens_per_sec = 0 + speedup = 0 + throughput_ratio = 0 + max_diff = 0 + + model_updated = qmoe_swiglu.recreate_onnx_model() + if model_updated and qmoe_swiglu.ort_sess is not None: + # Warm up ORT with full context + for _ in range(3): + _ = qmoe_swiglu.ort_forward(full_hidden_states) + + torch.manual_seed(42) + + # Measure ONNX Runtime TTFT (Time to First Token) + start_time = time.time() + for _ in range(num_runs): + ort_output = qmoe_swiglu.ort_forward(full_hidden_states) + end_time = time.time() + ort_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second for ONNX Runtime + ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000) + + speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0 + throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 0 + + print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)") + print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec") + print(f" TTFT Speedup: {speedup:.2f}x") + print(f" Throughput Gain: {throughput_ratio:.2f}x") + else: + print(" ONNX RT: Not available") + + # Calculate max difference if both outputs available + if torch_output is not None and ort_output is not None: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item() + print(f" Max diff: {max_diff:.6f}") + + results.append( + 
{ + "config": config_name, + "torch_ttft_ms": torch_ttft_ms, + "torch_tokens_per_sec": torch_tokens_per_sec, + "ort_ttft_ms": ort_ttft_ms, + "ort_tokens_per_sec": ort_tokens_per_sec, + "speedup": speedup, + "throughput_ratio": throughput_ratio, + "max_diff": max_diff, + } + ) + + except Exception as e: + print(f" Error: {e}") + continue + + # Summary + print("\n=== Token Generation Time & Throughput Summary ===") + print( + f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}" + ) + print("-" * 105) + for result in results: + config = result["config"] + torch_ttft = result["torch_ttft_ms"] + torch_tps = result["torch_tokens_per_sec"] + ort_ttft = result["ort_ttft_ms"] + ort_tps = result["ort_tokens_per_sec"] + speedup = result["speedup"] + throughput_ratio = result["throughput_ratio"] + max_diff = result["max_diff"] + + ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A" + ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A" + speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A" + throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A" + + print( + f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}" + ) + + print("\nNotes:") + print("- Time: Token generation time in ms (lower is better)") + print("- tok/s: Tokens per second throughput (higher is better)") + print("- Time Gain: ORT speedup for latency (higher is better)") + print("- Throughput: ORT throughput improvement (higher is better)") + + if __name__ == "__main__": unittest.main()
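
The ort_forward change above scatters normalized top-k routing weights back into a dense [num_tokens, num_experts] router_probs tensor for the non-quantized MoE path. A minimal vectorized sketch of that step, assuming routing_weights and selected_experts of shape [num_tokens, top_k] with distinct expert ids per token (the helper name is illustrative, not part of the test file):

import torch

def dense_router_probs(router_logits, routing_weights, selected_experts):
    # Normalize per token so each row sums to 1.0; the raw weights from the
    # sampling helper sum to top_k, but the MoE op expects probabilities.
    routing_weights = routing_weights / routing_weights.sum(dim=1, keepdim=True)
    # Scatter the top-k weights into a dense [num_tokens, num_experts] tensor;
    # experts that were not selected stay at zero. This is equivalent to the
    # explicit double loop in ort_forward when each token's expert ids are distinct.
    router_probs = torch.zeros_like(router_logits)
    router_probs.scatter_(1, selected_experts.long(), routing_weights.to(router_logits.dtype))
    return router_probs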
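
The recreate_onnx_model change above relies on the CPU kernel consuming SwiGLU fc1 weights in interleaved row order ([gate_0, linear_0, gate_1, linear_1, ...]). As a hedged sketch, this is how two separate gate/value matrices of shape [inter_size, hidden_size] could be combined into that layout (the helper name is hypothetical):

import torch

def interleave_swiglu_fc1(gate_weight, value_weight):
    # gate_weight, value_weight: [inter_size, hidden_size]
    inter_size, hidden_size = gate_weight.shape
    # Stacking on a new middle axis and reshaping alternates the rows:
    # gate_0, value_0, gate_1, value_1, ... -> [2 * inter_size, hidden_size]
    return torch.stack((gate_weight, value_weight), dim=1).reshape(2 * inter_size, hidden_size).contiguous()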
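
The quantized branch above reshapes fc1/fc2 scale tensors of varying rank into a flat array of a fixed expected size: tiling per-expert scales across the SwiGLU gate and value halves, truncating when too large, and padding with ones (a neutral scale) when too small. A compact numpy sketch of that normalization under the same assumptions (function and parameter names are illustrative):

import numpy

def normalize_scales(scales, num_experts, rows_per_expert, duplicate_for_swiglu=False):
    expected_size = num_experts * rows_per_expert * (2 if duplicate_for_swiglu else 1)
    flat = numpy.asarray(scales).reshape(-1)
    if duplicate_for_swiglu and flat.size == num_experts * rows_per_expert:
        # SwiGLU has gate and value projections, so reuse each expert's scales for both halves.
        flat = numpy.tile(flat.reshape(num_experts, rows_per_expert), (1, 2)).reshape(-1)
    if flat.size > expected_size:
        flat = flat[:expected_size]  # truncate extras
    elif flat.size < expected_size:
        pad = numpy.ones(expected_size - flat.size, dtype=flat.dtype)
        flat = numpy.concatenate([flat, pad])  # pad with 1.0 so padded rows are unscaled
    return flat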