Closed
22 commits
28337ca  enable aiter all reduce and fused ar_rmsnorm (vllmellm, Mar 6, 2026)
7a692cf  add unit test for aiter custom all reduce and move all reduce to same… (vllmellm, Mar 10, 2026)
d5a9c81  fix unittest (vllmellm, Mar 11, 2026)
a7c4229  add number of token threshold for aiter fused all reduce (vllmellm, Mar 11, 2026)
253e511  use aiter all reduce only for fusion pass (vllmellm, Mar 14, 2026)
d7bf6af  clean code (vllmellm, Mar 16, 2026)
45f638f  bugfixes (vllmellm, Mar 17, 2026)
aa3b50b  enable allreduce + rmsnorm using rocm_aiter_ops only (vllmellm, Mar 19, 2026)
e624300  Merge remote-tracking branch 'origin/main' into aiter-all-reduce-fuse… (vllmellm, Mar 20, 2026)
31d5e0c  remove unnecessary log (vllmellm, Mar 20, 2026)
fd5496b  [ROCm][Perf] Add fused AllReduce+RMSNorm for DeepSeek on MI355X (attila-dusnoki-htec, Mar 20, 2026)
65a9a35  Add clarifications to constrains (attila-dusnoki-htec, Mar 23, 2026)
92384b9  Add missed tests (attila-dusnoki-htec, Mar 24, 2026)
64419c8  Merge remote-tracking branch 'emb/aiter-all-reduce-fused-rmsnorm' int… (attila-dusnoki-htec, Mar 25, 2026)
771cf0f  Merge branch 'main' into dsr1-ar-rmsnorm (attila-dusnoki-htec, Mar 31, 2026)
eb80fed  Remove unnecessary change (vllmellm, Apr 2, 2026)
1039899  Remove unnecessary blank line (vllmellm, Apr 2, 2026)
e2daa9f  Merge remote-tracking branch 'origin/main' into aiter-all-reduce-fuse… (vllmellm, Apr 9, 2026)
2ca9464  Merge branch 'aiter-all-reduce-fused-rmsnorm' of https://github.com/E… (vllmellm, Apr 9, 2026)
e011547  Merge branch 'main' into aiter-all-reduce-fused-rmsnorm (vllmellm, Apr 9, 2026)
26ab444  fix unit-test (vllmellm, Apr 9, 2026)
0d7bbd3  Merge remote-tracking branch 'emb/aiter-all-reduce-fused-rmsnorm' int… (attila-dusnoki-htec, Apr 13, 2026)
1 change: 1 addition & 0 deletions .buildkite/test-amd.yaml
@@ -3687,6 +3687,7 @@ steps:
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
# TODO: this test is not supported on ROCm, there are aiter kernels for this.
# - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
- pytest -v -s tests/compile/passes/distributed/test_tp2_ar_rms.py::test_tp2_ar_rms_fusions
# - pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
# - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"

20 changes: 16 additions & 4 deletions tests/compile/fusions_e2e/test_tp2_ar_rms.py
@@ -19,6 +19,8 @@
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
FLASHMLA_SPARSE_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
@@ -33,8 +35,6 @@
qwen3_a3b_fp8,
)

pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
@@ -54,6 +54,7 @@
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
def test_tp2_ar_rms_fp8_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
@@ -123,6 +124,7 @@ def test_tp2_ar_rms_fp8_fusions(
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4")
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
def test_tp2_ar_rms_fp4_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
@@ -175,10 +177,19 @@ def test_tp2_ar_rms_fp4_fusions(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b, gpt_oss_20b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize(
"attn_backend",
[
TRITON_ATTN,
FLASHINFER_ATTN,
ROCM_ATTN,
ROCM_AITER_UNIFIED_ATTN,
],
)
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("custom_ops", tuple(custom_ops_combos("rms_norm")))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm")
def test_tp2_ar_rms_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
@@ -220,4 +231,5 @@ def test_tp2_ar_rms_fusions(
compilation_config,
matches_check,
tp_size=2,
use_aiter=current_platform.is_rocm(),
)
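
The diff above widens this end-to-end fusion test to ROCm by extending the attention-backend matrix and gating collection on current_platform.is_cuda_alike(). A minimal standalone sketch of that parametrize-plus-skipif pattern is shown below; it is a hypothetical example, not code from this PR, and assumes only pytest and vLLM's current_platform helper.

```python
# Sketch of the platform-gating pattern used in the test above
# (hypothetical example, not part of this PR).
import pytest

from vllm.platforms import current_platform

ATTN_BACKENDS = [
    "TRITON_ATTN",  # available on both CUDA and ROCm
    pytest.param(
        "ROCM_AITER_UNIFIED_ATTN",
        marks=pytest.mark.skipif(
            not current_platform.is_rocm(), reason="ROCm-only backend"
        ),
    ),
]


@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm"
)
def test_backend_gating(attn_backend: str):
    # The real test builds a TP=2 engine and counts fused allreduce+RMSNorm
    # matches; this sketch only demonstrates per-parameter skipping.
    assert attn_backend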
84 changes: 72 additions & 12 deletions tests/compile/passes/distributed/test_fusion_all_reduce.py
@@ -8,8 +8,12 @@
import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass
from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
AllReduceFusionPass,
RocmAiterAllReduceFusionPass,
)
from vllm.compilation.passes.utility.fix_functionalization import (
FixFunctionalizationPass,
)
@@ -40,13 +44,19 @@

class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
self,
hidden_size=16,
token_num=16,
eps=1e-6,
dtype: torch.dtype = torch.float16,
use_aiter: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
self.use_aiter = use_aiter

def forward(self, x):
# avoid having graph input be an arg to a pattern directly
@@ -74,6 +84,8 @@ def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]

def ops_in_model_after(self):
if self.use_aiter:
return [rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()]
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


@@ -192,12 +204,36 @@ def ops_in_model_before(self):

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model, enable_quant_fp8_custom_op",
"test_model, enable_quant_fp8_custom_op, use_aiter",
[
(TestAllReduceRMSNormModel, False),
(TestAllReduceRMSNormStaticQuantFP8Model, True),
(TestAllReduceRMSNormStaticQuantFP8Model, False),
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
True,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
],
)
@pytest.mark.parametrize("batch_size", [8])
@@ -208,9 +244,18 @@ def ops_in_model_before(self):
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
current_platform.is_rocm() and not IS_AITER_FOUND,
reason="aiter is not found",
)
@pytest.mark.skipif(
current_platform.is_cuda()
and (
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute(
"flashinfer.comm", "create_allreduce_fusion_workspace"
)
),
reason="flashinfer is not found or flashinfer "
"is not compiled with allreduce_fusion",
)
@@ -223,7 +268,14 @@ def test_all_reduce_fusion_pass_replace(
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter: bool,
monkeypatch: pytest.MonkeyPatch,
):
if use_aiter:
with monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", str(use_aiter))
rocm_aiter_ops.refresh_env_variables()

num_processes = 2
if (
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
@@ -247,6 +299,8 @@ def run_torch_spawn(fn, nprocs):
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter,
monkeypatch,
),
nprocs=nprocs,
)
@@ -265,6 +319,8 @@ def all_reduce_fusion_pass_on_test_model(
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter: bool,
monkeypatch: pytest.MonkeyPatch,
):
set_random_seed(0)

@@ -311,7 +367,11 @@ def all_reduce_fusion_pass_on_test_model(
)
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
all_reduce_fusion_pass = (
RocmAiterAllReduceFusionPass(vllm_config)
if use_aiter
else AllReduceFusionPass(vllm_config)
)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
@@ -321,7 +381,7 @@
)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num, dtype=dtype)
model = test_model_cls(hidden_size, token_num, dtype=dtype, use_aiter=use_aiter)

hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
