Merged
36 commits
adcdca7
Fixes CPU kernel
apsonawane Aug 3, 2025
541e08b
Additional fixes
apsonawane Aug 3, 2025
764b55a
Optimizations
apsonawane Aug 3, 2025
27d05d5
Fix pipelines
apsonawane Aug 3, 2025
c1500b8
Address comments
apsonawane Aug 4, 2025
85268ad
Address comments
apsonawane Aug 4, 2025
37c0858
Revert "Address comments"
apsonawane Aug 4, 2025
85874ff
Fix the memory optimization issue
apsonawane Aug 4, 2025
1c9f927
Fix race condition
apsonawane Aug 6, 2025
f774682
Fix unused variables
apsonawane Aug 7, 2025
728d7a8
Optimizations
apsonawane Aug 12, 2025
c2386f5
Fix
apsonawane Aug 12, 2025
a6da84d
Debugging alot
apsonawane Aug 13, 2025
e2c5d68
Remove comments
apsonawane Aug 13, 2025
4c905ae
Some modifications
apsonawane Aug 20, 2025
c364758
FC1 fixed
apsonawane Aug 21, 2025
ed52e13
Working fix
apsonawane Aug 21, 2025
1ea12bc
Remove print statements
apsonawane Aug 21, 2025
f5be0ce
Low diff values
apsonawane Aug 22, 2025
e450158
Rebase with main
apsonawane Aug 22, 2025
471bb8b
Fix
apsonawane Aug 22, 2025
b015c3d
Fix tests
apsonawane Aug 22, 2025
2b67465
Fix pipelines
apsonawane Aug 22, 2025
f85a9f1
refactoring
tianleiwu Aug 22, 2025
1bcb20d
format
tianleiwu Aug 22, 2025
25aa31b
parallel optimization
tianleiwu Aug 23, 2025
ca180b6
fix build
tianleiwu Aug 23, 2025
6a48486
eliminate the intermediate memcpy after SwiGLU
tianleiwu Aug 23, 2025
c369322
parallelize the routing logic
tianleiwu Aug 23, 2025
73a437c
format
tianleiwu Aug 25, 2025
94a2729
refactoring output
tianleiwu Aug 25, 2025
5de1b21
Fix pipelines
apsonawane Aug 25, 2025
27c1c05
Update cpu tests to use same python reference implementation as cuda …
apsonawane Aug 25, 2025
81e6713
Fix tests
apsonawane Aug 26, 2025
d11f51c
Remove failing CPU test
apsonawane Aug 26, 2025
a7978f8
Add legacy shape check back
apsonawane Aug 26, 2025
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -106,7 +106,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE);
 // ******** End: Quantization ******************* //

 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -272,7 +273,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE)>,
   };

   for (auto& function_table_entry : function_table) {
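The two typed CLASS_NAME declarations above correspond to per-type kernel definitions that presumably live in the un-rendered moe_quantization_cpu.cc. As a hedged sketch only (not copied from the PR; the "T" type-constraint name and builder details are assumptions), the usual ONNX Runtime registration for one instantiation looks roughly like:

```cpp
// Sketch: typical contrib-op registration of a templated kernel for one data type.
// This is repository-context code, not a standalone program.
ONNX_OPERATOR_TYPED_KERNEL_EX(
    QMoE,                     // op name
    kMSDomain,                // com.microsoft domain
    1,                        // since version
    float,                    // data type of this instantiation
    kCpuExecutionProvider,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    QMoECPU<float>);          // kernel class defined in moe_quantization_cpu.h
```

A second registration with MLFloat16 in place of float would match the MLFloat16 entry in the kernel table.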
12 changes: 11 additions & 1 deletion onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -6,7 +6,8 @@
 #include "core/common/common.h"
 #include "core/framework/tensor_shape.h"
 #include "core/framework/op_kernel.h"
-#include "contrib_ops/cpu/quantization/moe_helper.h"
+#include "contrib_ops/cpu/moe/moe_helper.h"
+#include <limits>

 namespace onnxruntime {
 namespace contrib {

Check warning (GitHub Actions / Optional Lint C++, cpplint via reviewdog) on onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h:10: Found C++ system header after other header. Should be: moe_base_cpu.h, c system, c++ system, other. [build/include_order] [4]

@@ -46,12 +47,21 @@
     if (use_sparse_mixer_) {
       ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2");
     }
+
+    swiglu_fusion_ = op_kernel_info.GetAttrOrDefault<int64_t>("swiglu_fusion", 0);
+    swiglu_limit_ = op_kernel_info.GetAttrOrDefault<float>("swiglu_limit", std::numeric_limits<float>::infinity());
+    activation_alpha_ = op_kernel_info.GetAttrOrDefault<float>("activation_alpha", 1.0f);
+    activation_beta_ = op_kernel_info.GetAttrOrDefault<float>("activation_beta", 0.0f);
   }

   bool normalize_routing_weights_;
   bool use_sparse_mixer_;
   int64_t k_;
   ActivationType activation_type_;
+  float activation_alpha_;
+  float activation_beta_;
+  float swiglu_limit_;
+  int64_t swiglu_fusion_;
 };

 }  // namespace contrib
393 changes: 393 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/moe/moe_base_cpu.h"

namespace onnxruntime {
namespace contrib {

/**
* @brief QMoE is the templated CPU implementation of the Quantized Mixture of Experts operator.
*
* This kernel supports both float and MLFloat16 data types for activations, scales, and outputs.
* It parallelizes expert computation using the ONNX Runtime thread pool and minimizes memory
* usage through on-the-fly block dequantization of weights.
*
* @tparam T The data type for the kernel (float or MLFloat16).
*/
template <typename T>
class QMoECPU final : public OpKernel, public MoEBaseCPU {
public:
explicit QMoECPU(const OpKernelInfo& op_kernel_info);
Status Compute(OpKernelContext* context) const override;

private:
int64_t expert_weight_bits_;
int64_t block_size_;
};

} // namespace contrib
} // namespace onnxruntime
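The header's comment mentions on-the-fly block dequantization of the expert weights. As a rough, self-contained illustration only (not the kernel's actual code; the symmetric 4-bit layout, nibble order, and zero-point of 8 are assumptions), dequantizing one block of packed weights could look like:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical sketch: two 4-bit values are packed per byte, and each block of
// `block_size` weights shares one scale. The packing order and zero-point used
// by the real QMoE kernel are assumptions here.
void DequantizeBlock4Bit(const uint8_t* packed, float scale, int64_t block_size, float* out) {
  for (int64_t i = 0; i < block_size; ++i) {
    uint8_t byte = packed[i / 2];
    uint8_t q = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);  // low nibble first (assumed)
    out[i] = (static_cast<int>(q) - 8) * scale;              // assumed zero-point of 8
  }
}

int main() {
  std::vector<uint8_t> packed = {0x98, 0x7A};  // four packed 4-bit values
  std::vector<float> w(4);
  DequantizeBlock4Bit(packed.data(), 0.5f, 4, w.data());
  for (float v : w) std::printf("%f\n", v);
  return 0;
}
```

Dequantizing a block immediately before its GEMM, instead of materializing all expert weights in float, is what keeps the kernel's memory footprint small.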
66 changes: 12 additions & 54 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -4,6 +4,7 @@
 #include "contrib_ops/cpu/moe/moe_utils.h"
 #include <cmath>
 #include <algorithm>
+#include "core/common/common.h"

 namespace onnxruntime {
 namespace contrib {
@@ -19,74 +20,31 @@ float ApplyActivation(float x, ActivationType activation_type) {
     case ActivationType::Identity:
       return x;
     case ActivationType::SwiGLU:
-      // SwiGLU: This is handled specially as it requires gating, not applied here
+      // SwiGLU is a special case handled by ApplySwiGLUActivation, this is just a placeholder
       return x;
     default:
-      return x;  // Default to identity
+      return x;
   }
 }

-// Helper method for applying SwiGLU activation with different memory layouts
-void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) {
-  constexpr float swiglu_alpha = 1.702f;
-  constexpr float clamp_limit = 7.0f;  // Clamping limit as specified
+void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format,
+                           float activation_alpha, float activation_beta, float clamp_limit) {
   if (is_interleaved_format) {
-    // For interleaved format [linear, gate, linear, gate, ...], process directly
-    // Make a temporary copy of each pair of values before modifying them
     for (int64_t i = 0; i < inter_size; ++i) {
-      const size_t idx = static_cast<size_t>(i);
-      const size_t linear_idx = 2 * idx;
-      const size_t gate_idx = linear_idx + 1;
+      float gate_val = input_data[2 * i];
+      float linear_val = input_data[2 * i + 1];

-      // Store original values
-      float linear_val = data[linear_idx];  // Interleaved: even index
-      float gate_val = data[gate_idx];      // Interleaved: odd index
+      gate_val = std::min(gate_val, clamp_limit);
+      linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit);

-      // Apply clamping to the values
-      if (gate_val > clamp_limit) gate_val = clamp_limit;        // Clamp gate max only
-      if (linear_val > clamp_limit) linear_val = clamp_limit;    // Clamp linear min/max
-      if (linear_val < -clamp_limit) linear_val = -clamp_limit;
-
-      // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
-      float sigmoid_arg = swiglu_alpha * gate_val;
+      float sigmoid_arg = activation_alpha * gate_val;
       float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
       float swish_out = gate_val * sigmoid_out;
-      float result = swish_out * (linear_val + 1.0f);
-
-      // Store result in first element (linear position)
-      data[idx] = result;
+      output_data[i] = swish_out * (linear_val + activation_beta);
     }
   } else {
-    // For chunked layout [linear..., gate...], handle separately
-    // Need to work with original data in-place
-    // First, store all the gate computations since they depend on original gate values
-    std::vector<float> computed_gates(static_cast<size_t>(inter_size));
-
-    for (int64_t i = 0; i < inter_size; ++i) {
-      const size_t idx = static_cast<size_t>(i);
-      float gate_val = data[idx + static_cast<size_t>(inter_size)];
-
-      // Apply clamping to the gate value (max only)
-      if (gate_val > clamp_limit) gate_val = clamp_limit;
-
-      // Compute the gate part of SwiGLU
-      float sigmoid_arg = swiglu_alpha * gate_val;
-      float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
-      computed_gates[idx] = gate_val * sigmoid_out;
-    }
-
-    // Now apply the full activation with the precomputed gate values
-    for (int64_t i = 0; i < inter_size; ++i) {
-      const size_t idx = static_cast<size_t>(i);
-      float linear_val = data[idx];
-
-      // Apply clamping to the linear value (min/max)
-      if (linear_val > clamp_limit) linear_val = clamp_limit;
-      if (linear_val < -clamp_limit) linear_val = -clamp_limit;
-
-      data[idx] = computed_gates[idx] * (linear_val + 1.0f);
-    }
+    ORT_NOT_IMPLEMENTED("Non-interleaved format not supported for SwiGLU activation");
   }
 }
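The rewritten ApplySwiGLUActivation reads an interleaved buffer of [gate, linear] pairs and writes inter_size outputs, with the hard-coded alpha (1.702), beta (implicit +1.0f) and clamp limit (7.0f) replaced by the new attributes. A standalone re-implementation of the same math for illustration (not the repository function; the sample parameter values are assumptions that mirror the previously hard-coded constants):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone sketch of the interleaved SwiGLU used above:
// input  = [gate0, linear0, gate1, linear1, ...]  (2 * inter_size floats)
// output = swish(gate; alpha) * (linear + beta)   (inter_size floats)
// gate is clamped above by `limit`, linear is clamped to [-limit, limit].
std::vector<float> SwiGLUInterleaved(const std::vector<float>& in, float alpha, float beta, float limit) {
  std::vector<float> out(in.size() / 2);
  for (size_t i = 0; i < out.size(); ++i) {
    float gate = std::min(in[2 * i], limit);
    float linear = std::clamp(in[2 * i + 1], -limit, limit);
    float swish = gate / (1.0f + std::exp(-alpha * gate));  // gate * sigmoid(alpha * gate)
    out[i] = swish * (linear + beta);
  }
  return out;
}

int main() {
  // alpha = 1.702, beta = 1.0, limit = 7.0 match the constants the old code hard-coded;
  // the attribute defaults in moe_base_cpu.h are alpha = 1.0, beta = 0.0, limit = +inf.
  std::vector<float> in = {0.5f, 2.0f, -1.0f, 9.0f};
  for (float v : SwiGLUInterleaved(in, 1.702f, 1.0f, 7.0f)) std::printf("%f\n", v);
  return 0;
}
```

Note that the new signature also splits input and output buffers, which is what lets the caller drop the intermediate memcpy after SwiGLU mentioned in commit 6a48486.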
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cpu/moe/moe_utils.h
@@ -9,7 +9,9 @@ namespace onnxruntime {
 namespace contrib {

 float ApplyActivation(float x, ActivationType activation_type);
-void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format);
+
+void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format,
+                           float activation_alpha, float activation_beta, float clamp_limit);

 }  // namespace contrib
 }  // namespace onnxruntime