Skip to content

Commit 0e63887

Browse files
committed
add 128 to FP8 PT
Signed-off-by: jiahanc <[email protected]>
1 parent c4be849 commit 0e63887

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -369,7 +369,7 @@ void trtllm_fp8_per_tensor_scale_moe(
369369
auto const hidden_size = hidden_states.size(1);
370370
bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8
371371

372-
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
372+
std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
373373
std::set<int32_t> selected_tile_nums =
374374
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
375375

@@ -929,10 +929,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
929929

930930
Tensor permuted_idx_to_token_idx =
931931
alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device());
932-
// Tensor expert_weights = alloc_tensor(
933-
// {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states.device());
934-
// Tensor expert_indexes = alloc_tensor(
935-
// {args.num_tokens, args.top_k}, dl_int32, hidden_states.device();
936932
int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
937933
Tensor expert_count_histogram =
938934
alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states.device());
@@ -942,10 +938,6 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
942938
// allocate workspace for activation/gemm/finalize kernels
943939
auto const gemm1_output_hidden =
944940
dtype_act == btg::Dtype::E2m1 ? intermediate_size / 2 : intermediate_size;
945-
// Tensor gemm1_output = alloc_tensor(
946-
// {max_num_padded_tokens, gemm1_output_hidden},
947-
// dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn,
948-
// hidden_states.device());
949941
Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden},
950942
dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_uint8,
951943
hidden_states.device());
@@ -1274,8 +1266,14 @@ int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const d
12741266
auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
12751267
auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
12761268
std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
1277-
if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
1278-
dtype_act != btg::Dtype::Bfloat16) {
1269+
// Check if we should add tile size 128
1270+
bool is_fp4_without_bf16_act =
1271+
(dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
1272+
dtype_act != btg::Dtype::Bfloat16;
1273+
bool is_fp8_per_tensor =
1274+
dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;
1275+
1276+
if (is_fp4_without_bf16_act || is_fp8_per_tensor) {
12791277
supported_tile_nums.push_back(128);
12801278
}
12811279
std::set<int32_t> selected_tile_nums =
@@ -1300,8 +1298,14 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
13001298
auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
13011299
auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
13021300
std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
1303-
if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
1304-
dtype_act != btg::Dtype::Bfloat16) {
1301+
// Check if we should add tile size 128
1302+
bool is_fp4_without_bf16_act =
1303+
(dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
1304+
dtype_act != btg::Dtype::Bfloat16;
1305+
bool is_fp8_per_tensor =
1306+
dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8;
1307+
1308+
if (is_fp4_without_bf16_act || is_fp8_per_tensor) {
13051309
supported_tile_nums.push_back(128);
13061310
}
13071311
std::set<int32_t> selected_tile_nums =

0 commit comments

Comments
 (0)