@@ -41,9 +41,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
     Tensor routing_logits, Optional<Tensor> routing_bias, Tensor hidden_states,
     Tensor gemm1_weights, Tensor output1_scales_scalar, Tensor output1_scales_gate_scalar,
     Tensor gemm2_weights, Tensor output2_scales_scalar, int64_t const num_experts,
-    int64_t const top_k, int64_t const n_group, int64_t const topk_group,
+    int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
     int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, double const routed_scaling_factor,
+    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
     int64_t const routing_method_type, bool enable_pdl) {
   static const std::tuple<int, int> device_props = [hidden_states] {
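Note: with this signature change, the group parameters and the routed scaling factor become optional at the FFI boundary. A minimal caller sketch, assuming an empty FFI Optional can be spelled Optional<int64_t>{} / Optional<double>{} and using illustrative argument values (neither detail is taken from this commit):

    // Hypothetical Llama4-style call with no expert groups and default scaling.
    trtllm_fp8_per_tensor_scale_moe_launcher(
        routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar,
        output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar,
        /*num_experts=*/128, /*top_k=*/1,
        /*n_group=*/Optional<int64_t>{}, /*topk_group=*/Optional<int64_t>{},
        /*intermediate_size=*/2048, /*local_expert_offset=*/0, /*local_num_experts=*/128,
        /*routed_scaling_factor=*/Optional<double>{},
        /*use_routing_scales_on_input=*/true, /*tile_tokens_dim=*/8,
        /*routing_method_type=*/static_cast<int64_t>(RoutingMethodType::Llama4),
        /*enable_pdl=*/true);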
@@ -61,8 +61,11 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(

   if (use_routing_scales_on_input) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
-  } else {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::DeepSeekV3) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  } else {
+    // TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
   }
   TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
   TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape.";
@@ -73,17 +76,32 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
         << "routing_bias has incorrect shape.";
   }

-  if (n_group <= 0 || topk_group <= 0) {
-    TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
-  } else {
-    TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8.";
-    TVM_FFI_ICHECK_LE(topk_group, 4)
-        << "Current routing kernel (with groups) only supports topk_group<=4.";
-    TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
-    TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
+  if (n_group.has_value() && n_group.value() != 0) {
+    TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
+                   RoutingMethodType::DeepSeekV3)
+        << "Routing kernel with groups implies DeepSeekV3 routing method.";
+    TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
+    TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
+        << "num_experts must be divisible by n_group";
+    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+        << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
+        << "Current routing kernel (with groups) only supports topk_group<=4 && topk_group>0.";
+    TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
+        << "n_group must not be smaller than topk_group.";
     // This check ensures we have enough experts in the selected groups to handle the top_k routing
-    TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
+    TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
         << "top_k must be less than total number of experts in selected groups";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::Renormalize ||
+             static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::RenormalizeNaive) {
+    TVM_FFI_LOG_AND_THROW(NotImplementedError)
+        << "Don't support routing method type Renormalize(Naive).";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::Llama4) {
+    TVM_FFI_ICHECK_EQ(top_k, 1)
+        << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
   TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
       << "Routing kernel expects that num_experts must be divisible by 4";
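To make the grouped-routing constraints concrete (illustrative numbers, not taken from this change): with num_experts = 256, n_group = 8, topk_group = 4, and top_k = 8, every check passes, since 256 % 8 == 0, 0 < 8 <= 8, 0 < 4 <= 4, 4 <= 8, and 8 < 4 * 256 / 8 = 128 experts available in the selected groups.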
@@ -121,11 +139,11 @@ Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
   args.hidden_size = hidden_states->shape[1];
   args.hidden_size_output = args.hidden_size;
   args.top_k = top_k;
-  args.n_group = n_group;
-  args.topk_group = topk_group;
+  args.n_group = n_group.has_value() ? n_group.value() : 0;
+  args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor;
+  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseRoutingScalesOnInput = use_routing_scales_on_input;

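The fallbacks chosen here (0 for the group parameters, 1.0 for routed_scaling_factor) keep the kernel arguments well defined when the caller omits them. If the FFI Optional exposes a value_or() helper like std::optional does (an assumption; this change only uses has_value()/value()), the same defaulting could be written more compactly:

    // Hypothetical equivalent of the defaulting above, assuming value_or() is available.
    args.n_group = n_group.value_or(0);
    args.topk_group = topk_group.value_or(0);
    args.routed_scaling_factor = routed_scaling_factor.value_or(1.0);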
@@ -279,8 +297,8 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
     Tensor routing_logits, Optional<Tensor> routing_bias, Tensor hidden_states,
     Tensor gemm1_weights, Tensor output1_scales_scalar, Tensor output1_scales_gate_scalar,
     Tensor gemm2_weights, Tensor output2_scales_scalar, int64_t num_experts, int64_t top_k,
-    int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset,
-    int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input,
+    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
+    int64_t local_num_experts, Optional<double> routed_scaling_factor, bool use_routing_scales_on_input,
     int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) {
   auto dtype = hidden_states->dtype;
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
@@ -299,9 +317,9 @@ void trtllm_fp8_block_scale_moe_launcher(
     Tensor routing_logits, Optional<Tensor> routing_bias, Tensor hidden_states,
     Tensor hidden_states_scale, Tensor gemm1_weights, Tensor gemm1_weights_scale,
     Tensor gemm2_weights, Tensor gemm2_weights_scale, Tensor output, int64_t const num_experts,
-    int64_t const top_k, int64_t const n_group, int64_t const topk_group,
+    int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
     int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, double const routed_scaling_factor,
+    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     int64_t const tile_tokens_dim, int64_t const routing_method_type,
     tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
     bool enable_pdl) {
@@ -318,7 +336,11 @@ void trtllm_fp8_block_scale_moe_launcher(
       << "This kernel requires 10.x architecture. Current device has SM "
       << std::get<0>(device_props) << std::get<1>(device_props);

-  TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+    TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
+  }
   TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
   TVM_FFI_ICHECK_EQ(routing_logits->shape[0], hidden_states->shape[0])
       << "routing_logits and hidden_states must have the same number of tokens.";
@@ -333,18 +355,33 @@ void trtllm_fp8_block_scale_moe_launcher(
         << "routing_bias has incorrect shape.";
   }

-  // if (n_group <= 0 || topk_group <= 0) {
-  //   TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
-  // } else {
-  //   TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports
-  //   top_k<=8."; TVM_FFI_ICHECK_LE(topk_group, 4)
-  //       << "Current routing kernel (with groups) only supports topk_group<=4.";
-  //   TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
-  //   TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
-  //   // This check ensures we have enough experts in the selected groups to handle the top_k
-  //   routing TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
-  //       << "top_k must be less than total number of experts in selected groups";
-  // }
+  if (n_group.has_value() && n_group.value() != 0) {
+    TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
+                   RoutingMethodType::DeepSeekV3)
+        << "Routing kernel with groups implies DeepSeekV3 routing method.";
+    TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
+    TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
+        << "num_experts must be divisible by n_group";
+    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+        << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
+        << "Current routing kernel (with groups) only supports topk_group<=4 && topk_group>0.";
+    TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
+        << "n_group must not be smaller than topk_group.";
+    // This check ensures we have enough experts in the selected groups to handle the top_k routing
+    TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
+        << "top_k must be less than total number of experts in selected groups";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::Renormalize ||
+             static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::RenormalizeNaive) {
+    TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
+        << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0.";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::Llama4) {
+    TVM_FFI_ICHECK_EQ(top_k, 1)
+        << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
+  }
   TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
       << "Routing kernel expects that num_experts must be divisible by 4";
   TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k";
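Note the asymmetry with the per-tensor-scale launcher above: there, Renormalize and RenormalizeNaive throw NotImplementedError, whereas the block-scale launcher accepts them as long as 0 < top_k <= 10. For example, a Renormalize call with top_k = 8 passes this check here but would be rejected by the per-tensor-scale path.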
@@ -380,11 +417,11 @@ void trtllm_fp8_block_scale_moe_launcher(
   args.hidden_size = hidden_states->shape[1];
   args.hidden_size_output = args.hidden_size;
   args.top_k = top_k;
-  args.n_group = n_group;
-  args.topk_group = topk_group;
+  args.n_group = n_group.has_value() ? n_group.value() : 0;
+  args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor;
+  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseDeepSeekFp8 = true;

@@ -569,10 +606,10 @@ void trtllm_fp8_block_scale_moe(Tensor routing_logits, Optional<Tensor> routing_
                                 Tensor hidden_states, Tensor hidden_states_scale,
                                 Tensor gemm1_weights, Tensor gemm1_weights_scale,
                                 Tensor gemm2_weights, Tensor gemm2_weights_scale, Tensor output,
-                                int64_t num_experts, int64_t top_k, int64_t n_group,
-                                int64_t topk_group, int64_t intermediate_size,
+                                int64_t num_experts, int64_t top_k, Optional<int64_t> n_group,
+                                Optional<int64_t> topk_group, int64_t intermediate_size,
                                 int64_t local_expert_offset, int64_t local_num_experts,
-                                double routed_scaling_factor, int64_t tile_tokens_dim,
+                                Optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
                                 int64_t routing_method_type, bool use_shuffled_weight,
                                 int64_t weight_layout, bool enable_pdl) {
   auto dtype = hidden_states->dtype;