diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index f3957e9717..41fcc3768e 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -131,6 +131,7 @@ class FusedMoeLauncher { btg::Dtype mRoutingBiasDtype{ btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias ActivationType activation_type{ActivationType::Swiglu}; + btg::Dtype mDtypeScore{btg::Dtype::Bfloat16}; int64_t intermediate_size_factor{2}; @@ -168,13 +169,19 @@ class FusedMoeLauncher { int64_t weight_layout, ActivationType activation_type); // Routing logits [num_tokens, num_experts] - void check_routing_logits_shape() const { + void check_routing_logits() const { if (routing_logits.has_value()) { + // Check shape TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; TVM_FFI_ICHECK_EQ(routing_logits.value().size(0), hidden_states.size(0)) << "routing_logits and hidden_states must have the same number of tokens."; TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), args->num_experts) << "routing_logits dim1 must match num_experts."; + + // Check dtype + TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || + routing_logits.value().dtype() == dl_bfloat16) + << "routing_logits must be float or bfloat16."; } } @@ -236,7 +243,7 @@ class FusedMoeLauncher { args->local_expert_offset + args->local_num_experts <= args->num_experts) << "expert offset and count must be within valid range"; - check_routing_logits_shape(); + check_routing_logits(); if (routing_bias.has_value()) { check_routing_bias_shape(); @@ -302,6 +309,19 @@ class FusedMoeLauncher { workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); + + // Set dtype of score based on actual routing_logits dtype + if (routing_logits.has_value()) { + if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { + TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) + << "routing_logits must be float."; + mDtypeScore = btg::Dtype::Fp32; + } else if (routing_logits.value().dtype() == dl_float32) { + mDtypeScore = btg::Dtype::Fp32; + } else { + mDtypeScore = btg::Dtype::Bfloat16; + } + } } void check_moe_common() const { @@ -387,8 +407,8 @@ class FusedMoeLauncher { static_cast(num_tokens_per_expert.data_ptr()), static_cast(cta_idx_xy_to_batch_idx.data_ptr()), static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, - use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt, + mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream); check_moe(); @@ -498,8 +518,9 @@ class Bf16MoeLauncher : public FusedMoeLauncher { if (has_precomputed_weights) { workspace.expert_weights = const_cast(expert_weights.data_ptr()); } else { + auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; FusedMoeLauncher::expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device()); workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr(); } } @@ -638,8 +659,9 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device()); workspace.expert_weights = expert_weights.data_ptr(); if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { @@ -913,9 +935,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0; if (!has_precomputed_weights) { - // Allocate expert_weights buffer for routing output + auto ew_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; FusedMoeLauncher::expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + alloc_tensor({args->num_tokens, args->top_k}, ew_dtype, hidden_states.device()); workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr(); } else { workspace.expert_weights = const_cast(expert_weights.data_ptr()); @@ -1092,8 +1114,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_tokens_per_expert.data_ptr()), static_cast(cta_idx_xy_to_batch_idx.data_ptr()), static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, - use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt, + mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream); check_moe(); @@ -1188,8 +1210,9 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device()); workspace.expert_weights = expert_weights.data_ptr(); } @@ -1555,8 +1578,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static_cast(num_tokens_per_expert.data_ptr()), static_cast(cta_idx_xy_to_batch_idx.data_ptr()), static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, - use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt, + mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, static_cast(routing_method_type), routing_stream); check_moe(); @@ -1694,13 +1717,12 @@ Array trtllm_fp8_per_tensor_scale_moe( // Basic type validation auto dtype = hidden_states.dtype(); auto activation = static_cast(activation_type); - if (use_routing_scales_on_input) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; - } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; - } else { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; + + if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) + << "routing_logits must be float for DeepSeekV3."; } + TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16) << "FP8 MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) @@ -1799,9 +1821,6 @@ Array trtllm_fp8_block_scale_moe( if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) << "routing_logits must be float."; - } else { - TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16) - << "routing_logits must be bfloat16."; } } TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) @@ -1930,18 +1949,6 @@ Array trtllm_fp4_block_scale_moe( << "unsupported weight_scale_vec_size."; auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1; - if (routing_logits.has_value()) { - TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || - routing_logits.value().dtype() == dl_bfloat16) - << "routing_logits must be float or bfloat16."; - TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) - << "routing_logits has incorrect shape."; - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) - << "routing_logits must be float."; - } - } if (routing_bias.has_value()) { TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || routing_bias.value().dtype() == dl_float32) diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 4091019efc..39dec8a707 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -56,8 +56,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, - int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, - bool useRoutingScalesOnInput, bool useDeepSeekFp8, + int32_t* numNonExitingCtas, btg::Dtype dtypeScore, btg::Dtype dtypeElt, + btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) { if (routingMethodType == RoutingMethodType::DeepSeekV3) { FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); @@ -140,6 +140,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 // routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeScore = dtypeScore; + // routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input routingData.mUsePdl = true; routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive; diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 28c1603bc5..5a154678c1 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -178,22 +178,45 @@ namespace moe::dev { #define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ stream, extraFlag1, numExperts, numTopExperts) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ + if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), \ kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Fp32 && \ + !extraFlag1) { \ LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), \ kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \ + extraFlag1) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \ + !extraFlag1) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32 && \ + extraFlag1) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32 && \ + !extraFlag1) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \ + extraFlag1) { \ LAUNCH_TILEN(data, coopLaunch, \ LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16 && \ + !extraFlag1) { \ LAUNCH_TILEN(data, coopLaunch, \ LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ kernel, numBlocks, numThreads, smemSize, stream); \ } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ + FLASHINFER_WARN("Unsupported combination of mDtypeScore and mDtypeExpW"); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index 456bcd7a75..83ab472f95 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -265,6 +265,7 @@ namespace routingRenormalize { struct Data : public DataBase { tg::Dtype mDtypeExpW{tg::Dtype::Fp32}; + tg::Dtype mDtypeScore{tg::Dtype::Fp32}; tg::Dtype mDtypeElt{tg::Dtype::Bfloat16}; bool mDoSoftmaxBeforeTopK{false}; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 46617e5dbd..43cc326fb5 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -126,9 +126,9 @@ class Runner { int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, - batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias, - bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, - cudaStream_t stream); + batchedGemm::trtllm::gen::Dtype dtypeScore, batchedGemm::trtllm::gen::Dtype dtypeElt, + batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput, + bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream); private: int32_t mTileTokensDim{8}; diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index 35d9aae594..04296872fc 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -619,6 +619,7 @@ def __init__(self): activation_type=ActivationType.Swiglu, num_tokens=seq_len, hidden_size=7168, # DeepSeek-V3 hidden size + logits_dtype=torch.float32, intermediate_size=intermediate_size, ) diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 887472ddbd..76702b8e9c 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -1329,11 +1329,7 @@ def call_moe( # Use autotuner for optimal kernel selection with autotune(enable_autotune): output = trtllm_fp8_per_tensor_scale_moe( - ( - expert_logits.to(torch.bfloat16) - if routing_method_type == RoutingMethodType.Llama4 - else expert_logits - ), + expert_logits, routing_bias, hidden_states_fp8, static_data["gemm1_weights"], @@ -2577,6 +2573,7 @@ def run_moe_test( weight_processing, activation_type, cache_permute_indices, + logits_dtype, zero_hidden_states=False, gemm1_bias=None, gemm2_bias=None, @@ -2590,6 +2587,7 @@ def run_moe_test( num_tokens, hidden_size, intermediate_size, + logits_dtype, zero_hidden_states=zero_hidden_states, ) @@ -2619,14 +2617,9 @@ def run_moe_test( assert top_k < (top_k_groups * num_experts / n_groups) # Create test data based on routing method - if routing_method_type == RoutingMethodType.DeepSeekV3: - expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( - torch.float - ) - else: - expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( - torch.bfloat16 - ) + expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( + logits_dtype + ) if routing_config["has_routing_bias"]: routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16) @@ -2932,6 +2925,13 @@ def run_moe_test( pytest.param(ActivationType.Geglu.value, id="Geglu"), ], ) +@pytest.mark.parametrize( + "logits_dtype", + [ + pytest.param(torch.float32, id="FP32_logits"), + pytest.param(torch.bfloat16, id="BF16_logits"), + ], +) def test_renormalize_routing( num_tokens, hidden_size, @@ -2941,6 +2941,7 @@ def test_renormalize_routing( weight_processing, activation_type, cache_permute_indices, + logits_dtype, zero_hidden_states, ): """Test Renormalize routing configurations.""" @@ -2953,6 +2954,7 @@ def test_renormalize_routing( weight_processing, activation_type, cache_permute_indices, + logits_dtype, zero_hidden_states=zero_hidden_states, ) @@ -3128,6 +3130,12 @@ def test_renormalize_routing( pytest.param(ActivationType.Relu2.value, id="Relu2"), ], ) +@pytest.mark.parametrize( + "logits_dtype", + [ + pytest.param(torch.float32, id="FP32_logits"), + ], +) def test_deepseekv3_routing( num_tokens, hidden_size, @@ -3136,6 +3144,7 @@ def test_deepseekv3_routing( routing_config, weight_processing, activation_type, + logits_dtype, cache_permute_indices, ): """Test DeepSeekV3 routing configurations.""" @@ -3148,6 +3157,7 @@ def test_deepseekv3_routing( weight_processing, activation_type, cache_permute_indices, + logits_dtype, ) @@ -3203,6 +3213,13 @@ def test_deepseekv3_routing( pytest.param(ActivationType.Geglu.value, id="Geglu"), ], ) +@pytest.mark.parametrize( + "logits_dtype", + [ + pytest.param(torch.float32, id="FP32_logits"), + pytest.param(torch.bfloat16, id="BF16_logits"), + ], +) def test_topk_routing( num_tokens, hidden_size, @@ -3211,6 +3228,7 @@ def test_topk_routing( routing_config, weight_processing, activation_type, + logits_dtype, cache_permute_indices, ): """Test TopK routing configuration.""" @@ -3223,6 +3241,7 @@ def test_topk_routing( weight_processing, activation_type, cache_permute_indices, + logits_dtype, ) @@ -3276,6 +3295,12 @@ def test_topk_routing( pytest.param(ActivationType.Swiglu.value, id="Swiglu"), ], ) +@pytest.mark.parametrize( + "logits_dtype", + [ + pytest.param(torch.bfloat16, id="BF16_logits"), + ], +) def test_llama4_routing( num_tokens, hidden_size, @@ -3284,6 +3309,7 @@ def test_llama4_routing( routing_config, weight_processing, activation_type, + logits_dtype, cache_permute_indices, ): """Test Llama4 routing configuration with FP8 per-tensor.""" @@ -3296,6 +3322,7 @@ def test_llama4_routing( weight_processing, activation_type, cache_permute_indices, + logits_dtype, ) @@ -3347,6 +3374,7 @@ def test_nvfp4_moe_gemm_bias( }, activation_type=ActivationType.Swiglu, cache_permute_indices=cache_permute_indices, + logits_dtype=torch.bfloat16, gemm1_bias=gemm1_bias, gemm2_bias=gemm2_bias, ) diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 7c8339cecf..c3010e9c0e 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -57,6 +57,7 @@ def skip_checks( num_tokens, hidden_size, intermediate_size, + logits_dtype, zero_hidden_states=False, ): """Common skip logic for all tests.""" @@ -157,3 +158,21 @@ def skip_checks( pytest.xfail( "Note(jimmzhou): Make MxFP4xBf16 nonfunctional on SM103 to avoid B200 regression" ) + + if ( + routing_config["routing_method_type"] == RoutingMethodType.DeepSeekV3 + and logits_dtype != torch.float32 + ): + pytest.skip( + f"Incompatible: logits_dtype={logits_dtype} with DeepSeekV3 routing" + ) + + if logits_dtype == torch.float32 and moe_impl.quant_mode not in [ + QuantMode.FP8_PER_TENSOR, + QuantMode.FP8_BLOCK_SCALE_DEEPSEEK, + QuantMode.FP8_BLOCK_SCALE_MXFP8, + QuantMode.BF16, + ]: + pytest.skip( + f"Incompatible: logits_dtype={logits_dtype} with {type(moe_impl).__name__} + {moe_impl.quant_mode}" + )