diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f3dcde1abe37a..b59ff63ea8260 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3089,7 +3089,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- activation_type : string
-- Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+- Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
- k : int
- Number of top experts to select from expert pool
- normalize_routing_weights : int
@@ -3106,9 +3106,9 @@ This version of the operator has been available since version 1 of the 'com.micr
- router_probs : T
- 2D input tensor with shape (num_rows, num_experts)
- fc1_experts_weights : T
-- 3D input tensor with shape (num_experts, hidden_size, inter_size)
+- 3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu
- fc1_experts_bias (optional) : T
-- 2D optional input tensor with shape (num_experts, inter_size)
+- 2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
- fc2_experts_weights : T
- 3D input tensor with shape (num_experts, inter_size, hidden_size)
- fc2_experts_bias (optional) : T
@@ -4523,7 +4523,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- activation_type : string
-- Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+- Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
- expert_weight_bits : int
- Number of bits used in quantized weights. Default is 4 bits
- k : int
@@ -4542,11 +4542,11 @@ This version of the operator has been available since version 1 of the 'com.micr
- router_probs : T
- 2D input tensor with shape (num_rows, num_experts)
- fc1_experts_weights : T1
-- 3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
+- 3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).
- fc1_scales : T
-- 2D input tensor with shape (num_experts, inter_size)
+- 2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
- fc1_experts_bias (optional) : T
-- 2D optional input tensor with shape (num_experts, inter_size)
+- 2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
- fc2_experts_weights : T1
- 3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
- fc2_scales : T
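As an aside, a small illustrative sketch of how the swiglu variant changes the expected MoE shapes (the variable names below are illustrative, not part of the schema):

```python
num_experts, hidden_size, inter_size = 8, 2048, 1024

# For swiglu, fc1 produces 2 * inter_size values per token (gate and linear halves),
# which the activation reduces back to inter_size before fc2.
fc1_weights_shape_swiglu = (num_experts, hidden_size, 2 * inter_size)
fc1_bias_shape_swiglu = (num_experts, 2 * inter_size)

# fc2 is unchanged by the activation choice.
fc2_weights_shape = (num_experts, inter_size, hidden_size)
```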
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index fa6c731231405..3b70e5da8b3e4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -43,6 +43,7 @@ Do not modify directly.*
|||[7, 21]|**T** = tensor(float)|
|Atanh|*in* input:**T**
*out* output:**T**|22+|**T** = tensor(float)|
|||[9, 21]|**T** = tensor(float)|
+|Attention|*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**|23+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)
**U** = tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)|
|||[19, 21]|**T** = tensor(float)|
|||[11, 18]|**T** = tensor(float)|
@@ -58,11 +59,11 @@ Do not modify directly.*
|BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 1a4a63de38790..e8cdc50ed4ca7 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -78,8 +78,11 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");
- ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr,
- normalize_routing_weights_, use_sparse_mixer_);
+ ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm,
+ activation_type_,
+ fc3_experts_weights_optional != nullptr,
+ normalize_routing_weights_,
+ use_sparse_mixer_);
size_t ws_size = moe_runner.getWorkspaceSize(
static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
index 36127054cfd5e..d5ad8161e100e 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
@@ -52,6 +52,7 @@ enum class ActivationType { Gelu,
GeGLU,
ReGLU,
SiGLU,
+ SwiGLU,
Identity,
InvalidType };
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
index ef1f97b9e57a2..8b8f45e77ab9d 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
@@ -391,12 +391,10 @@ void MoeGemmRunner::dispatch_to_arch(const T* A, con
dispatch_moe_gemm_to_cutlass(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
sm_, multi_processor_count_, stream, occupancy);
- } else if (sm_ >= 80 && sm_ < 90) {
+ } else if (sm_ >= 80) { // Hopper and Blackwell fall back to the Ampere kernels.
dispatch_moe_gemm_to_cutlass(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
sm_, multi_processor_count_, stream, occupancy);
- } else {
- ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
}
}
@@ -478,6 +476,7 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp
int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, ActivationType activation_type,
cudaStream_t stream) {
+ // SwiGLU calls this function with the Identity activation, so it does not need special handling here.
switch (activation_type) {
case ActivationType::Relu:
run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index bfbe1d81b1c15..4268b79e1e4f8 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -44,6 +44,72 @@
namespace ort_fastertransformer {
static constexpr int WARP_SIZE = 32;
+
+// SwiGLU with the interleaved layout is equivalent to the following PyTorch code:
+// dim = x.shape[-1]
+// x = x.view(-1, dim // 2, 2)
+// x_glu, x_linear = x[..., 0], x[..., 1]
+// y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)
+template <typename T>
+__global__ void swiglu_kernel_interleaved(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) {
+ int const row = blockIdx.x;
+ if (row >= num_rows) {
+ return;
+ }
+
+ T const* row_input = input + row * 2 * intermediate_size;
+ T* row_output = output + row * intermediate_size;
+
+ for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) {
+ T x_glu = row_input[2 * i];
+ T x_linear = row_input[2 * i + 1];
+
+ float sigmoid_arg = swiglu_alpha * static_cast<float>(x_glu);
+ float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg));
+
+ float swish_out = static_cast<float>(x_glu) * sigmoid_out;
+ row_output[i] = static_cast<T>(swish_out * (static_cast<float>(x_linear) + 1.f));
+ }
+}
+
+// Non-interleaved version of the SwiGLU kernel, which splits each row into two contiguous chunks of the same size.
+template <typename T>
+__global__ void swiglu_kernel_chunked(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) {
+ int const row = blockIdx.x;
+ if (row >= num_rows) {
+ return;
+ }
+
+ T const* row_input = input + row * 2 * intermediate_size;
+ T* row_output = output + row * intermediate_size;
+
+ for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) {
+ T x_glu = row_input[i];
+ T x_linear = row_input[i + intermediate_size];
+
+ float sigmoid_arg = swiglu_alpha * static_cast<float>(x_glu);
+ float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg));
+
+ float swish_out = static_cast<float>(x_glu) * sigmoid_out;
+ row_output[i] = static_cast<T>(swish_out * (static_cast<float>(x_linear) + 1.f));
+ }
+}
+
+template <typename T, bool interleaved>
+void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream) {
+ if (num_rows == 0) {
+ return;
+ }
+ dim3 block(std::min(intermediate_size, 1024));
+ dim3 grid(num_rows);
+
+ if constexpr (interleaved) {
+ swiglu_kernel_interleaved<<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
+ } else {
+ swiglu_kernel_chunked<<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
+ }
+}
+
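For reference, a minimal PyTorch sketch of the two layouts implemented by the SwiGLU kernels above; this is a hedged reference implementation (not code from this patch), using the same alpha = 1.702 that run_moe_fc passes below:

```python
import torch

def swiglu_interleaved_ref(x: torch.Tensor, alpha: float = 1.702) -> torch.Tensor:
    # x: (..., 2 * inter_size) with gate/linear values interleaved pairwise.
    x = x.view(*x.shape[:-1], -1, 2)
    x_glu, x_linear = x[..., 0], x[..., 1]
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)

def swiglu_chunked_ref(x: torch.Tensor, alpha: float = 1.702) -> torch.Tensor:
    # x: (..., 2 * inter_size) split into two contiguous halves.
    x_glu, x_linear = x.chunk(2, dim=-1)
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)
```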
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
@@ -666,9 +732,14 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i
}
template
-CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3,
+CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3,
bool normalize_routing_weights, bool use_sparse_mixer)
- : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) {
+ : activation_type_(activation_type),
+ has_fc3_(has_fc3),
+ total_past_rows_(0),
+ total_covered_rows_(0),
+ normalize_routing_weights_(normalize_routing_weights),
+ use_sparse_mixer_(use_sparse_mixer) {
moe_gemm_runner_.initialize(sm_version);
}
@@ -695,8 +766,16 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro
total_ws_bytes += buf_size * sizeof(T); // permuted_data
total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
total_ws_bytes += num_softmax_outs * sizeof(T);
- const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T);
- const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows));
+
+ size_t bytes_for_fc1_result;
+ if (activation_type_ == ActivationType::SwiGLU) {
+ // Space for both fc1_result_ and act_result_.
+ bytes_for_fc1_result = (2 * interbuf_size + interbuf_size) * sizeof(T);
+ } else {
+ bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T);
+ }
+
+ const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows));
sorter_.update_num_experts(static_cast(num_experts));
size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
@@ -705,7 +784,7 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro
bytes_for_intermediate_and_sorting += remaining_bytes;
}
- total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace
+ total_ws_bytes += bytes_for_intermediate_and_sorting;
return total_ws_bytes;
}
@@ -725,16 +804,34 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr,
total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size);
+ char* current_ptr = reinterpret_cast<char*>(total_rows_before_expert_ + padded_experts);
+
+ if (activation_type_ == ActivationType::SwiGLU) {
+ // fc1_result_ is used for GEMM1 output (2 * inter_size)
+ fc1_result_ = reinterpret_cast<T*>(current_ptr);
+ current_ptr += 2 * interbuf_size * sizeof(T);
+
+ // act_result_ is used for SwiGLU output (inter_size)
+ act_result_ = reinterpret_cast<T*>(current_ptr);
+ current_ptr += interbuf_size * sizeof(T);
+
+ ORT_ENFORCE(!has_fc3_, "SwiGLU activation is not supported with fc3");
+ } else {
+ fc1_result_ = reinterpret_cast<T*>(current_ptr);
+ act_result_ = nullptr; // No extra buffer for the activation since it is applied in place.
+ current_ptr += interbuf_size * sizeof(T);
+ }
+
if (has_fc3_) {
- fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts);
- fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size);
+ fc3_result_ = reinterpret_cast<T*>(current_ptr);
+ current_ptr += interbuf_size * sizeof(T);
} else {
- fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts);
+ fc3_result_ = nullptr;
}
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
- softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size);
+ softmax_out_ = reinterpret_cast<T*>(current_ptr);
} else {
softmax_out_ = nullptr;
}
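A simplified model of the fc1-related workspace sizing implied by getWorkspaceSize and configure_ws_ptrs above (a sketch in elements of T, assuming interbuf_size is the per-buffer element count used in the surrounding code; not the exact allocator):

```python
def fc1_workspace_elements(interbuf_size: int, activation: str, has_fc3: bool) -> int:
    if activation == "swiglu":
        # fc1_result_ holds the 2 * inter_size GEMM1 output and act_result_ holds
        # the inter_size SwiGLU output; fc3 is rejected for this activation.
        assert not has_fc3, "SwiGLU activation is not supported with fc3"
        return 2 * interbuf_size + interbuf_size
    # Other activations are applied in place; fc3 (if present) adds one extra buffer.
    return 2 * interbuf_size if has_fc3 else interbuf_size
```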
@@ -880,8 +977,51 @@ void CutlassMoeFCRunner::run_moe_fc(
stream);
}
- // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size,
- // expanded_active_expert_rows);
+ if (fc1_activation_type == ActivationType::SwiGLU) {
+ T* gemm1_output_buffer = fc1_result_;
+ T* swiglu_output_buffer = act_result_;
+
+ moe_gemm_runner_.moe_gemm_bias_act(
+ permuted_data_ + total_past_rows_ * hidden_size,
+ fc1_expert_weights,
+ fc1_scales,
+ fc1_expert_biases,
+ gemm1_output_buffer + total_past_rows_ * 2 * inter_size,
+ total_rows_before_expert_ + local_experts_start_index,
+ expanded_active_expert_rows,
+ 2 * inter_size,
+ hidden_size,
+ local_num_experts,
+ ActivationType::Identity,
+ stream);
+
+ constexpr bool swiglu_interleaved = true;
+ constexpr float swiglu_alpha = 1.702f;
+ invokeSwiGLU<T, swiglu_interleaved>(
+ swiglu_output_buffer + total_past_rows_ * inter_size,
+ gemm1_output_buffer + total_past_rows_ * 2 * inter_size,
+ inter_size,
+ static_cast<int>(total_covered_rows_),
+ swiglu_alpha,
+ stream);
+
+ moe_gemm_runner_.moe_gemm(
+ swiglu_output_buffer + total_past_rows_ * inter_size,
+ fc2_expert_weights,
+ fc2_scales,
+ nullptr,
+ fc2_result + total_past_rows_ * hidden_size,
+ total_rows_before_expert_ + local_experts_start_index,
+ expanded_active_expert_rows,
+ hidden_size,
+ inter_size,
+ local_num_experts,
+ stream);
+
+ // No fc3 for SwiGLU
+ return;
+ }
+
moe_gemm_runner_.moe_gemm_bias_act(
permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases,
fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index,
@@ -1178,4 +1318,7 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl
template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*,
const half*, const int*, const int*, int, int, int, cudaStream_t);
+template void invokeSwiGLU<float, true>(float*, float const*, int, int, float, cudaStream_t);
+template void invokeSwiGLU<half, true>(half*, half const*, int, int, float, cudaStream_t);
+
} // namespace ort_fastertransformer
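To summarize the SwiGLU path added to run_moe_fc in this file: GEMM1 runs with the identity activation to produce 2 * inter_size values per token, the fused SwiGLU kernel reduces them to inter_size, and GEMM2 projects back to hidden_size. A dense single-expert sketch (illustrative PyTorch, ignoring routing and permutation; the fc2 bias is applied later in finalize_moe_routing, matching the nullptr bias passed to the second GEMM):

```python
import torch

def swiglu_expert_forward(x, w1, b1, w2, alpha: float = 1.702):
    h = x @ w1 + b1                              # GEMM1 (identity activation) -> (..., 2 * inter_size)
    h = h.view(*h.shape[:-1], -1, 2)             # interleaved gate/linear layout
    gate, linear = h[..., 0], h[..., 1]
    y = gate * torch.sigmoid(alpha * gate) * (linear + 1)  # invokeSwiGLU -> (..., inter_size)
    return y @ w2                                # GEMM2; fc2 bias handled during routing finalization
```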
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index c457b608decbf..3ac4862e101c3 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -54,7 +54,10 @@ static inline size_t pad_to_multiple_of_16(size_t input) {
template <typename T>
void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out,
int* indices, int* source_row, int num_rows, int num_experts, int k,
- cudaStream_t stream);
+ bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream);
+
+template <typename T, bool interleaved>
+void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream);
class CubKeyValueSorter {
public:
@@ -109,7 +112,7 @@ template
class CutlassMoeFCRunner {
public:
- CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
+ CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k);
@@ -157,8 +160,10 @@ class CutlassMoeFCRunner {
int64_t* total_rows_before_expert_;
T* fc1_result_;
+ T* act_result_;
T* fc3_result_;
+ ActivationType activation_type_;
bool has_fc3_;
bool normalize_routing_weights_;
bool use_sparse_mixer_;
@@ -176,7 +181,7 @@ class CutlassMoeFCRunner {
template
class CutlassMoeFCRunner::value>> {
public:
- CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
+ CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
return 0;
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index c5352d931ce2c..cc6fe871a3bc1 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -48,8 +48,11 @@ Status MoE::ComputeInternal(OpKernelContext* context) const {
auto& device_prop = GetDeviceProp();
const int sm = device_prop.major * 10 + device_prop.minor;
- ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr,
- normalize_routing_weights_, use_sparse_mixer_);
+ ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm,
+ activation_type_,
+ fc3_experts_weights_optional != nullptr,
+ normalize_routing_weights_,
+ use_sparse_mixer_);
size_t ws_size = moe_runner.getWorkspaceSize(
static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
index 6b65557444a66..194f33acbeb59 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
@@ -76,15 +76,16 @@ class MoEBase {
}
const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1;
- if (fc1_experts_weights_dims[2] != inter_size / coe) {
+ const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1;
+ if (fc1_experts_weights_dims[2] != act * inter_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "fc1_experts_weights_dims[2] must be equal to inter_size, got ",
- fc1_experts_weights_dims[2], " and ", inter_size);
+ "fc1_experts_weights_dims[2] is ",
+ fc1_experts_weights_dims[2], ", expected ", act * inter_size / coe);
}
if (fc2_experts_weights_dims[2] != hidden_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "fc2_experts_weights_dims[2] must be equal to hidden_size, got ",
- fc2_experts_weights_dims[2], " and ", hidden_size);
+ "fc2_experts_weights_dims[2] is ",
+ fc2_experts_weights_dims[2], ", expected ", hidden_size / coe);
}
if (router_probs_dims.size() != 2) {
@@ -116,10 +117,10 @@ class MoEBase {
"fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0],
" and ", num_experts);
}
- if (fc1_experts_bias_dims[1] != inter_size) {
+ if (fc1_experts_bias_dims[1] != act * inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1],
- " and ", inter_size);
+ "fc1_experts_bias_dims[1] is ", fc1_experts_bias_dims[1],
+ ", expected ", act * inter_size);
}
if (fc2_experts_bias_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -182,10 +183,14 @@ class MoEBase {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ",
fc1_experts_scales_dims[0], " and ", num_experts);
}
- if (fc1_experts_scales_dims[1] != inter_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ",
- fc1_experts_scales_dims[1], " and ", inter_size);
+
+ // The activation type affects the output dimension of the first FC layer.
+ const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1;
+ if (fc1_experts_scales_dims[1] != act * inter_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to act * inter_size, got ",
+ fc1_experts_scales_dims[1], " and ", act * inter_size);
}
+
if (fc2_experts_scales_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ",
fc2_experts_scales->Shape().GetDims().size());
@@ -219,6 +224,8 @@ class MoEBase {
activation_type_ = ort_fastertransformer::ActivationType::Gelu;
} else if (activation_type_str == "silu") {
activation_type_ = ort_fastertransformer::ActivationType::Silu;
+ } else if (activation_type_str == "swiglu") {
+ activation_type_ = ort_fastertransformer::ActivationType::SwiGLU;
} else if (activation_type_str == "identity") {
activation_type_ = ort_fastertransformer::ActivationType::Identity;
} else {
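The shape checks above combine two scale factors: swiglu doubles the logical fc1 width, and 4-bit packing halves the stored width. A small sketch mirroring the checks (illustrative Python, names of my own choosing):

```python
def expected_fc1_last_dim(inter_size: int, activation: str, quant: str | None) -> int:
    act = 2 if activation == "swiglu" else 1  # swiglu doubles the logical width
    coe = 2 if quant == "uint4" else 1        # 4-bit packing halves the stored width
    return act * inter_size // coe
```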
diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
index 4dd5a079d1a29..db6d99674cf5a 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
@@ -72,6 +72,7 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context,
using CudaT = typename ToCudaType::MappedType;
ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm,
+ activation_type_,
fc3_experts_weights_optional != nullptr,
normalize_routing_weights_,
use_sparse_mixer_);
@@ -185,4 +186,4 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
} // namespace cuda
} // namespace contrib
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 5511275239e45..39bf2bf855976 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1392,14 +1392,14 @@ constexpr const char* MoE_ver1_doc = R"DOC(
ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
OpSchema()
.SetDoc(MoE_ver1_doc)
- .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu"))
+ .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu"))
.Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1))
.Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0))
.Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0))
.Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
.Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
- .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")
- .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
+ .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu", "T")
+ .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional)
.Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T")
.Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional)
.Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional)
@@ -1413,7 +1413,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc("Quantized MoE")
.Attr("activation_type",
- "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu",
+ "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu",
AttributeProto::STRING,
std::string("relu"))
.Attr("k",
@@ -1438,12 +1438,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(2,
"fc1_experts_weights",
"3D input tensor with shape (num_experts, hidden_size, inter_size) "
- "or (num_experts, hidden_size, inter_size / 2)",
+ "or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).",
"T1")
- .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T")
+ .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T")
.Input(4,
"fc1_experts_bias",
- "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
+ "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional)
.Input(5,
"fc2_experts_weights",
"3D input tensor with shape (num_experts, inter_size, hidden_size) "
diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py
index 252d89a2257fc..d805c8f9cae3c 100644
--- a/onnxruntime/test/python/transformers/test_parity_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_moe.py
@@ -9,6 +9,7 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
+import itertools
import unittest
from collections import OrderedDict
@@ -24,11 +25,6 @@
torch.manual_seed(42)
numpy.random.seed(42)
-USE_QUANT = False
-ORT_DTYPE = TensorProto.FLOAT16 if USE_QUANT else TensorProto.FLOAT
-NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
-THRESHOLD = 5e-1 if USE_QUANT else 1e-2
-
def value_string_of(numpy_array):
arr = numpy_array.flatten()
@@ -40,26 +36,69 @@ def print_tensor(name, numpy_array):
print(f"const std::vector {name} = {value_string_of(numpy_array)};")
-def quant_dequant(weights, quant_mode: bool = True):
- # use the test version `_symmetric_...` to get the non-interleaved weights
- type = torch.quint4x2 if quant_mode else torch.int8
- # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix()
- # Comment out this line for passing the lintrunner check in the CI.
- # import tensorrt_llm
+def quant_dequant(weights: torch.Tensor, is_4_bit_quantization: bool):
+ """
+ Performs symmetric per-column quantization and dequantization on a weight tensor.
+
+ This implementation is a pure PyTorch replacement for the original function that
+ relied on a custom tensorrt_llm operator. It supports both 8-bit (int8) and
+ 4-bit (quint4x2 style) quantization.
+
+ Args:
+ weights (torch.Tensor): The input weight tensor to be quantized.
+ is_4_bit_quantization (bool): If True, performs 4-bit quantization. If False,
+ performs 8-bit quantization.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
+ - scales (torch.float16): The quantization scales for each column.
+ - processed_q_weight (torch.int8): The packed quantized weights. For
+ 4-bit mode, two 4-bit values are packed into a single int8. For
+ 8-bit mode, this is the standard int8 quantized tensor. It is
+ transposed relative to the input weights' shape.
+ - dequantized_weights (torch.Tensor): The weights after being dequantized,
+ restored to the original dtype and device.
+ """
+ # Determine quantization bits and range based on the mode
+ if is_4_bit_quantization:
+ # 4-bit symmetric quantization path
+ q_bits = 4
+ q_max = 2 ** (q_bits - 1) - 1 # 7
+ q_min = -(2 ** (q_bits - 1)) # -8
- quant_weights, processed_q_weight, torch_weight_scales = (
- torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type)
- )
+ max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values
+ max_abs_val[max_abs_val == 0] = 1.0
+ scales = max_abs_val / q_max
+
+ quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8)
+
+ # Pack two 4-bit integers into a single int8
+ q_weights_t = quant_weights.T.contiguous()
+ shape = q_weights_t.shape
+ q_weights_t_reshaped = q_weights_t.view(shape[0], shape[1] // 2, 2)
+ lower_nibble = q_weights_t_reshaped[..., 0]
+ upper_nibble = q_weights_t_reshaped[..., 1]
+ processed_q_weight = (lower_nibble & 0x0F) | (upper_nibble << 4)
+
+ else:
+ # 8-bit symmetric quantization path
+ q_bits = 8
+ q_max = 2 ** (q_bits - 1) - 1 # 127
+ q_min = -(2 ** (q_bits - 1)) # -128
- # Unpack the int4s int int8s
- if quant_mode:
- upper = quant_weights >> 4
- lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends
- quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape)
+ max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values
+ max_abs_val[max_abs_val == 0] = 1.0
+ scales = max_abs_val / q_max
- quant_weights = quant_weights.to(dtype=weights.dtype)
- result = torch.multiply(quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous()
- return torch_weight_scales.to(torch.float16), processed_q_weight, result.to(device=weights.device)
+ quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8)
+
+ # For 8-bit, the processed weights are just the transposed quantized weights (no packing)
+ processed_q_weight = quant_weights.T.contiguous()
+
+ # Dequantize the weights to verify and return for PyTorch-side parity check
+ dequantized_weights = quant_weights.to(weights.dtype) * scales.to(weights.dtype)
+
+ return (scales.squeeze(0).to(torch.float16), processed_q_weight, dequantized_weights.T.to(device=weights.device))
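A hypothetical round-trip check for the pure-PyTorch quant_dequant above (the shapes follow from the code; the tensor sizes here are illustrative):

```python
w = torch.randn(128, 64)                            # (in_features, out_features) layout
scales, packed, w_dq = quant_dequant(w, is_4_bit_quantization=True)
assert scales.shape == (64,)                        # one scale per output column
assert packed.shape == (64, 64)                     # two int4 values packed per int8, transposed
assert w_dq.shape == (64, 128)                      # dequantized weights, transposed vs. the input
max_err = (w.T - w_dq).abs().max()                  # roughly bounded by scale / 2 per element
```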
def create_moe_onnx_graph(
@@ -71,6 +110,7 @@ def create_moe_onnx_graph(
fc1_experts_bias,
fc2_experts_weights,
fc2_experts_bias,
+ ort_dtype,
):
nodes = [
helper.make_node(
@@ -94,19 +134,19 @@ def create_moe_onnx_graph(
fc1_shape = [num_experts, hidden_size, inter_size]
fc2_shape = [num_experts, inter_size, hidden_size]
- torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32
+ torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
initializers = [
helper.make_tensor(
"fc1_experts_weights",
- ORT_DTYPE,
+ ort_dtype,
fc1_shape,
fc1_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_weights",
- ORT_DTYPE,
+ ort_dtype,
fc2_shape,
fc2_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
@@ -119,14 +159,14 @@ def create_moe_onnx_graph(
[
helper.make_tensor(
"fc1_experts_bias",
- ORT_DTYPE,
+ ort_dtype,
fc1_bias_shape,
fc1_experts_bias.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_bias",
- ORT_DTYPE,
+ ort_dtype,
fc2_bias_shape,
fc2_experts_bias.to(torch_type).flatten().tolist(),
raw=False,
@@ -135,19 +175,19 @@ def create_moe_onnx_graph(
)
graph_inputs = [
- helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
- ORT_DTYPE,
+ ort_dtype,
[sequence_length, num_experts],
)
)
graph_outputs = [
- helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]),
]
graph = helper.make_graph(
@@ -171,6 +211,7 @@ def create_mixtral_moe_onnx_graph(
fc2_experts_weights,
fc3_experts_weights,
topk,
+ ort_dtype,
):
nodes = [
helper.make_node(
@@ -197,26 +238,26 @@ def create_mixtral_moe_onnx_graph(
fc2_shape = [num_experts, inter_size, hidden_size]
fc3_shape = [num_experts, hidden_size, inter_size]
- torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32
+ torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
initializers = [
helper.make_tensor(
"fc1_experts_weights",
- ORT_DTYPE,
+ ort_dtype,
fc1_shape,
fc1_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_weights",
- ORT_DTYPE,
+ ort_dtype,
fc2_shape,
fc2_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc3_experts_weights",
- ORT_DTYPE,
+ ort_dtype,
fc3_shape,
fc3_experts_weights.to(torch_type).flatten().tolist(),
raw=False,
@@ -224,19 +265,19 @@ def create_mixtral_moe_onnx_graph(
]
graph_inputs = [
- helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
- ORT_DTYPE,
+ ort_dtype,
[sequence_length, num_experts],
)
)
graph_outputs = [
- helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]),
]
graph = helper.make_graph(
@@ -259,12 +300,14 @@ def create_phi_moe_onnx_graph(
fc1_experts_weights,
fc2_experts_weights,
fc3_experts_weights,
- fc1_scales,
- fc2_scales,
- fc3_scales,
topk,
+ ort_dtype,
+ quant_bits=0,
+ fc1_scales=None,
+ fc2_scales=None,
+ fc3_scales=None,
):
- use_quant = USE_QUANT
+ use_quant = quant_bits > 0
if use_quant:
assert fc1_experts_weights.dtype == torch.int8
assert fc2_experts_weights.dtype == torch.int8
@@ -276,34 +319,37 @@ def create_phi_moe_onnx_graph(
assert fc2_scales.dtype == torch.float16
assert fc3_scales.dtype == torch.float16
+ op_name = "QMoE" if use_quant else "MoE"
+ inputs = (
+ [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "fc1_scales",
+ "",
+ "fc2_experts_weights",
+ "fc2_scales",
+ "",
+ "fc3_experts_weights",
+ "fc3_scales",
+ "",
+ ]
+ if use_quant
+ else [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "",
+ "fc2_experts_weights",
+ "",
+ "fc3_experts_weights",
+ ]
+ )
+
nodes = [
helper.make_node(
- "MoE" if not use_quant else "QMoE",
- (
- [
- "input",
- "router_probs",
- "fc1_experts_weights",
- "",
- "fc2_experts_weights",
- "",
- "fc3_experts_weights",
- ]
- if not use_quant
- else [
- "input",
- "router_probs",
- "fc1_experts_weights",
- "fc1_scales",
- "",
- "fc2_experts_weights",
- "fc2_scales",
- "",
- "fc3_experts_weights",
- "fc3_scales",
- "",
- ]
- ),
+ op_name,
+ inputs,
["output"],
"MoE_0",
k=topk,
@@ -315,37 +361,38 @@ def create_phi_moe_onnx_graph(
]
if use_quant:
- nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", 8)])
+ nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)])
- fc1_shape = [num_experts, hidden_size, inter_size]
- fc2_shape = [num_experts, inter_size, hidden_size]
- fc3_shape = [num_experts, hidden_size, inter_size]
+ components = 2 if quant_bits == 4 else 1
+ fc1_shape = [num_experts, hidden_size, inter_size // components]
+ fc2_shape = [num_experts, inter_size, hidden_size // components]
+ fc3_shape = [num_experts, hidden_size, inter_size // components]
- torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32
- numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
- if use_quant:
- numpy_type = numpy.uint8
+ torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
+ numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32
+ weight_numpy_type = numpy.uint8 if use_quant else numpy_type
+ weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype
initializers = [
helper.make_tensor(
"fc1_experts_weights",
- ORT_DTYPE if not use_quant else TensorProto.UINT8,
+ weight_onnx_type,
fc1_shape,
- fc1_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(),
+ fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_weights",
- ORT_DTYPE if not use_quant else TensorProto.UINT8,
+ weight_onnx_type,
fc2_shape,
- fc2_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(),
+ fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(),
raw=False,
),
helper.make_tensor(
"fc3_experts_weights",
- ORT_DTYPE if not use_quant else TensorProto.UINT8,
+ weight_onnx_type,
fc3_shape,
- fc3_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(),
+ fc3_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(),
raw=False,
),
]
@@ -358,21 +405,21 @@ def create_phi_moe_onnx_graph(
[
helper.make_tensor(
"fc1_scales",
- ORT_DTYPE,
+ ort_dtype,
fc1_scale_shape,
fc1_scales.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_scales",
- ORT_DTYPE,
+ ort_dtype,
fc2_scale_shape,
fc2_scales.to(torch_type).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc3_scales",
- ORT_DTYPE,
+ ort_dtype,
fc3_scale_shape,
fc3_scales.to(torch_type).flatten().tolist(),
raw=False,
@@ -381,19 +428,19 @@ def create_phi_moe_onnx_graph(
)
graph_inputs = [
- helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
- ORT_DTYPE,
+ ort_dtype,
[sequence_length, num_experts],
)
)
graph_outputs = [
- helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]),
]
graph = helper.make_graph(
@@ -546,8 +593,11 @@ def __init__(self, config: PhiMoEConfig):
class SparseMoeBlockORTHelper(nn.Module):
- def __init__(self):
+ def __init__(self, quant_bits=0):
super().__init__()
+ self.quant_bits = quant_bits
+ self.ort_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT
+ self.np_type = numpy.float16 if self.ort_dtype == TensorProto.FLOAT16 else numpy.float32
def create_ort_session(self, moe_onnx_graph):
from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415
@@ -573,8 +623,8 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten
router_logits = self.gate(hidden_states)
ort_inputs = {
- "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
- "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
+ "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(self.np_type)),
+ "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(self.np_type)),
}
ort_output = None
@@ -586,13 +636,6 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten
self.ort_run_with_iobinding(ort_inputs)
return None
- # print_tensor("input", ort_inputs["input"])
- # print_tensor("router_probs", ort_inputs["router_probs"])
- # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy())
- # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy())
- # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
- # print_tensor("output", ort_output[0])
-
return None
def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
@@ -603,7 +646,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
name="input",
device_type="cuda",
device_id=device_id,
- element_type=NP_TYPE,
+ element_type=self.np_type,
shape=ort_inputs["input"].shape,
buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(),
)
@@ -612,7 +655,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
name="router_probs",
device_type="cuda",
device_id=device_id,
- element_type=NP_TYPE,
+ element_type=self.np_type,
shape=ort_inputs["router_probs"].shape,
buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(
ort_inputs["router_probs"], "cuda", device_id
@@ -623,7 +666,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
name="output",
device_type="cuda",
device_id=device_id,
- element_type=NP_TYPE,
+ element_type=self.np_type,
shape=ort_inputs["input"].shape,
buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(
numpy.zeros(ort_inputs["input"].shape), "cuda", device_id
@@ -646,22 +689,27 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
e = time.time()
print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms")
- def parity_check(self):
+ def parity_check(self, atol=None, rtol=None):
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
torch_output = self.forward(hidden_state)
ort_output = self.ort_forward(hidden_state)
+
+ if atol is None:
+ atol = 5e-2 if self.quant_bits == 0 else (2.0 if self.quant_bits == 8 else 3.0)
+
+ if rtol is None:
+ rtol = 1e-3 if self.quant_bits == 0 else 1e-2
+
if ort_output is not None:
+ dtype_str = "FP32" if self.quant_bits == 0 else "FP16"
print(
- "name:",
- self.__class__.__name__,
- " batch_size:",
- self.batch_size,
- " sequence_length:",
- self.sequence_length,
- " max_diff:",
- (torch_output - ort_output).abs().max(),
+ f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str},"
+ f" batch: {self.batch_size}, seq_len: {self.sequence_length},"
+ f" max_diff: {(torch_output - ort_output).abs().max()}"
+ )
+ torch.testing.assert_close(
+ ort_output.to(torch.float32), torch_output.to(torch.float32), rtol=rtol, atol=atol
)
- torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD)
def benchmark_ort(self):
hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
@@ -680,7 +728,7 @@ def __init__(
eval_capacity=-1,
activation="gelu",
):
- super().__init__()
+ super().__init__(quant_bits=0) # SwitchMoE is not quantized
self.batch_size = batch_size
self.sequence_length = sequence_length
self.num_experts = num_experts
@@ -709,6 +757,7 @@ def __init__(
self.moe_experts.bias1,
self.moe_experts.weight2.transpose(1, 2),
self.moe_experts.bias2,
+ self.ort_dtype,
)
self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
@@ -744,7 +793,7 @@ class MixtralSparseMoeBlock(SparseMoeBlockORTHelper):
"""
def __init__(self, config, batch_size, sequence_length):
- super().__init__()
+ super().__init__(quant_bits=0) # Mixtral test is not quantized
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
@@ -778,6 +827,7 @@ def __init__(self, config, batch_size, sequence_length):
self.moe_experts_weight2,
self.moe_experts_weight3,
self.top_k,
+ self.ort_dtype,
)
self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
@@ -874,43 +924,44 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper):
and memory on padding.
"""
- def __init__(self, config, batch_size, sequence_length):
- super().__init__()
+ def __init__(self, config, batch_size, sequence_length, quant_bits=0):
+ super().__init__(quant_bits)
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.router_jitter_noise = config.router_jitter_noise
+ use_quant = self.quant_bits > 0
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
- w1_list = []
- w2_list = []
- w3_list = []
- w1_scale_list = []
- w2_scale_list = []
- w3_scale_list = []
- if not USE_QUANT:
+ w1_list, w2_list, w3_list = [], [], []
+ w1_scale_list, w2_scale_list, w3_scale_list = [], [], []
+
+ if not use_quant:
for i in range(self.num_experts):
w1_list.append(self.experts[i].w1.weight)
w2_list.append(self.experts[i].w2.weight)
w3_list.append(self.experts[i].w3.weight)
else:
+ is_4_bit = self.quant_bits == 4
for i in range(self.num_experts):
- w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False)
- w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False)
- w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False)
+ # Corrected quantization logic for per-output-channel quantization
+ w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit)
+ w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit)
+ w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight.T, is_4_bit)
self.experts[i].w1.weight.data = w1_qdq
self.experts[i].w2.weight.data = w2_qdq
self.experts[i].w3.weight.data = w3_qdq
- w1_list.append(pre_qweight1)
- w2_list.append(pre_qweight2)
- w3_list.append(pre_qweight3)
+ # Transpose quantized weights to match the expected ONNX layout
+ w1_list.append(pre_qweight1.T)
+ w2_list.append(pre_qweight2.T)
+ w3_list.append(pre_qweight3.T)
w1_scale_list.append(w1_scale)
w2_scale_list.append(w2_scale)
w3_scale_list.append(w3_scale)
@@ -919,9 +970,9 @@ def __init__(self, config, batch_size, sequence_length):
self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
self.moe_experts_weight3 = torch.stack(w3_list, dim=0)
- moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None
- moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None
- moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None
+ moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if use_quant else None
+ moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if use_quant else None
+ moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if use_quant else None
self.batch_size = batch_size
self.sequence_length = sequence_length
@@ -933,10 +984,12 @@ def __init__(self, config, batch_size, sequence_length):
self.moe_experts_weight1,
self.moe_experts_weight2,
self.moe_experts_weight3,
+ self.top_k,
+ self.ort_dtype,
+ self.quant_bits,
moe_experts_weight_scale1,
moe_experts_weight_scale2,
moe_experts_weight_scale3,
- self.top_k,
)
self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
@@ -992,19 +1045,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def small_test_cases():
for batch_size in [1, 4, 16]:
for sequence_length in [128, 512, 1024]:
- yield batch_size, sequence_length
+ yield batch_size, sequence_length, 0
-def phi3_test_cases():
- # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation.
- for batch_size in [1, 4, 16]:
- for sequence_length in [128]:
- yield batch_size, sequence_length
+# Test cases for Phi-3 MoE.
+# We test three modes: no quantization, 8-bit, and 4-bit.
+phi3_test_params = list(
+ itertools.product(
+ [1, 4], # batch_size
+ [1, 32], # sequence_length
+ [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16)
+ )
+)
class TestSwitchMoE(unittest.TestCase):
@parameterized.expand(small_test_cases())
- def test_switch_moe_parity(self, batch_size, sequence_length):
+ def test_switch_moe_parity(self, batch_size, sequence_length, quant_bits):
# if platform.system() == "Windows":
# pytest.skip("Skip on Windows")
switch_moe = SwitchMoE(
@@ -1020,8 +1077,8 @@ def test_switch_moe_parity(self, batch_size, sequence_length):
class TestMixtralMoE(unittest.TestCase):
- @parameterized.expand(small_test_cases())
- def test_mixtral_moe_parity(self, batch_size, sequence_length):
+ @parameterized.expand([(b, s, q) for b, s, q in small_test_cases() if q == 0]) # only run non-quantized
+ def test_mixtral_moe_parity(self, batch_size, sequence_length, quant_bits):
config = MixtralConfig(hidden_size=256, intermediate_size=1024)
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
mixtral_moe.parity_check()
@@ -1029,13 +1086,329 @@ def test_mixtral_moe_parity(self, batch_size, sequence_length):
class TestPhiMoE(unittest.TestCase):
- @parameterized.expand(phi3_test_cases())
- def test_phi3_moe_parity(self, batch_size, sequence_length):
+ @parameterized.expand(phi3_test_params)
+ def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits):
config = PhiMoEConfig(hidden_size=256, intermediate_size=1024)
- phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length)
+ phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits)
phi3_moe.parity_check()
# phi3_moe.benchmark_ort()
+# ---------------------------------------------
+# The following tests are for the swiglu activation
+# ---------------------------------------------
+class SwigluMoeConfig:
+ def __init__(
+ self,
+ hidden_size=2048,
+ intermediate_size=2048,
+ num_experts_per_token=2,
+ num_local_experts=8,
+ ):
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_experts_per_token = num_experts_per_token
+ self.num_local_experts = num_local_experts
+
+
+class SwigluMlp(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_size = config.intermediate_size
+ self.hidden_dim = config.hidden_size
+ self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True)
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True)
+
+ def swiglu(self, x: torch.Tensor):
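+        # w1 produces gate and linear values interleaved along the last dimension:
+        # even indices feed the gate (x * sigmoid(1.702 * x)), odd indices form the
+        # linear branch, which carries an implicit +1 bias.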
+ dim = x.shape[-1]
+ x = x.view(-1, dim // 2, 2)
+ x_glu, x_linear = x[..., 0], x[..., 1]
+ y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1)
+ return y
+
+ def forward(self, x):
+ y = self.swiglu(self.w1(x))
+ y = self.w2(y)
+ return y
+
+
+def create_swiglu_moe_onnx_graph(
+ num_tokens: int,
+ num_experts: int,
+ hidden_size: int,
+ inter_size: int,
+ topk: int,
+ ort_dtype: int,
+ quant_bits: int,
+ fc1_experts_weights: torch.Tensor,
+ fc1_experts_bias: torch.Tensor,
+ fc2_experts_weights: torch.Tensor,
+ fc2_experts_bias: torch.Tensor,
+ fc1_experts_weight_scale: torch.Tensor = None,
+ fc2_experts_weight_scale: torch.Tensor = None,
+):
+ use_quant = quant_bits > 0
+ op_name = "QMoE" if use_quant else "MoE"
+
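+    # The quantized path builds a QMoE node, whose input list carries an extra
+    # weight-scale tensor after each expert weight tensor.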
+ inputs = (
+ [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "fc1_experts_weight_scale",
+ "fc1_experts_bias",
+ "fc2_experts_weights",
+ "fc2_experts_weight_scale",
+ "fc2_experts_bias",
+ ]
+ if use_quant
+ else [
+ "input",
+ "router_probs",
+ "fc1_experts_weights",
+ "fc1_experts_bias",
+ "fc2_experts_weights",
+ "fc2_experts_bias",
+ ]
+ )
+
+ nodes = [
+ helper.make_node(
+ op_name,
+ inputs,
+ ["output"],
+ "MoE_0",
+ k=topk,
+ normalize_routing_weights=1,
+ activation_type="swiglu",
+ domain="com.microsoft",
+ ),
+ ]
+
+ if use_quant:
+ nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)])
+
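+    # 4-bit weights are packed two per uint8, so the packed (last) weight
+    # dimension is halved; 8-bit weights keep one value per byte.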
+ components = 2 if quant_bits == 4 else 1
+ fc1_weight_shape = [num_experts, hidden_size, 2 * inter_size // components]
+ fc1_bias_shape = [num_experts, 2 * inter_size]
+ fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size]
+
+ fc2_weight_shape = [num_experts, inter_size, hidden_size // components]
+ fc2_bias_shape = [num_experts, hidden_size]
+ fc2_experts_weight_scale_shape = [num_experts, hidden_size]
+
+ torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
+ numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32
+ weight_numpy_type = numpy.uint8 if use_quant else numpy_type
+ weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype
+
+ initializers = [
+ helper.make_tensor(
+ "fc1_experts_weights",
+ weight_onnx_type,
+ fc1_weight_shape,
+ fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist()
+ if use_quant
+ else fc1_experts_weights.to(torch_type).flatten().tolist(),
+ raw=False,
+ ),
+ helper.make_tensor(
+ "fc1_experts_bias",
+ ort_dtype,
+ fc1_bias_shape,
+ fc1_experts_bias.to(torch_type).flatten().tolist(),
+ raw=False,
+ ),
+ helper.make_tensor(
+ "fc2_experts_weights",
+ weight_onnx_type,
+ fc2_weight_shape,
+ fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist()
+ if use_quant
+ else fc2_experts_weights.to(torch_type).flatten().tolist(),
+ raw=False,
+ ),
+ helper.make_tensor(
+ "fc2_experts_bias",
+ ort_dtype,
+ fc2_bias_shape,
+ fc2_experts_bias.to(torch_type).flatten().tolist(),
+ raw=False,
+ ),
+ ]
+
+ if use_quant:
+ initializers.extend(
+ [
+ helper.make_tensor(
+ "fc1_experts_weight_scale",
+ ort_dtype,
+ fc1_experts_weight_scale_shape,
+ fc1_experts_weight_scale.to(torch_type).flatten().tolist(),
+ raw=False,
+ ),
+ helper.make_tensor(
+ "fc2_experts_weight_scale",
+ ort_dtype,
+ fc2_experts_weight_scale_shape,
+ fc2_experts_weight_scale.to(torch_type).flatten().tolist(),
+ raw=False,
+ ),
+ ]
+ )
+
+ graph_inputs = [
+ helper.make_tensor_value_info("input", ort_dtype, [num_tokens, hidden_size]),
+ ]
+
+ graph_inputs.append(
+ helper.make_tensor_value_info(
+ "router_probs",
+ ort_dtype,
+ [num_tokens, num_experts],
+ )
+ )
+
+ graph_outputs = [
+ helper.make_tensor_value_info("output", ort_dtype, [num_tokens, hidden_size]),
+ ]
+
+ graph = helper.make_graph(
+ nodes,
+ "MoE_Graph",
+ graph_inputs,
+ graph_outputs,
+ initializers,
+ )
+
+ model = helper.make_model(graph)
+ return model.SerializeToString()
+
+
+class SwigluMoEBlock(SparseMoeBlockORTHelper):
+ def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0):
+ super().__init__(quant_bits)
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.top_k = config.num_experts_per_token
+ use_quant = self.quant_bits > 0
+
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True)
+
+ self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)])
+
+ weight_1_list, weight_2_list = [], []
+ bias_1_list, bias_2_list = [], []
+ scale_1_list, scale_2_list = [], []
+
+ for i in range(self.num_experts):
+ bias_1_list.append(self.experts[i].w1.bias)
+ bias_2_list.append(self.experts[i].w2.bias)
+ if not use_quant:
+ weight_1_list.append(self.experts[i].w1.weight)
+ weight_2_list.append(self.experts[i].w2.weight)
+ else:
+ is_4_bit = self.quant_bits == 4
+ # Pass the transposed weight to quant_dequant to get correct scales,
+ # then transpose the resulting quantized weight back to the expected layout.
+ scale1, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit)
+ scale2, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit)
+
+ self.experts[i].w1.weight.data = w1_qdq
+ self.experts[i].w2.weight.data = w2_qdq
+
+ weight_1_list.append(pre_qweight1.T)
+ weight_2_list.append(pre_qweight2.T)
+ scale_1_list.append(scale1)
+ scale_2_list.append(scale2)
+
+ self.moe_experts_weight1 = torch.stack(weight_1_list, dim=0)
+ self.moe_experts_weight2 = torch.stack(weight_2_list, dim=0)
+
+ self.moe_experts_bias1 = torch.stack(bias_1_list, dim=0)
+ self.moe_experts_bias2 = torch.stack(bias_2_list, dim=0)
+
+ moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None
+ moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None
+
+ self.batch_size = batch_size
+ self.sequence_length = sequence_length
+ self.moe_onnx_graph = create_swiglu_moe_onnx_graph(
+ num_tokens=self.batch_size * self.sequence_length,
+ num_experts=self.num_experts,
+ hidden_size=self.hidden_dim,
+ inter_size=self.ffn_dim,
+ topk=self.top_k,
+ ort_dtype=self.ort_dtype,
+ quant_bits=self.quant_bits,
+ fc1_experts_weights=self.moe_experts_weight1,
+ fc1_experts_bias=self.moe_experts_bias1,
+ fc2_experts_weights=self.moe_experts_weight2,
+ fc2_experts_bias=self.moe_experts_bias2,
+ fc1_experts_weight_scale=moe_experts_weight_scale1,
+ fc2_experts_weight_scale=moe_experts_weight_scale2,
+ )
+
+ self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """ """
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ router_logits = self.gate(hidden_states) # router_logits shape is (batch * sequence_length, num_experts)
+ routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)
+
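+        # Normalize only over the selected top-k logits; this mirrors
+        # normalize_routing_weights=1 on the ONNX node built above.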
+ routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float)
+
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited.
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
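+        # expert_mask has shape (num_experts, top_k, num_tokens) after the permute.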
+
+ # Loop over all available experts in the model and perform the computation on each expert
+ for expert_idx in range(self.num_experts):
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx])
+
+ if top_x.shape[0] == 0:
+ continue
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens.
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However, `index_add_` only supports torch tensors for indexing, so we use
+            # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states
+
+
+swiglu_test_params = list(
+ itertools.product(
+ [1, 4], # batch_size
+ [1, 32], # sequence_length
+ [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16)
+ )
+)
+
+
+class TestSwigluMoE(unittest.TestCase):
+ @parameterized.expand(swiglu_test_params)
+ def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits):
+ config = SwigluMoeConfig(hidden_size=128, intermediate_size=512, num_experts_per_token=1, num_local_experts=4)
+ moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits)
+ moe.parity_check()
+
+
if __name__ == "__main__":
unittest.main()