@@ -369,7 +369,7 @@ void trtllm_fp8_per_tensor_scale_moe(
   auto const hidden_size = hidden_states.size(1);
   bool mUseDeepSeekFp8{false};  // FP8 per-tensor doesn't use DeepSeek FP8
 
-  std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
+  std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
   std::set<int32_t> selected_tile_nums =
       computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
 
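The FP8 per-tensor launcher now offers 128 as a candidate tile-N; `computeSelectedTileN` still decides which candidates actually apply to the workload. The filter's real heuristic is not shown in this diff; the sketch below is only an assumption of what such a selection could look like, keyed off the expected tokens per expert, and is not the repo's implementation:

```cpp
// Hypothetical sketch of tile-N selection; NOT the actual computeSelectedTileN.
// Assumption: candidates are filtered around the expected number of tokens
// routed to each local expert, so small batches don't pay for oversized tiles.
#include <algorithm>
#include <cstdint>
#include <set>
#include <vector>

std::set<int32_t> computeSelectedTileNSketch(std::vector<int32_t> const& supported,
                                             int64_t num_tokens, int64_t top_k,
                                             int64_t local_num_experts) {
  // Expected tokens per expert under uniform routing (ceil division).
  int64_t tokens_per_expert = std::max<int64_t>(
      1, (num_tokens * top_k + local_num_experts - 1) / local_num_experts);
  std::set<int32_t> selected;
  for (int32_t tile_n : supported) {
    // Keep tiles that don't wildly overshoot the per-expert workload,
    // but always retain at least one candidate.
    if (tile_n / 2 <= tokens_per_expert || selected.empty()) {
      selected.insert(tile_n);
    }
  }
  return selected;
}
```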
@@ -929,10 +929,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
 
   Tensor permuted_idx_to_token_idx =
       alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device());
-  // Tensor expert_weights = alloc_tensor(
-  //     {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states.device());
-  // Tensor expert_indexes = alloc_tensor(
-  //     {args.num_tokens, args.top_k}, dl_int32, hidden_states.device();
   int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
   Tensor expert_count_histogram =
       alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states.device());
@@ -942,10 +938,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
   // allocate workspace for activation/gemm/finalize kernels
   auto const gemm1_output_hidden =
       dtype_act == btg::Dtype::E2m1 ? intermediate_size / 2 : intermediate_size;
-  // Tensor gemm1_output = alloc_tensor(
-  //     {max_num_padded_tokens, gemm1_output_hidden},
-  //     dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn,
-  //     hidden_states.device());
   Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden},
                                      dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_uint8,
                                      hidden_states.device());
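Note the halved hidden dimension on the E2m1 activation path: E2m1 is a 4-bit float format (1 sign, 2 exponent, 1 mantissa bit), so two values pack into one byte, and the live allocation stores the packed payload as `dl_uint8` (presumably because there is no native FP4 tensor dtype here; the removed commented-out version still used `dl_float8_e4m3fn`). A minimal sketch of the size arithmetic, with the helper name being mine:

```cpp
// Sketch of the gemm1 output width computation from the hunk above.
// Assumption: byte-addressed storage, two E2m1 (FP4) values per uint8.
#include <cstdint>
#include <iostream>

int64_t packedHiddenSize(int64_t intermediate_size, bool act_is_e2m1) {
  // Two FP4 values per byte when activations are E2m1; otherwise one
  // element per storage slot of the wider activation dtype.
  return act_is_e2m1 ? intermediate_size / 2 : intermediate_size;
}

int main() {
  std::cout << packedHiddenSize(4096, true) << "\n";   // 2048 bytes per row
  std::cout << packedHiddenSize(4096, false) << "\n";  // 4096 elements per row
}
```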
@@ -1274,8 +1266,14 @@ int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const d
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
-  if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
-      dtype_act != btg::Dtype::Bfloat16) {
+  // Check if we should add tile size 128
+  bool is_fp4_without_bf16_act =
+      (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
+      dtype_act != btg::Dtype::Bfloat16;
+  bool is_fp8_per_tensor =
+      dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;
+
+  if (is_fp4_without_bf16_act || is_fp8_per_tensor) {
     supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
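The rewritten condition makes the tile-128 policy explicit: FP4 weights qualify unless the activations are BF16, and FP8 now also qualifies, but only on the per-tensor-scale path (not DeepSeek block-scale FP8). A self-contained sketch of the same predicate, with a mocked enum standing in for `btg::Dtype`, to make the accepted combinations easy to check:

```cpp
// Sketch of the tile-128 gating logic from the hunk above, pulled out so the
// dtype combinations are easy to eyeball. The enum here is a stand-in for
// btg::Dtype, which lives in the TRT-LLM batched GEMM headers.
#include <iostream>

enum class Dtype { Bfloat16, E4m3, E2m1, MxE2m1 };

bool allowsTile128(Dtype dtype_act, Dtype dtype_weights, bool useDeepSeekFp8) {
  // FP4 weights qualify unless the activations are BF16.
  bool is_fp4_without_bf16_act =
      (dtype_weights == Dtype::E2m1 || dtype_weights == Dtype::MxE2m1) &&
      dtype_act != Dtype::Bfloat16;
  // FP8 qualifies only on the per-tensor-scale path, not DeepSeek block FP8.
  bool is_fp8_per_tensor =
      dtype_weights == Dtype::E4m3 && dtype_act == Dtype::E4m3 && !useDeepSeekFp8;
  return is_fp4_without_bf16_act || is_fp8_per_tensor;
}

int main() {
  std::cout << allowsTile128(Dtype::E4m3, Dtype::E4m3, false) << "\n";     // 1: FP8 per-tensor
  std::cout << allowsTile128(Dtype::E4m3, Dtype::E4m3, true) << "\n";      // 0: DeepSeek FP8
  std::cout << allowsTile128(Dtype::Bfloat16, Dtype::E2m1, false) << "\n"; // 0: BF16 activations
  std::cout << allowsTile128(Dtype::E2m1, Dtype::E2m1, false) << "\n";     // 1: FP4
}
```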
@@ -1300,8 +1298,14 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
-  if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
-      dtype_act != btg::Dtype::Bfloat16) {
+  // Check if we should add tile size 128
+  bool is_fp4_without_bf16_act =
+      (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
+      dtype_act != btg::Dtype::Bfloat16;
+  bool is_fp8_per_tensor =
+      dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;
+
+  if (is_fp4_without_bf16_act || is_fp8_per_tensor) {
     supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
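The same two-clause predicate is now duplicated verbatim in `trtllm_get_default_moe_configs` and `trtllm_get_valid_moe_configs`, and mirrored by the FP8 launcher's hard-coded `{8, 16, 32, 64, 128}` list in the first hunk. A follow-up could consolidate it behind one helper so the tile-128 policy has a single home; a hypothetical sketch (the helper name is mine, and `btg::Dtype` is assumed in scope from the repo's headers):

```cpp
// Hypothetical helper consolidating the duplicated tile-128 gate; not part of
// this diff. Both config-query functions, and potentially the FP8 per-tensor
// launcher, could call it instead of repeating the dtype checks.
#include <cstdint>
#include <vector>

inline std::vector<int32_t> supportedTileNums(btg::Dtype dtype_act,
                                              btg::Dtype dtype_weights,
                                              bool useDeepSeekFp8) {
  std::vector<int32_t> tiles = {8, 16, 32, 64};
  bool is_fp4_without_bf16_act =
      (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
      dtype_act != btg::Dtype::Bfloat16;
  bool is_fp8_per_tensor =
      dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;
  if (is_fp4_without_bf16_act || is_fp8_per_tensor) {
    tiles.push_back(128);
  }
  return tiles;
}
```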