1 change: 1 addition & 0 deletions .buildkite/test-amd.yaml
@@ -3579,6 +3579,7 @@ steps:
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
# TODO: this test is not supported on ROCm yet; there are aiter kernels for this.
# - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
- pytest -v -s tests/compile/passes/distributed/test_tp2_ar_rms.py::test_tp2_ar_rms_fusions
# - pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
# - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"

20 changes: 16 additions & 4 deletions tests/compile/fusions_e2e/test_tp2_ar_rms.py
@@ -18,6 +18,8 @@
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
deepseek_v3_fp8,
gpt_oss_20b,
@@ -30,8 +32,6 @@
qwen3_a3b_fp8,
)

pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
@@ -45,6 +45,7 @@
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
def test_tp2_ar_rms_fp8_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
@@ -110,6 +111,7 @@ def test_tp2_ar_rms_fp8_fusions(
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4")
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
def test_tp2_ar_rms_fp4_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
@@ -161,10 +163,19 @@ def test_tp2_ar_rms_fp4_fusions(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b, gpt_oss_20b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize(
"attn_backend",
[
TRITON_ATTN,
FLASHINFER_ATTN,
ROCM_ATTN,
ROCM_AITER_UNIFIED_ATTN,
],
)
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("custom_ops", tuple(custom_ops_combos("rms_norm")))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm")
def test_tp2_ar_rms_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
@@ -205,4 +216,5 @@ def test_tp2_ar_rms_fusions(
compilation_config,
matches_check,
tp_size=2,
use_aiter=current_platform.is_rocm(),
)
79 changes: 67 additions & 12 deletions tests/compile/passes/distributed/test_fusion_all_reduce.py
@@ -8,8 +8,12 @@
import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass
from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
AllReduceFusionPass,
RocmAiterAllReduceFusionPass,
)
from vllm.compilation.passes.utility.fix_functionalization import (
FixFunctionalizationPass,
)
@@ -39,12 +43,13 @@


class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6, use_aiter=False):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
self.use_aiter = use_aiter

def forward(self, x):
# avoid having graph input be an arg to a pattern directly
@@ -72,6 +77,8 @@ def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]

def ops_in_model_after(self):
if self.use_aiter:
return [rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()]
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


@@ -185,12 +192,36 @@ def ops_in_model_before(self):

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model, enable_quant_fp8_custom_op",
"test_model, enable_quant_fp8_custom_op, use_aiter",
[
(TestAllReduceRMSNormModel, False),
(TestAllReduceRMSNormStaticQuantFP8Model, True),
(TestAllReduceRMSNormStaticQuantFP8Model, False),
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
True,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
],
)
@pytest.mark.parametrize("batch_size", [8])
@@ -201,9 +232,18 @@ def ops_in_model_before(self):
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA/ROCm")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
current_platform.is_rocm() and not IS_AITER_FOUND,
reason="aiter is not found",
)
@pytest.mark.skipif(
current_platform.is_cuda()
and (
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute(
"flashinfer.comm", "create_allreduce_fusion_workspace"
)
),
reason="flashinfer is not found or flashinfer "
"is not compiled with allreduce_fusion",
)
@@ -216,7 +256,14 @@ def test_all_reduce_fusion_pass_replace(
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter: bool,
monkeypatch: pytest.MonkeyPatch,
):
if use_aiter:
with monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", str(use_aiter))
# Re-read the AITER env vars so the cached class-level flags pick up
# the monkeypatched value.
rocm_aiter_ops.refresh_env_variables()

num_processes = 2
if (
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
@@ -240,6 +287,8 @@ def run_torch_spawn(fn, nprocs):
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter,
monkeypatch,
),
nprocs=nprocs,
)
@@ -258,6 +307,8 @@ def all_reduce_fusion_pass_on_test_model(
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
use_aiter: bool,
monkeypatch: pytest.MonkeyPatch,
):
set_random_seed(0)

@@ -304,7 +355,11 @@ def all_reduce_fusion_pass_on_test_model(
)
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
all_reduce_fusion_pass = (
RocmAiterAllReduceFusionPass(vllm_config)
if use_aiter
else AllReduceFusionPass(vllm_config)
)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
@@ -314,7 +369,7 @@
)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
model = test_model_cls(hidden_size, token_num, use_aiter=use_aiter)

hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

104 changes: 104 additions & 0 deletions vllm/_aiter_ops.py
@@ -2,9 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from contextlib import contextmanager
from typing import Protocol

import torch
from torch._ops import OpOverload
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.platforms import current_platform
@@ -33,6 +36,25 @@ def is_aiter_found() -> bool:
IS_AITER_FOUND = is_aiter_found()


class AiterCustomAllreduceProto(Protocol):
"""Structural type mirroring aiter's CustomAllreduce, so callers can be
type-checked without importing aiter at module import time."""

max_size: int
world_size: int
fully_connected: bool

@contextmanager
def capture(self): ...
def close(self) -> None: ...
def custom_fused_ar_rms(
self,
input: torch.Tensor,
residual_inp: torch.Tensor,
weight: torch.Tensor,
eps: float,
use_1stage: bool,
) -> tuple[torch.Tensor, torch.Tensor] | None: ...
def should_custom_ar(self, inp: torch.Tensor) -> bool: ...


def is_aiter_found_and_supported() -> bool:
"""Check if AITER library is available and platform supports it.

@@ -633,6 +655,47 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
return out, y_scale


def _rocm_aiter_fused_allreduce_rmsnorm_impl(
input_: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
assert aiter_ar is not None, "aiter allreduce must be initialized"

total_bytes = input_.numel() * input_.element_size()
hidden_dim = input_.shape[-1]
token_num = input_.shape[0]
# The 1-stage path is restricted to these hidden sizes and small batches.
hidden_ok = hidden_dim in (512, 1024, 2048, 4096)
token_ok = token_num <= 80
world_size = aiter_ar.world_size
full_nvlink = aiter_ar.fully_connected

# Beyond world size 2, the 1-stage path requires full connectivity and a
# payload under a cutoff that tightens as world size grows.
if world_size == 2:
size_ok = True
elif full_nvlink and world_size <= 4:
size_ok = total_bytes < 160 * 1024
elif full_nvlink and world_size <= 8:
size_ok = total_bytes < 80 * 1024
else:
size_ok = False

use_1stage = hidden_ok and token_ok and size_ok
result = aiter_ar.custom_fused_ar_rms(input_, residual, weight, epsilon, use_1stage)
assert result is not None
return result[0], result[1]
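
# Worked example of the heuristic above (illustrative, assuming bf16 inputs):
# hidden_dim=4096 and token_num=8 give total_bytes = 8 * 4096 * 2 = 64 KiB;
# on a fully connected world_size=8 setup that is under the 80 KiB cutoff, so
# use_1stage is True. At token_num=32 (256 KiB) the same setup would take the
# 2-stage path.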


def _rocm_aiter_fused_allreduce_rmsnorm_fake(
input_: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(input_), torch.empty_like(residual)


def _rocm_aiter_per_tensor_quant_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
@@ -1033,6 +1096,9 @@ class rocm_aiter_ops:
# TODO: Consolidate under _LINEAR_ENABLED
_TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

_ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2  # 128 MiB
_CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None

@classmethod
def refresh_env_variables(cls):
"""
@@ -1200,6 +1266,34 @@ def is_triton_rotary_embed_enabled(cls) -> bool:
def is_triton_gemm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM

@classmethod
@if_aiter_supported
def initialize_aiter_allreduce(
cls, group: ProcessGroup, device: torch.device
) -> None:
try:
Contributor review comment (high):

initialize_aiter_allreduce should check if _CUSTOM_ALL_REDUCE is already initialized before creating a new one. Repeatedly calling this method (e.g., during multiple pass initializations) without a corresponding destroy call will leak GPU communicator resources, as the old instance is overwritten without being closed.

Suggested change:

        if cls._CUSTOM_ALL_REDUCE is not None:
            return
        try:

(A lifecycle sketch follows the end of this method.)

from aiter.dist.device_communicators.custom_all_reduce import (
CustomAllreduce as AiterCustomAllreduce,
)

cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
except Exception:
cls._CUSTOM_ALL_REDUCE = None
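
# Illustrative lifecycle (hypothetical call sites; assumes the reviewer's
# guard above is applied so repeated initialization is a no-op):
#   rocm_aiter_ops.initialize_aiter_allreduce(group, device)
#   ar = rocm_aiter_ops.get_aiter_allreduce()
#   if ar is not None and ar.should_custom_ar(hidden_states):
#       ...  # dispatch to the fused allreduce+rmsnorm op
#   rocm_aiter_ops.destroy_aiter_allreduce()  # pairs the init with close()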

@classmethod
def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None:
return cls._CUSTOM_ALL_REDUCE

@classmethod
def destroy_aiter_allreduce(cls) -> None:
if cls._CUSTOM_ALL_REDUCE is not None:
cls._CUSTOM_ALL_REDUCE.close()
cls._CUSTOM_ALL_REDUCE = None

@classmethod
def get_aiter_allreduce_max_size(cls) -> int:
return cls._ALL_REDUCE_MAX_SIZE

@staticmethod
@if_aiter_supported
def register_ops_once() -> None:
@@ -1386,6 +1480,12 @@ def register_ops_once() -> None:
fake_impl=_triton_rotary_embedding_fake,
)

direct_register_custom_op(
op_name="rocm_aiter_fused_allreduce_rmsnorm",
op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl,
fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake,
)

_OPS_REGISTERED = True

@staticmethod
Expand Down Expand Up @@ -1432,6 +1532,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload:
def get_triton_rotary_embedding_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default

@staticmethod
def get_fused_allreduce_rmsnorm_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default
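
# Illustrative use (assumes register_ops_once() has run and the aiter
# communicator is initialized; tensor names are placeholders):
#   fused = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()
#   out, new_residual = fused(hidden_states, residual, norm_weight, 1e-6)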

@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float