Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,27 @@
from torch._inductor.utils import fresh_cache


# This function will load AITER function pointers and make them
# available for vllm tests.
def load_aiter_ops_if_available():
    """Register AITER custom ops for the test session, when available.

    On ROCm with the ``aiter`` package installed, temporarily enables
    ``VLLM_ROCM_USE_AITER`` so that importing ``vllm._aiter_ops`` registers
    the AITER op pointers, then restores the environment to its prior state.

    Side effects: may import ``vllm._aiter_ops`` (op registration) and
    mutates ``os.environ`` only for the duration of that import.
    """
    from importlib.util import find_spec

    from vllm.platforms import current_platform

    # NOTE: it's not possible to use vllm._aiter_ops.is_aiter_found
    # because the aiter ops won't load and all tests that want
    # to use aiter will fail because no aiter ops will be loaded.
    if current_platform.is_rocm() and find_spec("aiter") is not None:
        # Remember any pre-existing value so we don't clobber a
        # user/CI-provided setting.
        prev_value = os.environ.get("VLLM_ROCM_USE_AITER")
        os.environ["VLLM_ROCM_USE_AITER"] = "1"
        try:
            # Load the ops so they can be used by tests that need them.
            import vllm._aiter_ops  # noqa: F401
        finally:
            # Truly unset the variable (setting it to "0" would leave it
            # present in os.environ, breaking tests that require it to be
            # absent), or restore the caller's original value. The finally
            # guarantees cleanup even if the import raises.
            if prev_value is None:
                os.environ.pop("VLLM_ROCM_USE_AITER", None)
            else:
                os.environ["VLLM_ROCM_USE_AITER"] = prev_value


load_aiter_ops_if_available()

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.generation.utils import GenerateOutput
Expand Down
4 changes: 0 additions & 4 deletions tests/kernels/moe/test_rocm_aiter_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# and the platform is not ROCm.

import importlib.util
import os

import pytest
import torch
Expand All @@ -20,9 +19,6 @@
if not current_platform.is_rocm():
pytest.skip("This test can only run on ROCm.", allow_module_level=True)

# This environment variable must be set so ops will be registered.
os.environ["VLLM_ROCM_USE_AITER"] = "1"

# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401

Expand Down
1 change: 1 addition & 0 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_RMSNORM
and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
and cfg.model_config is not None
and cfg.model_config.get_hidden_size() == 2880
)

Expand Down