
Commit 4f5d752

nzmora-nvidia authored and nvchenghaoz committed
[NVIDIA#8732][feat] Update TRTLLM Cutlass MoE kernels with ReLU2 (NVIDIA#9011)
Update TRTLLM Cutlass MoE kernels with ReLU2 activation. Nemotron-6 requires the ReLU2 (i.e., squared ReLU) MoE activation function. This PR adds it, along with a general API for setting the activation function. The ReLU2 changes are based on this FlashInfer PR: flashinfer-ai/flashinfer#1954.

The PR also updates the Auto Deploy MoE backend for 16-bit and FP8 from Triton (`torch.ops.auto_deploy.triton_moe_fused`, `torch.ops.auto_deploy.triton_quant_fp8_moe`) to TRTLLM/Cutlass (`torch.ops.auto_deploy.trtllm_moe_fused`, `torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused`).

Signed-off-by: Neta Zmora <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
1 parent 19b76e4 commit 4f5d752
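
For reference, ReLU2 (squared ReLU) is simply the ReLU output multiplied by itself: relu2(x) = max(x, 0)^2. A minimal standalone C++ sketch of the element-wise function (illustration only, not part of the diff):

#include <algorithm>
#include <cstdio>

// relu2(x) = max(x, 0)^2: clamp negatives to zero, then square.
float relu2(float x)
{
    float r = std::max(x, 0.0f);
    return r * r;
}

int main()
{
    std::printf("%g %g %g\n", relu2(-1.5f), relu2(0.5f), relu2(2.0f)); // prints: 0 0.25 4
    return 0;
}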

File tree

10 files changed: +722 −31 lines


cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h

Lines changed: 24 additions & 0 deletions
@@ -59,6 +59,30 @@ __forceinline__ __device__ float tanh_opt(float x)
 #endif
 }
 
+template <typename T>
+struct Relu2
+{
+    static bool const kIsHeavy = false;
+
+    CUTLASS_HOST_DEVICE
+    T operator()(T threshold, T value) const
+    {
+        ReLu<T> relu_op;
+        multiplies<T> mul;
+        T val = relu_op(threshold, value);
+        return mul(val, val);
+    }
+
+    CUTLASS_HOST_DEVICE
+    T operator()(T value) const
+    {
+        ReLu<T> relu_op;
+        multiplies<T> mul;
+        T val = relu_op(value);
+        return mul(val, val);
+    }
+};
+
 } // namespace thread
 } // namespace epilogue
 } // namespace cutlass
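
The functor above provides the two call forms CUTLASS epilogues expect: a plain form and a thresholded form. A self-contained sketch of the same behavior without the CUTLASS dependency (assuming ReLu<T>'s thresholded overload clamps values below the threshold up to the threshold):

#include <cassert>

template <typename T>
struct Relu2Sketch
{
    // Thresholded form: clamp value to at least `threshold`, then square.
    T operator()(T threshold, T value) const
    {
        T val = value < threshold ? threshold : value;
        return val * val;
    }

    // Plain form: max(value, 0) squared.
    T operator()(T value) const
    {
        T val = value < T(0) ? T(0) : value;
        return val * val;
    }
};

int main()
{
    Relu2Sketch<float> relu2;
    assert(relu2(-3.0f) == 0.0f);
    assert(relu2(2.0f) == 4.0f);
    assert(relu2(1.0f, 0.5f) == 1.0f); // clamped up to the threshold, then squared
    return 0;
}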

cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ enum class ActivationType
     Geglu,
     SwigluBias,
     Identity,
+    Relu2,
     InvalidType
 };
 

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 1 addition & 0 deletions
@@ -954,6 +954,7 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(
     case ActivationType::Identity: runGemm<cutlass_extensions::EpilogueOpDefault>(inputs, hopper_inputs); break;
     case ActivationType::Swiglu: runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(inputs, hopper_inputs); break;
     case ActivationType::Geglu: runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(inputs, hopper_inputs); break;
+    case ActivationType::Relu2: TLLM_THROW("Relu2 is not supported."); break;
     case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
     default: TLLM_THROW("Invalid activation type."); break;
     }

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 2 additions & 0 deletions
@@ -2307,6 +2307,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
         decltype(block_scaling_type)::value>, // Geglu
     &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
         decltype(block_scaling_type)::value>, // SwigluBias
+    &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
+        decltype(block_scaling_type)::value>, // Relu2
     &doActivationKernel<T, GemmOutputType, ScaleBiasType,
         IdentityAdaptor<cutlass::epilogue::thread::Identity>,
         decltype(block_scaling_type)::value> // Identity
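
Relu2 is routed through IdentityAdaptor because it is a non-gated activation: each FC1 GEMM output element produces one activated element for FC2. Gated activations (Swiglu, Geglu) instead combine two GEMM output elements per result. A conceptual sketch of that difference (names and data layout here are illustrative, not the actual TRT-LLM kernel):

#include <cstddef>
#include <vector>

// Non-gated path (Identity, Relu2, ...): one fc1 output element per activated element.
template <typename Act>
std::vector<float> applyNonGated(std::vector<float> const& fc1_out, Act act)
{
    std::vector<float> out(fc1_out.size());
    for (std::size_t i = 0; i < fc1_out.size(); ++i)
        out[i] = act(fc1_out[i]);
    return out;
}

// Gated path (Swiglu, Geglu, ...): two fc1 output elements per activated element,
// assuming here that the gated half is stored before the linear half.
template <typename Act>
std::vector<float> applyGated(std::vector<float> const& fc1_out, Act act)
{
    std::size_t const inter = fc1_out.size() / 2;
    std::vector<float> out(inter);
    for (std::size_t i = 0; i < inter; ++i)
        out[i] = act(fc1_out[i]) * fc1_out[i + inter];
    return out;
}

int main()
{
    auto relu2 = [](float x) { float r = x > 0.0f ? x : 0.0f; return r * r; };
    std::vector<float> fc1_out = {1.0f, -2.0f, 3.0f, 0.5f};
    auto a = applyNonGated(fc1_out, relu2); // 4 activated elements
    auto b = applyGated(fc1_out, relu2);    // 2 activated elements
    return (a.size() == 4 && b.size() == 2) ? 0 : 1;
}

This is also why the weight-shape check in moeOp.cpp below distinguishes the two cases: fc1's inter dimension is twice fc2's for gated activations, and equal to it otherwise.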

cpp/tensorrt_llm/kernels/internal_cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ enum class ActivationType
     Geglu,
     SwigluBias,
     Identity,
+    Relu2,
     InvalidType
 };
 

cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 26 additions & 11 deletions
@@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);
         // Free the profile workspace to save memory
@@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
             "fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
 
+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (mUseINT8WoqPerChannel)
         {
             // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
@@ -337,8 +340,16 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         else
         {
-            TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
-                "fc1_expert_weights inter size must be fc2_expert_weights inter size.");
+            if (isGatedActivation(base_activation_type))
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
+                    "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
+            }
+            else
+            {
+                TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier,
+                    "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
+            }
         }
 
         int experts_per_token = token_selected_experts.sizes()[1];
@@ -375,7 +386,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         int const num_experts_on_rank = fc2_expert_weights.sizes()[0];
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;
+
         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -474,8 +485,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
         int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
         bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-        torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
-        torch::optional<torch::Tensor> const& out_tensor)
+        torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
+        torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
@@ -541,7 +552,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
         auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
         auto parallelism_config
             = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank);
-        ActivationType base_activation_type = ActivationType::Swiglu;
+        ActivationType base_activation_type = activation_type.has_value()
+            ? static_cast<ActivationType>(activation_type.value())
+            : ActivationType::Swiglu;
         if (swiglu_alpha.has_value())
         {
             CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -652,7 +665,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
         int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
         int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
-        int64_t const profile_id, bool const do_preparation, int64_t const unpadded_hidden_size)
+        int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
+        int64_t const unpadded_hidden_size)
     {
         std::lock_guard<std::mutex> lock(mMutex);
 
@@ -661,6 +675,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
         {
             return;
         }
+        ActivationType activation_type = static_cast<ActivationType>(activation_type_int);
 
         int64_t const num_rows = input.sizes()[0];
         int64_t hidden_size = fc2_expert_weights.sizes()[1];
@@ -715,14 +730,14 @@ class FusedMoeRunner : public torch::CustomClassHolder
             tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
             tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
             hidden_size, unpadded_hidden_size > 0 ? unpadded_hidden_size : hidden_size, inter_size, group_size,
-            ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
+            activation_type, USE_BIAS, USE_LORA, min_latency_mode,
             /*need_weights*/ false, parallelism_config, enable_alltoall);
 #else
         mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
             tensorrt_llm::runtime::TorchUtils::dataType(activation_dtype),
             tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
             tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
-            hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
+            hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA, min_latency_mode,
             /*need_weights*/ false, parallelism_config);
 #endif
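
The new torch::optional<int64_t> activation_type parameter defaults to Swiglu when absent, so existing callers keep their previous behavior. A hedged sketch of that resolution logic from the caller's side (the enum below is a stand-in for the real ActivationType in common.h; its ordering and values are not authoritative, and the Python/TorchScript plumbing is not part of this diff):

#include <cstdint>
#include <iostream>
#include <optional>

// Stand-in enum: only the members visible in this diff; the real definition lives in
// cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h.
enum class ActivationTypeSketch : int64_t { Swiglu, Geglu, SwigluBias, Identity, Relu2, InvalidType };

// Mirrors the runner's fallback: use the provided activation type, else default to Swiglu.
ActivationTypeSketch resolveActivation(std::optional<int64_t> const& activation_type)
{
    return activation_type.has_value()
        ? static_cast<ActivationTypeSketch>(activation_type.value())
        : ActivationTypeSketch::Swiglu;
}

int main()
{
    auto relu2 = resolveActivation(static_cast<int64_t>(ActivationTypeSketch::Relu2));
    auto defaulted = resolveActivation(std::nullopt);
    std::cout << (relu2 == ActivationTypeSketch::Relu2) << ' '
              << (defaulted == ActivationTypeSketch::Swiglu) << '\n'; // prints: 1 1
    return 0;
}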
