Merged
60 commits
89fc090
move all aiter kernel ops registration to _aiter_ops.py and centraliz…
vllmellm Aug 28, 2025
2f07bcc
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Aug 29, 2025
b67838b
clean remaining of aiter ops
vllmellm Aug 29, 2025
9c1387b
fix circular import
vllmellm Aug 29, 2025
414c638
avoid circular import
vllmellm Aug 29, 2025
ad0d26c
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 2, 2025
e878b9a
fix fused_topk return
vllmellm Sep 2, 2025
c9663f2
add missing aiter gemm_w8a8_block scale kernel in aiter ops.
vllmellm Sep 2, 2025
9cfcaf4
fix ops name
vllmellm Sep 2, 2025
ccf5a0a
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 4, 2025
6ffe492
remove unnecessary cache decorator
vllmellm Sep 8, 2025
ba3fa63
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 8, 2025
9c8a7c7
fix precommit error
vllmellm Sep 10, 2025
bdd7b74
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 12, 2025
37f38d0
fix get_vit_attn_backend after merge/main
vllmellm Sep 12, 2025
e85de16
attempt to fix pre-commit error
vllmellm Sep 12, 2025
c4f132f
make all dispatch function pure by removing envs check in dispatch lo…
vllmellm Sep 12, 2025
53457e4
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 12, 2025
2e50f91
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 12, 2025
d23f344
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 26, 2025
0320f49
refactor quark w4a4 mxfp4
vllmellm Sep 26, 2025
dc8e846
fix impl function return type and improve error message
vllmellm Sep 26, 2025
b02620b
refactor aiter triton embedding
vllmellm Sep 26, 2025
711d32a
fix env variable
vllmellm Sep 26, 2025
5bc6204
fix bug when +rotary_embedding is requested as custom_ops for Llama4 …
vllmellm Sep 26, 2025
c212c02
enable flash_attn_var_len from aiter or vllm package to be used as vi…
vllmellm Sep 26, 2025
7822bcd
bugfix: update aiter moe check
vllmellm Sep 26, 2025
4b52c1b
fix precommit error for forward_hip overridden function
vllmellm Sep 29, 2025
7be2b46
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Sep 29, 2025
1dd9afe
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Oct 9, 2025
642d27f
move shuffle weights into rocm aiter ops
vllmellm Oct 9, 2025
79ee26b
move quark scheme kernels into aiter ops
vllmellm Oct 9, 2025
9c21911
bugfixes of merge with main
vllmellm Oct 10, 2025
de9502f
bugfixes of merge with main
vllmellm Oct 10, 2025
6ddfc47
return quark_ocp_mx to its original base
vllmellm Oct 14, 2025
3232671
resolve merge conflicts
vllmellm Oct 14, 2025
d5990c3
fix api issue of aiter grouped topk and remove unused code
vllmellm Oct 14, 2025
d0d23c8
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Oct 14, 2025
674195c
fix docs
vllmellm Oct 14, 2025
9073213
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Oct 20, 2025
0050bac
bugfixes during merge with main
vllmellm Oct 20, 2025
645d08c
clean code
vllmellm Oct 20, 2025
154eba0
avoid calling aiter enable check in forward pass
vllmellm Oct 20, 2025
5a2efa2
Update vllm/_aiter_ops.py
vllmellm Oct 21, 2025
0932396
address pr comments
vllmellm Oct 21, 2025
4647e06
return back missing code during merge conflict
vllmellm Oct 21, 2025
87c7024
bugfix in per tensor w8a8 gemm when using --quantization fp8 vllm co…
vllmellm Oct 21, 2025
ebc35a4
sync upstream
vllmellm Oct 27, 2025
84ce928
check aiter package in global aiter_ops.py
vllmellm Nov 3, 2025
15c970e
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 3, 2025
c2b050d
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 5, 2025
488b5b8
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 7, 2025
fb25a7c
attempt to fix topk dispatcher unit test
vllmellm Nov 7, 2025
b28b6d0
fix topk dispatcher unit test
vllmellm Nov 7, 2025
5309c88
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 8, 2025
97e4c80
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 8, 2025
dab6827
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 8, 2025
b1aef31
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 9, 2025
f081a2b
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 9, 2025
b6cb4d2
Merge remote-tracking branch 'origin/main' into rfc-aiter-ops
vllmellm Nov 10, 2025
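Taken together, the commits above centralize AITER kernel registration behind a single `rocm_aiter_ops` entry point in `vllm/_aiter_ops.py` and keep environment checks out of the dispatch logic. The snippet below is a minimal sketch of that facade pattern, not the actual implementation: only the names `rocm_aiter_ops`, `topk_softmax`, `rms_norm`, and `rms_norm2d_with_add` appear in the diffs below, while the `is_enabled` helper, the method signatures, and the bodies are illustrative assumptions.

```python
# Minimal sketch of a centralized ops facade, loosely following the pattern
# described in the commits above. The is_enabled helper, method signatures,
# and bodies are illustrative assumptions, not vLLM's real API.
from importlib.util import find_spec

import torch


class rocm_aiter_ops:
    """Single entry point for ROCm AITER kernels (sketch)."""

    @staticmethod
    def is_enabled() -> bool:
        # Assumption: package availability is checked once here, so callers
        # (and forward passes) never re-check environment state themselves.
        return find_spec("aiter") is not None

    @staticmethod
    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Placeholder body; the real method would invoke the AITER kernel.
        raise NotImplementedError("AITER rms_norm kernel not wired in this sketch")
```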
docs/design/moe_kernel_features.md (2 changes: 1 addition & 1 deletion)
@@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |

tests/kernels/moe/test_moe.py (11 changes: 6 additions & 5 deletions)
@@ -6,6 +6,8 @@
"""

import functools
import importlib
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@@ -20,6 +22,7 @@
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
@@ -412,14 +415,12 @@ def test_mixtral_moe(
huggingface."""

# clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
# Force reload aiter_ops to pick up the new environment variables.
if "rocm_aiter_ops" in sys.modules:
importlib.reload(rocm_aiter_ops)

is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")

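The test change above replaces the old `is_rocm_aiter_moe_enabled.cache_clear()` call with a module reload, since state derived from environment variables at import time only changes when the module is re-imported. Below is a small, self-contained illustration of that reload idiom; the helper name and the usage comment are hypothetical, not part of the PR.

```python
# Hypothetical illustration of the reload-after-monkeypatch idiom used in the
# updated test: a module that evaluates environment variables at import time
# must be reloaded for a monkeypatched value to be observed.
import importlib
import sys


def reload_if_loaded(module_name: str) -> None:
    """Reload `module_name` if it was already imported, so any env-derived
    state inside it is recomputed."""
    if module_name in sys.modules:
        importlib.reload(sys.modules[module_name])


# Typical use inside a pytest test body (monkeypatch is pytest's fixture):
#   monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
#   reload_if_loaded("vllm._aiter_ops")
```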
tests/model_executor/test_enabled_custom_ops.py (41 changes: 14 additions & 27 deletions)
@@ -4,6 +4,7 @@
import pytest
import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (
@@ -15,9 +16,6 @@
dispatch_topk_func,
vllm_topk_softmax,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import (
RMSNorm,
dispatch_rocm_rmsnorm_func,
@@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
topk_func = dispatch_topk_func()
is_rocm_aiter_moe_enabled.cache_clear()
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax,
)
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_dispatch(use_rocm_aiter: bool):
topk_func = dispatch_topk_func(use_rocm_aiter)

assert topk_func == rocm_aiter_topk_softmax
if current_platform.is_rocm() and use_rocm_aiter:
assert topk_func == rocm_aiter_ops.topk_softmax
else:
assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
)
def test_rms_norm_dispatch(
add_residual: bool,
dtype: torch.dtype,
use_rocm_aiter: str,
use_rocm_aiter_norm: str,
monkeypatch,
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)

should_use_rocm_aiter = (
current_platform.is_rocm()
and int(use_rocm_aiter)
and int(use_rocm_aiter_norm)
and use_rocm_aiter
and dtype in RMS_NORM_SUPPORTED_DTYPES
)

if add_residual and should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else:
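The rewritten tests above exercise dispatchers that now take the AITER flag as an explicit argument instead of reading `VLLM_ROCM_USE_AITER` from the environment (see the commit "make all dispatch function pure ..."). A minimal sketch of that pure-dispatch shape follows; the stub kernels and the name `dispatch_topk_func_sketch` are placeholders, not vLLM code.

```python
# Sketch of a pure dispatcher: the decision depends only on its arguments,
# which makes it trivial to unit-test by parametrizing the flag, as the
# updated tests above do. The two kernel stubs are placeholders.
from collections.abc import Callable


def vllm_topk_softmax_stub(*args, **kwargs):
    """Placeholder for the default (non-AITER) top-k softmax kernel."""
    raise NotImplementedError


def rocm_aiter_topk_softmax_stub(*args, **kwargs):
    """Placeholder for the ROCm AITER top-k softmax kernel."""
    raise NotImplementedError


def dispatch_topk_func_sketch(use_rocm_aiter: bool) -> Callable:
    """Return the kernel to use; no environment reads, no hidden state."""
    return rocm_aiter_topk_softmax_stub if use_rocm_aiter else vllm_topk_softmax_stub


# Same inputs always select the same kernel.
assert dispatch_topk_func_sketch(False) is vllm_topk_softmax_stub
```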