diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py
index d26fe50b815b..c02183852532 100644
--- a/tests/kernels/moe/test_grouped_topk.py
+++ b/tests/kernels/moe/test_grouped_topk.py
@@ -8,6 +8,12 @@
 import pytest
 import torch
 
+from vllm.config import (
+    CompilationConfig,
+    VllmConfig,
+    get_cached_compilation_config,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     GroupedTopk,
     fused_grouped_topk,
@@ -41,6 +47,11 @@ def test_grouped_topk(
     routed_scaling_factor: float,
     dtype: torch.dtype,
 ):
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
+    )
+    get_cached_compilation_config.cache_clear()
+
     current_platform.seed_everything(0)
     hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
     gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
@@ -48,7 +59,7 @@ def test_grouped_topk(
         (n_expert,), dtype=torch.float32, device="cuda"
     )
 
-    with monkeypatch.context() as m:
+    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
         grouped_topk = GroupedTopk(
             topk=topk,
@@ -58,6 +69,7 @@ def test_grouped_topk(
             scoring_func=scoring_func,
             routed_scaling_factor=routed_scaling_factor,
         )
+        assert grouped_topk._forward_method.__name__ == "forward_cuda"
         baseline_topk_weights, baseline_topk_ids = grouped_topk(
             hidden_states=hidden_states,
             gating_output=gating_output,