diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 2821ce829a..77f2aad3c6 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -989,7 +989,7 @@ def __init__( use_deepseek_fp8: bool, hidden_size: int, intermediate_size: int, - activation_type: int = ActivationType.Swiglu, + activation_type: int = ActivationType.Swiglu.value, use_shuffled_weight: bool = False, weight_layout: int = WeightLayout.MajorK, use_packed_weights: bool = False, @@ -1422,7 +1422,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, - activation_type: ActivationType = ActivationType.Swiglu, + activation_type: int = ActivationType.Swiglu.value, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1482,7 +1482,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( use_routing_scales_on_input=use_routing_scales_on_input, routing_method_type=routing_method_type, enable_pdl=enable_pdl, - activation_type=activation_type.value, + activation_type=activation_type, ) # Call the C++ function result = moe_op.trtllm_fp8_per_tensor_scale_moe( @@ -1507,7 +1507,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type, enable_pdl, [-1, -1] if tactic == -1 else tactic, - activation_type.value, + activation_type, ) return result diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index a93767e457..58d2e2bb9d 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -2667,8 +2667,8 @@ def run_moe_test( @pytest.mark.parametrize( "activation_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), - pytest.param(ActivationType.Geglu, id="Geglu"), + pytest.param(ActivationType.Swiglu.value, id="Swiglu"), + pytest.param(ActivationType.Geglu.value, id="Geglu"), ], ) def test_renormalize_routing( @@ -2855,9 +2855,9 @@ def test_renormalize_routing( @pytest.mark.parametrize( "activation_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), - pytest.param(ActivationType.Geglu, id="Geglu"), - pytest.param(ActivationType.Relu2, id="Relu2"), + pytest.param(ActivationType.Swiglu.value, id="Swiglu"), + pytest.param(ActivationType.Geglu.value, id="Geglu"), + pytest.param(ActivationType.Relu2.value, id="Relu2"), ], ) def test_deepseekv3_routing( @@ -2931,8 +2931,8 @@ def test_deepseekv3_routing( @pytest.mark.parametrize( "activation_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), - pytest.param(ActivationType.Geglu, id="Geglu"), + pytest.param(ActivationType.Swiglu.value, id="Swiglu"), + pytest.param(ActivationType.Geglu.value, id="Geglu"), ], ) def test_topk_routing( @@ -3005,7 +3005,7 @@ def test_topk_routing( @pytest.mark.parametrize( "activation_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Swiglu.value, id="Swiglu"), ], ) def test_llama4_routing(