diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu
index 7f9a664291..21faec8ec7 100644
--- a/csrc/trtllm_fused_moe_routing_deepseek.cu
+++ b/csrc/trtllm_fused_moe_routing_deepseek.cu
@@ -14,6 +14,8 @@
  * limitations under the License.
  */

+#include <cmath>
+
 #include "flashinfer/exception.h"
 #include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh"

@@ -56,10 +58,9 @@ __global__ void routingMainKernel(KernelParams params) {
     }
   }

-  // note that for invalid scores, we simply use a negative value:
-  // they work well even with the compacted format used in topK, and
-  // sigmoid / bias activated scores cannot be negative
-  static constexpr float invalidScoreFloat = -1.F;
+  // note that for invalid scores, we use negative infinity,
+  // needed for GLM-style routing where bias can be negative
+  static constexpr float invalidScoreFloat = -float(INFINITY);
   const OutputT invalidScore = OutputT{invalidScoreFloat};

   // load bias already; each warp represents one expert group
@@ -101,8 +102,8 @@ __global__ void routingMainKernel(KernelParams params) {
     smemScoreSigmoid[threadExpert] = scoreSigmoid;
   }
   // get the score with bias
-  // note that with invalid values, because sigmoid is < 1 and bias is -1,
-  // we must get a negative value, which is smaller than any valid value
+  // note: with invalid values, invalidScoreFloat ensures values are always smaller than valid
+  // ones
   auto scoreBias = float{scoreSigmoid + float{biasVal}};

   if (expertSelected) {
diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py
index c209e5c509..471aba6fc1 100644
--- a/tests/moe/test_trtllm_gen_fused_moe.py
+++ b/tests/moe/test_trtllm_gen_fused_moe.py
@@ -2691,6 +2691,22 @@ def test_renormalize_routing(
             },
             id="DSLite",
         ),
+        pytest.param(
+            {
+                "num_experts": 160,
+                "top_k": 8,
+                "padding": 8,
+                "n_groups": 1,
+                "top_k_groups": 1,
+                "routed_scaling": 2.5,
+                "has_routing_bias": True,
+                "routing_method_type": RoutingMethodType.DeepSeekV3,
+                "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe, BF16Moe],
+                "compatible_intermediate_size": [512, 1024, 1536],
+                "enable_autotune": False,
+            },
+            id="GLM4_MoE",
+        ),
     ],
 )
 @pytest.mark.parametrize(