diff --git a/tests/conftest.py b/tests/conftest.py
index 822d08e21675..f3ee0ccb3d47 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -74,6 +74,27 @@
 from torch._inductor.utils import fresh_cache
 
+# This function will load AITER function pointers and make them
+# available for vllm tests.
+def load_aiter_ops_if_available():
+    from vllm.platforms import current_platform
+    from importlib.util import find_spec
+
+    # NOTE: it's not possible to use vllm._aiter_ops.is_aiter_found
+    # because the aiter ops won't load and all tests that want
+    # to use aiter will fail because no aiter ops will be loaded.
+    if current_platform.is_rocm() and find_spec("aiter") is not None:
+        os.environ["VLLM_ROCM_USE_AITER"] = "1"
+        # Load the ops so they can be used by tests that need them.
+        import vllm._aiter_ops  # noqa: F401
+
+        # Reset the environment variable back to "0" so that tests which
+        # rely on AITER being disabled by default function properly.
+        os.environ["VLLM_ROCM_USE_AITER"] = "0"
+
+
+load_aiter_ops_if_available()
+
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
     from transformers.generation.utils import GenerateOutput
diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py
index 070d00f61120..b0ecc9ed71f6 100644
--- a/tests/kernels/moe/test_rocm_aiter_topk.py
+++ b/tests/kernels/moe/test_rocm_aiter_topk.py
@@ -10,7 +10,6 @@
 # and the platform is not ROCm.
 
 import importlib.util
-import os
 
 import pytest
 import torch
@@ -20,9 +19,6 @@
 if not current_platform.is_rocm():
     pytest.skip("This test can only run on ROCm.", allow_module_level=True)
 
-# This environment variable must be set so ops will be registered. 
-os.environ["VLLM_ROCM_USE_AITER"] = "1"
-
 # this import statement is needed to ensure the ops are registered
 import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 0f499c39ead3..69879cbc78cd 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -109,6 +109,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
         envs.VLLM_ROCM_USE_AITER
         and envs.VLLM_ROCM_USE_AITER_RMSNORM
         and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
+        and cfg.model_config is not None
         and cfg.model_config.get_hidden_size() == 2880
     )