@@ -435,7 +435,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
-        auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+        auto const quant_params
+            = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);
         kernels::MoeMinLatencyParams min_latency_params{};
 
         // TODO: support lora in the future
@@ -613,7 +614,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
-        auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+        auto const quant_params
+            = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);
 
         // TODO: support lora in the future
         ::tensorrt_llm::kernels::LoraParams lora_params{};
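Both call sites now thread `base_activation_type` into `getQuantParams` so the per-quantization shape checks can tell whether fc1 is a fused gate+up projection. A minimal, self-contained sketch of that idea (the enum values and helper below are illustrative stand-ins, not the library's actual definitions):

```cpp
#include <cstdint>

// Illustrative stand-in for the real ActivationType / isGatedActivation.
enum class ActivationType { Relu, Gelu, Swiglu, Geglu };

// Gated activations (e.g. SwiGLU/GeGLU) fuse the gate and up projections into
// fc1, so its output dimension is 2 * inter_size; plain activations keep it
// at inter_size. This is the factor the shape checks call expand_ratio.
constexpr bool isGatedActivation(ActivationType t)
{
    return t == ActivationType::Swiglu || t == ActivationType::Geglu;
}

constexpr int64_t fc1OutputDim(int64_t inter_size, ActivationType t)
{
    return inter_size * (isGatedActivation(t) ? 2 : 1);
}

static_assert(fc1OutputDim(4096, ActivationType::Swiglu) == 8192);
static_assert(fc1OutputDim(4096, ActivationType::Relu) == 4096);
```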
@@ -859,7 +861,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
-        int64_t const inter_size, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales) const
+        int64_t const inter_size, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
+        ActivationType base_activation_type) const
     {
         if (isFp8Quant())
         {
@@ -921,16 +924,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
             TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
             // Check shapes
+            int expand_ratio = isGatedActivation(base_activation_type) ? 2 : 1;
             TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
                     && fc1_weight_block.sizes()[1]
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                                inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX)
-                            * 2
+                            * expand_ratio
                     && fc1_weight_block.sizes()[2] * FP8_PER_INT32
                             * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                             hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX),
-                "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+                "fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
                 "block_scale_vector_size)");
             TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
             TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
@@ -974,16 +978,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D");
             TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
             TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
+            int expand_ratio = isGatedActivation(base_activation_type) ? 2 : 1;
             TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
                     && fc1_weight_block.sizes()[1]
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                                inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX)
-                            * 2
+                            * expand_ratio
                     && fc1_weight_block.sizes()[2] * FP8_PER_INT32
                             * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                             hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX),
-                "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+                "fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
                 "block_scale_vector_size)");
             TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
             TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank
@@ -1040,16 +1045,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
             // Check shapes
             TORCH_CHECK(fc1_act_global.dim() == 0 || fc1_act_global.sizes()[0] == num_experts_on_rank,
                 "fc1 act global must be scalar or (num_experts_on_rank,)");
+            int expand_ratio = isGatedActivation(base_activation_type) ? 2 : 1;
             TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
                     && fc1_weight_block.sizes()[1]
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                                inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)
-                            * 2
+                            * expand_ratio
                     && fc1_weight_block.sizes()[2] * FP8_PER_INT32
                             * TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                             hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4),
-                "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+                "fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
                 "block_scale_vector_size)");
             TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
             TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
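For reference, a hedged sketch of the fc1 block-scale shape these TORCH_CHECKs enforce. The constants and `alignTo` helper are placeholders for the real values in `TmaWarpSpecializedGroupedGemmInput`; only the arithmetic pattern is taken from the checks above:

```cpp
#include <array>
#include <cstdint>

constexpr int64_t kFp8PerInt32 = 4;           // packed FP8 scale factors per int32 (stand-in for FP8_PER_INT32)
constexpr int64_t kBlockScaleVectorSize = 32; // elements covered by one block scale (assumed value)

// Stand-in for alignToSfDim: round value up to a multiple of alignment.
constexpr int64_t alignTo(int64_t value, int64_t alignment)
{
    return (value + alignment - 1) / alignment * alignment;
}

// Expected sizes of the 3D fc1 block-scale tensor:
// [num_experts_on_rank,
//  aligned(inter_size) * expand_ratio,
//  aligned(hidden_size) / kBlockScaleVectorSize / kFp8PerInt32]
constexpr std::array<int64_t, 3> expectedFc1BlockShape(int64_t num_experts_on_rank, int64_t hidden_size,
    int64_t inter_size, bool gated, int64_t n_align, int64_t k_align)
{
    int64_t const expand_ratio = gated ? 2 : 1;
    return {num_experts_on_rank, alignTo(inter_size, n_align) * expand_ratio,
        alignTo(hidden_size, k_align) / kBlockScaleVectorSize / kFp8PerInt32};
}
```

With a gated activation the second dimension doubles; with a plain activation it stays at the (aligned) `inter_size`, which is exactly what replacing the hard-coded `* 2` with `* expand_ratio` allows.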