Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -765,10 +765,15 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {

FusedMoeLauncher::check_routing_common();

if (args->n_group != 0) {
TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
RoutingMethodType::DeepSeekV3)
<< "Routing kernel with groups implies DeepSeekV3 routing method.";
if (static_cast<RoutingMethodType>(routing_method_type) != RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK(args->n_group <= 1)
<< "Current routing kernel (no groups) only supports n_group <= 1";
TVM_FFI_ICHECK(args->topk_group <= 1)
<< "Current routing kernel (no groups) only supports topk_group <= 1";
}

if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK(args->n_group != 0) << "n_group should not be zero for DeepSeekV3 routing";
TVM_FFI_ICHECK(args->topk_group != 0) << "if n_group is given, topk_group must be given";
TVM_FFI_ICHECK_EQ(args->num_experts % args->n_group, 0)
<< "num_experts must be divisible by n_group";
Expand All @@ -790,6 +795,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
TVM_FFI_ICHECK_EQ(args->top_k, 1)
<< "Current routing kernel (no groups, Llama4) only supports top_k=1.";
}

TVM_FFI_ICHECK_EQ(args->num_experts % 4, 0)
<< "Routing kernel expects that num_experts must be divisible by 4";
TVM_FFI_ICHECK_GT(args->num_experts, args->top_k) << "num_experts must be greater than top_k";
Expand Down Expand Up @@ -2004,9 +2010,8 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
}

TVM_FFI_LOG_AND_THROW(NotImplementedError)
<< "Unsupported data type combination for getValidConfigs: "
<< "dtype_act=" << static_cast<int>(dtype_act)
<< ", dtype_weights=" << static_cast<int>(dtype_weights)
<< "Unsupported data type combination for getValidConfigs: " << "dtype_act="
<< static_cast<int>(dtype_act) << ", dtype_weights=" << static_cast<int>(dtype_weights)
<< ", useDeepSeekFp8=" << useDeepSeekFp8;

// Unreachable code - added to suppress compiler warning
Expand Down
Loading