diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 82e97bfbb1b2..692df1ece101 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -3579,6 +3579,7 @@ steps: - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py # TODO: this test is not supported on ROCm, there are aiter kernels for this. # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py + - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py::test_tp2_ar_rms_fusions # - pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm # - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index 301409b2bf6a..3cd0a86452ad 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -18,6 +18,8 @@ from .models import ( FLASHINFER_ATTN, FLASHINFER_MLA_ATTN, + ROCM_AITER_UNIFIED_ATTN, + ROCM_ATTN, TRITON_ATTN, deepseek_v3_fp8, gpt_oss_20b, @@ -30,8 +32,6 @@ qwen3_a3b_fp8, ) -pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") - @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( @@ -45,6 +45,7 @@ @pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") def test_tp2_ar_rms_fp8_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -110,6 +111,7 @@ def test_tp2_ar_rms_fp8_fusions( @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4") 
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") def test_tp2_ar_rms_fp4_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -161,10 +163,19 @@ def test_tp2_ar_rms_fp4_fusions( "model_name, matches_fn, model_kwargs, hf_overrides", [llama3_8b, qwen3_a3b, gpt_oss_20b], ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN]) +@pytest.mark.parametrize( + "attn_backend", + [ + TRITON_ATTN, + FLASHINFER_ATTN, + ROCM_ATTN, + ROCM_AITER_UNIFIED_ATTN, + ], +) @pytest.mark.parametrize("n_layers", [4]) -@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) +@pytest.mark.parametrize("custom_ops", tuple(custom_ops_combos("rms_norm"))) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm") def test_tp2_ar_rms_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -205,4 +216,5 @@ def test_tp2_ar_rms_fusions( compilation_config, matches_check, tp_size=2, + use_aiter=current_platform.is_rocm(), ) diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index 92e7402c0537..c37b1ef822e4 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -8,8 +8,12 @@ import vllm.envs as envs from tests.compile.backend import TestBackend from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test +from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass +from vllm.compilation.passes.fusion.allreduce_rms_fusion import ( + AllReduceFusionPass, + RocmAiterAllReduceFusionPass, +) from vllm.compilation.passes.utility.fix_functionalization import ( FixFunctionalizationPass, ) @@ -39,12 +43,13 
@@ class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): + def __init__(self, hidden_size=16, token_num=16, eps=1e-6, use_aiter=False): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + self.use_aiter = use_aiter def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -72,6 +77,8 @@ def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] def ops_in_model_after(self): + if self.use_aiter: + return [rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -185,12 +192,36 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model, enable_quant_fp8_custom_op", + "test_model, enable_quant_fp8_custom_op, use_aiter", [ - (TestAllReduceRMSNormModel, False), - (TestAllReduceRMSNormStaticQuantFP8Model, True), - (TestAllReduceRMSNormStaticQuantFP8Model, False), - (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), + (TestAllReduceRMSNormModel, False, IS_AITER_FOUND), + pytest.param( + TestAllReduceRMSNormStaticQuantFP8Model, + True, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Not supported on ROCm platform", + ), + ), + pytest.param( + TestAllReduceRMSNormStaticQuantFP8Model, + False, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Not supported on ROCm platform", + ), + ), + pytest.param( + TestAllReduceFusedAddRMSNormStaticQuantFP4Model, + False, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Not supported on ROCm platform", + ), + ), ], ) @pytest.mark.parametrize("batch_size", [8]) @@ -201,9 +232,18 @@ def ops_in_model_before(self): @pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"]) 
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( - not find_spec("flashinfer") - or not has_module_attribute("flashinfer.comm", "allreduce_fusion") - or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"), + current_platform.is_rocm() and not IS_AITER_FOUND, + reason="aiter is not found", +) +@pytest.mark.skipif( + current_platform.is_cuda() + and ( + not find_spec("flashinfer") + or not has_module_attribute("flashinfer.comm", "allreduce_fusion") + or not has_module_attribute( + "flashinfer.comm", "create_allreduce_fusion_workspace" + ) + ), reason="flashinfer is not found or flashinfer " "is not compiled with allreduce_fusion", ) @@ -216,7 +256,14 @@ def test_all_reduce_fusion_pass_replace( enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter: bool, + monkeypatch: pytest.MonkeyPatch, ): + if use_aiter: + with monkeypatch.context() as m: + m.setenv("VLLM_ROCM_USE_AITER", str(use_aiter)) + rocm_aiter_ops.refresh_env_variables() + num_processes = 2 if ( test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model @@ -240,6 +287,8 @@ def run_torch_spawn(fn, nprocs): enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter, + monkeypatch, ), nprocs=nprocs, ) @@ -258,6 +307,8 @@ def all_reduce_fusion_pass_on_test_model( enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter: bool, + monkeypatch: pytest.MonkeyPatch, ): set_random_seed(0) @@ -304,7 +355,11 @@ def all_reduce_fusion_pass_on_test_model( ) with set_current_vllm_config(vllm_config): initialize_model_parallel(tensor_model_parallel_size=world_size) - all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + all_reduce_fusion_pass = ( + RocmAiterAllReduceFusionPass(vllm_config) + if use_aiter + else AllReduceFusionPass(vllm_config) + ) noop_pass = NoOpEliminationPass(vllm_config) func_pass = 
FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) @@ -314,7 +369,7 @@ def all_reduce_fusion_pass_on_test_model( ) token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num) + model = test_model_cls(hidden_size, token_num, use_aiter=use_aiter) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c4ba8053cc58..3f38115b2792 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -2,9 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable +from contextlib import contextmanager +from typing import Protocol import torch from torch._ops import OpOverload +from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.platforms import current_platform @@ -33,6 +36,25 @@ def is_aiter_found() -> bool: IS_AITER_FOUND = is_aiter_found() +class AiterCustomAllreduceProto(Protocol): + max_size: int + world_size: int + fully_connected: bool + + @contextmanager + def capture(self): ... + def close(self) -> None: ... + def custom_fused_ar_rms( + self, + input: torch.Tensor, + residual_inp: torch.Tensor, + weight: torch.Tensor, + eps: float, + use_1stage: bool, + ) -> tuple[torch.Tensor, torch.Tensor] | None: ... + def should_custom_ar(self, inp: torch.Tensor) -> bool: ... + + def is_aiter_found_and_supported() -> bool: """Check if AITER library is available and platform supports it. 
@@ -633,6 +655,47 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake( return out, y_scale +def _rocm_aiter_fused_allreduce_rmsnorm_impl( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + assert aiter_ar is not None, "aiter allreduce must be initialized" + + total_bytes = input_.numel() * input_.element_size() + hidden_dim = input_.shape[-1] + token_num = input_.shape[0] + hidden_ok = hidden_dim in (512, 1024, 2048, 4096) + token_ok = token_num <= 80 + world_size = aiter_ar.world_size + full_nvlink = aiter_ar.fully_connected + + if world_size == 2: + size_ok = True + elif full_nvlink and world_size <= 4: + size_ok = total_bytes < 160 * 1024 + elif full_nvlink and world_size <= 8: + size_ok = total_bytes < 80 * 1024 + else: + size_ok = False + + use_1stage = hidden_ok and token_ok and size_ok + result = aiter_ar.custom_fused_ar_rms(input_, residual, weight, epsilon, use_1stage) + assert result is not None + return result[0], result[1] + + +def _rocm_aiter_fused_allreduce_rmsnorm_fake( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(input_), torch.empty_like(residual) + + def _rocm_aiter_per_tensor_quant_impl( x: torch.Tensor, quant_dtype: torch.dtype, @@ -1033,6 +1096,9 @@ class rocm_aiter_ops: # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + _ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2 + _CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None + @classmethod def refresh_env_variables(cls): """ @@ -1200,6 +1266,34 @@ def is_triton_rotary_embed_enabled(cls) -> bool: def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM + @classmethod + @if_aiter_supported + def initialize_aiter_allreduce( + cls, group: ProcessGroup, device: 
torch.device + ) -> None: + try: + from aiter.dist.device_communicators.custom_all_reduce import ( + CustomAllreduce as AiterCustomAllreduce, + ) + + cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) + except Exception: + cls._CUSTOM_ALL_REDUCE = None + + @classmethod + def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None: + return cls._CUSTOM_ALL_REDUCE + + @classmethod + def destroy_aiter_allreduce(cls) -> None: + if cls._CUSTOM_ALL_REDUCE is not None: + cls._CUSTOM_ALL_REDUCE.close() + cls._CUSTOM_ALL_REDUCE = None + + @classmethod + def get_aiter_allreduce_max_size(cls) -> int: + return cls._ALL_REDUCE_MAX_SIZE + @staticmethod @if_aiter_supported def register_ops_once() -> None: @@ -1386,6 +1480,12 @@ def register_ops_once() -> None: fake_impl=_triton_rotary_embedding_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_fused_allreduce_rmsnorm", + op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl, + fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake, + ) + _OPS_REGISTERED = True @staticmethod @@ -1432,6 +1532,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload: def get_triton_rotary_embedding_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default + @staticmethod + def get_fused_allreduce_rmsnorm_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default + @staticmethod def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index f141a7c171f7..4f851f5ec57e 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -10,9 +10,11 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from 
vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -860,3 +862,243 @@ def __del__(self) -> None: return with contextlib.suppress(Exception): destroy_fi_ar_workspace() + + +class AiterAllreduceFusedRMSNormPattern: + # NOTE: the fused op is resolved lazily in replacement(); it is only + # registered when AITER is available, so no class-body lookup here. + + def __init__( + self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True + ) -> None: + self.dtype = dtype + self.epsilon = epsilon + self.rmsnorm_matcher = MatcherRMSNorm( + epsilon, match_rocm_aiter=use_aiter_rmsnorm + ) + + def get_inputs(self) -> list[torch.Tensor]: + input, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms = self.rmsnorm_matcher(allreduce_output, weight) + + return rms, allreduce_output + + def replacement( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + residual = torch.empty_like(input) + allreduce = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()( + input_=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + return allreduce[0], allreduce[1] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class AiterAllreduceFusedAddRMSNormPattern: + # NOTE: the fused op is resolved lazily in replacement(); it is only + # registered when AITER is available, so no class-body lookup here. + + def __init__( + self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True + ) -> None: + self.epsilon = epsilon + self.dtype = dtype + self.rmsnorm_matcher = 
MatcherFusedAddRMSNorm( + epsilon, match_rocm_aiter=use_aiter_rmsnorm + ) + + def get_inputs(self) -> list[torch.Tensor]: + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + + return rms, residual + + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()( + input_=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + return allreduce[0], allreduce[1] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class RocmAiterAllReduceFusionPass(VllmPatternMatcherPass): + def __init__(self, config: VllmConfig) -> None: + super().__init__(config) + self.disabled = True + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") + return + + if config.model_config is None: + logger.warning_once( + "AllReduce fusion pass is disabled for missing model_config." 
+ ) + return + + device_comm = get_tp_group().device_communicator + if device_comm is None: + logger.warning_once("Device communicator is required.") + return + + ca_comm = getattr(device_comm, "ca_comm", None) + if ca_comm is None: + logger.warning_once("Custom Allreduce is required.") + return + self.ca_comm = ca_comm + + assert isinstance(ca_comm, CustomAllreduce) + + group = get_tp_group().cpu_group + rocm_aiter_ops.initialize_aiter_allreduce(group, self.device) + hidden_dim = config.model_config.get_hidden_size() + element_size = torch.tensor([], dtype=self.model_dtype).element_size() + max_size = rocm_aiter_ops.get_aiter_allreduce_max_size() + if rocm_aiter_ops.get_aiter_allreduce() is None: + logger.warning("AITER allreduce fusion must be initialized") + return + + max_token_num = (max_size / 2) // (hidden_dim * element_size) + self.max_token_num = min( + max_token_num, + config.scheduler_config.max_num_batched_tokens, + ) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" + ) + + self.register_patterns() + self.dump_patterns(config, self.patterns) + + @enable_fake_mode + def register_patterns(self): + for epsilon in [1e-5, 1e-6]: + AiterAllreduceFusedRMSNormPattern( + epsilon, + self.model_dtype, + ).register(self.patterns) + + AiterAllreduceFusedAddRMSNormPattern( + epsilon, + self.model_dtype, + ).register(self.patterns) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. 
+ torch._inductor.pattern_matcher._seen_patterns.clear() + + self.disabled = False + + def is_applicable_for_range(self, compile_range: Range) -> bool: + if self.disabled: + logger.warning_once("AllReduce fusion pass is disabled.") + return False + return bool(compile_range.end <= self.max_token_num) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + if self.disabled: + logger.debug("ROCmAiterAllReduceRMSNormFusionPass disabled") + return + + self._bypass_noop_views_after_allreduce(graph) + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def _bypass_noop_views_after_allreduce(self, graph: fx.Graph) -> None: + """Remove no-op view/reshape nodes sitting between all_reduce and + rmsnorm so the pattern matcher can fuse them. + + Some models (e.g. DeepSeek MoE) insert + ``final_hidden_states.view(num_tokens, hidden_dim)`` after + all_reduce. The view is identity-shaped but creates an intermediate + node that prevents the ``all_reduce -> rmsnorm`` pattern from + matching. 
+ """ + from torch.fx.experimental.symbolic_shapes import statically_known_true + + count = 0 + for node in list(graph.nodes): + if node.op != "call_function" or node.target not in ( + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, + ): + continue + + input_node = node.args[0] + if not isinstance(input_node, fx.Node): + continue + + if ( + input_node.op != "call_function" + or input_node.target != torch.ops.vllm.all_reduce.default + ): + continue + + input_val = input_node.meta.get("val") + output_val = node.meta.get("val") + if input_val is None or output_val is None: + continue + + in_shape = input_val.shape + out_shape = output_val.shape + if len(in_shape) != len(out_shape): + continue + if not all( + statically_known_true(s == o) + for s, o in zip(in_shape, out_shape) + ): + continue + + node.replace_all_uses_with(input_node) + graph.erase_node(node) + count += 1 + + if count: + logger.debug( + "Bypassed %s no-op view(s) after all_reduce", count + ) + + def __del__(self) -> None: + if getattr(self, "disabled", True): + return + with contextlib.suppress(Exception): + rocm_aiter_ops.destroy_aiter_allreduce() + + def uuid(self) -> str: + return VllmInductorPass.hash_source( + self, + AiterAllreduceFusedRMSNormPattern, + AiterAllreduceFusedAddRMSNormPattern, + ) diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 70f86c8d2ae3..61189b8889d0 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -17,6 +17,9 @@ from .vllm_inductor_pass import VllmInductorPass if rocm_aiter_ops.is_enabled(): + from .fusion.allreduce_rms_fusion import ( + RocmAiterAllReduceFusionPass, + ) from .fusion.rocm_aiter_fusion import ( RocmAiterRMSNormQuantFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, @@ -122,7 +125,10 @@ def configure(self, config: VllmConfig) -> None: self.passes += [AsyncTPPass(config)] if self.pass_config.fuse_allreduce_rms: - self.passes += 
[AllReduceFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterAllReduceFusionPass(config)] + else: + self.passes += [AllReduceFusionPass(config)] if self.pass_config.fuse_norm_quant: self.passes += [RMSNormQuantFusionPass(config)] diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f525ac871c3e..8407e00fce60 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -115,6 +115,15 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer + if current_platform.is_rocm(): + from vllm._aiter_ops import rocm_aiter_ops + + return ( + rocm_aiter_ops.is_enabled() + and rocm_aiter_ops.is_rmsnorm_enabled() + and cfg.parallel_config.tensor_parallel_size > 1 + ) + return ( cfg.parallel_config.tensor_parallel_size > 1 and current_platform.is_cuda() diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index bd5741e8dc72..c37256287c2f 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -56,7 +56,6 @@ def __init__( self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem = use_torch_symm_mem self.use_flashinfer_allreduce = use_flashinfer_allreduce - # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 04187b34ec7a..b947d89076bd 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -470,6 +470,7 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None # only cuda uses this function, # so we don't abstract it into the base class maybe_ca_context = nullcontext() + maybe_aiter_context = nullcontext() from 
vllm.distributed.device_communicators.cuda_communicator import ( CudaCommunicator, ) @@ -480,13 +481,21 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None if ca_comm is not None: maybe_ca_context = ca_comm.capture() # type: ignore + from vllm._aiter_ops import rocm_aiter_ops + + aiter_enabled = rocm_aiter_ops.is_enabled() + if aiter_enabled: + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + if aiter_ar is not None: + maybe_aiter_context = aiter_ar.capture() # type: ignore + # ensure all initialization operations complete before attempting to # capture the graph on another stream curr_stream = torch.cuda.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) - with torch.cuda.stream(stream), maybe_ca_context: + with torch.cuda.stream(stream), maybe_ca_context, maybe_aiter_context: yield graph_capture_context def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: