Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
28337ca
enable aiter all reduce and fused ar_rmsnorm
vllmellm Mar 6, 2026
7a692cf
add unit test for aiter custom all reduce and move all reduce to same…
vllmellm Mar 10, 2026
d5a9c81
fix unittest
vllmellm Mar 11, 2026
a7c4229
add number of token threshold for aiter fused all reduce
vllmellm Mar 11, 2026
253e511
use aiter all reduce only for fusion pass
vllmellm Mar 14, 2026
d7bf6af
clean code
vllmellm Mar 16, 2026
45f638f
bugfixes
vllmellm Mar 17, 2026
aa3b50b
enable allreduce + rmsnorm using rocm_aiter_ops only
vllmellm Mar 19, 2026
e624300
Merge remote-tracking branch 'origin/main' into aiter-all-reduce-fuse…
vllmellm Mar 20, 2026
31d5e0c
remove unnecessary log
vllmellm Mar 20, 2026
eb80fed
Remove unnecessary change
vllmellm Apr 2, 2026
1039899
Remove unnecessary blank line
vllmellm Apr 2, 2026
e2daa9f
Merge remote-tracking branch 'origin/main' into aiter-all-reduce-fuse…
vllmellm Apr 9, 2026
2ca9464
Merge branch 'aiter-all-reduce-fused-rmsnorm' of https://github.com/E…
vllmellm Apr 9, 2026
e011547
Merge branch 'main' into aiter-all-reduce-fused-rmsnorm
vllmellm Apr 9, 2026
26ab444
fix unit-test
vllmellm Apr 9, 2026
8bd7669
[ROCm] Use AITER fused_ar_rms API and refine use_1stage heuristic (#81)
rbrugaro-amd Apr 27, 2026
a110817
Merge remote-tracking branch 'origin/main' into aiter-all-reduce-fuse…
junkang1991 Apr 27, 2026
cbc3219
Merge remote-tracking branch 'origin/main' into aiter-all-reduce-fuse…
vllmellm Apr 27, 2026
661dddd
Update vllm/distributed/parallel_state.py
vllmellm Apr 27, 2026
bf722ad
fix pattern matcher replacement
vllmellm Apr 27, 2026
e3e6feb
add todo comment
vllmellm Apr 27, 2026
9898365
fix ci: keep aiter op in instance
vllmellm Apr 29, 2026
e4a4d57
fix platform check in test
vllmellm Apr 29, 2026
6b6d16c
fix allreduce+rmsnorm+quant fusion test
vllmellm Apr 29, 2026
2c94d38
add max token size to compile range
vllmellm Apr 30, 2026
9a0c520
Merge remote-tracking branch 'origin/main' into aiter-all-reduce-fuse…
vllmellm Apr 30, 2026
8b2e047
Merge branch 'main' into aiter-all-reduce-fused-rmsnorm
tjtanaa May 1, 2026
fdc50a8
attempt to fix precommit
vllmellm May 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test-amd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,7 @@ steps:
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
- pytest -v -s tests/compile/passes/distributed/test_tp2_ar_rms.py::test_tp2_ar_rms_fusions

#----------------------------------------------------------- mi300 · cuda ------------------------------------------------------------#

Expand Down
22 changes: 19 additions & 3 deletions tests/compile/fusions_e2e/test_tp2_ar_rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
FLASHMLA_SPARSE_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
deepseek_coder_v2_lite_fp8,
deepseek_r1_fp4,
Expand All @@ -34,7 +36,9 @@
qwen3_a3b_fp8,
)

pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
Comment thread
tjtanaa marked this conversation as resolved.
pytestmark = pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm"
)


@multi_gpu_test(num_gpus=2)
Expand All @@ -55,6 +59,7 @@
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
def test_tp2_ar_rms_fp8_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
Expand Down Expand Up @@ -124,6 +129,7 @@ def test_tp2_ar_rms_fp8_fusions(
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4")
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
def test_tp2_ar_rms_fp4_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
Expand Down Expand Up @@ -176,10 +182,19 @@ def test_tp2_ar_rms_fp4_fusions(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b, gpt_oss_20b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize(
"attn_backend",
[
TRITON_ATTN,
FLASHINFER_ATTN,
ROCM_ATTN,
ROCM_AITER_UNIFIED_ATTN,
],
)
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("custom_ops", tuple(custom_ops_combos("rms_norm")))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm")
def test_tp2_ar_rms_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
Expand Down Expand Up @@ -221,4 +236,5 @@ def test_tp2_ar_rms_fusions(
compilation_config,
matches_check,
tp_size=2,
use_aiter=current_platform.is_rocm(),
)
89 changes: 77 additions & 12 deletions tests/compile/passes/distributed/test_fusion_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass
from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
AllReduceFusionPass,
RocmAiterAllReduceFusionPass,
)
from vllm.compilation.passes.utility.fix_functionalization import (
FixFunctionalizationPass,
)
Expand Down Expand Up @@ -42,13 +46,19 @@

class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
self,
hidden_size=16,
token_num=16,
eps=1e-6,
dtype: torch.dtype = torch.float16,
use_aiter: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
self.use_aiter = use_aiter

def forward(self, x):
# avoid having graph input be an arg to a pattern directly
Expand Down Expand Up @@ -76,6 +86,8 @@ def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]

def ops_in_model_after(self):
if self.use_aiter:
return [rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()]
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


Expand Down Expand Up @@ -194,12 +206,36 @@ def ops_in_model_before(self):

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model, enable_quant_fp8_custom_op",
"test_model, enable_quant_fp8_custom_op, use_aiter",
[
(TestAllReduceRMSNormModel, False),
(TestAllReduceRMSNormStaticQuantFP8Model, True),
(TestAllReduceRMSNormStaticQuantFP8Model, False),
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
True,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
],
)
@pytest.mark.parametrize("batch_size", [8])
Expand All @@ -210,9 +246,18 @@ def ops_in_model_before(self):
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
current_platform.is_rocm() and not IS_AITER_FOUND,
reason="aiter is not found",
)
@pytest.mark.skipif(
current_platform.is_cuda()
and (
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute(
"flashinfer.comm", "create_allreduce_fusion_workspace"
)
),
reason="flashinfer is not found or flashinfer "
"is not compiled with allreduce_fusion",
)
Expand All @@ -225,7 +270,14 @@ def test_all_reduce_fusion_pass_replace(
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter: bool,
monkeypatch: pytest.MonkeyPatch,
):
if use_aiter:
with monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", str(use_aiter))
rocm_aiter_ops.refresh_env_variables()

num_processes = 2
if (
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
Expand All @@ -249,6 +301,8 @@ def run_torch_spawn(fn, nprocs):
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter,
monkeypatch,
),
nprocs=nprocs,
)
Expand All @@ -267,6 +321,8 @@ def all_reduce_fusion_pass_on_test_model(
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter: bool,
monkeypatch: pytest.MonkeyPatch,
):
set_random_seed(0)

Expand Down Expand Up @@ -313,7 +369,11 @@ def all_reduce_fusion_pass_on_test_model(
)
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
all_reduce_fusion_pass = (
RocmAiterAllReduceFusionPass(vllm_config)
if use_aiter
else AllReduceFusionPass(vllm_config)
)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
Expand All @@ -323,7 +383,12 @@ def all_reduce_fusion_pass_on_test_model(
)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num, dtype=dtype)
if test_model_cls is TestAllReduceRMSNormModel:
model = test_model_cls(
hidden_size, token_num, dtype=dtype, use_aiter=use_aiter
)
else:
model = test_model_cls(hidden_size, token_num, dtype=dtype)

hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

Expand Down
Loading
Loading