diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index b59ff63ea8260..9c6fc6ce57a20 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3129,8 +3129,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
-T : tensor(float), tensor(float16)
-Constrain input and output types to float or float16 tensors.
+T : tensor(float), tensor(float16), tensor(bfloat16)
+Constrain input and output types to float tensors.
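
With `tensor(bfloat16)` added to the `T` constraint, a MoE graph can now declare bfloat16 activations directly. Below is a minimal, illustrative sketch using `onnx.helper`; the tensor names and shapes are not taken from this patch, and required attributes such as `k` and `activation_type` are omitted for brevity.

```python
from onnx import TensorProto, helper

# Illustrative sizes only; not taken from the patch.
num_rows, hidden_size, num_experts = 8, 256, 4

# fc3 weights/bias are optional inputs; empty names skip them.
moe_node = helper.make_node(
    "MoE",
    inputs=[
        "input", "router_probs",
        "fc1_experts_weights", "fc1_experts_bias",
        "fc2_experts_weights", "fc2_experts_bias",
        "", "",
    ],
    outputs=["output"],
    domain="com.microsoft",
)

# BFLOAT16 is now accepted for T, alongside FLOAT and FLOAT16.
bf16 = TensorProto.BFLOAT16
graph_inputs = [
    helper.make_tensor_value_info("input", bf16, [num_rows, hidden_size]),
    helper.make_tensor_value_info("router_probs", bf16, [num_rows, num_experts]),
]
graph_outputs = [helper.make_tensor_value_info("output", bf16, [num_rows, hidden_size])]
```
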
@@ -4543,19 +4543,19 @@ This version of the operator has been available since version 1 of the 'com.micr
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).
-fc1_scales : T
+fc1_scales : T2
2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
-fc2_scales : T
+fc2_scales : T2
2D input tensor with shape (num_experts, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T1
3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-fc3_scales (optional) : T
+fc3_scales (optional) : T2
2D optional input tensor with shape (num_experts, inter_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
@@ -4571,10 +4571,12 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
-T : tensor(float16)
-Constrain input and output types to float or float16 tensors.
+T : tensor(float16), tensor(bfloat16)
+Constrain input and output types to float tensors.
T1 : tensor(uint8)
Constrain weights type to uint8 tensors.
+T2 : tensor(float), tensor(float16)
+Constrain scales type to float tensors.
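
The new `T2` constraint decouples the dequantization scales from the activation type, so a bfloat16 QMoE graph can keep its scales in float16 or float32. A hypothetical sketch of the corresponding type declarations follows; only the `expert_weight_bits` attribute is taken from this patch, everything else (names, symbolic shapes) is illustrative.

```python
from onnx import TensorProto, helper

qmoe_node = helper.make_node(
    "QMoE",
    inputs=[
        "input", "router_probs",
        "fc1_experts_weights", "fc1_scales", "",   # fc1 bias omitted
        "fc2_experts_weights", "fc2_scales", "",   # fc2 bias omitted
    ],
    outputs=["output"],
    domain="com.microsoft",
)
qmoe_node.attribute.extend([helper.make_attribute("expert_weight_bits", 8)])

# T  (activations/output): float16 or, with this change, bfloat16.
# T1 (quantized weights):  uint8.
# T2 (scales):             float or float16, independent of T.
act_type, scale_type = TensorProto.BFLOAT16, TensorProto.FLOAT16
value_infos = [
    helper.make_tensor_value_info("input", act_type, ["num_rows", "hidden_size"]),
    # In a real model the weights and scales would usually be initializers;
    # value_info is used here only to show the element types.
    helper.make_tensor_value_info("fc1_experts_weights", TensorProto.UINT8,
                                  ["num_experts", "hidden_size", "inter_size"]),
    helper.make_tensor_value_info("fc1_scales", scale_type, ["num_experts", "inter_size"]),
]
```
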
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3b70e5da8b3e4..26b701fea6fbb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -949,7 +949,7 @@ Do not modify directly.* |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)| -|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -957,7 +957,7 @@ Do not modify directly.* |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(bfloat16), tensor(float16)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index f3346d4513261..a50ee907c302b 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -92,7 +92,9 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE); -class CUDA_MS_OP_CLASS_NAME(1, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention); @@ -307,7 +309,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 0000000000000..5f0a71147b366 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 0000000000000..4a84581127156 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 0000000000000..6c23127955ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, uint8_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 8b8f45e77ab9d..f855092670bc3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -53,6 +53,8 @@ #include "cutlass_heuristic.h" #include "moe_gemm_kernels.h" +#include + #include #include #include @@ -66,8 +68,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, int* kernel_occupancy = nullptr) { - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float, bfloat16"); static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || @@ -76,12 +78,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, T>::type>::type; using ElementType = ElementType_; using CutlassWeightType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, - WeightType>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, WeightType>::type>::type; using CutlassWeightType = CutlassWeightType_; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 4268b79e1e4f8..fc412a02e0383 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -38,10 +38,13 @@ #include "moe_kernel.h" +#include #include #include #include +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" + namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; @@ -103,11 +106,16 @@ void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows dim3 block(std::min(intermediate_size, 1024)); dim3 grid(num_rows); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("swiglu input", input, num_rows, 2 * intermediate_size); + if constexpr (interleaved) { swiglu_kernel_interleaved<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); } else { swiglu_kernel_chunked<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); } + + DUMP_TENSOR("swiglu output", output, num_rows, intermediate_size); } // ====================== Softmax things =============================== @@ -838,11 +846,15 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* 
ws_ptr, } namespace { - -struct __align__(8) Half4 { +typedef struct __CUDA_ALIGN__(8) { half2 x; half2 y; -}; +} half2_2; + +typedef struct __CUDA_ALIGN__(8) { + __nv_bfloat162 x; + __nv_bfloat162 y; +} __nv_bfloat162_2; // TODO(wy): move to common header template @@ -853,7 +865,11 @@ struct T4 { }; template <> struct T4 { - using Type = Half4; + using Type = half2_2; +}; +template <> +struct T4<__nv_bfloat16> { + using Type = __nv_bfloat162_2; }; template @@ -866,6 +882,10 @@ template <> struct T2 { using Type = half2; }; +template <> +struct T2<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } @@ -882,15 +902,27 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha #endif // TODO(wy): use cuda common header and investigate pipeline build issue. -inline __device__ Half4 operator*(const Half4 a, const Half4 b) { +inline __device__ half2_2 operator*(const half2_2 a, const half2_2 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) - Half4 result; + half2_2 result; result.x = a.x * b.x; result.y = a.y * b.y; return result; #else - return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; + return half2_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; +#endif +} + +inline __device__ __nv_bfloat162_2 operator*(const __nv_bfloat162_2 a, const __nv_bfloat162_2 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && \ + ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + __nv_bfloat162_2 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; +#else + return __nv_bfloat162_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; #endif } @@ -1291,18 +1323,26 @@ template void topk_gating_softmax_kernelLauncher(const float*, const bool*, floa int, bool, bool, cudaStream_t); template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, bool, bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, const bool*, __nv_bfloat16*, __nv_bfloat16*, int*, int*, int, int, + int, bool, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>; +// For qMoE: template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; +template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; // ===================== Specializations for init routing ========================= template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, cudaStream_t); template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*, int, int, int, int, + cudaStream_t); // ==================== Specializations for final routing =================================== template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, @@ -1317,6 +1357,8 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl const float*, const 
int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const int*, const int*, int, int, int, cudaStream_t); template void invokeSwiGLU(float*, float const*, int, int, float, cudaStream_t); template void invokeSwiGLU(half*, half const*, int, int, float, cudaStream_t); diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 3ac4862e101c3..de11d357a8c07 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -178,14 +178,4 @@ class CutlassMoeFCRunner { std::vector total_rows_before_expert_host_; }; -template -class CutlassMoeFCRunner::value>> { - public: - CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); - - size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { - return 0; - } -}; - } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index cc6fe871a3bc1..6409a6e12afc6 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -3,6 +3,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "moe.h" using namespace onnxruntime::cuda; @@ -20,6 +21,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { @@ -42,7 +44,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, fc3_experts_bias_optional)); - typedef typename ToCudaType::MappedType CudaT; + using CudaT = typename OrtToCudaType::type; auto stream = context->GetComputeStream(); auto& device_prop = GetDeviceProp(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index db6d99674cf5a..aef31c7e9ed3a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -5,6 +5,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/quantization/moe_quantization.h" +#include "core/providers/cuda/cuda_type_conversion.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -14,16 +15,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", BuildKernelDefConstraints()) \ - .TypeConstraint("T1", BuildKernelDefConstraints()), \ - QMoE); - -REGISTER_KERNEL() - namespace { template struct ToCudaTypeWrapper : public ToCudaType {}; @@ -40,27 +31,29 @@ struct ToCudaTypeWrapper { } // anonymous namespace -QMoE::QMoE(const OpKernelInfo& op_kernel_info) : 
CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { +template +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4, "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_); } +template template -Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional, - const cudaDeviceProp& device_prop) const { +Status QMoE::QuantizedMoEImpl(OpKernelContext* context, + MoEParameters& moe_params, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, + const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional, + const Tensor* fc1_scales, + const Tensor* fc2_scales, + const Tensor* fc3_scales_optional, + const cudaDeviceProp& device_prop) const { auto stream = context->GetComputeStream(); const int sm = device_prop.major * 10 + device_prop.minor; @@ -68,8 +61,7 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - using T = MLFloat16; - using CudaT = typename ToCudaType::MappedType; + using CudaT = typename OrtToCudaType::type; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, activation_type_, @@ -137,7 +129,8 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, return Status::OK(); } -Status QMoE::ComputeInternal(OpKernelContext* context) const { +template +Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); @@ -184,6 +177,32 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { #endif } +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h index c0164576d7c7f..c4698a1f277ef 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -14,6 +14,7 @@ namespace cuda { using namespace onnxruntime::cuda; +template class QMoE final : public CudaKernel, public 
MoEBase { public: explicit QMoE(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 39bf2bf855976..686ebfb1f6fb5 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1405,7 +1405,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1440,7 +1440,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "3D input tensor with shape (num_experts, hidden_size, inter_size) " "or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).", "T1") - .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T2") .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) @@ -1449,7 +1449,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "3D input tensor with shape (num_experts, inter_size, hidden_size) " "or (num_experts, inter_size, hidden_size / 2)", "T1") - .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") + .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T2") .Input(7, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", @@ -1464,7 +1464,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(9, "fc3_scales", "2D optional input tensor with shape (num_experts, inter_size)", - "T", + "T2", OpSchema::Optional) .Input(10, "fc3_experts_bias", @@ -1476,8 +1476,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " "(batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)"}, "Constrain scales type to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_moe_cuda.py similarity index 68% rename from onnxruntime/test/python/transformers/test_parity_moe.py rename to 
onnxruntime/test/python/transformers/test_moe_cuda.py index d805c8f9cae3c..9b69d63970311 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -10,6 +10,7 @@ # license information. # -------------------------------------------------------------------------- import itertools +import os import unittest from collections import OrderedDict @@ -22,83 +23,59 @@ import onnxruntime -torch.manual_seed(42) -numpy.random.seed(42) - - -def value_string_of(numpy_array): - arr = numpy_array.flatten() - lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] - return "{\n " + "f,\n ".join(lines) + "f}" - +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" -def print_tensor(name, numpy_array): - print(f"const std::vector {name} = {value_string_of(numpy_array)};") +onnxruntime.preload_dlls() +# Determine the execution provider and device based on CUDA availability. +use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") +ort_provider = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"] -def quant_dequant(weights: torch.Tensor, is_4_bit_quantization: bool): - """ - Performs symmetric per-column quantization and dequantization on a weight tensor. - - This implementation is a pure PyTorch replacement for the original function that - relied on a custom tensorrt_llm operator. It supports both 8-bit (int8) and - 4-bit (quint4x2 style) quantization. - - Args: - weights (torch.Tensor): The input weight tensor to be quantized. - is_4_bit_quantization (bool): If True, performs 4-bit quantization. If False, - performs 8-bit quantization. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: - - scales (torch.float16): The quantization scales for each column. - - processed_q_weight (torch.int8): The packed quantized weights. For - 4-bit mode, two 4-bit values are packed into a single int8. For - 8-bit mode, this is the standard int8 quantized tensor. It is - transposed relative to the input weights' shape. - - dequantized_weights (torch.Tensor): The weights after being dequantized, - restored to the original dtype and device. 
- """ - # Determine quantization bits and range based on the mode - if is_4_bit_quantization: - # 4-bit symmetric quantization path - q_bits = 4 - q_max = 2 ** (q_bits - 1) - 1 # 7 - q_min = -(2 ** (q_bits - 1)) # -8 - - max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values - max_abs_val[max_abs_val == 0] = 1.0 - scales = max_abs_val / q_max +torch.manual_seed(42) +numpy.random.seed(42) - quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8) +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.BFLOAT16: torch.bfloat16, + TensorProto.UINT8: torch.uint8, +} - # Pack two 4-bit integers into a single int8 - q_weights_t = quant_weights.T.contiguous() - shape = q_weights_t.shape - q_weights_t_reshaped = q_weights_t.view(shape[0], shape[1] // 2, 2) - lower_nibble = q_weights_t_reshaped[..., 0] - upper_nibble = q_weights_t_reshaped[..., 1] - processed_q_weight = (lower_nibble & 0x0F) | (upper_nibble << 4) +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} - else: - # 8-bit symmetric quantization path - q_bits = 8 - q_max = 2 ** (q_bits - 1) - 1 # 127 - q_min = -(2 ** (q_bits - 1)) # -128 +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", + TensorProto.BFLOAT16: "BF16", +} - max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values - max_abs_val[max_abs_val == 0] = 1.0 - scales = max_abs_val / q_max - quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8) +def quant_dequant(weights, is_4_bit_quantization: bool = True): + # use the test version `_symmetric_...` to get the non-interleaved weights + type = torch.quint4x2 if is_4_bit_quantization else torch.int8 + # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() + # Comment out this line for passing the lintrunner check in the CI. 
+ # import tensorrt_llm - # For 8-bit, the processed weights are just the transposed quantized weights (no packing) - processed_q_weight = quant_weights.T.contiguous() + quant_weights, processed_q_weight, torch_weight_scales = ( + torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) + ) - # Dequantize the weights to verify and return for PyTorch-side parity check - dequantized_weights = quant_weights.to(weights.dtype) * scales.to(weights.dtype) + # Unpack the int4s int int8s + if is_4_bit_quantization: + upper = quant_weights >> 4 + lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends + quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) - return (scales.squeeze(0).to(torch.float16), processed_q_weight, dequantized_weights.T.to(device=weights.device)) + quant_weights = quant_weights.to(dtype=weights.dtype) + result = torch.multiply(quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous() + return torch_weight_scales.to(torch.float16), processed_q_weight, result.to(device=weights.device) def create_moe_onnx_graph( @@ -110,7 +87,7 @@ def create_moe_onnx_graph( fc1_experts_bias, fc2_experts_weights, fc2_experts_bias, - ort_dtype, + onnx_dtype, ): nodes = [ helper.make_node( @@ -134,21 +111,21 @@ def create_moe_onnx_graph( fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] - torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ort_dtype, + onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ort_dtype, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] @@ -159,35 +136,35 @@ def create_moe_onnx_graph( [ helper.make_tensor( "fc1_experts_bias", - ort_dtype, + onnx_dtype, fc1_bias_shape, - fc1_experts_bias.to(torch_type).flatten().tolist(), + fc1_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_bias", - ort_dtype, + onnx_dtype, fc2_bias_shape, - fc2_experts_bias.to(torch_type).flatten().tolist(), + fc2_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ort_dtype, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -211,7 +188,7 @@ def create_mixtral_moe_onnx_graph( fc2_experts_weights, fc3_experts_weights, topk, - ort_dtype, + onnx_dtype, ): nodes = [ helper.make_node( @@ -238,46 +215,46 @@ def create_mixtral_moe_onnx_graph( fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] - torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ort_dtype, + 
onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ort_dtype, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ort_dtype, + onnx_dtype, fc3_shape, - fc3_experts_weights.to(torch_type).flatten().tolist(), + fc3_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] graph_inputs = [ - helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ort_dtype, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -301,7 +278,7 @@ def create_phi_moe_onnx_graph( fc2_experts_weights, fc3_experts_weights, topk, - ort_dtype, + onnx_dtype, quant_bits=0, fc1_scales=None, fc2_scales=None, @@ -368,31 +345,31 @@ def create_phi_moe_onnx_graph( fc2_shape = [num_experts, inter_size, hidden_size // components] fc3_shape = [num_experts, hidden_size, inter_size // components] - torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 - numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32 - weight_numpy_type = numpy.uint8 if use_quant else numpy_type - weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + + weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype initializers = [ helper.make_tensor( "fc1_experts_weights", weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), + fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), + fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", weight_onnx_type, fc3_shape, - fc3_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), + fc3_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), ] @@ -405,42 +382,42 @@ def create_phi_moe_onnx_graph( [ helper.make_tensor( "fc1_scales", - ort_dtype, + onnx_dtype, fc1_scale_shape, - fc1_scales.to(torch_type).flatten().tolist(), + fc1_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_scales", - ort_dtype, + onnx_dtype, fc2_scale_shape, - fc2_scales.to(torch_type).flatten().tolist(), + fc2_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_scales", - ort_dtype, + onnx_dtype, fc3_scale_shape, - fc3_scales.to(torch_type).flatten().tolist(), + fc3_scales.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", 
onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ort_dtype, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -593,127 +570,127 @@ def __init__(self, config: PhiMoEConfig): class SparseMoeBlockORTHelper(nn.Module): - def __init__(self, quant_bits=0): + def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() self.quant_bits = quant_bits - self.ort_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT - self.np_type = numpy.float16 if self.ort_dtype == TensorProto.FLOAT16 else numpy.float32 + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415 sess_options = SessionOptions() + sess_options.log_severity_level = 2 - cuda_providers = ["CUDAExecutionProvider"] - if cuda_providers[0] not in onnxruntime.get_available_providers(): + try: + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception as e: + print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print("Skipping ONNX Runtime execution for this test case.") return None - sess_options.log_severity_level = 2 - ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) - return ort_session def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pass - def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) + hidden_states_flat = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(self.np_type)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(self.np_type)), - } + router_logits = self.gate(hidden_states_flat) - ort_output = None - if self.ort_sess is not None: - if not iobinding: - ort_output = self.ort_sess.run(None, ort_inputs) - return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits - else: - self.ort_run_with_iobinding(ort_inputs) - return None + # Determine the correct torch dtype from the onnx_dtype + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] - return None + # Prepare tensors on the correct device for ORT inference with the CORRECT dtype + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } - def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + # Bind inputs and outputs to torch tensors directly. 
iobinding = self.ort_sess.io_binding() - device_id = torch.cuda.current_device() - - iobinding.bind_input( - name="input", - device_type="cuda", - device_id=device_id, - element_type=self.np_type, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), - ) - - iobinding.bind_input( - name="router_probs", - device_type="cuda", - device_id=device_id, - element_type=self.np_type, - shape=ort_inputs["router_probs"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - ort_inputs["router_probs"], "cuda", device_id - ).data_ptr(), - ) - iobinding.bind_output( - name="output", - device_type="cuda", - device_id=device_id, - element_type=self.np_type, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - numpy.zeros(ort_inputs["input"].shape), "cuda", device_id - ).data_ptr(), - ) - - # warm up - for _ in range(5): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - - import time # noqa: PLC0415 - - s = time.time() - for _ in range(repeat): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - e = time.time() - print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") - - def parity_check(self, atol=None, rtol=None): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + for name, tensor in tensors.items(): + # Ensure tensor is on the globally defined device + if name == "output": + iobinding.bind_output( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + else: + iobinding.bind_input( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + + if enable_performance_test: + import time # noqa: PLC0415 + + repeat = 1000 + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + + # The output tensor is on `device`. Reshape and return it. 
+ return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) + + def parity_check(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) - if atol is None: - atol = 5e-2 if self.quant_bits == 0 else (2.0 if self.quant_bits == 8 else 3.0) - - if rtol is None: - rtol = 1e-3 if self.quant_bits == 0 else 1e-2 + dtype_str = ort_dtype_name_map[self.onnx_dtype] + + # Maps "ort_type:quant_bits" to (atol, rtol) + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (3.0, 1e-2), + "FP16:8": (2.0, 1e-2), + "BF16:0": (1.0, 1e-2), + "BF16:4": (30.0, 1e-1), + "BF16:8": (20.0, 1e-1), + } + atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"] if ort_output is not None: - dtype_str = "FP32" if self.quant_bits == 0 else "FP16" print( f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," f" batch: {self.batch_size}, seq_len: {self.sequence_length}," - f" max_diff: {(torch_output - ort_output).abs().max()}" + f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}" ) torch.testing.assert_close( - ort_output.to(torch.float32), torch_output.to(torch.float32), rtol=rtol, atol=atol + ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol ) def benchmark_ort(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) - self.ort_forward(hidden_state, iobinding=True) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) class SwitchMoE(SparseMoeBlockORTHelper): @@ -757,7 +734,7 @@ def __init__( self.moe_experts.bias1, self.moe_experts.weight2.transpose(1, 2), self.moe_experts.bias2, - self.ort_dtype, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -827,7 +804,7 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2, self.moe_experts_weight3, self.top_k, - self.ort_dtype, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -924,8 +901,8 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): and memory on padding. 
""" - def __init__(self, config, batch_size, sequence_length, quant_bits=0): - super().__init__(quant_bits) + def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype=None): + super().__init__(quant_bits, onnx_dtype) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -950,18 +927,18 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0): is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): # Corrected quantization logic for per-output-channel quantization - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit) - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight.T, is_4_bit) + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) self.experts[i].w1.weight.data = w1_qdq self.experts[i].w2.weight.data = w2_qdq self.experts[i].w3.weight.data = w3_qdq # Transpose quantized weights to match the expected ONNX layout - w1_list.append(pre_qweight1.T) - w2_list.append(pre_qweight2.T) - w3_list.append(pre_qweight3.T) + w1_list.append(pre_qweight1) + w2_list.append(pre_qweight2) + w3_list.append(pre_qweight3) w1_scale_list.append(w1_scale) w2_scale_list.append(w2_scale) w3_scale_list.append(w3_scale) @@ -985,7 +962,7 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0): self.moe_experts_weight2, self.moe_experts_weight3, self.top_k, - self.ort_dtype, + self.onnx_dtype, self.quant_bits, moe_experts_weight_scale1, moe_experts_weight_scale2, @@ -1045,25 +1022,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def small_test_cases(): for batch_size in [1, 4, 16]: for sequence_length in [128, 512, 1024]: - yield batch_size, sequence_length, 0 - - -# Test cases for Phi-3 MoE. -# We test three modes: no quantization, 8-bit, and 4-bit. -phi3_test_params = list( - itertools.product( - [1, 4], # batch_size - [1, 32], # sequence_length - [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) - ) -) + yield batch_size, sequence_length +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) - def test_switch_moe_parity(self, batch_size, sequence_length, quant_bits): - # if platform.system() == "Windows": - # pytest.skip("Skip on Windows") + def test_switch_moe_parity(self, batch_size, sequence_length): switch_moe = SwitchMoE( batch_size=batch_size, sequence_length=sequence_length, @@ -1072,26 +1037,42 @@ def test_switch_moe_parity(self, batch_size, sequence_length, quant_bits): hidden_features=1024, out_features=256, ) + switch_moe.to(device) switch_moe.parity_check() - # switch_moe.benchmark_ort() +# quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) +# since qMoE test requires tensorrt_llm for quant_dequant. We disable it in CI pipeline to avoid extra dependency. 
+quant_bits_list = [0] if pipeline_mode else [0, 8, 4] + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestMixtralMoE(unittest.TestCase): - @parameterized.expand([(b, s, q) for b, s, q in small_test_cases() if q == 0]) # only run non-quantized - def test_mixtral_moe_parity(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(small_test_cases()) + def test_mixtral_moe_parity(self, batch_size, sequence_length): config = MixtralConfig(hidden_size=256, intermediate_size=1024) mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.to(device) mixtral_moe.parity_check() - # mixtral_moe.benchmark_ort() +phi3_test_cases = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + quant_bits_list, + ) +) + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestPhiMoE(unittest.TestCase): - @parameterized.expand(phi3_test_params) + @parameterized.expand(phi3_test_cases) def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits) + phi3_moe.to(device) phi3_moe.parity_check() - # phi3_moe.benchmark_ort() # --------------------------------------------- @@ -1111,6 +1092,14 @@ def __init__( self.num_local_experts = num_local_experts +def swiglu(x: torch.Tensor): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1) + return y + + class SwigluMlp(nn.Module): def __init__(self, config): super().__init__() @@ -1119,26 +1108,45 @@ def __init__(self, config): self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) - def swiglu(self, x: torch.Tensor): - dim = x.shape[-1] - x = x.view(-1, dim // 2, 2) - x_glu, x_linear = x[..., 0], x[..., 1] - y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1) - return y - def forward(self, x): - y = self.swiglu(self.w1(x)) + x1 = self.w1(x) + y = swiglu(x1) y = self.w2(y) return y +# Note that the shape might not match the tensor shape. See Attention note in this file. 
+def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype): + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + if torch_dtype == torch.bfloat16: + numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy() + initializer = helper.make_tensor( + name=name, + data_type=TensorProto.BFLOAT16, + dims=shape, + vals=numpy_vals_uint16.tobytes(), + raw=True, + ) + else: + initializer = helper.make_tensor( + name=name, + data_type=onnx_dtype, + dims=shape, + vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist() + if onnx_dtype == TensorProto.UINT8 + else tensor.detach().to(torch_dtype).flatten().tolist(), + raw=False, + ) + return initializer + + def create_swiglu_moe_onnx_graph( num_tokens: int, num_experts: int, hidden_size: int, inter_size: int, topk: int, - ort_dtype: int, + onnx_dtype: int, quant_bits: int, fc1_experts_weights: torch.Tensor, fc1_experts_bias: torch.Tensor, @@ -1189,6 +1197,8 @@ def create_swiglu_moe_onnx_graph( nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) components = 2 if quant_bits == 4 else 1 + # ATTENTION: Actual weight layout is like [num_experts, 2 * inter_size, hidden_size // components] + # Here we claim a different shape for the initializer to match the operator spec for weight tensor! fc1_weight_shape = [num_experts, hidden_size, 2 * inter_size // components] fc1_bias_shape = [num_experts, 2 * inter_size] fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] @@ -1197,80 +1207,54 @@ def create_swiglu_moe_onnx_graph( fc2_bias_shape = [num_experts, hidden_size] fc2_experts_weight_scale_shape = [num_experts, hidden_size] - torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 - numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32 - weight_numpy_type = numpy.uint8 if use_quant else numpy_type - weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type] initializers = [ - helper.make_tensor( - "fc1_experts_weights", - weight_onnx_type, - fc1_weight_shape, - fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist() - if use_quant - else fc1_experts_weights.to(torch_type).flatten().tolist(), - raw=False, + make_onnx_intializer( + "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type ), - helper.make_tensor( - "fc1_experts_bias", - ort_dtype, - fc1_bias_shape, - fc1_experts_bias.to(torch_type).flatten().tolist(), - raw=False, - ), - helper.make_tensor( - "fc2_experts_weights", - weight_onnx_type, - fc2_weight_shape, - fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist() - if use_quant - else fc2_experts_weights.to(torch_type).flatten().tolist(), - raw=False, - ), - helper.make_tensor( - "fc2_experts_bias", - ort_dtype, - fc2_bias_shape, - fc2_experts_bias.to(torch_type).flatten().tolist(), - raw=False, + make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype), + make_onnx_intializer( + "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type ), + make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype), ] if use_quant: initializers.extend( [ - helper.make_tensor( + make_onnx_intializer( 
"fc1_experts_weight_scale", - ort_dtype, + fc1_experts_weight_scale.to(torch_dtype), fc1_experts_weight_scale_shape, - fc1_experts_weight_scale.to(torch_type).flatten().tolist(), - raw=False, + onnx_dtype, ), - helper.make_tensor( + make_onnx_intializer( "fc2_experts_weight_scale", - ort_dtype, + fc2_experts_weight_scale.to(torch_dtype), fc2_experts_weight_scale_shape, - fc2_experts_weight_scale.to(torch_type).flatten().tolist(), - raw=False, + onnx_dtype, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ort_dtype, [num_tokens, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ort_dtype, + onnx_dtype, [num_tokens, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ort_dtype, [num_tokens, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]), ] graph = helper.make_graph( @@ -1286,8 +1270,10 @@ def create_swiglu_moe_onnx_graph( class SwigluMoEBlock(SparseMoeBlockORTHelper): - def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0): - super().__init__(quant_bits) + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -1298,54 +1284,61 @@ def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: in self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) - weight_1_list, weight_2_list = [], [] - bias_1_list, bias_2_list = [], [] + # For the ONNX MoE operator, weights must be transposed to [In, Out] format. + # Biases do not require transposition. + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] scale_1_list, scale_2_list = [], [] - for i in range(self.num_experts): - bias_1_list.append(self.experts[i].w1.bias) - bias_2_list.append(self.experts[i].w2.bias) + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) if not use_quant: - weight_1_list.append(self.experts[i].w1.weight) - weight_2_list.append(self.experts[i].w2.weight) + # ATTENTION: Weight tensor for CUDA shall have [E, out, in] memory layout just like Linear. + # But the initializer shape shall be [E, in, out] to match op spec. + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) else: is_4_bit = self.quant_bits == 4 - # Pass the transposed weight to quant_dequant to get correct scales, - # then transpose the resulting quantized weight back to the expected layout. - scale1, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit) - scale2, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit) - self.experts[i].w1.weight.data = w1_qdq - self.experts[i].w2.weight.data = w2_qdq + # quant_dequant expects [Out, In] format, matching nn.Linear.weight + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) - weight_1_list.append(pre_qweight1.T) - weight_2_list.append(pre_qweight2.T) + # Update the expert's weight with the dequantized version for the PyTorch reference. 
+ expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) scale_1_list.append(scale1) scale_2_list.append(scale2) - self.moe_experts_weight1 = torch.stack(weight_1_list, dim=0) - self.moe_experts_weight2 = torch.stack(weight_2_list, dim=0) - - self.moe_experts_bias1 = torch.stack(bias_1_list, dim=0) - self.moe_experts_bias2 = torch.stack(bias_2_list, dim=0) + # Stack the prepared tensors for the graph builder + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None self.batch_size = batch_size self.sequence_length = sequence_length + + # Build the ONNX graph with the correctly shaped tensors self.moe_onnx_graph = create_swiglu_moe_onnx_graph( num_tokens=self.batch_size * self.sequence_length, num_experts=self.num_experts, hidden_size=self.hidden_dim, inter_size=self.ffn_dim, topk=self.top_k, - ort_dtype=self.ort_dtype, + onnx_dtype=self.onnx_dtype, quant_bits=self.quant_bits, - fc1_experts_weights=self.moe_experts_weight1, - fc1_experts_bias=self.moe_experts_bias1, - fc2_experts_weights=self.moe_experts_weight2, - fc2_experts_bias=self.moe_experts_bias2, + fc1_experts_weights=fc1_experts_weights, + fc1_experts_bias=fc1_experts_bias, + fc2_experts_weights=fc2_experts_weights, + fc2_experts_bias=fc2_experts_bias, fc1_experts_weight_scale=moe_experts_weight_scale1, fc2_experts_weight_scale=moe_experts_weight_scale2, ) @@ -1353,26 +1346,24 @@ def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: in self.ort_sess = self.create_ort_session(self.moe_onnx_graph) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """ + This is the robust PyTorch reference implementation. It directly uses the + nn.Module experts, which is cleaner and less error-prone than manual matmul. + """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states) # router_logits shape is (batch * sequence_length, num_experts) + router_logits = self.gate(hidden_states) routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) - routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) - # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) @@ -1380,35 +1371,78 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if top_x.shape[0] == 0: continue - # Index the correct hidden states and compute the expert hidden state for - # the current expert. 
We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states -swiglu_test_params = list( +swiglu_test_cases = list( itertools.product( - [1, 4], # batch_size - [1, 32], # sequence_length - [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + [1, 2], # batch_size + [1, 3], # sequence_length + quant_bits_list, # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) ) ) +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestSwigluMoE(unittest.TestCase): - @parameterized.expand(swiglu_test_params) + @parameterized.expand(swiglu_test_cases) def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): - config = SwigluMoeConfig(hidden_size=128, intermediate_size=512, num_experts_per_token=1, num_local_experts=4) + config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4) moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.parity_check() + + +def has_bf16_moe(): + if "CUDAExecutionProvider" not in onnxruntime.get_available_providers() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 + + +@unittest.skipIf(not has_bf16_moe(), "skipping bf16 moe tests.") +class TestSwigluMoeBf16(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=128, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits, onnx_dtype=TensorProto.BFLOAT16) + moe.to(device) moe.parity_check() +perf_test_cases = list( + itertools.product( + [1], # batch_size + [128, 512, 1024, 2048, 4096], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +@unittest.skipIf(pipeline_mode or not use_cuda, "skipping performance test in CI pipeline.") +class TestSwigluMoEPerf(unittest.TestCase): + @parameterized.expand(perf_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + hidden_size = 2880 + intermediate_size = 2880 + num_experts_per_token = 8 + num_local_experts = 128 + config = SwigluMoeConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts_per_token=num_experts_per_token, + num_local_experts=num_local_experts, + ) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.benchmark_ort() + + if __name__ == "__main__": unittest.main()
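As a side note on the PyTorch reference path in `SwigluMoEBlock.forward`, the expert-mask indexing is the least obvious step. A small self-contained sketch (token and expert counts picked arbitrarily) of how `one_hot(selected_experts).permute(2, 1, 0)` combined with `torch.where` yields, per expert, the routed token indices and their top-k slots:

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 5, 4, 2
router_logits = torch.randn(num_tokens, num_experts)

routing_weights, selected_experts = torch.topk(router_logits, top_k, dim=-1)
routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float)

# expert_mask[e, k, t] == 1 when token t picked expert e as its k-th choice.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    # idx: the top-k slot, top_x: the token indices routed to this expert.
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(expert_idx, top_x.tolist(), routing_weights[top_x, idx].tolist())
```

`routing_weights[top_x, idx]` then picks, for each routed token, the softmax weight of the slot in which it selected that expert, which is what the `index_add_` accumulation in the reference block applies to the expert outputs.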
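Similarly, the bfloat16 branch of `make_onnx_intializer` works by bit-reinterpreting the tensor as uint16 and writing raw bytes, since NumPy has no native bfloat16 dtype. A short sketch of why that round-trips losslessly (assuming a PyTorch build with `torch.uint16`, which the helper already requires):

```python
import numpy
import torch

x = torch.randn(2, 3).to(torch.bfloat16)

# Same trick as make_onnx_intializer: reinterpret the bf16 bits as uint16 so they
# can pass through NumPy and be serialized as the raw bytes of the initializer.
payload = x.view(torch.uint16).cpu().numpy().tobytes()

# A bfloat16 value is the top 16 bits of the corresponding float32, so widening
# the stored bits recovers exactly what x.to(torch.float32) produces.
bits = numpy.frombuffer(payload, dtype=numpy.uint16).astype(numpy.uint32)
restored = (bits << 16).view(numpy.float32).reshape(2, 3)
assert numpy.array_equal(restored, x.to(torch.float32).numpy())
```

The `raw=True` path in the helper stores exactly these bytes as the BFLOAT16 initializer payload.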