@@ -59,8 +59,8 @@ inline int32_t nextPowerOfTwo(float value) {
 }
 
 std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
-                                       int64_t const num_tokens, int64_t const top_k,
-                                       int64_t const num_local_experts) {
+                                       int64_t const num_tokens, int64_t const top_k,
+                                       int64_t const num_local_experts) {
   float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
   int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
                                        supported_tile_nums.front(), supported_tile_nums.back());
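For context, the heuristic above sizes the GEMM tile to the average number of tokens routed to each expert: num_tokens * top_k / num_local_experts, rounded up to the next power of two and clamped to the supported range. A minimal standalone sketch of that selection (the body of nextPowerOfTwo is assumed here; only its clamped use appears in this hunk):

#include <algorithm>
#include <cstdint>

// Assumed behavior: smallest power of two >= value.
inline int32_t nextPowerOfTwo(float value) {
  int32_t v = 1;
  while (v < value) v *= 2;
  return v;
}

int32_t pickTileTokensDim(int64_t num_tokens, int64_t top_k, int64_t num_local_experts) {
  // e.g. 300 tokens, top_k = 8, 64 local experts -> 37.5 tokens/expert -> tile of 64
  float avg = static_cast<float>(num_tokens * top_k) / num_local_experts;
  return std::clamp(nextPowerOfTwo(avg), 8, 64);  // supported tiles here are {8, 16, 32, 64}
}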
@@ -82,7 +82,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
     int64_t const intermediate_size, int64_t const local_expert_offset,
     int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool enable_pdl) {
+    int64_t const routing_method_type,
+    tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
+    bool enable_pdl) {
   static const std::tuple<int, int> device_props = [hidden_states] {
     int major, minor;
     cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor,
@@ -160,6 +162,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
   }
+  args.mDtypeOut = btg::Dtype::Bfloat16;  // Output is always bfloat16 for fp8 per-tensor scale
 
   args.routing_logits = routing_logits.data_ptr();
   auto const routing_bias_dtype =
@@ -194,6 +197,13 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   int32_t max_num_padded_tokens =
       tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
           args.num_tokens, top_k, num_experts, tile_tokens_dim);
+  int32_t max_num_padded_tokens_gemm1 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
+  int32_t max_num_padded_tokens_gemm2 =
+      tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
+          max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
+
   Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device());
   Tensor expanded_idx_to_permuted_idx =
       alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device());
@@ -210,16 +220,17 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
                               routing_logits.device());
 
   // allocate workspace for activation/gemm/finalize kernels
-  Tensor gemm1_output =
-      alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, hidden_states.device());
-  Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens},
-                                           dl_float32, hidden_states.device());
-  Tensor activation_output =
-      alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states.device());
-  Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens},
-                                                dl_float32, hidden_states.device());
-  Tensor gemm2_output =
-      alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states.device());
+  Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8,
+                                     hidden_states.device());
+  Tensor gemm1_output_scale =
+      alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32,
+                   hidden_states.device());
+  Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size},
+                                          dl_uint8, hidden_states.device());
+  Tensor activation_output_scale = alloc_tensor(
+      {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states.device());
+  Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16,
+                                     hidden_states.device());
 
   int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
       args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
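The split into max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 exists because the two GEMMs are padded independently: GEMM1 buffers span 2 * intermediate_size in the FP8 element type, while the GEMM2 output spans hidden_size in bfloat16, and maybeGetMinTokenCount can raise either count to a kernel minimum. A rough sketch of the byte sizes the allocations above imply (illustrative only; the exact rounding inside maybeGetMinTokenCount is not shown in this diff):

// FP8 activations are stored as uint8: 1 byte per element.
size_t gemm1_output_bytes = size_t(max_num_padded_tokens_gemm1) * 2 * intermediate_size;
// One float32 scale per 128-element block along the channel dimension.
size_t gemm1_scale_bytes = size_t(2 * intermediate_size / 128) * max_num_padded_tokens_gemm1 * 4;
// The final GEMM2 output is bfloat16: 2 bytes per element.
size_t gemm2_output_bytes = size_t(max_num_padded_tokens_gemm2) * args.hidden_size * 2;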
@@ -289,7 +300,8 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
 
   // setup workspace
   workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens.data_ptr());
-  workspace.total_max_padded_tokens = max_num_padded_tokens;
+  workspace.total_max_padded_tokens =
+      std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
   workspace.ProjUpTileN = tile_tokens_dim;
   workspace.routing_expert_indexes = static_cast<int*>(expert_indexes.data_ptr());
   workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens.data_ptr());
@@ -315,13 +327,6 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   args.output = output.data_ptr();
   args.output_scale = nullptr;
 
-  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
-      args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim, /*useShuffledMatrixA*/ true);
-
-  auto const moeConfigIndex =
-      moe_runner.getDefaultValidConfigIndex(args.top_k, args.hidden_size, args.intermediate_size,
-                                            args.local_num_experts, args.num_tokens);
-
   auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
   Tensor workspace_fc1 =
       alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device());
@@ -341,16 +346,56 @@ void trtllm_fp8_per_tensor_scale_moe(
     TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
     Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
-    bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type,
-    bool enable_pdl) {
+    bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl,
+    Array<int64_t> config_index) {
   auto dtype = hidden_states.dtype();
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
+    using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
+
+    // Convert PyTorch dtype to TensorRT-LLM dtype
+    btg::Dtype mDtypeElt;
+    if (dtype == dl_float16) {
+      mDtypeElt = btg::Dtype::Fp16;
+    } else if (dtype == dl_bfloat16) {
+      mDtypeElt = btg::Dtype::Bfloat16;
+    } else if (dtype == dl_float8_e4m3fn) {
+      mDtypeElt = btg::Dtype::E4m3;
+    } else {
+      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
+    }
+
+    auto const num_tokens = hidden_states.size(0);
+    auto const hidden_size = hidden_states.size(1);
+    bool mUseDeepSeekFp8{false};  // FP8 per-tensor doesn't use DeepSeek FP8
+
+    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
+    std::set<int32_t> selected_tile_nums =
+        computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
+
+    // Build runners for all supported tile sizes
+    std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
+    for (int32_t tile_N : selected_tile_nums) {
+      // Construct one runner per selected tile size
+      mRunners.emplace(tile_N, std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, tile_N,
+                                                            /*useShuffledMatrixA*/ true));
+    }
+
+    // moeConfigIndex corresponds to pair (tile_N, config)
+    int64_t tile_N = config_index[0];
+    int64_t config = config_index[1];
+    // Autotuner has requested a default or 'fallback' config index
+    if (tile_N == -1 || config == -1) {
+      tile_N = *selected_tile_nums.begin();
+      config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                            local_num_experts, num_tokens);
+    }
+
     trtllm_fp8_per_tensor_scale_moe_launcher(
         routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar,
         output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, output, num_experts,
         top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-        routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type,
-        enable_pdl);
+        routed_scaling_factor, use_routing_scales_on_input, tile_N, routing_method_type,
+        *mRunners[tile_N], config, enable_pdl);
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype.";
   }
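With this signature change the tile size is no longer passed explicitly; callers hand over a (tile_N, config) pair instead. A hypothetical call from the binding side, using the untuned fallback (argument values elided; the surrounding FFI wrapper is outside this diff):

// {-1, -1} asks the function to pick the smallest selected tile and that runner's default config.
Array<int64_t> config_index = {-1, -1};
trtllm_fp8_per_tensor_scale_moe(routing_logits, routing_bias, hidden_states, gemm1_weights,
                                output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights,
                                output2_scales_scalar, output, num_experts, top_k, n_group,
                                topk_group, intermediate_size, local_expert_offset,
                                local_num_experts, routed_scaling_factor,
                                use_routing_scales_on_input, routing_method_type, enable_pdl,
                                config_index);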
@@ -651,16 +696,14 @@ void trtllm_fp8_block_scale_moe_launcher(
       enable_pdl);
 }
 
-void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView> routing_bias,
-                                TensorView hidden_states, TensorView hidden_states_scale,
-                                TensorView gemm1_weights, TensorView gemm1_weights_scale,
-                                TensorView gemm2_weights, TensorView gemm2_weights_scale,
-                                TensorView output, int64_t num_experts, int64_t top_k,
-                                Optional<int64_t> n_group, Optional<int64_t> topk_group,
-                                int64_t intermediate_size, int64_t local_expert_offset,
-                                int64_t local_num_experts, Optional<double> routed_scaling_factor,
-                                int64_t tile_tokens_dim, int64_t routing_method_type,
-                                bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl) {
+void trtllm_fp8_block_scale_moe(
+    TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
+    TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale,
+    TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output,
+    int64_t num_experts, int64_t top_k, Optional<int64_t> n_group, Optional<int64_t> topk_group,
+    int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts,
+    Optional<double> routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight,
+    int64_t weight_layout, bool enable_pdl, Array<int64_t> config_index) {
   auto dtype = hidden_states.dtype();
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
     using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
@@ -671,24 +714,36 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView>
     TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2)
         << "the value of weight_layout is not recognized";
 
-    // Properly initialize the runner using make_unique like in the original code
-    auto mRunner = std::make_unique<RunnerType>(
-        mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim, use_shuffled_weight,
-        static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
-
-    // Always use fallback config (equivalent to moeConfigIndex == -1 case from original code)
     auto const num_tokens = hidden_states.size(0);
     auto const hidden_size = hidden_states.size(1);
 
-    int64_t moeConfigIndex = mRunner->getDefaultValidConfigIndex(
-        top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
+    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
+    std::set<int32_t> selected_tile_nums =
+        computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
+
+    // Build runners for all supported tile sizes
+    std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
+    for (int32_t tile_N : selected_tile_nums) {
+      mRunners.emplace(tile_N, std::make_unique<RunnerType>(
+                                   mDtypeElt, mUseDeepSeekFp8, tile_N, use_shuffled_weight,
+                                   static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout)));
+    }
+
+    // moeConfigIndex corresponds to pair (tile_N, config)
+    int64_t tile_N = config_index[0];
+    int64_t config = config_index[1];
+    // Autotuner has requested a default or 'fallback' config index
+    if (tile_N == -1 || config == -1) {
+      tile_N = *selected_tile_nums.begin();
+      config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                            local_num_experts, num_tokens);
+    }
 
     trtllm_fp8_block_scale_moe_launcher(
         routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
         gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output, num_experts, top_k,
         n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
-        routed_scaling_factor, tile_tokens_dim, routing_method_type, *mRunner, moeConfigIndex,
-        enable_pdl);
+        routed_scaling_factor, tile_N, routing_method_type, *mRunners[tile_N], config, enable_pdl);
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported hidden state dtype.";
   }
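To make the fallback concrete, a small worked example under assumed sizes (numbers are illustrative; the full body of computeSelectedTileN is outside this diff, so the claim about its minimum entry is an assumption):

// Say num_tokens = 64, top_k = 4, local_num_experts = 32 and config_index = {-1, -1}:
//   avg tokens per expert = 64 * 4 / 32 = 8  ->  nextPowerOfTwo(8) = 8, clamped to [8, 64] -> 8.
// Assuming the selection keeps that clamped value as its smallest entry, the untuned path runs
// with tile_N = 8 and config = mRunners[8]->getDefaultValidConfigIndex(...).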
@@ -1184,29 +1239,29 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
       computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
   // Build runners for all supported tile sizes
   std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
-  for (int32_t tileN : selected_tile_nums) {
-    mRunners.emplace(tileN,
-                     std::make_unique<RunnerType>(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tileN,
+  for (int32_t tile_N : selected_tile_nums) {
+    mRunners.emplace(tile_N,
+                     std::make_unique<RunnerType>(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tile_N,
                                                   static_cast<GatedActType>(gated_act_type),
                                                   /*useShuffledMatrixA*/ true));
   }
 
-  // moeConfigIndex corresponds to pair (tileN, config)
-  int64_t tileN = config_index[0];
+  // moeConfigIndex corresponds to pair (tile_N, config)
+  int64_t tile_N = config_index[0];
   int64_t config = config_index[1];
   // Autotuner has requested a default or 'fallback' config index
-  if (tileN == -1 || config == -1) {
-    tileN = *selected_tile_nums.begin();
-    config = mRunners[tileN]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
-                                                         local_num_experts, num_tokens);
+  if (tile_N == -1 || config == -1) {
+    tile_N = *selected_tile_nums.begin();
+    config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                          local_num_experts, num_tokens);
   }
   return trtllm_fp4_block_scale_moe_launcher(
       routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale,
       gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit,
       gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scales_scalar,
       output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
-      intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
-      routing_method_type, do_finalize, *mRunners[tileN], mDtypeAct, mDtypeWeights, config,
+      intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_N,
+      routing_method_type, do_finalize, *mRunners[tile_N], mDtypeAct, mDtypeWeights, config,
       enable_pdl, output);
 }
 
@@ -1218,41 +1273,58 @@ int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const d
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
-  if (dtype_act != btg::Dtype::Bfloat16) {
+  if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
+      dtype_act != btg::Dtype::Bfloat16) {
     supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
       computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
-  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
-      dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
-      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
-  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
-                                               num_local_experts, num_tokens);
+
+  std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner =
+      std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
+          static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
+
+  return moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                num_local_experts, num_tokens);
 }
 
 Array<Array<int64_t>> trtllm_get_valid_moe_configs(
     int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8,
     int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size,
-    int64_t const num_local_experts, int64_t const gated_act_type, int64_t const num_tokens) {
-  // returns (tileN, config)
+    int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight,
+    int64_t const weight_layout, int64_t const num_tokens) {
+  // returns (tile_N, config)
   Array<Array<int64_t>> valid_configs;
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
-  if (dtype_act != btg::Dtype::Bfloat16) {
+  if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
+      dtype_act != btg::Dtype::Bfloat16) {
     supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
       computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
 
-  for (int32_t tileN : selected_tile_nums) {
-    tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
-        dtype_act, dtype_weights, useDeepSeekFp8, tileN, static_cast<GatedActType>(gated_act_type),
-        /*useShuffledMatrixA*/ true);
-    auto cfgs = moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size,
-                                                 num_local_experts, num_tokens);
+  for (int32_t tile_N : selected_tile_nums) {
+    std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner;
+
+    if (dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3) {
+      // FP8 block scale MOE runner
+      moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight,
+          static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+    } else {
+      // FP4 block scale MOE runner
+      moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_act, dtype_weights, useDeepSeekFp8, tile_N,
+          static_cast<GatedActType>(gated_act_type),
+          /*useShuffledMatrixA*/ true);
+    }
+    auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size,
+                                                  num_local_experts, num_tokens);
     for (auto cfg : cfgs) {
-      valid_configs.push_back({tileN, cfg});
+      valid_configs.push_back({tile_N, cfg});
     }
   }
   return valid_configs;
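Downstream, these (tile_N, config) pairs are what an autotuner would time and feed back through config_index. A hedged sketch of that loop (benchmarkOnce is a hypothetical timing helper, not part of this file):

auto candidates = trtllm_get_valid_moe_configs(dtype_act_, dtype_weights_, useDeepSeekFp8, top_k,
                                               hidden_size, intermediate_size, num_local_experts,
                                               gated_act_type, use_shuffled_weight, weight_layout,
                                               num_tokens);
Array<int64_t> best{-1, -1};  // start from the fallback pair
float best_ms = std::numeric_limits<float>::max();
for (auto const& cfg : candidates) {  // each cfg is {tile_N, config}
  float ms = benchmarkOnce(cfg);      // hypothetical: run the MoE once with this pair and time it
  if (ms < best_ms) {
    best_ms = ms;
    best = cfg;
  }
}
// `best` is then passed as config_index to the corresponding *_moe entry point.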