diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 3ff5413f707e..844dbe639b3c 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -1108,6 +1108,7 @@ steps:
     - export VLLM_TEST_CLEAN_GPU_MEMORY=1
     - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
     - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
+    - pytest -v -s tests/compile/passes/distributed/test_tp2_ar_rms.py::test_tp2_ar_rms_fusions
 #----------------------------------------------------------- mi300 · cuda ------------------------------------------------------------#
diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py
index 9156f6afa06a..b5e2b2dc07ea 100644
--- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py
+++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py
@@ -19,6 +19,8 @@
     FLASHINFER_ATTN,
     FLASHINFER_MLA_ATTN,
     FLASHMLA_SPARSE_ATTN,
+    ROCM_AITER_UNIFIED_ATTN,
+    ROCM_ATTN,
     TRITON_ATTN,
     deepseek_coder_v2_lite_fp8,
     deepseek_r1_fp4,
@@ -34,7 +36,9 @@
     qwen3_a3b_fp8,
 )
 
-pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
+pytestmark = pytest.mark.skipif(
+    not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm"
+)
 
 
 @multi_gpu_test(num_gpus=2)
@@ -55,6 +59,7 @@
 @pytest.mark.parametrize("n_layers", [4])
 @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
 @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
 def test_tp2_ar_rms_fp8_fusions(
     model_name: str,
     matches_fn: Callable[[int], Matches],
@@ -124,6 +129,7 @@ def test_tp2_ar_rms_fp8_fusions(
 @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
 @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
 @pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4")
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
 def test_tp2_ar_rms_fp4_fusions(
     model_name: str,
     matches_fn: Callable[[int], Matches],
@@ -176,10 +182,19 @@ def test_tp2_ar_rms_fp4_fusions(
     "model_name, matches_fn, model_kwargs, hf_overrides",
     [llama3_8b, qwen3_a3b, gpt_oss_20b],
 )
-@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [
+        TRITON_ATTN,
+        FLASHINFER_ATTN,
+        ROCM_ATTN,
+        ROCM_AITER_UNIFIED_ATTN,
+    ],
+)
 @pytest.mark.parametrize("n_layers", [4])
-@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
+@pytest.mark.parametrize("custom_ops", tuple(custom_ops_combos("rms_norm")))
 @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
+@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm")
 def test_tp2_ar_rms_fusions(
     model_name: str,
     matches_fn: Callable[[int], Matches],
@@ -221,4 +236,5 @@ def test_tp2_ar_rms_fusions(
         compilation_config,
         matches_check,
         tp_size=2,
+        use_aiter=current_platform.is_rocm(),
     )
diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py
index e2c461e6692d..1a175b8dd335 100644
--- a/tests/compile/passes/distributed/test_fusion_all_reduce.py
+++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py
@@ -8,8 +8,12 @@
 import vllm.envs as envs
 from tests.compile.backend import TestBackend
 from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test
+from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass
+from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
+    AllReduceFusionPass,
+    RocmAiterAllReduceFusionPass,
+)
 from vllm.compilation.passes.utility.fix_functionalization import (
     FixFunctionalizationPass,
 )
@@ -42,13 +46,19 @@
 
 class TestAllReduceRMSNormModel(torch.nn.Module):
     def __init__(
-        self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
+        self,
+        hidden_size=16,
+        token_num=16,
+        eps=1e-6,
+        dtype: torch.dtype = torch.float16,
+        use_aiter: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.eps = eps
         self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
         self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
+        self.use_aiter = use_aiter
 
     def forward(self, x):
         # avoid having graph input be an arg to a pattern directly
@@ -76,6 +86,8 @@
 
     def ops_in_model_before(self):
         return [torch.ops.vllm.all_reduce.default]
 
     def ops_in_model_after(self):
+        if self.use_aiter:
+            return [rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()]
         return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@@ -194,12 +206,36 @@ def ops_in_model_before(self):
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
-    "test_model, enable_quant_fp8_custom_op",
+    "test_model, enable_quant_fp8_custom_op, use_aiter",
     [
-        (TestAllReduceRMSNormModel, False),
-        (TestAllReduceRMSNormStaticQuantFP8Model, True),
-        (TestAllReduceRMSNormStaticQuantFP8Model, False),
-        (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
+        (TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
+        pytest.param(
+            TestAllReduceRMSNormStaticQuantFP8Model,
+            True,
+            False,
+            marks=pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason="Not supported on ROCm platform",
+            ),
+        ),
+        pytest.param(
+            TestAllReduceRMSNormStaticQuantFP8Model,
+            False,
+            False,
+            marks=pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason="Not supported on ROCm platform",
+            ),
+        ),
+        pytest.param(
+            TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
+            False,
+            False,
+            marks=pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason="Not supported on ROCm platform",
+            ),
+        ),
     ],
 )
 @pytest.mark.parametrize("batch_size", [8])
@@ -210,9 +246,18 @@ def ops_in_model_before(self):
 @pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 @pytest.mark.skipif(
-    not find_spec("flashinfer")
-    or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
-    or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
+    current_platform.is_rocm() and not IS_AITER_FOUND,
+    reason="aiter is not found",
+)
+@pytest.mark.skipif(
+    current_platform.is_cuda()
+    and (
+        not find_spec("flashinfer")
+        or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
+        or not has_module_attribute(
+            "flashinfer.comm", "create_allreduce_fusion_workspace"
+        )
+    ),
     reason="flashinfer is not found or flashinfer "
     "is not compiled with allreduce_fusion",
 )
@@ -225,7 +270,14 @@ def test_all_reduce_fusion_pass_replace(
     enable_rms_norm_custom_op,
     enable_quant_fp8_custom_op,
     flashinfer_allreduce_backend,
+    use_aiter: bool,
+    monkeypatch: pytest.MonkeyPatch,
 ):
+    if use_aiter:
+        with monkeypatch.context() as m:
+            m.setenv("VLLM_ROCM_USE_AITER", str(use_aiter))
+            rocm_aiter_ops.refresh_env_variables()
+
     num_processes = 2
     if (
         test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
@@ -249,6 +301,8 @@ def run_torch_spawn(fn, nprocs):
                 enable_rms_norm_custom_op,
                 enable_quant_fp8_custom_op,
                 flashinfer_allreduce_backend,
+                use_aiter,
+                monkeypatch,
             ),
             nprocs=nprocs,
         )
@@ -267,6 +321,8 @@ def all_reduce_fusion_pass_on_test_model(
     enable_rms_norm_custom_op,
     enable_quant_fp8_custom_op,
     flashinfer_allreduce_backend,
+    use_aiter: bool,
+    monkeypatch: pytest.MonkeyPatch,
 ):
     set_random_seed(0)
@@ -313,7 +369,11 @@ def all_reduce_fusion_pass_on_test_model(
     )
     with set_current_vllm_config(vllm_config):
         initialize_model_parallel(tensor_model_parallel_size=world_size)
-        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
+        all_reduce_fusion_pass = (
+            RocmAiterAllReduceFusionPass(vllm_config)
+            if use_aiter
+            else AllReduceFusionPass(vllm_config)
+        )
         noop_pass = NoOpEliminationPass(vllm_config)
         func_pass = FixFunctionalizationPass(vllm_config)
         cleanup_pass = PostCleanupPass(vllm_config)
@@ -323,7 +383,12 @@ def all_reduce_fusion_pass_on_test_model(
         )
 
         token_num = batch_size * seq_len
-        model = test_model_cls(hidden_size, token_num, dtype=dtype)
+        if test_model_cls is TestAllReduceRMSNormModel:
+            model = test_model_cls(
+                hidden_size, token_num, dtype=dtype, use_aiter=use_aiter
+            )
+        else:
+            model = test_model_cls(hidden_size, token_num, dtype=dtype)
 
         hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index f702a025f5ad..b11fc21975ca 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -2,9 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
 from collections.abc import Callable
+from contextlib import contextmanager
+from typing import Protocol
 
 import torch
 from torch._ops import OpOverload
+from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
 from vllm.platforms import current_platform
@@ -39,6 +42,27 @@ def is_aiter_found() -> bool:
 
 IS_AITER_FOUND = is_aiter_found()
 
 
+class AiterCustomAllreduceProto(Protocol):
+    max_size: int
+    world_size: int
+    fully_connected: bool
+
+    @contextmanager
+    def capture(self): ...
+    def close(self) -> None: ...
+    def fused_ar_rms(
+        self,
+        inp: torch.Tensor,
+        res_inp: torch.Tensor,
+        *,
+        w: torch.Tensor,
+        eps: float,
+        registered: bool = False,
+        use_1stage: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]: ...
+    def should_custom_ar(self, inp: torch.Tensor) -> bool: ...
+
+
 def is_aiter_found_and_supported() -> bool:
     """Check if AITER library is available and platform supports it.
@@ -750,6 +774,55 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
     return out, y_scale
 
 
+def _rocm_aiter_fused_allreduce_rmsnorm_impl(
+    input_: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
+    assert aiter_ar is not None, "aiter allreduce must be initialized"
+
+    total_bytes = input_.numel() * input_.element_size()
+    hidden_dim = input_.shape[-1]
+    token_num = input_.shape[0]
+    hidden_ok = hidden_dim in (512, 1024, 2048, 4096, 7168)
+    token_ok = token_num <= 80
+    world_size = aiter_ar.world_size
+    full_nvlink = aiter_ar.fully_connected
+
+    if world_size == 2:
+        size_ok = True
+    elif full_nvlink and world_size <= 4:
+        size_ok = total_bytes < 256 * 1024
+    elif full_nvlink and world_size <= 8:
+        size_ok = total_bytes < 128 * 1024
+    else:
+        size_ok = False
+
+    use_1stage = hidden_ok and token_ok and size_ok
+
+    result = aiter_ar.fused_ar_rms(
+        input_,
+        residual,
+        w=weight,
+        eps=epsilon,
+        registered=torch.cuda.is_current_stream_capturing(),
+        use_1stage=use_1stage,
+    )
+    assert result is not None
+    return result[0], result[1]
+
+
+def _rocm_aiter_fused_allreduce_rmsnorm_fake(
+    input_: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(input_), torch.empty_like(residual)
+
+
 def _rocm_aiter_per_tensor_quant_impl(
     x: torch.Tensor,
     quant_dtype: torch.dtype,
@@ -1188,6 +1261,9 @@ class rocm_aiter_ops:
     # TODO: Consolidate under _LINEAR_ENABLED
     _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
 
+    _ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2
+    _CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None
+
     @classmethod
     def refresh_env_variables(cls):
         """
@@ -1362,6 +1438,35 @@ def is_tgemm_enabled(cls) -> bool:
         return cls.is_linear_enabled() and on_gfx950()
 
+    @classmethod
+    def initialize_aiter_allreduce(
+        cls, group: ProcessGroup, device: torch.device
+    ) -> None:
+        try:
+            from aiter.dist.device_communicators.custom_all_reduce import (
+                CustomAllreduce as AiterCustomAllreduce,
+            )
+
+            cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
+        except Exception:
+            cls._CUSTOM_ALL_REDUCE = None
+
+    @classmethod
+    def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None:
+        return cls._CUSTOM_ALL_REDUCE
+
+    @classmethod
+    def destroy_aiter_allreduce(cls) -> None:
+        if cls._CUSTOM_ALL_REDUCE is not None:
+            cls._CUSTOM_ALL_REDUCE.close()
+            cls._CUSTOM_ALL_REDUCE = None
+
+    @classmethod
+    def get_aiter_allreduce_max_size(cls) -> int | None:
+        # effective max input size (based on upstream aiter version: v0.1.10.post3)
+        # https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273
+        return int(cls._ALL_REDUCE_MAX_SIZE / 2)
+
     @staticmethod
     @if_aiter_supported
     def register_ops_once() -> None:
@@ -1552,6 +1657,12 @@ def register_ops_once() -> None:
             fake_impl=_triton_rotary_embedding_fake,
         )
 
+        direct_register_custom_op(
+            op_name="rocm_aiter_fused_allreduce_rmsnorm",
+            op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl,
+            fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake,
+        )
+
         direct_register_custom_op(
             op_name="fused_mla_dual_rms_norm",
             op_func=_fused_mla_dual_rms_norm_impl,
             fake_impl=_fused_mla_dual_rms_norm_fake,
         )
@@ -1605,6 +1716,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload:
     def get_triton_rotary_embedding_op() -> OpOverload:
         return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
 
+    @staticmethod
+    def get_fused_allreduce_rmsnorm_op() -> OpOverload:
+        return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default
+
     @staticmethod
     def get_fused_mla_dual_rms_norm_op() -> OpOverload:
         return torch.ops.vllm.fused_mla_dual_rms_norm.default
diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py
index 73234ec7920d..e35fc5cd4084 100644
--- a/vllm/compilation/passes/fusion/act_quant_fusion.py
+++ b/vllm/compilation/passes/fusion/act_quant_fusion.py
@@ -190,6 +190,7 @@ def __init__(
         is_scale_transposed: bool = False,
         is_e8m0: bool = False,
         is_tma_aligned: bool = False,
+        match_aiter: bool = False,
     ) -> None:
         super().__init__(quant_key)
         self.quant_matcher = MatcherQuantFP8(
diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index 09b9a557fe45..e683b1dfa69f 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -12,12 +12,14 @@
 from torch._inductor.pattern_matcher import PatternMatcherPass
 
 import vllm.ir.ops
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.compilation.passes.fusion.rms_quant_fusion import (
     _rms_input_weight_dtype_match,
 )
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
+from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -32,7 +34,12 @@
 )
 
 from ..inductor_pass import enable_fake_mode
-from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
+from ..vllm_inductor_pass import (
+    VllmFusionPatternMatcherPass,
+    VllmInductorPass,
+    VllmPatternMatcherPass,
+    VllmPatternReplacement,
+)
 from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -889,3 +896,204 @@ def __del__(self) -> None:
             return
         with contextlib.suppress(Exception):
             destroy_fi_ar_workspace()
+
+
+# TODO: make BasePattern inherit from VllmPatternReplacement
+class AiterAllreduceFusedRMSNormPattern(BasePattern, VllmPatternReplacement):
+    def __init__(
+        self,
+        epsilon: float,
+        dtype: torch.dtype,
+        device: str | None,
+        use_aiter_rmsnorm: bool = True,
+    ) -> None:
+        super().__init__(dtype, device)
+        self.dtype = dtype
+        self.epsilon = epsilon
+        self.FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()
+
+    def get_inputs(self) -> list[torch.Tensor]:
+        return [self.empty(5, 16), self.empty(16)]
+
+    @property
+    def pattern(self):
+        def _pattern(
+            input: torch.Tensor, weight: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            allreduce_output = tensor_model_parallel_all_reduce(input)
+            rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon)
+
+            return rms, allreduce_output
+
+        return _pattern
+
+    @property
+    def replacement(self):
+        def _replacement(
+            input: torch.Tensor, weight: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            residual = torch.empty_like(input)
+            allreduce = self.FUSED_AR_RMSNORM_OP(
+                input_=input,
+                residual=residual,
+                weight=weight,
+                epsilon=self.epsilon,
+            )
+            return allreduce[0], allreduce[1]
+
+        return _replacement
+
+
+class AiterAllreduceFusedAddRMSNormPattern(BasePattern, VllmPatternReplacement):
+    def __init__(
+        self,
+        epsilon: float,
+        dtype: torch.dtype,
+        device: str | None,
+        use_aiter_rmsnorm: bool = True,
+    ) -> None:
+        super().__init__(dtype, device)
+        self.epsilon = epsilon
+        self.dtype = dtype
+        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
+            epsilon, match_rocm_aiter=use_aiter_rmsnorm
+        )
+        self.FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()
+
+    def get_inputs(self) -> list[torch.Tensor]:
+        input, residual, weight = self.rmsnorm_matcher.inputs()
+
+        return [residual, input.to(self.dtype), weight]
+
+    @property
+    def pattern(self):
+        def _pattern(
+            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            allreduce_output = tensor_model_parallel_all_reduce(input)
+            rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
+
+            return rms, residual
+
+        return _pattern
+
+    @property
+    def replacement(self):
+        def _replacement(
+            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            allreduce = self.FUSED_AR_RMSNORM_OP(
+                input_=input,
+                residual=residual,
+                weight=weight,
+                epsilon=self.epsilon,
+            )
+            return allreduce[0], allreduce[1]
+
+        return _replacement
+
+
+class RocmAiterAllReduceFusionPass(VllmFusionPatternMatcherPass):
+    def __init__(self, config: VllmConfig) -> None:
+        super().__init__(config, "rocm_aiter_allreduce_fusion_pass")
+        self.disabled = True
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size <= 1:
+            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
+            return
+
+        if config.model_config is None:
+            logger.warning_once(
+                "AllReduce fusion pass is disabled for missing model_config."
+            )
+            return
+
+        device_comm = get_tp_group().device_communicator
+        if device_comm is None:
+            logger.warning_once("Device communicator is required.")
+            return
+
+        ca_comm = getattr(device_comm, "ca_comm", None)
+        if ca_comm is None:
+            logger.warning_once("Custom Allreduce is required.")
+            return
+        self.ca_comm = ca_comm
+
+        assert isinstance(ca_comm, CustomAllreduce)
+
+        group = get_tp_group().cpu_group
+        rocm_aiter_ops.initialize_aiter_allreduce(group, self.device)
+        hidden_dim = config.model_config.get_hidden_size()
+        element_size = torch.tensor([], dtype=self.model_dtype).element_size()
+        max_size = rocm_aiter_ops.get_aiter_allreduce_max_size()
+        if max_size is None:
+            logger.warning("AITER allreduce fusion must be initialized")
+            return
+
+        # Aiter's fused_allreduce_rmsnorm kernel dispatches on hidden_dim.
+        # Before aiter v0.1.12 the launcher was template-specialized on HIDDEN_DIM
+        # and silently no-op'd for sizes outside {512, 1024, 2048, 4096}. From v0.1.12
+        # hidden_dim is a runtime argument. Detect the older API via the missing
+        # `_pool` attribute and skip fusion for unsupported sizes.
+        # Ref (old kernel): https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/csrc/include/custom_all_reduce.cuh#L2590
+        aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
+        _AITER_OLD_FUSED_AR_RMS_HIDDEN = (512, 1024, 2048, 4096)
+        if (
+            aiter_ar is not None
+            and not hasattr(aiter_ar, "_pool")
+            and hidden_dim not in _AITER_OLD_FUSED_AR_RMS_HIDDEN
+        ):
+            logger.warning_once(
+                "AITER allreduce-rmsnorm fusion disabled: aiter<0.1.12 "
+                "only supports hidden_dim in %s; got %d. Upgrade aiter to "
+                ">=0.1.12 to enable fusion for this model.",
+                _AITER_OLD_FUSED_AR_RMS_HIDDEN,
+                hidden_dim,
+            )
+            # Tear down aiter's custom-allreduce so its IPC handles don't
+            # race with vllm's ca_comm on the unfused fallback path.
+            with contextlib.suppress(Exception):
+                rocm_aiter_ops.destroy_aiter_allreduce()
+            return
+
+        max_token_num = max_size // (hidden_dim * element_size)
+        self.max_token_num = min(
+            max_token_num,
+            config.scheduler_config.max_num_batched_tokens,
+        )
+
+        for epsilon in [1e-5, 1e-6]:
+            self.register(
+                AiterAllreduceFusedRMSNormPattern(
+                    epsilon,
+                    self.model_dtype,
+                    self.device,
+                )
+            )
+            self.register(
+                AiterAllreduceFusedAddRMSNormPattern(
+                    epsilon,
+                    self.model_dtype,
+                    self.device,
+                )
+            )
+
+        # WARNING: This is a hack to clear the pattern matcher cache
+        # and allow multiple values of epsilon.
+        torch._inductor.pattern_matcher._seen_patterns.clear()
+
+        self.disabled = False
+
+        self.dump_patterns(config, self.pm_pass)
+
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
+        if self.disabled:
+            logger.warning_once("AllReduce fusion pass is disabled.")
+            return False
+        return bool(compile_range.end <= self.max_token_num)
+
+    def __del__(self) -> None:
+        if getattr(self, "disabled", True):
+            return
+        with contextlib.suppress(Exception):
+            rocm_aiter_ops.destroy_aiter_allreduce()
diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py
index b7c0d525c91d..3dc0d7b096ba 100644
--- a/vllm/compilation/passes/pass_manager.py
+++ b/vllm/compilation/passes/pass_manager.py
@@ -18,6 +18,9 @@
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 if rocm_aiter_ops.is_enabled():
+    from .fusion.allreduce_rms_fusion import (
+        RocmAiterAllReduceFusionPass,
+    )
     from .fusion.rocm_aiter_fusion import (
         MLADualRMSNormFusionPass,
         RocmAiterRMSNormQuantFusionPass,
@@ -137,7 +140,10 @@ def configure(self, config: VllmConfig) -> None:
             self.passes += [AsyncTPPass(config)]
 
         if self.pass_config.fuse_allreduce_rms:
-            self.passes += [AllReduceFusionPass(config)]
+            if rocm_aiter_ops.is_enabled():
+                self.passes += [RocmAiterAllReduceFusionPass(config)]
+            else:
+                self.passes += [AllReduceFusionPass(config)]
 
         if self.pass_config.fuse_minimax_qk_norm:
             self.passes += [MiniMaxQKNormPass(config)]
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 870ce57e4398..bb3ea81bce52 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -121,6 +121,15 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
     from vllm.platforms import current_platform
     from vllm.utils.flashinfer import has_flashinfer
 
+    if current_platform.is_rocm():
+        from vllm._aiter_ops import rocm_aiter_ops
+
+        return (
+            rocm_aiter_ops.is_enabled()
+            and rocm_aiter_ops.is_rmsnorm_enabled()
+            and cfg.parallel_config.tensor_parallel_size > 1
+        )
+
     return (
         cfg.parallel_config.tensor_parallel_size > 1
         and current_platform.is_cuda()
@@ -1604,10 +1613,15 @@ def _set_compile_ranges(self):
             if compile_range_end is not None:
                 computed_compile_ranges_endpoints.append(compile_range_end)
 
-        # Add the compile ranges for flashinfer
+        # Add the compile ranges for flashinfer/aiter.
         if compilation_config.pass_config.fuse_allreduce_rms:
             tp_size = self.parallel_config.tensor_parallel_size
-            max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
+            from vllm._aiter_ops import rocm_aiter_ops
+
+            if rocm_aiter_ops.is_enabled():
+                max_size = rocm_aiter_ops.get_aiter_allreduce_max_size()
+            else:
+                max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
             if max_size is not None:
                 assert isinstance(self.model_config.dtype, torch.dtype)
                 max_token_num = max_size // (
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 473acb908b28..58c49c09dc54 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -472,6 +472,7 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
         # only cuda uses this function,
         # so we don't abstract it into the base class
         maybe_ca_context = nullcontext()
+        maybe_aiter_context = nullcontext()
         from vllm.distributed.device_communicators.cuda_communicator import (
             CudaCommunicator,
         )
@@ -482,13 +483,20 @@
         if ca_comm is not None:
             maybe_ca_context = ca_comm.capture()  # type: ignore
 
+        from vllm._aiter_ops import rocm_aiter_ops
+
+        if rocm_aiter_ops.is_enabled():
+            aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
+            if aiter_ar is not None:
+                maybe_aiter_context = aiter_ar.capture()  # type: ignore
+
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
         curr_stream = torch.cuda.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)
 
-        with torch.cuda.stream(stream), maybe_ca_context:
+        with torch.cuda.stream(stream), maybe_ca_context, maybe_aiter_context:
             yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
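A minimal end-to-end sketch of how the new ROCm path would be exercised, assuming a node with at least two ROCm GPUs. VLLM_ROCM_USE_AITER is the env var toggled in the test changes above, and pass_config.fuse_allreduce_rms mirrors the config checks in this diff; the model name and the dict form of compilation_config are illustrative assumptions, not something this patch defines.

# Hedged smoke-test sketch for the AITER allreduce+rmsnorm fusion (assumptions noted above).
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"  # gates rocm_aiter_ops.is_enabled()

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model only
    tensor_parallel_size=2,  # the fusion pass bails out for tp_size <= 1
    compilation_config={"pass_config": {"fuse_allreduce_rms": True}},
)
print(llm.generate(["allreduce+rmsnorm fusion smoke test"], SamplingParams(max_tokens=8)))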