diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index b59ff63ea8260..9c6fc6ce57a20 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3129,8 +3129,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float), tensor(float16)
-- Constrain input and output types to float or float16 tensors.
+- T : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
@@ -4543,19 +4543,19 @@ This version of the operator has been available since version 1 of the 'com.micr
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).
-fc1_scales : T
+fc1_scales : T2
2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
-fc2_scales : T
+fc2_scales : T2
2D input tensor with shape (num_experts, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T1
3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-fc3_scales (optional) : T
+fc3_scales (optional) : T2
2D optional input tensor with shape (num_experts, inter_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
@@ -4571,10 +4571,12 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float16)
-- Constrain input and output types to float or float16 tensors.
+- T : tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
- T1 : tensor(uint8)
- Constrain weights type to uint8 tensors.
+- T2 : tensor(float), tensor(float16)
+- Constrain scales type to float tensors.
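To see how the updated constraints combine in practice, here is a minimal, illustrative sketch (not part of this change) of an 8-bit QMoE node in which activations use T (bfloat16), quantized weights use T1 (uint8), and the new T2 keeps the dequantization scales in float16. The tensor names, shapes, and attribute values below are hypothetical.

```python
# Illustrative sketch only: assumed names, shapes, and attribute values.
from onnx import TensorProto, helper

qmoe = helper.make_node(
    "QMoE",
    inputs=[
        "input", "router_probs",                  # T  = bfloat16
        "fc1_experts_weights", "fc1_scales", "",  # T1 = uint8, T2 = float16 (optional bias omitted)
        "fc2_experts_weights", "fc2_scales", "",  # T1 = uint8, T2 = float16 (optional bias omitted)
    ],
    outputs=["output"],                           # T  = bfloat16
    domain="com.microsoft",
    k=2,
    expert_weight_bits=8,
)

graph = helper.make_graph(
    [qmoe],
    "qmoe_bf16_with_fp16_scales",
    inputs=[
        helper.make_tensor_value_info("input", TensorProto.BFLOAT16, ["num_rows", 256]),
        helper.make_tensor_value_info("router_probs", TensorProto.BFLOAT16, ["num_rows", 8]),
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.BFLOAT16, ["num_rows", 256])],
)
# fc*_experts_weights (uint8) and fc*_scales (float16) would be supplied as initializers of the graph.
```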
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 3b70e5da8b3e4..26b701fea6fbb 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -949,7 +949,7 @@ Do not modify directly.*
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)|
-|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -957,7 +957,7 @@ Do not modify directly.*
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PagedAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* cumulative_sequence_length:**S**<br> *in* past_seqlens:**S**<br> *in* block_table:**S**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* key_cache_out:**T**<br> *out* value_cache_out:**T**|1+|**S** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
-|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
+|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index f3346d4513261..a50ee907c302b 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -92,7 +92,9 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
-class CUDA_MS_OP_CLASS_NAME(1, QMoE);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention);
@@ -307,7 +309,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-      BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu
new file mode 100644
index 0000000000000..5f0a71147b366
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4200)
+#endif
+
+#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+namespace ort_fastertransformer {
+template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>;
+} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu
new file mode 100644
index 0000000000000..4a84581127156
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4200)
+#endif
+
+#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+namespace ort_fastertransformer {
+template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
+} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu
new file mode 100644
index 0000000000000..6c23127955ac2
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4200)
+#endif
+
+#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+namespace ort_fastertransformer {
+template class MoeGemmRunner<__nv_bfloat16, uint8_t>;
+} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
index 8b8f45e77ab9d..f855092670bc3 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
@@ -53,6 +53,8 @@
#include "cutlass_heuristic.h"
#include "moe_gemm_kernels.h"
+#include <cuda_bf16.h>
+
#include
#include
#include
@@ -66,8 +68,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
CutlassGemmConfig gemm_config, const int multi_processor_count,
cudaStream_t stream, int* kernel_occupancy = nullptr) {
-  static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
-                "Specialized for half, float");
+  static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || cutlass::platform::is_same<T, __nv_bfloat16>::value,
+                "Specialized for half, float, bfloat16");
  static_assert(cutlass::platform::is_same<T, WeightType>::value ||
                    cutlass::platform::is_same<WeightType, uint8_t>::value ||
@@ -76,12 +78,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
using ElementType_ =
-      typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
+      typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, typename cutlass::platform::conditional<cutlass::platform::is_same<T, __nv_bfloat16>::value, cutlass::bfloat16_t, T>::type>::type;
using ElementType = ElementType_;
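+  // The weight type below gets the same mapping: half -> cutlass::half_t, __nv_bfloat16 -> cutlass::bfloat16_t; any other weight type passes through unchanged.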
using CutlassWeightType_ =
-      typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
-                                              WeightType>::type;
+      typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t, typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, __nv_bfloat16>::value, cutlass::bfloat16_t, WeightType>::type>::type;
using CutlassWeightType = CutlassWeightType_;
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index 4268b79e1e4f8..fc412a02e0383 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -38,10 +38,13 @@
#include "moe_kernel.h"
+#include <cuda_bf16.h>
#include
#include
#include
+#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
+
namespace ort_fastertransformer {
static constexpr int WARP_SIZE = 32;
@@ -103,11 +106,16 @@ void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows
dim3 block(std::min(intermediate_size, 1024));
dim3 grid(num_rows);
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR("swiglu input", input, num_rows, 2 * intermediate_size);
+
if constexpr (interleaved) {
    swiglu_kernel_interleaved<<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
} else {
    swiglu_kernel_chunked<<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
}
+
+ DUMP_TENSOR("swiglu output", output, num_rows, intermediate_size);
}
// ====================== Softmax things ===============================
@@ -838,11 +846,15 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr,
}
namespace {
-
-struct __align__(8) Half4 {
+typedef struct __CUDA_ALIGN__(8) {
half2 x;
half2 y;
-};
+} half2_2;
+
+typedef struct __CUDA_ALIGN__(8) {
+ __nv_bfloat162 x;
+ __nv_bfloat162 y;
+} __nv_bfloat162_2;
// TODO(wy): move to common header
template <typename T>
@@ -853,7 +865,11 @@ struct T4 {
};
template <>
struct T4<half> {
- using Type = Half4;
+ using Type = half2_2;
+};
+template <>
+struct T4<__nv_bfloat16> {
+ using Type = __nv_bfloat162_2;
};
template <typename T>
@@ -866,6 +882,10 @@ template <>
struct T2<half> {
using Type = half2;
};
+template <>
+struct T2<__nv_bfloat16> {
+ using Type = __nv_bfloat162;
+};
inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
@@ -882,15 +902,27 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha
#endif
// TODO(wy): use cuda common header and investigate pipeline build issue.
-inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
+inline __device__ half2_2 operator*(const half2_2 a, const half2_2 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
- Half4 result;
+ half2_2 result;
result.x = a.x * b.x;
result.y = a.y * b.y;
return result;
#else
- return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
+ return half2_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
+#endif
+}
+
+inline __device__ __nv_bfloat162_2 operator*(const __nv_bfloat162_2 a, const __nv_bfloat162_2 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && \
+ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
+ __nv_bfloat162_2 result;
+ result.x = a.x * b.x;
+ result.y = a.y * b.y;
+ return result;
+#else
+ return __nv_bfloat162_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
#endif
}
@@ -1291,18 +1323,26 @@ template void topk_gating_softmax_kernelLauncher(const float*, const bool*, floa
int, bool, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int,
int, bool, bool, cudaStream_t);
+template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, const bool*, __nv_bfloat16*, __nv_bfloat16*, int*, int*, int, int,
+ int, bool, bool, cudaStream_t);
// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;
template class CutlassMoeFCRunner<half, half>;
+template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>;
+// For qMoE:
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
template class CutlassMoeFCRunner<half, uint8_t>;
+template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>;
+template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>;
// ===================== Specializations for init routing =========================
template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int,
cudaStream_t);
template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int,
cudaStream_t);
+template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*, int, int, int, int,
+ cudaStream_t);
// ==================== Specializations for final routing ===================================
template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*,
@@ -1317,6 +1357,8 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl
const float*, const int*, const int*, int, int, int, cudaStream_t);
template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*,
const half*, const int*, const int*, int, int, int, cudaStream_t);
+template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*,
+ const __nv_bfloat16*, const int*, const int*, int, int, int, cudaStream_t);
template void invokeSwiGLU(float*, float const*, int, int, float, cudaStream_t);
template void invokeSwiGLU(half*, half const*, int, int, float, cudaStream_t);
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index 3ac4862e101c3..de11d357a8c07 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -178,14 +178,4 @@ class CutlassMoeFCRunner {
  std::vector<int64_t> total_rows_before_expert_host_;
};
-template <typename WeightType>
-class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
- public:
- CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
-
- size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
- return 0;
- }
-};
-
} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index cc6fe871a3bc1..6409a6e12afc6 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -3,6 +3,7 @@
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cuda_type_conversion.h"
#include "moe.h"
using namespace onnxruntime::cuda;
@@ -20,6 +21,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
template <typename T>
MoE<T>::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {
@@ -42,7 +44,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const {
fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
fc3_experts_weights_optional, fc3_experts_bias_optional));
-  typedef typename ToCudaType<T>::MappedType CudaT;
+  using CudaT = typename OrtToCudaType<T>::type;
auto stream = context->GetComputeStream();
auto& device_prop = GetDeviceProp();
diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
index db6d99674cf5a..aef31c7e9ed3a 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
@@ -5,6 +5,7 @@
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/quantization/moe_quantization.h"
+#include "core/providers/cuda/cuda_type_conversion.h"
using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@@ -14,16 +15,6 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-#define REGISTER_KERNEL() \
- ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .MayInplace(0, 0) \
-                              .TypeConstraint("T", BuildKernelDefConstraints<MLFloat16>())          \
-                              .TypeConstraint("T1", BuildKernelDefConstraints<uint8_t>()),          \
- QMoE);
-
-REGISTER_KERNEL()
-
namespace {
template <typename T>
struct ToCudaTypeWrapper : public ToCudaType<T> {};
@@ -40,27 +31,29 @@ struct ToCudaTypeWrapper {
} // anonymous namespace
-QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {
+template <typename T>
+QMoE<T>::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK());
ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4,
"expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_);
}
+template <typename T>
template
-Status QMoE::QuantizedMoEImpl(OpKernelContext* context,
- MoEParameters& moe_params,
- const Tensor* input,
- const Tensor* router_probs,
- const Tensor* fc1_experts_weights,
- const Tensor* fc1_experts_bias_optional,
- const Tensor* fc2_experts_weights,
- const Tensor* fc2_experts_bias_optional,
- const Tensor* fc3_experts_weights_optional,
- const Tensor* fc3_experts_bias_optional,
- const Tensor* fc1_scales,
- const Tensor* fc2_scales,
- const Tensor* fc3_scales_optional,
- const cudaDeviceProp& device_prop) const {
+Status QMoE<T>::QuantizedMoEImpl(OpKernelContext* context,
+ MoEParameters& moe_params,
+ const Tensor* input,
+ const Tensor* router_probs,
+ const Tensor* fc1_experts_weights,
+ const Tensor* fc1_experts_bias_optional,
+ const Tensor* fc2_experts_weights,
+ const Tensor* fc2_experts_bias_optional,
+ const Tensor* fc3_experts_weights_optional,
+ const Tensor* fc3_experts_bias_optional,
+ const Tensor* fc1_scales,
+ const Tensor* fc2_scales,
+ const Tensor* fc3_scales_optional,
+ const cudaDeviceProp& device_prop) const {
auto stream = context->GetComputeStream();
const int sm = device_prop.major * 10 + device_prop.minor;
@@ -68,8 +61,7 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context,
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
- using T = MLFloat16;
-  using CudaT = typename ToCudaType<T>::MappedType;
+  using CudaT = typename OrtToCudaType<T>::type;
ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm,
activation_type_,
@@ -137,7 +129,8 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context,
return Status::OK();
}
-Status QMoE::ComputeInternal(OpKernelContext* context) const {
+template <typename T>
+Status QMoE<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* router_probs = context->Input<Tensor>(1);
  const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
@@ -184,6 +177,32 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
#endif
}
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ QMoE,
+ kMSDomain,
+ 1,
+ MLFloat16,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .MayInplace(0, 0)
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<MLFloat16>()),
+ QMoE);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ QMoE,
+ kMSDomain,
+ 1,
+ BFloat16,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .MayInplace(0, 0)
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<BFloat16>())
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<BFloat16>()),
+ QMoE);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h
index c0164576d7c7f..c4698a1f277ef 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h
+++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h
@@ -14,6 +14,7 @@ namespace cuda {
using namespace onnxruntime::cuda;
+template <typename T>
class QMoE final : public CudaKernel, public MoEBase {
public:
explicit QMoE(const OpKernelInfo& op_kernel_info);
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 39bf2bf855976..686ebfb1f6fb5 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1405,7 +1405,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
.Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional)
.Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
.Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
- .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
ONNX_MS_OPERATOR_SET_SCHEMA(
@@ -1440,7 +1440,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"3D input tensor with shape (num_experts, hidden_size, inter_size) "
"or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).",
"T1")
- .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T")
+ .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T2")
.Input(4,
"fc1_experts_bias",
"2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional)
@@ -1449,7 +1449,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"3D input tensor with shape (num_experts, inter_size, hidden_size) "
"or (num_experts, inter_size, hidden_size / 2)",
"T1")
- .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T")
+ .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T2")
.Input(7,
"fc2_experts_bias",
"2D optional input tensor with shape (num_experts, hidden_size)",
@@ -1464,7 +1464,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(9,
"fc3_scales",
"2D optional input tensor with shape (num_experts, inter_size)",
- "T",
+ "T2",
OpSchema::Optional)
.Input(10,
"fc3_experts_bias",
@@ -1476,8 +1476,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape "
"(batch_size, sequence_length, hidden_size)",
"T")
- .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.")
+ .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.")
+ .TypeConstraint("T2", {"tensor(float)", "tensor(float16)"}, "Constrain scales type to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1,
diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_moe_cuda.py
similarity index 68%
rename from onnxruntime/test/python/transformers/test_parity_moe.py
rename to onnxruntime/test/python/transformers/test_moe_cuda.py
index d805c8f9cae3c..9b69d63970311 100644
--- a/onnxruntime/test/python/transformers/test_parity_moe.py
+++ b/onnxruntime/test/python/transformers/test_moe_cuda.py
@@ -10,6 +10,7 @@
# license information.
# --------------------------------------------------------------------------
import itertools
+import os
import unittest
from collections import OrderedDict
@@ -22,83 +23,59 @@
import onnxruntime
-torch.manual_seed(42)
-numpy.random.seed(42)
-
-
-def value_string_of(numpy_array):
- arr = numpy_array.flatten()
- lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)]
- return "{\n " + "f,\n ".join(lines) + "f}"
-
+# Reduces number of tests to run for faster pipeline checks
+pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1"
-def print_tensor(name, numpy_array):
-    print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")
+onnxruntime.preload_dlls()
+# Determine the execution provider and device based on CUDA availability.
+use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available()
+device = torch.device("cuda:0" if use_cuda else "cpu")
+ort_provider = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"]
-def quant_dequant(weights: torch.Tensor, is_4_bit_quantization: bool):
- """
- Performs symmetric per-column quantization and dequantization on a weight tensor.
-
- This implementation is a pure PyTorch replacement for the original function that
- relied on a custom tensorrt_llm operator. It supports both 8-bit (int8) and
- 4-bit (quint4x2 style) quantization.
-
- Args:
- weights (torch.Tensor): The input weight tensor to be quantized.
- is_4_bit_quantization (bool): If True, performs 4-bit quantization. If False,
- performs 8-bit quantization.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
- - scales (torch.float16): The quantization scales for each column.
- - processed_q_weight (torch.int8): The packed quantized weights. For
- 4-bit mode, two 4-bit values are packed into a single int8. For
- 8-bit mode, this is the standard int8 quantized tensor. It is
- transposed relative to the input weights' shape.
- - dequantized_weights (torch.Tensor): The weights after being dequantized,
- restored to the original dtype and device.
- """
- # Determine quantization bits and range based on the mode
- if is_4_bit_quantization:
- # 4-bit symmetric quantization path
- q_bits = 4
- q_max = 2 ** (q_bits - 1) - 1 # 7
- q_min = -(2 ** (q_bits - 1)) # -8
-
- max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values
- max_abs_val[max_abs_val == 0] = 1.0
- scales = max_abs_val / q_max
+torch.manual_seed(42)
+numpy.random.seed(42)
- quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8)
+onnx_to_torch_type_map = {
+ TensorProto.FLOAT16: torch.float16,
+ TensorProto.FLOAT: torch.float,
+ TensorProto.BFLOAT16: torch.bfloat16,
+ TensorProto.UINT8: torch.uint8,
+}
- # Pack two 4-bit integers into a single int8
- q_weights_t = quant_weights.T.contiguous()
- shape = q_weights_t.shape
- q_weights_t_reshaped = q_weights_t.view(shape[0], shape[1] // 2, 2)
- lower_nibble = q_weights_t_reshaped[..., 0]
- upper_nibble = q_weights_t_reshaped[..., 1]
- processed_q_weight = (lower_nibble & 0x0F) | (upper_nibble << 4)
+ort_to_numpy_type_map = {
+ TensorProto.FLOAT16: numpy.float16,
+ TensorProto.FLOAT: numpy.float32,
+ TensorProto.UINT8: numpy.uint8,
+}
- else:
- # 8-bit symmetric quantization path
- q_bits = 8
- q_max = 2 ** (q_bits - 1) - 1 # 127
- q_min = -(2 ** (q_bits - 1)) # -128
+ort_dtype_name_map = {
+ TensorProto.FLOAT16: "FP16",
+ TensorProto.FLOAT: "FP32",
+ TensorProto.BFLOAT16: "BF16",
+}
- max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values
- max_abs_val[max_abs_val == 0] = 1.0
- scales = max_abs_val / q_max
- quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8)
+def quant_dequant(weights, is_4_bit_quantization: bool = True):
+ # use the test version `_symmetric_...` to get the non-interleaved weights
+ type = torch.quint4x2 if is_4_bit_quantization else torch.int8
+ # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix()
+    # It is commented out here so the lintrunner check in CI passes.
+ # import tensorrt_llm
- # For 8-bit, the processed weights are just the transposed quantized weights (no packing)
- processed_q_weight = quant_weights.T.contiguous()
+ quant_weights, processed_q_weight, torch_weight_scales = (
+ torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type)
+ )
- # Dequantize the weights to verify and return for PyTorch-side parity check
- dequantized_weights = quant_weights.to(weights.dtype) * scales.to(weights.dtype)
+    # Unpack the int4s into int8s
+ if is_4_bit_quantization:
+ upper = quant_weights >> 4
+ lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends
+ quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape)
- return (scales.squeeze(0).to(torch.float16), processed_q_weight, dequantized_weights.T.to(device=weights.device))
+ quant_weights = quant_weights.to(dtype=weights.dtype)
+ result = torch.multiply(quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous()
+ return torch_weight_scales.to(torch.float16), processed_q_weight, result.to(device=weights.device)
def create_moe_onnx_graph(
@@ -110,7 +87,7 @@ def create_moe_onnx_graph(
fc1_experts_bias,
fc2_experts_weights,
fc2_experts_bias,
- ort_dtype,
+ onnx_dtype,
):
nodes = [
helper.make_node(
@@ -134,21 +111,21 @@ def create_moe_onnx_graph(
fc1_shape = [num_experts, hidden_size, inter_size]
fc2_shape = [num_experts, inter_size, hidden_size]
- torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
+ torch_dtype = onnx_to_torch_type_map[onnx_dtype]
initializers = [
helper.make_tensor(
"fc1_experts_weights",
- ort_dtype,
+ onnx_dtype,
fc1_shape,
- fc1_experts_weights.to(torch_type).flatten().tolist(),
+ fc1_experts_weights.to(torch_dtype).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_weights",
- ort_dtype,
+ onnx_dtype,
fc2_shape,
- fc2_experts_weights.to(torch_type).flatten().tolist(),
+ fc2_experts_weights.to(torch_dtype).flatten().tolist(),
raw=False,
),
]
@@ -159,35 +136,35 @@ def create_moe_onnx_graph(
[
helper.make_tensor(
"fc1_experts_bias",
- ort_dtype,
+ onnx_dtype,
fc1_bias_shape,
- fc1_experts_bias.to(torch_type).flatten().tolist(),
+ fc1_experts_bias.to(torch_dtype).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_bias",
- ort_dtype,
+ onnx_dtype,
fc2_bias_shape,
- fc2_experts_bias.to(torch_type).flatten().tolist(),
+ fc2_experts_bias.to(torch_dtype).flatten().tolist(),
raw=False,
),
]
)
graph_inputs = [
- helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
- ort_dtype,
+ onnx_dtype,
[sequence_length, num_experts],
)
)
graph_outputs = [
- helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]),
]
graph = helper.make_graph(
@@ -211,7 +188,7 @@ def create_mixtral_moe_onnx_graph(
fc2_experts_weights,
fc3_experts_weights,
topk,
- ort_dtype,
+ onnx_dtype,
):
nodes = [
helper.make_node(
@@ -238,46 +215,46 @@ def create_mixtral_moe_onnx_graph(
fc2_shape = [num_experts, inter_size, hidden_size]
fc3_shape = [num_experts, hidden_size, inter_size]
- torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
+ torch_dtype = onnx_to_torch_type_map[onnx_dtype]
initializers = [
helper.make_tensor(
"fc1_experts_weights",
- ort_dtype,
+ onnx_dtype,
fc1_shape,
- fc1_experts_weights.to(torch_type).flatten().tolist(),
+ fc1_experts_weights.to(torch_dtype).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_weights",
- ort_dtype,
+ onnx_dtype,
fc2_shape,
- fc2_experts_weights.to(torch_type).flatten().tolist(),
+ fc2_experts_weights.to(torch_dtype).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc3_experts_weights",
- ort_dtype,
+ onnx_dtype,
fc3_shape,
- fc3_experts_weights.to(torch_type).flatten().tolist(),
+ fc3_experts_weights.to(torch_dtype).flatten().tolist(),
raw=False,
),
]
graph_inputs = [
- helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
- ort_dtype,
+ onnx_dtype,
[sequence_length, num_experts],
)
)
graph_outputs = [
- helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]),
]
graph = helper.make_graph(
@@ -301,7 +278,7 @@ def create_phi_moe_onnx_graph(
fc2_experts_weights,
fc3_experts_weights,
topk,
- ort_dtype,
+ onnx_dtype,
quant_bits=0,
fc1_scales=None,
fc2_scales=None,
@@ -368,31 +345,31 @@ def create_phi_moe_onnx_graph(
fc2_shape = [num_experts, inter_size, hidden_size // components]
fc3_shape = [num_experts, hidden_size, inter_size // components]
- torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
- numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32
- weight_numpy_type = numpy.uint8 if use_quant else numpy_type
- weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype
+ torch_dtype = onnx_to_torch_type_map[onnx_dtype]
+
+ weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype]
+ weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype
initializers = [
helper.make_tensor(
"fc1_experts_weights",
weight_onnx_type,
fc1_shape,
- fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(),
+ fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(),
raw=False,
),
helper.make_tensor(
"fc2_experts_weights",
weight_onnx_type,
fc2_shape,
- fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(),
+ fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(),
raw=False,
),
helper.make_tensor(
"fc3_experts_weights",
weight_onnx_type,
fc3_shape,
- fc3_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(),
+ fc3_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(),
raw=False,
),
]
@@ -405,42 +382,42 @@ def create_phi_moe_onnx_graph(
[
helper.make_tensor(
"fc1_scales",
- ort_dtype,
+ onnx_dtype,
fc1_scale_shape,
- fc1_scales.to(torch_type).flatten().tolist(),
+ fc1_scales.to(torch_dtype).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc2_scales",
- ort_dtype,
+ onnx_dtype,
fc2_scale_shape,
- fc2_scales.to(torch_type).flatten().tolist(),
+ fc2_scales.to(torch_dtype).flatten().tolist(),
raw=False,
),
helper.make_tensor(
"fc3_scales",
- ort_dtype,
+ onnx_dtype,
fc3_scale_shape,
- fc3_scales.to(torch_type).flatten().tolist(),
+ fc3_scales.to(torch_dtype).flatten().tolist(),
raw=False,
),
]
)
graph_inputs = [
- helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
- ort_dtype,
+ onnx_dtype,
[sequence_length, num_experts],
)
)
graph_outputs = [
- helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]),
+ helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]),
]
graph = helper.make_graph(
@@ -593,127 +570,127 @@ def __init__(self, config: PhiMoEConfig):
class SparseMoeBlockORTHelper(nn.Module):
- def __init__(self, quant_bits=0):
+ def __init__(self, quant_bits=0, onnx_dtype=None):
super().__init__()
self.quant_bits = quant_bits
- self.ort_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT
- self.np_type = numpy.float16 if self.ort_dtype == TensorProto.FLOAT16 else numpy.float32
+ if onnx_dtype is None:
+ self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT
+ else:
+ self.onnx_dtype = onnx_dtype
+ self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32
def create_ort_session(self, moe_onnx_graph):
from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415
sess_options = SessionOptions()
+ sess_options.log_severity_level = 2
- cuda_providers = ["CUDAExecutionProvider"]
- if cuda_providers[0] not in onnxruntime.get_available_providers():
+ try:
+ ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider)
+ except Exception as e:
+ print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}")
+ print("Skipping ONNX Runtime execution for this test case.")
return None
- sess_options.log_severity_level = 2
- ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"])
-
return ort_session
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
pass
- def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor:
+ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor:
+ if self.ort_sess is None:
+ return None
+
batch_size, sequence_length, hidden_dim = hidden_states.shape
- hidden_states = hidden_states.view(-1, hidden_dim)
+ hidden_states_flat = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
- router_logits = self.gate(hidden_states)
-
- ort_inputs = {
- "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(self.np_type)),
- "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(self.np_type)),
- }
+ router_logits = self.gate(hidden_states_flat)
- ort_output = None
- if self.ort_sess is not None:
- if not iobinding:
- ort_output = self.ort_sess.run(None, ort_inputs)
- return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits
- else:
- self.ort_run_with_iobinding(ort_inputs)
- return None
+ # Determine the correct torch dtype from the onnx_dtype
+ torch_dtype = onnx_to_torch_type_map[self.onnx_dtype]
- return None
+ # Prepare tensors on the correct device for ORT inference with the CORRECT dtype
+ tensors = {
+ "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype),
+ "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype),
+ "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype),
+ }
- def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
+ # Bind inputs and outputs to torch tensors directly.
iobinding = self.ort_sess.io_binding()
- device_id = torch.cuda.current_device()
-
- iobinding.bind_input(
- name="input",
- device_type="cuda",
- device_id=device_id,
- element_type=self.np_type,
- shape=ort_inputs["input"].shape,
- buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(),
- )
-
- iobinding.bind_input(
- name="router_probs",
- device_type="cuda",
- device_id=device_id,
- element_type=self.np_type,
- shape=ort_inputs["router_probs"].shape,
- buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(
- ort_inputs["router_probs"], "cuda", device_id
- ).data_ptr(),
- )
- iobinding.bind_output(
- name="output",
- device_type="cuda",
- device_id=device_id,
- element_type=self.np_type,
- shape=ort_inputs["input"].shape,
- buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(
- numpy.zeros(ort_inputs["input"].shape), "cuda", device_id
- ).data_ptr(),
- )
-
- # warm up
- for _ in range(5):
- iobinding.synchronize_inputs()
- self.ort_sess.run_with_iobinding(iobinding)
- iobinding.synchronize_outputs()
-
- import time # noqa: PLC0415
-
- s = time.time()
- for _ in range(repeat):
- iobinding.synchronize_inputs()
- self.ort_sess.run_with_iobinding(iobinding)
- iobinding.synchronize_outputs()
- e = time.time()
- print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms")
-
- def parity_check(self, atol=None, rtol=None):
- hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
+ for name, tensor in tensors.items():
+ # Ensure tensor is on the globally defined device
+ if name == "output":
+ iobinding.bind_output(
+ name=name,
+ device_type=tensor.device.type,
+ device_id=tensor.device.index or 0,
+ element_type=self.onnx_dtype,
+ shape=tensor.shape,
+ buffer_ptr=tensor.data_ptr(),
+ )
+ else:
+ iobinding.bind_input(
+ name=name,
+ device_type=tensor.device.type,
+ device_id=tensor.device.index or 0,
+ element_type=self.onnx_dtype,
+ shape=tensor.shape,
+ buffer_ptr=tensor.data_ptr(),
+ )
+
+ iobinding.synchronize_inputs()
+ self.ort_sess.run_with_iobinding(iobinding)
+ iobinding.synchronize_outputs()
+
+ if enable_performance_test:
+ import time # noqa: PLC0415
+
+ repeat = 1000
+ s = time.time()
+ for _ in range(repeat):
+ iobinding.synchronize_inputs()
+ self.ort_sess.run_with_iobinding(iobinding)
+ iobinding.synchronize_outputs()
+ e = time.time()
+ print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms")
+
+ # The output tensor is on `device`. Reshape and return it.
+ return tensors["output"].reshape(batch_size, sequence_length, hidden_dim)
+
+ def parity_check(self):
+ hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device)
torch_output = self.forward(hidden_state)
ort_output = self.ort_forward(hidden_state)
- if atol is None:
- atol = 5e-2 if self.quant_bits == 0 else (2.0 if self.quant_bits == 8 else 3.0)
-
- if rtol is None:
- rtol = 1e-3 if self.quant_bits == 0 else 1e-2
+ dtype_str = ort_dtype_name_map[self.onnx_dtype]
+
+ # Maps "ort_type:quant_bits" to (atol, rtol)
+ ort_dtype_quant_bits_tolerance_map = {
+ "FP32:0": (5e-3, 1e-3),
+ "FP16:0": (5e-2, 1e-3),
+ "FP16:4": (3.0, 1e-2),
+ "FP16:8": (2.0, 1e-2),
+ "BF16:0": (1.0, 1e-2),
+ "BF16:4": (30.0, 1e-1),
+ "BF16:8": (20.0, 1e-1),
+ }
+ atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"]
if ort_output is not None:
- dtype_str = "FP32" if self.quant_bits == 0 else "FP16"
print(
f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str},"
f" batch: {self.batch_size}, seq_len: {self.sequence_length},"
- f" max_diff: {(torch_output - ort_output).abs().max()}"
+ f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}"
)
torch.testing.assert_close(
- ort_output.to(torch.float32), torch_output.to(torch.float32), rtol=rtol, atol=atol
+ ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol
)
def benchmark_ort(self):
- hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
- self.ort_forward(hidden_state, iobinding=True)
+ hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device)
+ self.ort_forward(hidden_state, enable_performance_test=True)
class SwitchMoE(SparseMoeBlockORTHelper):
@@ -757,7 +734,7 @@ def __init__(
self.moe_experts.bias1,
self.moe_experts.weight2.transpose(1, 2),
self.moe_experts.bias2,
- self.ort_dtype,
+ self.onnx_dtype,
)
self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
@@ -827,7 +804,7 @@ def __init__(self, config, batch_size, sequence_length):
self.moe_experts_weight2,
self.moe_experts_weight3,
self.top_k,
- self.ort_dtype,
+ self.onnx_dtype,
)
self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
@@ -924,8 +901,8 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper):
and memory on padding.
"""
- def __init__(self, config, batch_size, sequence_length, quant_bits=0):
- super().__init__(quant_bits)
+ def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype=None):
+ super().__init__(quant_bits, onnx_dtype)
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
@@ -950,18 +927,18 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0):
is_4_bit = self.quant_bits == 4
for i in range(self.num_experts):
# Corrected quantization logic for per-output-channel quantization
- w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit)
- w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit)
- w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight.T, is_4_bit)
+ w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit)
+ w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit)
+ w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit)
self.experts[i].w1.weight.data = w1_qdq
self.experts[i].w2.weight.data = w2_qdq
self.experts[i].w3.weight.data = w3_qdq
# Transpose quantized weights to match the expected ONNX layout
- w1_list.append(pre_qweight1.T)
- w2_list.append(pre_qweight2.T)
- w3_list.append(pre_qweight3.T)
+ w1_list.append(pre_qweight1)
+ w2_list.append(pre_qweight2)
+ w3_list.append(pre_qweight3)
w1_scale_list.append(w1_scale)
w2_scale_list.append(w2_scale)
w3_scale_list.append(w3_scale)
@@ -985,7 +962,7 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0):
self.moe_experts_weight2,
self.moe_experts_weight3,
self.top_k,
- self.ort_dtype,
+ self.onnx_dtype,
self.quant_bits,
moe_experts_weight_scale1,
moe_experts_weight_scale2,
@@ -1045,25 +1022,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def small_test_cases():
for batch_size in [1, 4, 16]:
for sequence_length in [128, 512, 1024]:
- yield batch_size, sequence_length, 0
-
-
-# Test cases for Phi-3 MoE.
-# We test three modes: no quantization, 8-bit, and 4-bit.
-phi3_test_params = list(
- itertools.product(
- [1, 4], # batch_size
- [1, 32], # sequence_length
- [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16)
- )
-)
+ yield batch_size, sequence_length
+@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.")
class TestSwitchMoE(unittest.TestCase):
@parameterized.expand(small_test_cases())
- def test_switch_moe_parity(self, batch_size, sequence_length, quant_bits):
- # if platform.system() == "Windows":
- # pytest.skip("Skip on Windows")
+ def test_switch_moe_parity(self, batch_size, sequence_length):
switch_moe = SwitchMoE(
batch_size=batch_size,
sequence_length=sequence_length,
@@ -1072,26 +1037,42 @@ def test_switch_moe_parity(self, batch_size, sequence_length, quant_bits):
hidden_features=1024,
out_features=256,
)
+ switch_moe.to(device)
switch_moe.parity_check()
- # switch_moe.benchmark_ort()
+# quant_bits: 0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16.
+# The qMoE tests require tensorrt_llm for quant_dequant, so they are disabled in the CI pipeline to avoid an extra dependency.
+quant_bits_list = [0] if pipeline_mode else [0, 8, 4]
+
+
+@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.")
class TestMixtralMoE(unittest.TestCase):
- @parameterized.expand([(b, s, q) for b, s, q in small_test_cases() if q == 0]) # only run non-quantized
- def test_mixtral_moe_parity(self, batch_size, sequence_length, quant_bits):
+ @parameterized.expand(small_test_cases())
+ def test_mixtral_moe_parity(self, batch_size, sequence_length):
config = MixtralConfig(hidden_size=256, intermediate_size=1024)
mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
+ mixtral_moe.to(device)
mixtral_moe.parity_check()
- # mixtral_moe.benchmark_ort()
+phi3_test_cases = list(
+ itertools.product(
+ [1, 4], # batch_size
+ [1, 32], # sequence_length
+ quant_bits_list,
+ )
+)
+
+
+@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.")
class TestPhiMoE(unittest.TestCase):
- @parameterized.expand(phi3_test_params)
+ @parameterized.expand(phi3_test_cases)
def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits):
config = PhiMoEConfig(hidden_size=256, intermediate_size=1024)
phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits)
+ phi3_moe.to(device)
phi3_moe.parity_check()
- # phi3_moe.benchmark_ort()
# ---------------------------------------------
@@ -1111,6 +1092,14 @@ def __init__(
self.num_local_experts = num_local_experts
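+# Note: swiglu expects the last dimension to interleave [gate, linear] pairs, e.g. a row
+# [g0, l0, g1, l1] maps to [g0 * sigmoid(1.702 * g0) * (l0 + 1), g1 * sigmoid(1.702 * g1) * (l1 + 1)].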
+def swiglu(x: torch.Tensor):
+ dim = x.shape[-1]
+ x = x.view(-1, dim // 2, 2)
+ x_glu, x_linear = x[..., 0], x[..., 1]
+ y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1)
+ return y
+
+
class SwigluMlp(nn.Module):
def __init__(self, config):
super().__init__()
@@ -1119,26 +1108,45 @@ def __init__(self, config):
self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True)
self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True)
- def swiglu(self, x: torch.Tensor):
- dim = x.shape[-1]
- x = x.view(-1, dim // 2, 2)
- x_glu, x_linear = x[..., 0], x[..., 1]
- y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1)
- return y
-
def forward(self, x):
- y = self.swiglu(self.w1(x))
+ x1 = self.w1(x)
+ y = swiglu(x1)
y = self.w2(y)
return y
+# Note that the shape might not match the tensor shape. See Attention note in this file.
+def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype):
+ torch_dtype = onnx_to_torch_type_map[onnx_dtype]
+ if torch_dtype == torch.bfloat16:
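+        # numpy has no bfloat16 dtype, so reinterpret the bf16 payload as uint16 and store the raw bytes.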
+ numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy()
+ initializer = helper.make_tensor(
+ name=name,
+ data_type=TensorProto.BFLOAT16,
+ dims=shape,
+ vals=numpy_vals_uint16.tobytes(),
+ raw=True,
+ )
+ else:
+ initializer = helper.make_tensor(
+ name=name,
+ data_type=onnx_dtype,
+ dims=shape,
+ vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist()
+ if onnx_dtype == TensorProto.UINT8
+ else tensor.detach().to(torch_dtype).flatten().tolist(),
+ raw=False,
+ )
+ return initializer
+
+
def create_swiglu_moe_onnx_graph(
num_tokens: int,
num_experts: int,
hidden_size: int,
inter_size: int,
topk: int,
- ort_dtype: int,
+ onnx_dtype: int,
quant_bits: int,
fc1_experts_weights: torch.Tensor,
fc1_experts_bias: torch.Tensor,
@@ -1189,6 +1197,8 @@ def create_swiglu_moe_onnx_graph(
nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)])
components = 2 if quant_bits == 4 else 1
+ # ATTENTION: Actual weight layout is like [num_experts, 2 * inter_size, hidden_size // components]
+ # Here we claim a different shape for the initializer to match the operator spec for weight tensor!
fc1_weight_shape = [num_experts, hidden_size, 2 * inter_size // components]
fc1_bias_shape = [num_experts, 2 * inter_size]
fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size]
@@ -1197,80 +1207,54 @@ def create_swiglu_moe_onnx_graph(
fc2_bias_shape = [num_experts, hidden_size]
fc2_experts_weight_scale_shape = [num_experts, hidden_size]
- torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32
- numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32
- weight_numpy_type = numpy.uint8 if use_quant else numpy_type
- weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype
+ weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype
+
+ torch_dtype = onnx_to_torch_type_map[onnx_dtype]
+ weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type]
initializers = [
- helper.make_tensor(
- "fc1_experts_weights",
- weight_onnx_type,
- fc1_weight_shape,
- fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist()
- if use_quant
- else fc1_experts_weights.to(torch_type).flatten().tolist(),
- raw=False,
+ make_onnx_intializer(
+ "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type
),
- helper.make_tensor(
- "fc1_experts_bias",
- ort_dtype,
- fc1_bias_shape,
- fc1_experts_bias.to(torch_type).flatten().tolist(),
- raw=False,
- ),
- helper.make_tensor(
- "fc2_experts_weights",
- weight_onnx_type,
- fc2_weight_shape,
- fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist()
- if use_quant
- else fc2_experts_weights.to(torch_type).flatten().tolist(),
- raw=False,
- ),
- helper.make_tensor(
- "fc2_experts_bias",
- ort_dtype,
- fc2_bias_shape,
- fc2_experts_bias.to(torch_type).flatten().tolist(),
- raw=False,
+ make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype),
+ make_onnx_intializer(
+ "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type
),
+ make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype),
]
if use_quant:
initializers.extend(
[
- helper.make_tensor(
+ make_onnx_intializer(
"fc1_experts_weight_scale",
- ort_dtype,
+ fc1_experts_weight_scale.to(torch_dtype),
fc1_experts_weight_scale_shape,
- fc1_experts_weight_scale.to(torch_type).flatten().tolist(),
- raw=False,
+ onnx_dtype,
),
- helper.make_tensor(
+ make_onnx_intializer(
"fc2_experts_weight_scale",
- ort_dtype,
+ fc2_experts_weight_scale.to(torch_dtype),
fc2_experts_weight_scale_shape,
- fc2_experts_weight_scale.to(torch_type).flatten().tolist(),
- raw=False,
+ onnx_dtype,
),
]
)
graph_inputs = [
- helper.make_tensor_value_info("input", ort_dtype, [num_tokens, hidden_size]),
+ helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]),
]
graph_inputs.append(
helper.make_tensor_value_info(
"router_probs",
- ort_dtype,
+ onnx_dtype,
[num_tokens, num_experts],
)
)
graph_outputs = [
- helper.make_tensor_value_info("output", ort_dtype, [num_tokens, hidden_size]),
+ helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]),
]
graph = helper.make_graph(
@@ -1286,8 +1270,10 @@ def create_swiglu_moe_onnx_graph(
class SwigluMoEBlock(SparseMoeBlockORTHelper):
- def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0):
- super().__init__(quant_bits)
+ def __init__(
+ self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None
+ ):
+ super().__init__(quant_bits, onnx_dtype=onnx_dtype)
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
@@ -1298,54 +1284,61 @@ def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: in
self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)])
- weight_1_list, weight_2_list = [], []
- bias_1_list, bias_2_list = [], []
+        # The ONNX MoE operator spec declares weight initializers with an [In, Out] shape, while the
+        # tensors below keep the nn.Linear [Out, In] memory layout (see the ATTENTION note in this loop).
+        # Biases do not require transposition.
+ fc1_w_list, fc2_w_list = [], []
+ fc1_b_list, fc2_b_list = [], []
scale_1_list, scale_2_list = [], []
- for i in range(self.num_experts):
- bias_1_list.append(self.experts[i].w1.bias)
- bias_2_list.append(self.experts[i].w2.bias)
+ for expert in self.experts:
+ fc1_b_list.append(expert.w1.bias)
+ fc2_b_list.append(expert.w2.bias)
if not use_quant:
- weight_1_list.append(self.experts[i].w1.weight)
- weight_2_list.append(self.experts[i].w2.weight)
+                # ATTENTION: The CUDA kernel expects the weight tensor in [E, out, in] memory layout,
+                # just like nn.Linear, but the initializer shape must be [E, in, out] to match the op spec.
+ fc1_w_list.append(expert.w1.weight)
+ fc2_w_list.append(expert.w2.weight)
else:
is_4_bit = self.quant_bits == 4
- # Pass the transposed weight to quant_dequant to get correct scales,
- # then transpose the resulting quantized weight back to the expected layout.
- scale1, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit)
- scale2, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit)
- self.experts[i].w1.weight.data = w1_qdq
- self.experts[i].w2.weight.data = w2_qdq
+ # quant_dequant expects [Out, In] format, matching nn.Linear.weight
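+                # It returns (scales, quantized weights, dequantized weights); the dequantized
+                # copy replaces the expert weights for the PyTorch reference below.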
+ scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit)
+ scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit)
- weight_1_list.append(pre_qweight1.T)
- weight_2_list.append(pre_qweight2.T)
+ # Update the expert's weight with the dequantized version for the PyTorch reference.
+ expert.w1.weight.data = w1_qdq
+ expert.w2.weight.data = w2_qdq
+
+ fc1_w_list.append(pre_qweight1)
+ fc2_w_list.append(pre_qweight2)
scale_1_list.append(scale1)
scale_2_list.append(scale2)
- self.moe_experts_weight1 = torch.stack(weight_1_list, dim=0)
- self.moe_experts_weight2 = torch.stack(weight_2_list, dim=0)
-
- self.moe_experts_bias1 = torch.stack(bias_1_list, dim=0)
- self.moe_experts_bias2 = torch.stack(bias_2_list, dim=0)
+ # Stack the prepared tensors for the graph builder
+ fc1_experts_weights = torch.stack(fc1_w_list, dim=0)
+ fc2_experts_weights = torch.stack(fc2_w_list, dim=0)
+ fc1_experts_bias = torch.stack(fc1_b_list, dim=0)
+ fc2_experts_bias = torch.stack(fc2_b_list, dim=0)
moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None
moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None
self.batch_size = batch_size
self.sequence_length = sequence_length
+
+ # Build the ONNX graph with the correctly shaped tensors
self.moe_onnx_graph = create_swiglu_moe_onnx_graph(
num_tokens=self.batch_size * self.sequence_length,
num_experts=self.num_experts,
hidden_size=self.hidden_dim,
inter_size=self.ffn_dim,
topk=self.top_k,
- ort_dtype=self.ort_dtype,
+ onnx_dtype=self.onnx_dtype,
quant_bits=self.quant_bits,
- fc1_experts_weights=self.moe_experts_weight1,
- fc1_experts_bias=self.moe_experts_bias1,
- fc2_experts_weights=self.moe_experts_weight2,
- fc2_experts_bias=self.moe_experts_bias2,
+ fc1_experts_weights=fc1_experts_weights,
+ fc1_experts_bias=fc1_experts_bias,
+ fc2_experts_weights=fc2_experts_weights,
+ fc2_experts_bias=fc2_experts_bias,
fc1_experts_weight_scale=moe_experts_weight_scale1,
fc2_experts_weight_scale=moe_experts_weight_scale2,
)
@@ -1353,26 +1346,24 @@ def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: in
self.ort_sess = self.create_ort_session(self.moe_onnx_graph)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
+ """
+        PyTorch reference implementation. Tokens are routed through the nn.Module experts
+        directly rather than re-implementing the expert matmuls by hand.
+ """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
- router_logits = self.gate(hidden_states) # router_logits shape is (batch * sequence_length, num_experts)
+ router_logits = self.gate(hidden_states)
routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)
-
routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float)
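+        # Softmax is computed in float32 for numerical stability, then cast back to the activation dtype below.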
- # we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
- # One hot encode the selected experts to create an expert mask
- # this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
- # Loop over all available experts in the model and perform the computation on each expert
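+        # expert_mask has shape [num_experts, top_k, num_tokens]; its nonzero entries mark
+        # the tokens routed to each expert.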
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
@@ -1380,35 +1371,78 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if top_x.shape[0] == 0:
continue
- # Index the correct hidden states and compute the expert hidden state for
- # the current expert. We need to make sure to multiply the output hidden
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
- # However `index_add_` only support torch tensors for indexing so we'll use
- # the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states
-swiglu_test_params = list(
+swiglu_test_cases = list(
itertools.product(
- [1, 4], # batch_size
- [1, 32], # sequence_length
- [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16)
+ [1, 2], # batch_size
+ [1, 3], # sequence_length
+        quant_bits_list,  # quantization bits for the expert weights (0 means no quantization)
)
)
+@unittest.skipIf(not use_cuda, "skipping MoE tests since they require a CUDA environment.")
class TestSwigluMoE(unittest.TestCase):
- @parameterized.expand(swiglu_test_params)
+ @parameterized.expand(swiglu_test_cases)
def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits):
- config = SwigluMoeConfig(hidden_size=128, intermediate_size=512, num_experts_per_token=1, num_local_experts=4)
+ config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4)
moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits)
+ moe.to(device)
+ moe.parity_check()
+
+
+def has_bf16_moe():
+ if "CUDAExecutionProvider" not in onnxruntime.get_available_providers() or not torch.cuda.is_available():
+ return False
+ major, _ = torch.cuda.get_device_capability()
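+    # bfloat16 MoE kernels require compute capability 8.0 (Ampere) or newer.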
+ return major >= 8
+
+
+@unittest.skipIf(not has_bf16_moe(), "skipping bf16 MoE tests since they require CUDA with compute capability >= 8.0.")
+class TestSwigluMoeBf16(unittest.TestCase):
+ @parameterized.expand(swiglu_test_cases)
+ def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits):
+ config = SwigluMoeConfig(hidden_size=64, intermediate_size=128, num_experts_per_token=2, num_local_experts=4)
+ moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits, onnx_dtype=TensorProto.BFLOAT16)
+ moe.to(device)
moe.parity_check()
+perf_test_cases = list(
+ itertools.product(
+ [1], # batch_size
+ [128, 512, 1024, 2048, 4096], # sequence_length
+ [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16)
+ )
+)
+
+
+@unittest.skipIf(pipeline_mode or not use_cuda, "skipping performance test in CI pipeline or without CUDA.")
+class TestSwigluMoEPerf(unittest.TestCase):
+ @parameterized.expand(perf_test_cases)
+    def test_swiglu_moe_perf(self, batch_size, sequence_length, quant_bits):
+ hidden_size = 2880
+ intermediate_size = 2880
+ num_experts_per_token = 8
+ num_local_experts = 128
+ config = SwigluMoeConfig(
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_experts_per_token=num_experts_per_token,
+ num_local_experts=num_local_experts,
+ )
+ moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits)
+ moe.to(device)
+ moe.benchmark_ort()
+
+
if __name__ == "__main__":
unittest.main()