
Commit f080a83

[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Parent commit: 40e2eee
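The commit gathers the AITER wrappers that were previously spread across modules (for example `rocm_aiter_fused_moe` and the rmsnorm custom ops) behind a single `rocm_aiter_ops` class exposed from `vllm._aiter_ops`, mirroring how `_custom_ops` and `_ipex_ops` already group their kernels. A minimal sketch of that pattern follows; only the class name and the `topk_softmax` / `rms_norm` / `rms_norm2d_with_add` method names come from the diffs below, while the `is_enabled()` helper, the signatures, and the bodies are illustrative assumptions.

```python
# Sketch of the ops-class pattern, not vLLM's actual implementation.
import os

import torch


class rocm_aiter_ops:
    """Single entry point for ROCm AITER kernels, like _custom_ops/_ipex_ops."""

    @staticmethod
    def is_enabled() -> bool:
        # Assumption: enablement follows the VLLM_ROCM_USE_AITER env var
        # that the tests below toggle.
        return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"

    @staticmethod
    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Placeholder reference body; the real method calls the AITER RMSNorm kernel.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

    @staticmethod
    def rms_norm2d_with_add(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Placeholder reference body; the real method fuses the residual add.
        x = x + residual
        return rocm_aiter_ops.rms_norm(x, weight, eps), x
```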

File tree

25 files changed: +1194 / -925 lines


docs/design/moe_kernel_features.md

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
 | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
 | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
 | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
-| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] |
+| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
 | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
 | naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |

tests/kernels/moe/test_moe.py

Lines changed: 6 additions & 5 deletions
@@ -6,6 +6,8 @@
 """
 
 import functools
+import importlib
+import sys
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any
@@ -20,6 +22,7 @@
 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.moe.utils import fused_moe
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.forward_context import set_forward_context
@@ -412,14 +415,12 @@ def test_mixtral_moe(
     huggingface."""
 
     # clear the cache before every test
-    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-        is_rocm_aiter_moe_enabled,
-    )
+    # Force reload aiter_ops to pick up the new environment variables.
+    if "rocm_aiter_ops" in sys.modules:
+        importlib.reload(rocm_aiter_ops)
 
-    is_rocm_aiter_moe_enabled.cache_clear()
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
-
     if dtype == torch.float32:
         pytest.skip("AITER ROCm test skip for float32")

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 14 additions & 27 deletions
@@ -4,6 +4,7 @@
 import pytest
 import torch
 
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (
@@ -15,9 +16,6 @@
     dispatch_topk_func,
     vllm_topk_softmax,
 )
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    is_rocm_aiter_moe_enabled,
-)
 from vllm.model_executor.layers.layernorm import (
     RMSNorm,
     dispatch_rocm_rmsnorm_func,
@@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()
 
 
-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
-    topk_func = dispatch_topk_func()
-    is_rocm_aiter_moe_enabled.cache_clear()
-    if current_platform.is_rocm() and int(use_rocm_aiter):
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_topk_softmax,
-        )
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
+)
+def test_topk_dispatch(use_rocm_aiter: bool):
+    topk_func = dispatch_topk_func(use_rocm_aiter)
 
-        assert topk_func == rocm_aiter_topk_softmax
+    if current_platform.is_rocm() and use_rocm_aiter:
+        assert topk_func == rocm_aiter_ops.topk_softmax
     else:
         assert topk_func == vllm_topk_softmax
 
 
 @pytest.mark.parametrize("add_residual", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter", [True, False])
 @pytest.mark.skipif(
     not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
 )
 def test_rms_norm_dispatch(
-    add_residual: bool,
-    dtype: torch.dtype,
-    use_rocm_aiter: str,
-    use_rocm_aiter_norm: str,
-    monkeypatch,
+    add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
 ):
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
-    rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
+    rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
 
     should_use_rocm_aiter = (
         current_platform.is_rocm()
-        and int(use_rocm_aiter)
-        and int(use_rocm_aiter_norm)
+        and use_rocm_aiter
         and dtype in RMS_NORM_SUPPORTED_DTYPES
     )
 
     if add_residual and should_use_rocm_aiter:
-        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+        assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
     elif should_use_rocm_aiter:
-        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
+        assert rms_norm_func == rocm_aiter_ops.rms_norm
     elif add_residual:
         assert rms_norm_func == fused_add_rms_norm
     else:
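The rewritten tests no longer set `VLLM_ROCM_USE_AITER` / `VLLM_ROCM_USE_AITER_RMSNORM` and clear an `lru_cache`; instead they pass `use_rocm_aiter` straight into `dispatch_topk_func` and `dispatch_rocm_rmsnorm_func` and compare the result against `rocm_aiter_ops` methods. A standalone sketch of that dispatch style is below; the placeholder callables and the test are illustrative, and only the boolean-parameter idea mirrors the diff.

```python
# Sketch of boolean-driven dispatch; the callables are placeholders, not vLLM's.
from collections.abc import Callable


def aiter_topk_softmax(*args, **kwargs):
    raise NotImplementedError  # stands in for rocm_aiter_ops.topk_softmax


def native_topk_softmax(*args, **kwargs):
    raise NotImplementedError  # stands in for vllm_topk_softmax


def dispatch_topk_func(use_rocm_aiter: bool) -> Callable:
    # The caller decides whether AITER applies (platform + config checks),
    # keeping the dispatcher a pure, easily testable function.
    return aiter_topk_softmax if use_rocm_aiter else native_topk_softmax


def test_topk_dispatch_sketch():
    assert dispatch_topk_func(True) is aiter_topk_softmax
    assert dispatch_topk_func(False) is native_topk_softmax
```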
