@@ -96,9 +96,7 @@ struct genericMoeGemmKernelLauncher {
9696
9797 static_assert (cutlass::platform::is_same<T, WeightType>::value ||
9898 cutlass::platform::is_same<WeightType, uint8_t >::value ||
99- #if defined(ENABLE_FP4)
10099 cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value ||
101- #endif
102100 cutlass::platform::is_same<WeightType, cutlass::uint4b_t >::value);
103101
104102 static_assert (arch::kMinComputeCapability < 90 ,
@@ -739,42 +737,34 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
739737 " Hopper configuration provided for non-Hopper architecture" );
740738
741739 if (sm_ >= 75 && sm_ < 80 ) {
742- #if defined(ENABLE_FP4)
743740 if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
744- #endif
745741 cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
746742 cutlass::arch::Sm75, EpilogueTag>(
747743 inputs, multi_processor_count_);
748- #if defined(ENABLE_FP4)
749744 } else {
750745 TLLM_THROW (" FP4 data type is not supported on SM < 90" );
751746 }
752- #endif
753747 } else if (sm_ >= 80 && sm_ < 90 ) {
754- if constexpr (use_fp8 || use_w4afp8) {
748+ if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
749+ if constexpr (use_fp8 || use_w4afp8) {
755750#if defined(ENABLE_FP8)
756- static_assert (
757- !std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
758- " FP8 GEMM Output not supported" );
751+ static_assert (!std::is_same_v<OutputType, __nv_fp8_e4m3> &&
752+ !std::is_same_v<OutputType, __nv_fp8_e5m2>,
753+ " FP8 GEMM Output not supported" );
759754#endif
760755
761- TLLM_CHECK_WITH_INFO (sm_ == 89 , " For sm >= 80 and < 90, fp8 is only supported with sm == 89" );
762- cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
763- cutlass::arch::Sm89, EpilogueTag>(
764- inputs, multi_processor_count_);
765- } else {
766- #ifdef ENABLE_FP4
767- if constexpr (std::is_same_v<WeightType, __nv_fp4_e2m1>) {
768- TLLM_THROW (" FP4 data type is not supported on SM < 90" );
756+ TLLM_CHECK_WITH_INFO (sm_ == 89 ,
757+ " For sm >= 80 and < 90, fp8 is only supported with sm == 89" );
758+ cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
759+ cutlass::arch::Sm89, EpilogueTag>(
760+ inputs, multi_processor_count_);
769761 } else {
770762 cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
771763 cutlass::arch::Sm80, EpilogueTag>(
772764 inputs, multi_processor_count_);
773765 }
774- #else
775- dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
776- inputs, multi_processor_count_);
777- #endif
766+ } else {
767+ TLLM_THROW (" FP4 data type is not supported on SM < 90" );
778768 }
779769 } else if (sm_ >= 90 ) {
780770 // For SM120+ pure FP8 MoE (not FP8 x FP4), redirect to SM89 (Ada) FP8 kernel implementations.
@@ -995,9 +985,6 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(
995985 case ActivationType::Geglu:
996986 runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(inputs, hopper_inputs);
997987 break ;
998- case ActivationType::Relu2:
999- TLLM_THROW (" Relu2 is not supported." );
1000- break ;
1001988 case ActivationType::InvalidType:
1002989 TLLM_THROW (" Activation type for fpA_intB must be valid." );
1003990 break ;
0 commit comments