Skip to content

Commit 5e53008

Browse files
authored
[#8732][feat] Add ReLU2 to TRTLLM Cutlass MoE BF16 kernels (#9191)
Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent fd99164 commit 5e53008

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ struct EpilogueOpDefaultReLU
6161
{
6262
};
6363

64+
// Empty tag type. It carries no data or behavior of its own; it is passed as
// the last template argument of `Epilogue` to pick the specialization whose
// `Op` is built on cutlass::epilogue::thread::Relu2, and it is the epilogue
// tag that moeGemmBiasAct hands to runGemm for ActivationType::Relu2.
struct EpilogueOpDefaultRelu2
65+
{
66+
};
67+
6468
// Empty tag type. `Epilogue` has a partial specialization keyed on it, and
// moeGemmBiasAct passes it to runGemm for ActivationType::Geglu.
struct EpilogueOpDefaultFtGelu
6569
{
6670
};
@@ -122,6 +126,14 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
122126
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
123127
};
124128

129+
// Partial specialization of `Epilogue` selected by the EpilogueOpDefaultRelu2 tag.
// It exposes a single member alias, `Op`, naming the CUTLASS thread-level epilogue
// functor cutlass::epilogue::thread::LinearCombinationGeneric instantiated with, in order:
//   cutlass::epilogue::thread::Relu2, ElementType, ElementsPerVectorAccess,
//   ElementAccumulator (twice), DefaultScaleMode,
//   cutlass::FloatRoundStyle::round_to_nearest, and `true`.
// NOTE(review): the trailing `true` is presumably LinearCombinationGeneric's
// "IsHeavy" flag -- confirm against the CUTLASS declaration before relying on it.
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
130+
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultRelu2>
131+
{
132+
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Relu2, ElementType,
133+
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
134+
cutlass::FloatRoundStyle::round_to_nearest, true>;
135+
};
136+
125137
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
126138
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
127139
{

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(
954954
case ActivationType::Identity: runGemm<cutlass_extensions::EpilogueOpDefault>(inputs, hopper_inputs); break;
955955
case ActivationType::Swiglu: runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(inputs, hopper_inputs); break;
956956
case ActivationType::Geglu: runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(inputs, hopper_inputs); break;
957-
case ActivationType::Relu2: TLLM_THROW("Relu2 is not supported."); break;
957+
case ActivationType::Relu2: runGemm<cutlass_extensions::EpilogueOpDefaultRelu2>(inputs, hopper_inputs); break;
958958
case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
959959
default: TLLM_THROW("Invalid activation type."); break;
960960
}

0 commit comments

Comments
 (0)