From 28337caa2bd34239e1398250a0099e694aa2c15c Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Fri, 6 Mar 2026 11:02:10 +0000
Subject: [PATCH 01/10] enable aiter all reduce and fused ar_rmsnorm

Signed-off-by: vllmellm
---
 vllm/_aiter_ops.py                            | 116 ++++++++++++
 .../passes/fusion/rocm_aiter_fusion.py        | 151 ++++++++++++++++++
 vllm/compilation/passes/pass_manager.py       |   6 +-
 .../device_communicators/cuda_communicator.py |  27 +++-
 4 files changed, 298 insertions(+), 2 deletions(-)

diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index c8366ecce543..6b1abd3674fe 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -633,6 +633,102 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake(
     return out, y_scale
 
 
+def _rocm_aiter_fused_allreduce_rmsnorm_impl(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from vllm.distributed import get_tp_group
+
+    group = get_tp_group()
+
+    device_comm = group.device_communicator
+    # getattr tolerates device_comm being None and falls back to the default
+    aiter_ar_comm = getattr(device_comm, "aiter_ar_comm", None)
+
+    if (
+        aiter_ar_comm is not None
+        and not aiter_ar_comm.disabled
+        and aiter_ar_comm.should_custom_ar(input_)
+        and hasattr(aiter_ar_comm, "custom_fused_ar_rms")
+    ):
+        total_bytes = input_.numel() * input_.element_size()
+        use_1stage = total_bytes <= 128 * 1024
+
+        out, res_out = aiter_ar_comm.custom_fused_ar_rms(
+            input_,
+            residual_inp=torch.zeros_like(input_),
+            weight=weight,
+            eps=epsilon,
+            use_1stage=use_1stage,
+        )
+        return out, res_out
+
+    # Fallback: launch all-reduce and rmsnorm separately.
+    # Keep the (rmsnorm output, allreduce output) order of the fused path.
+    ar_out = group._all_reduce_out_place(input_)
+
+    out = _rocm_aiter_rms_norm_impl(ar_out, weight, epsilon)
+    return out, ar_out
+
+
+def _rocm_aiter_fused_allreduce_rmsnorm_fake(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(input_), torch.empty_like(input_)
+
+
+def _rocm_aiter_fused_allreduce_add_rmsnorm_impl(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+    residual: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from vllm.distributed import get_tp_group
+
+    group = get_tp_group()
+
+    device_comm = group.device_communicator
+    # getattr tolerates device_comm being None and falls back to the default
+    aiter_ar_comm = getattr(device_comm, "aiter_ar_comm", None)
+
+    if (
+        aiter_ar_comm is not None
+        and not aiter_ar_comm.disabled
+        and aiter_ar_comm.should_custom_ar(input_)
+        and hasattr(aiter_ar_comm, "custom_fused_ar_rms")
+    ):
+        total_bytes = input_.numel() * input_.element_size()
+        use_1stage = total_bytes <= 128 * 1024
+        out, res_out = aiter_ar_comm.custom_fused_ar_rms(
+            input_,
+            residual_inp=residual,
+            weight=weight,
+            eps=epsilon,
+            use_1stage=use_1stage,
+        )
+        return out, res_out
+
+    # Fallback: launch all-reduce and rmsnorm separately
+    ar_out = group._all_reduce_out_place(input_)
+
+    out, residual_out = _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
+        ar_out, residual, weight, epsilon
+    )
+
+    return out, residual_out
+
+
+def _rocm_aiter_fused_allreduce_add_rmsnorm_fake(
+    input_: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+    residual: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(input_), torch.empty_like(residual)
+
+
 def _rocm_aiter_per_tensor_quant_impl(
     x: torch.Tensor,
     quant_dtype: torch.dtype,
@@ -1345,6 +1441,18 @@ def register_ops_once() -> None:
             fake_impl=_triton_rotary_embedding_fake,
         )
 
+        direct_register_custom_op(
+            op_name="rocm_aiter_fused_allreduce_rmsnorm",
op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl, + fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fused_allreduce_add_rmsnorm", + op_func=_rocm_aiter_fused_allreduce_add_rmsnorm_impl, + fake_impl=_rocm_aiter_fused_allreduce_add_rmsnorm_fake, + ) + _OPS_REGISTERED = True @staticmethod @@ -1391,6 +1499,14 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload: def get_triton_rotary_embedding_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default + @staticmethod + def get_fused_allreduce_rmsnorm() -> OpOverload: + return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default + + @staticmethod + def get_fused_allreduce_add_rmsnorm() -> OpOverload: + return torch.ops.vllm.rocm_aiter_fused_allreduce_add_rmsnorm.default + @staticmethod def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index 59c94db5e812..a7d157b597bf 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -9,6 +9,10 @@ import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -500,3 +504,150 @@ def __call__(self, graph: torch.fx.Graph) -> None: def uuid(self) -> str: return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern) + + +class AiterAllreduceFusedRMSNormPattern: + FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm() + + def __init__( + self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True + ) -> None: + self.dtype = dtype + self.epsilon = epsilon + self.rmsnorm_matcher = MatcherRMSNorm( + epsilon, match_rocm_aiter=use_aiter_rmsnorm + ) + + def get_inputs(self) -> list[torch.Tensor]: + input, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms = self.rmsnorm_matcher(allreduce_output, weight) + + return rms, allreduce_output + + def replacement( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce = self.FUSED_AR_RMSNORM_OP( + input_=input, + weight=weight, + epsilon=self.epsilon, + ) + return allreduce[0], allreduce[1] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class AiterAllreduceFusedAddRMSNormPattern: + FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_add_rmsnorm() + + def __init__( + self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True + ) -> None: + self.epsilon = epsilon + self.dtype = dtype + self.rmsnorm_matcher = MatcherFusedAddRMSNorm( + epsilon, match_rocm_aiter=use_aiter_rmsnorm + ) + + def get_inputs(self) -> list[torch.Tensor]: + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] + + def 
register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + + return rms, residual + + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce = self.FUSED_AR_RMSNORM_OP( + input_=input, + weight=weight, + epsilon=self.epsilon, + residual=residual, + ) + return allreduce[0], allreduce[1] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class RocmAiterAllReduceFusionPass(VllmPatternMatcherPass): + def __init__(self, config: VllmConfig) -> None: + super().__init__(config) + self.disabled = True + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") + return + + if config.model_config is None: + logger.warning_once( + "AllReduce fusion pass is disabled for missing model_config." + ) + return + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" + ) + + self.register_patterns() + self.dump_patterns(config, self.patterns) + + @enable_fake_mode + def register_patterns(self): + for epsilon in [1e-5, 1e-6]: + AiterAllreduceFusedRMSNormPattern( + epsilon, + self.model_dtype, + ).register(self.patterns) + + AiterAllreduceFusedAddRMSNormPattern( + epsilon, + self.model_dtype, + ).register(self.patterns) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + self.disabled = False + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + if self.disabled: + logger.debug("ROCmAiterAllReduceRMSNormFusionPass disabled") + return + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def __del__(self) -> None: + if getattr(self, "disabled", True): + return + + def uuid(self) -> str: + return VllmInductorPass.hash_source( + self, + AiterAllreduceFusedRMSNormPattern, + AiterAllreduceFusedAddRMSNormPattern, + ) diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 70f86c8d2ae3..389aaf14582b 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -18,6 +18,7 @@ if rocm_aiter_ops.is_enabled(): from .fusion.rocm_aiter_fusion import ( + RocmAiterAllReduceFusionPass, RocmAiterRMSNormQuantFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, RocmAiterTritonAddRMSNormPadFusionPass, @@ -122,7 +123,10 @@ def configure(self, config: VllmConfig) -> None: self.passes += [AsyncTPPass(config)] if self.pass_config.fuse_allreduce_rms: - self.passes += [AllReduceFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterAllReduceFusionPass(config)] + else: + self.passes += [AllReduceFusionPass(config)] if self.pass_config.fuse_norm_quant: self.passes += [RMSNormQuantFusionPass(config)] diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 5e18dbde91d2..358a2241e6e4 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py 
@@ -46,17 +46,20 @@ def __init__(
             use_custom_allreduce = False
             use_torch_symm_mem = False
             use_flashinfer_allreduce = False
+            use_aiter_allreduce = False
         else:
+            from vllm._aiter_ops import rocm_aiter_ops
             from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE
 
             use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
             use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
             use_flashinfer_allreduce = envs.VLLM_ALLREDUCE_USE_FLASHINFER
+            use_aiter_allreduce = rocm_aiter_ops.is_enabled()
 
         self.use_custom_allreduce = use_custom_allreduce
         self.use_torch_symm_mem = use_torch_symm_mem
         self.use_flashinfer_allreduce = use_flashinfer_allreduce
-
+        self.use_aiter_allreduce = use_aiter_allreduce
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
@@ -83,6 +86,7 @@ def __init__(
         self.qr_comm: QuickAllReduce | None = None
         self.symm_mem_comm: SymmMemCommunicator | None = None
         self.fi_ar_comm: FlashInferAllReduce | None = None
+        self.aiter_ar_comm = None
 
         if use_torch_symm_mem and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
@@ -96,6 +100,16 @@ def __init__(
             device=self.device,
         )
 
+        if self.use_aiter_allreduce and self.world_size > 1:
+            from aiter.dist.device_communicators.custom_all_reduce import (
+                CustomAllreduce as AiterCustomAllreduce,
+            )
+
+            self.aiter_ar_comm = AiterCustomAllreduce(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -187,6 +201,17 @@ def all_reduce(self, input_):
                 out = fi_ar_comm.all_reduce(input_)
                 assert out is not None
                 return out
+        aiter_ar_comm = self.aiter_ar_comm
+        if (
+            aiter_ar_comm is not None
+            and not aiter_ar_comm.disabled
+            and aiter_ar_comm.should_custom_ar(input_)
+        ):
+            out = aiter_ar_comm.custom_all_reduce(
+                input_, use_new=True, open_fp8_quant=False
+            )
+            assert out is not None
+            return out
         ca_comm = self.ca_comm
         if (
             ca_comm is not None

From 7a692cfad9f6994e6499848e460758624087976a Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 10 Mar 2026 11:38:07 +0000
Subject: [PATCH 02/10] add unit test for aiter custom all reduce and move all
 reduce to same file

Signed-off-by: vllmellm
---
 .buildkite/test-amd.yaml                      |   1 +
 tests/compile/fusions_e2e/test_tp2_ar_rms.py  |  21 ++-
 .../distributed/test_fusion_all_reduce.py     |  78 +++++++--
 vllm/_aiter_ops.py                            |   4 +-
 .../passes/fusion/allreduce_rms_fusion.py     | 148 +++++++++++++++++
 .../passes/fusion/rocm_aiter_fusion.py        | 151 ------------------
 vllm/compilation/passes/pass_manager.py       |   4 +-
 vllm/config/vllm.py                           |   9 ++
 .../device_communicators/cuda_communicator.py |   9 +-
 9 files changed, 253 insertions(+), 172 deletions(-)

diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 6eda7bce9586..5e9a7570c34e 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -3246,6 +3246,7 @@ steps:
     - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
     - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
    - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
+    - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py::test_tp2_ar_rms_fusions
    #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
    # - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
    # Old E2E tests were removed in 
https://github.com/vllm-project/vllm/pull/33293 diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index ab4aefcaf79a..05ad62d63932 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -17,6 +17,8 @@ ) from .models import ( FLASHINFER_ATTN, + ROCM_AITER_UNIFIED_ATTN, + ROCM_ATTN, TRITON_ATTN, llama3_8b, llama3_8b_fp4, @@ -27,8 +29,6 @@ qwen3_a3b_fp8, ) -pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") - @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( @@ -36,10 +36,10 @@ # qwen3-fp8 should still fuse AR+rms even though group quant is not yet supported [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8], ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN]) @pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") def test_tp2_ar_rms_fp8_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -104,6 +104,7 @@ def test_tp2_ar_rms_fp8_fusions( @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4") +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") def test_tp2_ar_rms_fp4_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -155,10 +156,19 @@ def test_tp2_ar_rms_fp4_fusions( "model_name, matches_fn, model_kwargs, hf_overrides", [llama3_8b, qwen3_a3b], ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN]) +@pytest.mark.parametrize( + "attn_backend", + [ + TRITON_ATTN, + FLASHINFER_ATTN, + ROCM_ATTN, + ROCM_AITER_UNIFIED_ATTN, + ], +) @pytest.mark.parametrize("n_layers", [4]) -@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) +@pytest.mark.parametrize("custom_ops", tuple(custom_ops_combos("rms_norm"))) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm") def test_tp2_ar_rms_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -199,4 +209,5 @@ def test_tp2_ar_rms_fusions( compilation_config, matches_check, tp_size=2, + use_aiter=current_platform.is_rocm(), ) diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index 4beac8c4fb53..cd854631eb7b 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -8,8 +8,12 @@ import vllm.envs as envs from tests.compile.backend import TestBackend from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test +from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass +from vllm.compilation.passes.fusion.allreduce_rms_fusion import ( + AllReduceFusionPass, + RocmAiterAllReduceFusionPass, +) from vllm.compilation.passes.utility.fix_functionalization import ( FixFunctionalizationPass, ) @@ -39,12 +43,13 @@ class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, 
hidden_size=16, token_num=16, eps=1e-6): + def __init__(self, hidden_size=16, token_num=16, eps=1e-6, use_aiter=False): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + self.use_aiter = use_aiter def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -69,9 +74,13 @@ def forward(self, x): return y4 def ops_in_model_before(self): + if self.use_aiter: + return [rocm_aiter_ops.get_rmsnorm_fused_add_op()] return [torch.ops.vllm.all_reduce.default] def ops_in_model_after(self): + if self.use_aiter: + return [rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -185,12 +194,33 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model, enable_quant_fp8_custom_op", + "test_model, enable_quant_fp8_custom_op, use_aiter", [ - (TestAllReduceRMSNormModel, False), - (TestAllReduceRMSNormStaticQuantFP8Model, True), - (TestAllReduceRMSNormStaticQuantFP8Model, False), - (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), + (TestAllReduceRMSNormModel, False, IS_AITER_FOUND), + pytest.param( + TestAllReduceRMSNormStaticQuantFP8Model, + True, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + ), + ), + pytest.param( + TestAllReduceRMSNormStaticQuantFP8Model, + False, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + ), + ), + pytest.param( + TestAllReduceFusedAddRMSNormStaticQuantFP4Model, + False, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + ), + ), ], ) @pytest.mark.parametrize("batch_size", [8]) @@ -201,9 +231,18 @@ def ops_in_model_before(self): @pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( - not find_spec("flashinfer") - or not has_module_attribute("flashinfer.comm", "allreduce_fusion") - or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"), + current_platform.is_rocm() and not IS_AITER_FOUND, + reason="aiter is not found", +) +@pytest.mark.skipif( + current_platform.is_cuda() + and ( + not find_spec("flashinfer") + or not has_module_attribute("flashinfer.comm", "allreduce_fusion") + or not has_module_attribute( + "flashinfer.comm", "create_allreduce_fusion_workspace" + ) + ), reason="flashinfer is not found or flashinfer " "is not compiled with allreduce_fusion", ) @@ -216,7 +255,14 @@ def test_all_reduce_fusion_pass_replace( enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter: bool, + monkeypatch: pytest.MonkeyPatch, ): + if use_aiter: + with monkeypatch.context() as m: + m.setenv("VLLM_ROCM_USE_AITER", str(use_aiter)) + rocm_aiter_ops.refresh_env_variables() + num_processes = 2 if ( test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model @@ -240,6 +286,8 @@ def run_torch_spawn(fn, nprocs): enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter, + monkeypatch, ), nprocs=nprocs, ) @@ -258,6 +306,8 @@ def all_reduce_fusion_pass_on_test_model( enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter: bool, + monkeypatch: pytest.MonkeyPatch, ): set_random_seed(0) @@ -304,7 +354,11 @@ def all_reduce_fusion_pass_on_test_model( ) with 
set_current_vllm_config(vllm_config):
         initialize_model_parallel(tensor_model_parallel_size=world_size)
-        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
+        all_reduce_fusion_pass = (
+            RocmAiterAllReduceFusionPass(vllm_config)
+            if use_aiter
+            else AllReduceFusionPass(vllm_config)
+        )
         noop_pass = NoOpEliminationPass(vllm_config)
         func_pass = FixFunctionalizationPass(vllm_config)
         cleanup_pass = PostCleanupPass(vllm_config)
@@ -314,7 +368,7 @@ def all_reduce_fusion_pass_on_test_model(
         )
 
         token_num = batch_size * seq_len
-        model = test_model_cls(hidden_size, token_num)
+        model = test_model_cls(hidden_size, token_num, use_aiter=use_aiter)
 
         hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
 
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 6b1abd3674fe..602c7727a5e8 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -1500,11 +1500,11 @@ def get_triton_rotary_embedding_op() -> OpOverload:
         return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
 
     @staticmethod
-    def get_fused_allreduce_rmsnorm() -> OpOverload:
+    def get_fused_allreduce_rmsnorm_op() -> OpOverload:
         return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default
 
     @staticmethod
-    def get_fused_allreduce_add_rmsnorm() -> OpOverload:
+    def get_fused_allreduce_add_rmsnorm_op() -> OpOverload:
         return torch.ops.vllm.rocm_aiter_fused_allreduce_add_rmsnorm.default
 
     @staticmethod
diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index 44dc3d67bb98..fe8587f552d6 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -10,6 +10,7 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
 
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
@@ -860,3 +861,150 @@ def __del__(self) -> None:
         return
     with contextlib.suppress(Exception):
         destroy_fi_ar_workspace()
+
+
+class AiterAllreduceFusedRMSNormPattern:
+    FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()
+
+    def __init__(
+        self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True
+    ) -> None:
+        self.dtype = dtype
+        self.epsilon = epsilon
+        self.rmsnorm_matcher = MatcherRMSNorm(
+            epsilon, match_rocm_aiter=use_aiter_rmsnorm
+        )
+
+    def get_inputs(self) -> list[torch.Tensor]:
+        input, weight = self.rmsnorm_matcher.inputs()
+
+        # input goes through allreduce first, always 16-bit
+        return [input.to(self.dtype), weight]
+
+    def register(self, pm_pass: PatternMatcherPass) -> None:
+        def pattern(
+            input: torch.Tensor, weight: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            allreduce_output = tensor_model_parallel_all_reduce(input)
+            rms = self.rmsnorm_matcher(allreduce_output, weight)
+
+            return rms, allreduce_output
+
+        def replacement(
+            input: torch.Tensor, weight: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            allreduce = self.FUSED_AR_RMSNORM_OP(
+                input_=input,
+                weight=weight,
+                epsilon=self.epsilon,
+            )
+            return allreduce[0], allreduce[1]
+
+        pm.register_replacement(
+            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+        )
+
+
+class AiterAllreduceFusedAddRMSNormPattern:
+    FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_add_rmsnorm_op()
+
+    def __init__(
+        self, epsilon: float, dtype: torch.dtype, 
use_aiter_rmsnorm: bool = True + ) -> None: + self.epsilon = epsilon + self.dtype = dtype + self.rmsnorm_matcher = MatcherFusedAddRMSNorm( + epsilon, match_rocm_aiter=use_aiter_rmsnorm + ) + + def get_inputs(self) -> list[torch.Tensor]: + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + + return rms, residual + + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce = self.FUSED_AR_RMSNORM_OP( + input_=input, + weight=weight, + epsilon=self.epsilon, + residual=residual, + ) + return allreduce[0], allreduce[1] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class RocmAiterAllReduceFusionPass(VllmPatternMatcherPass): + def __init__(self, config: VllmConfig) -> None: + super().__init__(config) + self.disabled = True + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") + return + + if config.model_config is None: + logger.warning_once( + "AllReduce fusion pass is disabled for missing model_config." + ) + return + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" + ) + + self.register_patterns() + self.dump_patterns(config, self.patterns) + + @enable_fake_mode + def register_patterns(self): + for epsilon in [1e-5, 1e-6]: + AiterAllreduceFusedRMSNormPattern( + epsilon, + self.model_dtype, + ).register(self.patterns) + + AiterAllreduceFusedAddRMSNormPattern( + epsilon, + self.model_dtype, + ).register(self.patterns) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. 
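+        # Each epsilon is captured via self, so the traced pattern source is
+        # identical across iterations; without clearing, the matcher would
+        # reject the second epsilon's patterns as duplicates of the first.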
+ torch._inductor.pattern_matcher._seen_patterns.clear() + + self.disabled = False + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + if self.disabled: + logger.debug("ROCmAiterAllReduceRMSNormFusionPass disabled") + return + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def __del__(self) -> None: + if getattr(self, "disabled", True): + return + + def uuid(self) -> str: + return VllmInductorPass.hash_source( + self, + AiterAllreduceFusedRMSNormPattern, + AiterAllreduceFusedAddRMSNormPattern, + ) diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index a7d157b597bf..59c94db5e812 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -9,10 +9,6 @@ import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig -from vllm.distributed import ( - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -504,150 +500,3 @@ def __call__(self, graph: torch.fx.Graph) -> None: def uuid(self) -> str: return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern) - - -class AiterAllreduceFusedRMSNormPattern: - FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm() - - def __init__( - self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True - ) -> None: - self.dtype = dtype - self.epsilon = epsilon - self.rmsnorm_matcher = MatcherRMSNorm( - epsilon, match_rocm_aiter=use_aiter_rmsnorm - ) - - def get_inputs(self) -> list[torch.Tensor]: - input, weight = self.rmsnorm_matcher.inputs() - - # input goes through allreduce first, always 16-bit - return [input.to(self.dtype), weight] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - allreduce_output = tensor_model_parallel_all_reduce(input) - rms = self.rmsnorm_matcher(allreduce_output, weight) - - return rms, allreduce_output - - def replacement( - input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - allreduce = self.FUSED_AR_RMSNORM_OP( - input_=input, - weight=weight, - epsilon=self.epsilon, - ) - return allreduce[0], allreduce[1] - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class AiterAllreduceFusedAddRMSNormPattern: - FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_add_rmsnorm() - - def __init__( - self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True - ) -> None: - self.epsilon = epsilon - self.dtype = dtype - self.rmsnorm_matcher = MatcherFusedAddRMSNorm( - epsilon, match_rocm_aiter=use_aiter_rmsnorm - ) - - def get_inputs(self) -> list[torch.Tensor]: - input, residual, weight = self.rmsnorm_matcher.inputs() - - # input goes through allreduce first, always 16-bit - return [residual, input.to(self.dtype), weight] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - allreduce_output = tensor_model_parallel_all_reduce(input) - rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) - - return 
rms, residual - - def replacement( - residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - allreduce = self.FUSED_AR_RMSNORM_OP( - input_=input, - weight=weight, - epsilon=self.epsilon, - residual=residual, - ) - return allreduce[0], allreduce[1] - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class RocmAiterAllReduceFusionPass(VllmPatternMatcherPass): - def __init__(self, config: VllmConfig) -> None: - super().__init__(config) - self.disabled = True - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size <= 1: - logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") - return - - if config.model_config is None: - logger.warning_once( - "AllReduce fusion pass is disabled for missing model_config." - ) - return - - self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" - ) - - self.register_patterns() - self.dump_patterns(config, self.patterns) - - @enable_fake_mode - def register_patterns(self): - for epsilon in [1e-5, 1e-6]: - AiterAllreduceFusedRMSNormPattern( - epsilon, - self.model_dtype, - ).register(self.patterns) - - AiterAllreduceFusedAddRMSNormPattern( - epsilon, - self.model_dtype, - ).register(self.patterns) - - # WARNING: This is a hack to clear the pattern matcher cache - # and allow multiple values of epsilon. - torch._inductor.pattern_matcher._seen_patterns.clear() - - self.disabled = False - - @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph): - if self.disabled: - logger.debug("ROCmAiterAllReduceRMSNormFusionPass disabled") - return - - self.matched_count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", self.matched_count) - - def __del__(self) -> None: - if getattr(self, "disabled", True): - return - - def uuid(self) -> str: - return VllmInductorPass.hash_source( - self, - AiterAllreduceFusedRMSNormPattern, - AiterAllreduceFusedAddRMSNormPattern, - ) diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 389aaf14582b..61189b8889d0 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -17,8 +17,10 @@ from .vllm_inductor_pass import VllmInductorPass if rocm_aiter_ops.is_enabled(): - from .fusion.rocm_aiter_fusion import ( + from .fusion.allreduce_rms_fusion import ( RocmAiterAllReduceFusionPass, + ) + from .fusion.rocm_aiter_fusion import ( RocmAiterRMSNormQuantFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, RocmAiterTritonAddRMSNormPadFusionPass, diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a7c43135388c..4e4084d23ece 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -115,6 +115,15 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer + if current_platform.is_rocm(): + from vllm._aiter_ops import rocm_aiter_ops + + return ( + rocm_aiter_ops.is_enabled() + and rocm_aiter_ops.is_rmsnorm_enabled() + and cfg.parallel_config.tensor_parallel_size > 1 + ) + return ( cfg.parallel_config.tensor_parallel_size > 1 and current_platform.is_cuda() diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 358a2241e6e4..5b3831df276c 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ 
b/vllm/distributed/device_communicators/cuda_communicator.py @@ -6,6 +6,7 @@ from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.config import get_cached_compilation_config from vllm.distributed.device_communicators.all_reduce_utils import ( should_nccl_symm_mem_allreduce, ) @@ -51,10 +52,16 @@ def __init__( from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE + fuse_allreduce_rms_enabled = ( + get_cached_compilation_config().pass_config.fuse_allreduce_rms + ) + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM use_flashinfer_allreduce = envs.VLLM_ALLREDUCE_USE_FLASHINFER - use_aiter_allreduce = rocm_aiter_ops.is_enabled() + use_aiter_allreduce = ( + rocm_aiter_ops.is_enabled() and fuse_allreduce_rms_enabled + ) self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem = use_torch_symm_mem From d5a9c813fe05e4bc50a1cde4d19d9163f1c2dca9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 11 Mar 2026 03:06:08 +0000 Subject: [PATCH 03/10] fix unittest Signed-off-by: vllmellm --- .../distributed/test_fusion_all_reduce.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index cd854631eb7b..c1e33eb733b7 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -74,8 +74,6 @@ def forward(self, x): return y4 def ops_in_model_before(self): - if self.use_aiter: - return [rocm_aiter_ops.get_rmsnorm_fused_add_op()] return [torch.ops.vllm.all_reduce.default] def ops_in_model_after(self): @@ -203,6 +201,7 @@ def ops_in_model_before(self): False, marks=pytest.mark.skipif( current_platform.is_rocm(), + reason="Not supported on ROCm platform", ), ), pytest.param( @@ -211,6 +210,7 @@ def ops_in_model_before(self): False, marks=pytest.mark.skipif( current_platform.is_rocm(), + reason="Not supported on ROCm platform", ), ), pytest.param( @@ -219,6 +219,7 @@ def ops_in_model_before(self): False, marks=pytest.mark.skipif( current_platform.is_rocm(), + reason="Not supported on ROCm platform", ), ), ], @@ -382,6 +383,17 @@ def all_reduce_fusion_pass_on_test_model( assert all_reduce_fusion_pass.matched_count == 4, ( f"{all_reduce_fusion_pass.matched_count=}" ) - backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) + if use_aiter: + # aiter all_reduce is not a torch op, check by callable identity + import aiter as aiter_ops + + pre_nodes = [ + n + for n in backend.graph_pre_pass.nodes + if n.op == "call_function" and n.target is aiter_ops.all_reduce + ] + assert len(pre_nodes) > 0 + else: + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_after_ops(model.ops_in_model_after()) del all_reduce_fusion_pass From a7c422946ad16e2d29a80e263e628f789f305a2b Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 11 Mar 2026 03:37:44 +0000 Subject: [PATCH 04/10] add number of token threshold for aiter fused all reduce Signed-off-by: vllmellm --- vllm/_aiter_ops.py | 12 +++++++++++ .../passes/fusion/allreduce_rms_fusion.py | 20 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 602c7727a5e8..33b65ddd8c20 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1096,6 +1096,8 @@ class rocm_aiter_ops: # TODO: Consolidate 
under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + _AR_MAX_SIZE = 8192 * 1024 * 8 + @classmethod def refresh_env_variables(cls): """ @@ -1184,6 +1186,16 @@ def get_aiter_quant_type(quant_type_str: str): } return mapping.get(name) + @classmethod + def custom_allreduce_max_size(cls, world_size: int) -> int | None: + """Returns max allreduce size in bytes for aiter. None if unsupported.""" + if not cls.is_enabled(): + return None + AITER_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + if world_size not in AITER_SUPPORTED_WORLD_SIZES: + return None + return cls._AR_MAX_SIZE + @classmethod @if_aiter_supported def is_enabled(cls) -> bool: diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index fe8587f552d6..93194921cf79 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -963,6 +963,20 @@ def __init__(self, config: VllmConfig) -> None: ) return + hidden_dim = config.model_config.get_hidden_size() + max_size = rocm_aiter_ops.custom_allreduce_max_size(self.tp_size) + + if max_size is None: + # AITER doesn't support current world size + logger.warning( + "AITER allreduce fusion is not supported for world size %s" + " or max size is not provided", + self.tp_size, + ) + return + element_size = torch.tensor([], dtype=self.model_dtype).element_size() + self.max_token_num = max_size // (hidden_dim * element_size) + self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" ) @@ -989,6 +1003,12 @@ def register_patterns(self): self.disabled = False + def is_applicable_for_range(self, compile_range: Range) -> bool: + if self.disabled: + logger.warning_once("AllReduce fusion pass is disabled.") + return False + return bool(compile_range.end <= self.max_token_num) + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: From 253e511ea78b4625ecab3b398e50ec76ae8a7326 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 14 Mar 2026 03:58:22 +0000 Subject: [PATCH 05/10] use aiter all reduce only for fusion pass Signed-off-by: vllmellm --- .../distributed/test_fusion_all_reduce.py | 13 +- vllm/_aiter_ops.py | 122 ++++----- .../passes/fusion/allreduce_rms_fusion.py | 37 ++- .../device_communicators/aiter_all_reduce.py | 236 ++++++++++++++++++ .../device_communicators/cuda_communicator.py | 43 +--- vllm/distributed/parallel_state.py | 15 +- 6 files changed, 334 insertions(+), 132 deletions(-) create mode 100644 vllm/distributed/device_communicators/aiter_all_reduce.py diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index c1e33eb733b7..d9efc99bc79a 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -383,17 +383,6 @@ def all_reduce_fusion_pass_on_test_model( assert all_reduce_fusion_pass.matched_count == 4, ( f"{all_reduce_fusion_pass.matched_count=}" ) - if use_aiter: - # aiter all_reduce is not a torch op, check by callable identity - import aiter as aiter_ops - - pre_nodes = [ - n - for n in backend.graph_pre_pass.nodes - if n.op == "call_function" and n.target is aiter_ops.all_reduce - ] - assert len(pre_nodes) > 0 - else: - backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) 
backend.check_after_ops(model.ops_in_model_after()) del all_reduce_fusion_pass diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 33b65ddd8c20..13daaae09b2c 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -635,80 +635,56 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake( def _rocm_aiter_fused_allreduce_rmsnorm_impl( input_: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm.distributed import get_tp_group - - group = get_tp_group() - - device_comm = group.device_communicator - if device_comm is not None: - aiter_ar_comm = getattr(device_comm, "aiter_ar_comm", None) - - if ( - aiter_ar_comm is not None - and not aiter_ar_comm.disabled - and aiter_ar_comm.should_custom_ar(input_) - and hasattr(aiter_ar_comm, "custom_fused_ar_rms") - ): - total_bytes = input_.numel() * input_.element_size() - use_1stage = total_bytes <= 128 * 1024 - - out, res_out = aiter_ar_comm.custom_fused_ar_rms( - input_, - residual_inp=torch.zeros_like(input_), - weight=weight, - eps=epsilon, - use_1stage=use_1stage, - ) - return out, res_out - - # Fallback: launch all-reduce and rmsnorm separately - ar_out = group._all_reduce_out_place(input_) - - out = _rocm_aiter_rms_norm_impl(ar_out, weight, epsilon) - return ar_out, out - - -def _rocm_aiter_fused_allreduce_rmsnorm_fake( - input_: torch.Tensor, - weight: torch.Tensor, - epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(input_), torch.empty_like(input_) + from aiter import fused_allreduce_rmsnorm - -def _rocm_aiter_fused_allreduce_add_rmsnorm_impl( - input_: torch.Tensor, - weight: torch.Tensor, - epsilon: float, - residual: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: from vllm.distributed import get_tp_group + from vllm.distributed.device_communicators.aiter_all_reduce import ( + get_aiter_allreduce, + ) group = get_tp_group() + out = torch.empty_like(input_) + residual_out = torch.empty_like(input_) + + aiter_ar = get_aiter_allreduce() device_comm = group.device_communicator - if device_comm is not None: - aiter_ar_comm = getattr(device_comm, "aiter_ar_comm", None) - - if ( - aiter_ar_comm is not None - and not aiter_ar_comm.disabled - and aiter_ar_comm.should_custom_ar(input_) - and hasattr(aiter_ar_comm, "custom_fused_ar_rms") - ): - total_bytes = input_.numel() * input_.element_size() - use_1stage = total_bytes <= 128 * 1024 - out, res_out = aiter_ar_comm.custom_fused_ar_rms( - input_, - residual_inp=residual, - weight=weight, - eps=epsilon, - use_1stage=use_1stage, - ) - return out, res_out + + if ( + aiter_ar is not None + and device_comm is not None + and not aiter_ar.disabled + and aiter_ar.should_custom_ar(input_) + ): + if aiter_ar._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + # Static CUDA graph buffer — already registered, pass directly + reg_buffer = None + else: + # Warmup run before graph capture — return dummy + return out, residual_out + else: + # Eager mode — use aiter's pre-registered staging buffer + reg_buffer = aiter_ar.input_buffer + + total_bytes = input_.numel() * input_.element_size() + use_1stage = total_bytes <= 128 * 1024 + fused_allreduce_rmsnorm( + aiter_ar._ptr, + input_, + residual, + residual_out, + out, + weight, + epsilon, + reg_buffer, + use_1stage, + ) + return out, residual_out # Fallback: launch all-reduce and rmsnorm separately ar_out = group._all_reduce_out_place(input_) @@ -720,11 +696,11 @@ def _rocm_aiter_fused_allreduce_add_rmsnorm_impl( return out, 
residual_out -def _rocm_aiter_fused_allreduce_add_rmsnorm_fake( +def _rocm_aiter_fused_allreduce_rmsnorm_fake( input_: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, epsilon: float, - residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(input_), torch.empty_like(residual) @@ -1096,7 +1072,7 @@ class rocm_aiter_ops: # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM - _AR_MAX_SIZE = 8192 * 1024 * 8 + _AR_MAX_SIZE = 8192 * 1024 * 8 * 2 @classmethod def refresh_env_variables(cls): @@ -1459,12 +1435,6 @@ def register_ops_once() -> None: fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake, ) - direct_register_custom_op( - op_name="rocm_aiter_fused_allreduce_add_rmsnorm", - op_func=_rocm_aiter_fused_allreduce_add_rmsnorm_impl, - fake_impl=_rocm_aiter_fused_allreduce_add_rmsnorm_fake, - ) - _OPS_REGISTERED = True @staticmethod @@ -1515,10 +1485,6 @@ def get_triton_rotary_embedding_op() -> OpOverload: def get_fused_allreduce_rmsnorm_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default - @staticmethod - def get_fused_allreduce_add_rmsnorm_op() -> OpOverload: - return torch.ops.vllm.rocm_aiter_fused_allreduce_add_rmsnorm.default - @staticmethod def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 93194921cf79..79b476506467 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -14,6 +14,11 @@ from vllm.config import VllmConfig from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.aiter_all_reduce import ( + destroy_aiter_allreduce, + initialize_aiter_allreduce, +) +from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -893,8 +898,10 @@ def pattern( def replacement( input: torch.Tensor, weight: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: + residual = torch.empty_like(input) allreduce = self.FUSED_AR_RMSNORM_OP( input_=input, + residual=residual, weight=weight, epsilon=self.epsilon, ) @@ -906,7 +913,7 @@ def replacement( class AiterAllreduceFusedAddRMSNormPattern: - FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_add_rmsnorm_op() + FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op() def __init__( self, epsilon: float, dtype: torch.dtype, use_aiter_rmsnorm: bool = True @@ -937,9 +944,9 @@ def replacement( ) -> tuple[torch.Tensor, torch.Tensor]: allreduce = self.FUSED_AR_RMSNORM_OP( input_=input, + residual=residual, weight=weight, epsilon=self.epsilon, - residual=residual, ) return allreduce[0], allreduce[1] @@ -963,20 +970,30 @@ def __init__(self, config: VllmConfig) -> None: ) return + device_comm = get_tp_group().device_communicator + if device_comm is None: + logger.warning_once("Device communicator is required.") + return + + ca_comm = getattr(device_comm, "ca_comm", None) + if ca_comm is None: + logger.warning_once("Custom Allreduce is required.") + return + self.ca_comm = ca_comm + + assert isinstance(ca_comm, CustomAllreduce) hidden_dim = config.model_config.get_hidden_size() max_size = rocm_aiter_ops.custom_allreduce_max_size(self.tp_size) - if max_size 
is None: - # AITER doesn't support current world size - logger.warning( - "AITER allreduce fusion is not supported for world size %s" - " or max size is not provided", - self.tp_size, - ) + logger.warning_once("max size is required.") return + element_size = torch.tensor([], dtype=self.model_dtype).element_size() self.max_token_num = max_size // (hidden_dim * element_size) + rank = get_tensor_model_parallel_rank() + group = get_tp_group().cpu_group + initialize_aiter_allreduce(self.tp_size, rank, max_size, group, self.device) self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" ) @@ -1021,6 +1038,8 @@ def __call__(self, graph: fx.Graph): def __del__(self) -> None: if getattr(self, "disabled", True): return + with contextlib.suppress(Exception): + destroy_aiter_allreduce() def uuid(self) -> str: return VllmInductorPass.hash_source( diff --git a/vllm/distributed/device_communicators/aiter_all_reduce.py b/vllm/distributed/device_communicators/aiter_all_reduce.py new file mode 100644 index 000000000000..4e77a80ac3fd --- /dev/null +++ b/vllm/distributed/device_communicators/aiter_all_reduce.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Global instance — one per worker process, initialized once +_aiter_allreduce: Optional["AiterAllreduce"] = None + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) + + +def get_aiter_allreduce() -> Optional["AiterAllreduce"]: + return _aiter_allreduce + + +def initialize_aiter_allreduce( + world_size: int, rank: int, max_size: int, group: ProcessGroup, device: torch.device +) -> None: + """Initialize the aiter fused AR+RMSNorm instance if not already done. + + Called by RocmAiterAllReduceFusionPass at model init time. + The instance owns the aiter C++ ptr, staging buffer, and capture state. + """ + global _aiter_allreduce + if _aiter_allreduce is not None: + return + try: + _aiter_allreduce = AiterAllreduce(world_size, rank, max_size, group, device) + logger.debug( + "Initialized aiter allreduce: world_size=%d, rank=%d, max_size=%d", + world_size, + rank, + max_size, + ) + except Exception as e: + logger.warning("Failed to initialize aiter allreduce: %s", e) + _aiter_allreduce = None + + +def destroy_aiter_allreduce() -> None: + global _aiter_allreduce + if _aiter_allreduce is not None: + _aiter_allreduce.close() + _aiter_allreduce = None + + +class AiterAllreduce: + """Self-contained instance for aiter's fused allreduce+RMSNorm kernel. + + Owns: + - aiter C++ custom_ar ptr (_ptr) + - local staging buffer (input_buffer) — IPC-registered with all ranks + - CUDA graph capture state (_IS_CAPTURING) + + Intentionally separate from vLLM's CustomAllreduce so that vLLM's CA + is used for regular (non-fused) allreduce while this object is used + exclusively for the fused AR+RMS path. 
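+
+    Minimal lifecycle sketch (all names below are defined in this module):
+
+        initialize_aiter_allreduce(world_size, rank, max_size, group, device)
+        ar = get_aiter_allreduce()
+        if ar is not None and not ar.disabled and ar.should_custom_ar(x):
+            ...  # the fused allreduce+rmsnorm op can take the custom path
+        destroy_aiter_allreduce()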
+    """
+
+    def __init__(
+        self,
+        world_size: int,
+        rank: int,
+        max_size: int,
+        group: ProcessGroup,
+        device: torch.device,
+    ) -> None:
+        import aiter as aiter_ops
+
+        self.group = group
+        self.rank = rank
+        self.world_size = world_size
+        self.max_size = max_size
+        self.device = device
+        self._IS_CAPTURING = False
+        self._ptr = 0
+        self.disabled = True
+
+        # TODO: detect the real link topology; assume fully connected for now.
+        fully_connected = True
+        if world_size > 2 and not fully_connected:
+            logger.warning(
+                "Custom allreduce is disabled because it's not supported on"
+                " more than two PCIe-only GPUs. To silence this warning, "
+                "specify disable_custom_all_reduce=True explicitly."
+            )
+            return
+        self.disabled = False
+        self.fully_connected = fully_connected
+        # buffer memory is owned by this Python class and passed to C++.
+        # meta data consists of two parts: meta data for synchronization
+        # (256 bytes) and a temporary buffer for storing intermediate
+        # allreduce results.
+        self.meta = aiter_ops.allocate_meta_buffer(aiter_ops.meta_size() + max_size)
+        # This is a pre-registered IPC buffer. In eager mode, input tensors
+        # are first copied into this buffer before allreduce is performed
+        self.input_buffer = torch.empty(
+            max_size, dtype=torch.uint8, device=self.device
+        )
+        # This is a pre-registered IPC buffer for output. In eager mode, kernel
+        # writes results to this buffer, then it's copied to the actual output
+        self.output_buffer = torch.empty(
+            max_size, dtype=torch.uint8, device=self.device
+        )
+        # This is a buffer for storing the tuples of pointers pointing to
+        # IPC buffers from all ranks. Each registered tuple has size of
+        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+        # is enough for 131072 such tuples. The largest model I've seen only
+        # needs less than 10000 of registered tuples.
+        self.rank_data = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+        )
+        handle = aiter_ops.get_meta_buffer_ipc_handle(self.meta)
+        shard_data = (
+            handle,  # ipc handle to base ptr
+            0,  # offset of base ptr
+        )
+        handles, offsets = self._gather_ipc_meta(shard_data)
+
+        self._ptr = aiter_ops.init_custom_ar(
+            self.meta, self.rank_data, handles, offsets, self.rank, self.fully_connected
+        )
+        # Register the staging input buffer with all ranks.
+        self.register_input_buffer(self.input_buffer)
+
+    def _get_ipc_meta(self, inp: torch.Tensor):
+        import aiter as aiter_ops
+
+        # _share_cuda_() doesn't accept a meta buffer that was not allocated
+        # by the PyTorch cache allocator, so use the direct HIP call to get
+        # the IPC handle instead.
+        handle = aiter_ops.get_meta_buffer_ipc_handle(inp)
+        shard_data = (
+            handle,  # ipc handle to base ptr
+            0,  # offset of base ptr
+        )
+        return self._gather_ipc_meta(shard_data)
+
+    def _gather_ipc_meta(self, shard_data):
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: list[list[Any]] = [[None] for _ in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(
+                all_data[i], src=rank, group=self.group, device="cpu"
+            )
+
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode. 
+ # see https://github.com/pytorch/pytorch/issues/126032 for details. + + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore + return handles, offsets + + def register_input_buffer(self, inp: torch.Tensor): + import aiter as aiter_ops + + handles, offsets = self._get_ipc_meta(inp) + aiter_ops.register_input_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + import aiter as aiter_ops + + handle, offset = aiter_ops.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((handle, offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + aiter_ops.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.fully_connected: + return inp_size <= (self.max_size / 2) + return False + + @contextmanager + def capture(self): + """Context manager for CUDA graph capture. + + Sets _IS_CAPTURING so the fused op knows to use registered=True, + then calls register_graph_buffers after capture completes. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if self._ptr: + self.register_graph_buffers() + + def close(self) -> None: + if self._ptr: + try: + import aiter as aiter_ops + + aiter_ops.dispose(self._ptr) + except Exception: + pass + self._ptr = 0 diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 5b3831df276c..abedcbec7384 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -6,7 +6,6 @@ from torch.distributed import ProcessGroup import vllm.envs as envs -from vllm.config import get_cached_compilation_config from vllm.distributed.device_communicators.all_reduce_utils import ( should_nccl_symm_mem_allreduce, ) @@ -47,26 +46,16 @@ def __init__( use_custom_allreduce = False use_torch_symm_mem = False use_flashinfer_allreduce = False - use_aiter_allreduce = False else: - from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE - fuse_allreduce_rms_enabled = ( - get_cached_compilation_config().pass_config.fuse_allreduce_rms - ) - use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM use_flashinfer_allreduce = envs.VLLM_ALLREDUCE_USE_FLASHINFER - use_aiter_allreduce = ( - rocm_aiter_ops.is_enabled() and fuse_allreduce_rms_enabled - ) self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem = use_torch_symm_mem self.use_flashinfer_allreduce = use_flashinfer_allreduce - self.use_aiter_allreduce = use_aiter_allreduce # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce, @@ -107,16 +96,6 @@ def __init__( device=self.device, ) - if self.use_aiter_allreduce and self.world_size > 1: - from aiter.dist.device_communicators.custom_all_reduce import ( - CustomAllreduce as AiterCustomAllreduce, - ) - - 
-            self.aiter_ar_comm = AiterCustomAllreduce(
-                group=self.cpu_group,
-                device=self.device,
-            )
-
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -208,17 +187,17 @@ def all_reduce(self, input_):
             out = fi_ar_comm.all_reduce(input_)
             assert out is not None
             return out
-        aiter_ar_comm = self.aiter_ar_comm
-        if (
-            aiter_ar_comm is not None
-            and not aiter_ar_comm.disabled
-            and aiter_ar_comm.should_custom_ar(input_)
-        ):
-            out = aiter_ar_comm.custom_all_reduce(
-                input_, use_new=True, open_fp8_quant=False
-            )
-            assert out is not None
-            return out
+        # aiter_ar_comm = self.aiter_ar_comm
+        # if (
+        #     aiter_ar_comm is not None
+        #     and not aiter_ar_comm.disabled
+        #     and aiter_ar_comm.should_custom_ar(input_)
+        # ):
+        #     out = aiter_ar_comm.custom_all_reduce(
+        #         input_, use_new=True, open_fp8_quant=False
+        #     )
+        #     assert out is not None
+        #     return out
         ca_comm = self.ca_comm
         if (
             ca_comm is not None
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index fe48a6006cc5..03a74851c04a 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -467,6 +467,7 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
         # only cuda uses this function,
         # so we don't abstract it into the base class
         maybe_ca_context = nullcontext()
+        maybe_aiter_context = nullcontext()
         from vllm.distributed.device_communicators.cuda_communicator import (
             CudaCommunicator,
         )
@@ -477,13 +478,25 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
         if ca_comm is not None:
             maybe_ca_context = ca_comm.capture()  # type: ignore
 
+        from vllm._aiter_ops import rocm_aiter_ops
+
+        aiter_enabled = rocm_aiter_ops.is_enabled()
+        if aiter_enabled:
+            from vllm.distributed.device_communicators.aiter_all_reduce import (
+                get_aiter_allreduce,
+            )
+
+            aiter_ar = get_aiter_allreduce()
+            if aiter_ar is not None:
+                maybe_aiter_context = aiter_ar.capture()  # type: ignore
+
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
         curr_stream = torch.cuda.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)
 
-        with torch.cuda.stream(stream), maybe_ca_context:
+        with torch.cuda.stream(stream), maybe_ca_context, maybe_aiter_context:
             yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:

From d7bf6af35a010ea2932c7a95e94e4a012660ab5e Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Mon, 16 Mar 2026 02:33:37 +0000
Subject: [PATCH 06/10] clean code

Signed-off-by: vllmellm
---
 .../passes/fusion/allreduce_rms_fusion.py     |  4 +++-
 .../device_communicators/cuda_communicator.py | 12 ------------
 2 files changed, 3 insertions(+), 13 deletions(-)

diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index 79b476506467..ba70de9ef2d8 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -989,7 +989,9 @@ def __init__(self, config: VllmConfig) -> None:
             return
 
         element_size = torch.tensor([], dtype=self.model_dtype).element_size()
-        self.max_token_num = max_size // (hidden_dim * element_size)
+        self.max_token_num = (max_size / 2) // (hidden_dim * element_size)
+        print(" max token num")
+        print(self.max_token_num)
 
         rank = get_tensor_model_parallel_rank()
         group = get_tp_group().cpu_group
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index abedcbec7384..6fab779d08d4 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -82,7 +82,6 @@ def __init__(
         self.qr_comm: QuickAllReduce | None = None
         self.symm_mem_comm: SymmMemCommunicator | None = None
         self.fi_ar_comm: FlashInferAllReduce | None = None
-        self.aiter_ar_comm = None
 
         if use_torch_symm_mem and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
@@ -187,17 +186,6 @@ def all_reduce(self, input_):
             out = fi_ar_comm.all_reduce(input_)
             assert out is not None
             return out
-        # aiter_ar_comm = self.aiter_ar_comm
-        # if (
-        #     aiter_ar_comm is not None
-        #     and not aiter_ar_comm.disabled
-        #     and aiter_ar_comm.should_custom_ar(input_)
-        # ):
-        #     out = aiter_ar_comm.custom_all_reduce(
-        #         input_, use_new=True, open_fp8_quant=False
-        #     )
-        #     assert out is not None
-        #     return out
         ca_comm = self.ca_comm
         if (
             ca_comm is not None

From 45f638fc674986f77e72e5ce0bcbe28fd91edcfa Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 17 Mar 2026 09:07:41 +0000
Subject: [PATCH 07/10] bugfixes

Signed-off-by: vllmellm
---
 vllm/_aiter_ops.py                            | 30 +++++++++++--------
 .../passes/fusion/allreduce_rms_fusion.py     | 14 ++++-----
 .../device_communicators/aiter_all_reduce.py  | 29 ++++++++++--------
 3 files changed, 39 insertions(+), 34 deletions(-)

diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 13daaae09b2c..05babf6c0ef2 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -672,7 +672,23 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl(
             reg_buffer = aiter_ar.input_buffer
 
         total_bytes = input_.numel() * input_.element_size()
-        use_1stage = total_bytes <= 128 * 1024
+        hidden_dim = input_.shape[-1]
+        token_num = input_.shape[0]
+        hidden_ok = hidden_dim in (512, 1024, 2048, 4096)
+        token_ok = token_num <= 80
+        world_size = aiter_ar.world_size
+        full_nvlink = aiter_ar.fully_connected
+
+        if world_size == 2:
+            size_ok = True
+        elif full_nvlink and world_size <= 4:
+            size_ok = total_bytes < 160 * 1024
+        elif full_nvlink and world_size <= 8:
+            size_ok = total_bytes < 80 * 1024
+        else:
+            size_ok = False
+
+        use_1stage = hidden_ok and token_ok and size_ok
         fused_allreduce_rmsnorm(
             aiter_ar._ptr,
             input_,
@@ -1072,8 +1088,6 @@ class rocm_aiter_ops:
     # TODO: Consolidate under _LINEAR_ENABLED
     _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
 
-    _AR_MAX_SIZE = 8192 * 1024 * 8 * 2
-
     @classmethod
     def refresh_env_variables(cls):
         """
@@ -1162,16 +1176,6 @@ def get_aiter_quant_type(quant_type_str: str):
         }
         return mapping.get(name)
 
-    @classmethod
-    def custom_allreduce_max_size(cls, world_size: int) -> int | None:
-        """Returns max allreduce size in bytes for aiter.
-        None if unsupported."""
-        if not cls.is_enabled():
-            return None
-        AITER_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
-        if world_size not in AITER_SUPPORTED_WORLD_SIZES:
-            return None
-        return cls._AR_MAX_SIZE
-
     @classmethod
     @if_aiter_supported
     def is_enabled(cls) -> bool:
diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index ba70de9ef2d8..3ada04a19fd8 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -15,6 +15,7 @@
 from vllm.config.utils import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.aiter_all_reduce import (
+    _AR_MAX_SIZE,
     destroy_aiter_allreduce,
     initialize_aiter_allreduce,
 )
@@ -983,19 +984,16 @@ def __init__(self, config: VllmConfig) -> None:
         assert isinstance(ca_comm, CustomAllreduce)
 
         hidden_dim = config.model_config.get_hidden_size()
-        max_size = rocm_aiter_ops.custom_allreduce_max_size(self.tp_size)
-        if max_size is None:
-            logger.warning_once("max size is required.")
-            return
 
         element_size = torch.tensor([], dtype=self.model_dtype).element_size()
-        self.max_token_num = (max_size / 2) // (hidden_dim * element_size)
-        print(" max token num")
-        print(self.max_token_num)
+        self.max_token_num = (_AR_MAX_SIZE / 2) // (hidden_dim * element_size)
+        self.max_token_num = min(
+            self.max_token_num, config.scheduler_config.max_num_batched_tokens
+        )
 
         rank = get_tensor_model_parallel_rank()
         group = get_tp_group().cpu_group
-        initialize_aiter_allreduce(self.tp_size, rank, max_size, group, self.device)
+        initialize_aiter_allreduce(self.tp_size, rank, group, self.device)
 
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass"
         )
diff --git a/vllm/distributed/device_communicators/aiter_all_reduce.py b/vllm/distributed/device_communicators/aiter_all_reduce.py
index 4e77a80ac3fd..662d8f88a4ba 100644
--- a/vllm/distributed/device_communicators/aiter_all_reduce.py
+++ b/vllm/distributed/device_communicators/aiter_all_reduce.py
@@ -15,6 +15,8 @@
 # Global instance — one per worker process, initialized once
 _aiter_allreduce: Optional["AiterAllreduce"] = None
 
+_AR_MAX_SIZE = 8192 * 1024 * 8 * 2
+
 
 def is_weak_contiguous(inp: torch.Tensor):
     return inp.is_contiguous() or (
@@ -28,7 +30,7 @@ def get_aiter_allreduce() -> Optional["AiterAllreduce"]:
     return _aiter_allreduce
 
 
 def initialize_aiter_allreduce(
-    world_size: int, rank: int, max_size: int, group: ProcessGroup, device: torch.device
+    world_size: int, rank: int, group: ProcessGroup, device: torch.device
 ) -> None:
     """Initialize the aiter fused AR+RMSNorm instance if not already done.
@@ -39,12 +41,11 @@ def initialize_aiter_allreduce(
     if _aiter_allreduce is not None:
         return
     try:
-        _aiter_allreduce = AiterAllreduce(world_size, rank, max_size, group, device)
+        _aiter_allreduce = AiterAllreduce(world_size, rank, group, device)
         logger.debug(
-            "Initialized aiter allreduce: world_size=%d, rank=%d, max_size=%d",
+            "Initialized aiter allreduce: world_size=%d, rank=%d",
             world_size,
             rank,
-            max_size,
         )
     except Exception as e:
         logger.warning("Failed to initialize aiter allreduce: %s", e)
@@ -75,7 +76,6 @@ def __init__(
         self,
         world_size: int,
         rank: int,
-        max_size: int,
         group: ProcessGroup,
         device: torch.device,
     ) -> None:
@@ -84,7 +84,8 @@ def __init__(
         self.group = group
         self.rank = rank
         self.world_size = world_size
-        self.max_size = max_size
+
+        self.max_size = _AR_MAX_SIZE
         self.group = group
         self._IS_CAPTURING = False
         self._ptr = 0
@@ -105,14 +106,14 @@ def __init__(
         # (256 bytes) and a temporary buffer for storing intermediate
         # allreduce results.
         # if current_platform.is_rocm():
-        self.meta = aiter_ops.allocate_meta_buffer(aiter_ops.meta_size() + max_size)
+        self.meta_size = aiter_ops.meta_size()
+        self.meta = aiter_ops.allocate_meta_buffer(
+            aiter_ops.meta_size() + self.max_size
+        )
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.input_buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
-        # This is a pre-registered IPC buffer for output. In eager mode, kernel
-        # writes results to this buffer, then it's copied to the actual output
-        self.output_buffer = torch.empty(
-            max_size, dtype=torch.uint8, device=self.device
+        self.input_buffer = torch.empty(
+            self.max_size, dtype=torch.uint8, device=self.device
         )
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
@@ -122,7 +123,6 @@ def __init__(
         self.rank_data = torch.empty(
             8 * 1024 * 1024, dtype=torch.uint8, device=self.device
         )
-        self.max_size = max_size
         self.world_size = world_size
         handle = aiter_ops.get_meta_buffer_ipc_handle(self.meta)
         shard_data = (
@@ -225,6 +225,9 @@ def capture(self):
             if self._ptr:
                 self.register_graph_buffers()
 
+    def __del__(self):
+        self.close()
+
     def close(self) -> None:
         if self._ptr:
             try:

From aa3b50b1740f08ec92e9e89c43d13d4867ba2afa Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Thu, 19 Mar 2026 08:11:16 +0000
Subject: [PATCH 08/10] enable allreduce + rmsnorm using rocm_aiter_ops only

Signed-off-by: vllmellm
---
 vllm/_aiter_ops.py                            | 146 ++++++-----
 .../passes/fusion/allreduce_rms_fusion.py     |  30 ++-
 .../device_communicators/aiter_all_reduce.py  | 239 ------------------
 vllm/distributed/parallel_state.py            |   6 +-
 4 files changed, 95 insertions(+), 326 deletions(-)
 delete mode 100644 vllm/distributed/device_communicators/aiter_all_reduce.py

diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 05babf6c0ef2..b54f727eca0e 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -2,9 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
 from collections.abc import Callable
+from contextlib import contextmanager
+from typing import Protocol
 
 import torch
 from torch._ops import OpOverload
+from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
 from vllm.platforms import current_platform
@@ -33,6 +36,25 @@ def is_aiter_found() -> bool:
 IS_AITER_FOUND = is_aiter_found()
 
 
+class AiterCustomAllreduceProto(Protocol):
+    max_size: int
+    world_size: int
+    fully_connected: bool
+
+    @contextmanager
+    def capture(self): ...
+    def close(self) -> None: ...
+    def custom_fused_ar_rms(
+        self,
+        input: torch.Tensor,
+        residual_inp: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float,
+        use_1stage: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor] | None: ...
+    def should_custom_ar(self, inp: torch.Tensor) -> bool: ...
+
+
 def is_aiter_found_and_supported() -> bool:
     """Check if AITER library is available and platform supports it.
@@ -639,77 +661,30 @@ def _rocm_aiter_fused_allreduce_rmsnorm_impl(
     weight: torch.Tensor,
     epsilon: float,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    from aiter import fused_allreduce_rmsnorm
-
-    from vllm.distributed import get_tp_group
-    from vllm.distributed.device_communicators.aiter_all_reduce import (
-        get_aiter_allreduce,
-    )
-
-    group = get_tp_group()
-
-    out = torch.empty_like(input_)
-    residual_out = torch.empty_like(input_)
-
-    aiter_ar = get_aiter_allreduce()
-    device_comm = group.device_communicator
-
-    if (
-        aiter_ar is not None
-        and device_comm is not None
-        and not aiter_ar.disabled
-        and aiter_ar.should_custom_ar(input_)
-    ):
-        if aiter_ar._IS_CAPTURING:
-            if torch.cuda.is_current_stream_capturing():
-                # Static CUDA graph buffer — already registered, pass directly
-                reg_buffer = None
-            else:
-                # Warmup run before graph capture — return dummy
-                return out, residual_out
-        else:
-            # Eager mode — use aiter's pre-registered staging buffer
-            reg_buffer = aiter_ar.input_buffer
-
-        total_bytes = input_.numel() * input_.element_size()
-        hidden_dim = input_.shape[-1]
-        token_num = input_.shape[0]
-        hidden_ok = hidden_dim in (512, 1024, 2048, 4096)
-        token_ok = token_num <= 80
-        world_size = aiter_ar.world_size
-        full_nvlink = aiter_ar.fully_connected
-
-        if world_size == 2:
-            size_ok = True
-        elif full_nvlink and world_size <= 4:
-            size_ok = total_bytes < 160 * 1024
-        elif full_nvlink and world_size <= 8:
-            size_ok = total_bytes < 80 * 1024
-        else:
-            size_ok = False
-
-        use_1stage = hidden_ok and token_ok and size_ok
-        fused_allreduce_rmsnorm(
-            aiter_ar._ptr,
-            input_,
-            residual,
-            residual_out,
-            out,
-            weight,
-            epsilon,
-            reg_buffer,
-            use_1stage,
-        )
-        return out, residual_out
-
-    # Fallback: launch all-reduce and rmsnorm separately
-    ar_out = group._all_reduce_out_place(input_)
-
-    out, residual_out = _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
-        ar_out, residual, weight, epsilon
-    )
-
-    return out, residual_out
+    aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
+    assert aiter_ar is not None, "aiter allreduce must be initialized"
+
+    # Heuristic for aiter's one-stage kernel: it only pays off for small
+    # inputs with supported hidden sizes; the byte thresholds depend on
+    # world size and connectivity.
+    total_bytes = input_.numel() * input_.element_size()
+    hidden_dim = input_.shape[-1]
+    token_num = input_.shape[0]
+    hidden_ok = hidden_dim in (512, 1024, 2048, 4096)
+    token_ok = token_num <= 80
+    world_size = aiter_ar.world_size
+    full_nvlink = aiter_ar.fully_connected
+
+    if world_size == 2:
+        size_ok = True
+    elif full_nvlink and world_size <= 4:
+        size_ok = total_bytes < 160 * 1024
+    elif full_nvlink and world_size <= 8:
+        size_ok = total_bytes < 80 * 1024
+    else:
+        size_ok = False
+
+    use_1stage = hidden_ok and token_ok and size_ok
+    # This variant has no residual input; pass a zero residual to the
+    # fused kernel, matching the original unfused implementation.
+    result = aiter_ar.custom_fused_ar_rms(
+        input_, torch.zeros_like(input_), weight, epsilon, use_1stage
+    )
+    assert result is not None
+    return result[0], result[1]
 
 
 def _rocm_aiter_fused_allreduce_rmsnorm_fake(
@@ -1088,6 +1063,9 @@ class rocm_aiter_ops:
     # TODO: Consolidate under _LINEAR_ENABLED
     _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
 
+    _ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2
+    _CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None
+
     @classmethod
     def refresh_env_variables(cls):
         """
@@ -1255,6 +1233,34 @@ def is_triton_rotary_embed_enabled(cls) -> bool:
     def is_triton_gemm_enabled(cls) -> bool:
         return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM
 
+    @classmethod
+    @if_aiter_supported
+    def initialize_aiter_allreduce(
+        cls, group: ProcessGroup, device: torch.device
+    ) -> None:
+        try:
+            from aiter.dist.device_communicators.custom_all_reduce import (
+                CustomAllreduce as AiterCustomAllreduce,
+            )
+
+            cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
+        except Exception:
+            cls._CUSTOM_ALL_REDUCE = None
+
+    @classmethod
+    def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None:
+        return cls._CUSTOM_ALL_REDUCE
+
+    @classmethod
+    def destroy_aiter_allreduce(cls) -> None:
+        if cls._CUSTOM_ALL_REDUCE is not None:
+            cls._CUSTOM_ALL_REDUCE.close()
+            cls._CUSTOM_ALL_REDUCE = None
+
+    @classmethod
+    def get_aiter_allreduce_max_size(cls) -> int:
+        return cls._ALL_REDUCE_MAX_SIZE
+
     @staticmethod
     @if_aiter_supported
     def register_ops_once() -> None:
diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index 3ada04a19fd8..0afc1e44691f 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -14,11 +14,6 @@
 from vllm.config import VllmConfig
 from vllm.config.utils import Range
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
-from vllm.distributed.device_communicators.aiter_all_reduce import (
-    _AR_MAX_SIZE,
-    destroy_aiter_allreduce,
-    initialize_aiter_allreduce,
-)
 from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank,
@@ -983,17 +978,28 @@ def __init__(self, config: VllmConfig) -> None:
         self.ca_comm = ca_comm
         assert isinstance(ca_comm, CustomAllreduce)
 
+        group = get_tp_group().cpu_group
+        rocm_aiter_ops.initialize_aiter_allreduce(group, self.device)
+
         hidden_dim = config.model_config.get_hidden_size()
         element_size = torch.tensor([], dtype=self.model_dtype).element_size()
-        self.max_token_num = (_AR_MAX_SIZE / 2) // (hidden_dim * element_size)
+        max_size = rocm_aiter_ops.get_aiter_allreduce_max_size()
+        if rocm_aiter_ops.get_aiter_allreduce() is None:
+            logger.warning("AITER allreduce failed to initialize; skipping fusion")
+            return
+
+        max_token_num = (max_size // 2) // (hidden_dim * element_size)
         self.max_token_num = min(
-            self.max_token_num, config.scheduler_config.max_num_batched_tokens
+            max_token_num,
+            config.scheduler_config.max_num_batched_tokens,
+        )
+
+        logger.debug_once(
+            f"AITER stage-1 fused allreduce max tokens: {self.max_token_num} "
+            f"(tp={self.tp_size}, hidden={hidden_dim})",
+            scope="global",
         )
 
-        rank = get_tensor_model_parallel_rank()
-        group = get_tp_group().cpu_group
-        initialize_aiter_allreduce(self.tp_size, rank, group, self.device)
-
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass"
         )
@@ -1039,7 +1045,7 @@ def __del__(self) -> None:
         if getattr(self, "disabled", True):
             return
         with contextlib.suppress(Exception):
-            destroy_aiter_allreduce()
+            rocm_aiter_ops.destroy_aiter_allreduce()
 
     def uuid(self) -> str:
         return VllmInductorPass.hash_source(
diff --git a/vllm/distributed/device_communicators/aiter_all_reduce.py b/vllm/distributed/device_communicators/aiter_all_reduce.py
deleted file mode 100644
index 662d8f88a4ba..000000000000
--- a/vllm/distributed/device_communicators/aiter_all_reduce.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from contextlib import contextmanager
-from typing import Any, Optional
-
-import torch
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-# Global instance — one per worker process, initialized once
-_aiter_allreduce: Optional["AiterAllreduce"] = None
-
-_AR_MAX_SIZE = 8192 * 1024 * 8 * 2
-
-
-def is_weak_contiguous(inp: torch.Tensor):
-    return inp.is_contiguous() or (
-        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
-        == inp.numel() * inp.element_size()
-    )
-
-
-def get_aiter_allreduce() -> Optional["AiterAllreduce"]:
-    return _aiter_allreduce
-
-
-def initialize_aiter_allreduce(
-    world_size: int, rank: int, group: ProcessGroup, device: torch.device
-) -> None:
-    """Initialize the aiter fused AR+RMSNorm instance if not already done.
-
-    Called by RocmAiterAllReduceFusionPass at model init time.
-    The instance owns the aiter C++ ptr, staging buffer, and capture state.
-    """
-    global _aiter_allreduce
-    if _aiter_allreduce is not None:
-        return
-    try:
-        _aiter_allreduce = AiterAllreduce(world_size, rank, group, device)
-        logger.debug(
-            "Initialized aiter allreduce: world_size=%d, rank=%d",
-            world_size,
-            rank,
-        )
-    except Exception as e:
-        logger.warning("Failed to initialize aiter allreduce: %s", e)
-        _aiter_allreduce = None
-
-
-def destroy_aiter_allreduce() -> None:
-    global _aiter_allreduce
-    if _aiter_allreduce is not None:
-        _aiter_allreduce.close()
-        _aiter_allreduce = None
-
-
-class AiterAllreduce:
-    """Self-contained instance for aiter's fused allreduce+RMSNorm kernel.
-
-    Owns:
-    - aiter C++ custom_ar ptr (_ptr)
-    - local staging buffer (input_buffer) — IPC-registered with all ranks
-    - CUDA graph capture state (_IS_CAPTURING)
-
-    Intentionally separate from vLLM's CustomAllreduce so that vLLM's CA
-    is used for regular (non-fused) allreduce while this object is used
-    exclusively for the fused AR+RMS path.
-    """
-
-    def __init__(
-        self,
-        world_size: int,
-        rank: int,
-        group: ProcessGroup,
-        device: torch.device,
-    ) -> None:
-        import aiter as aiter_ops
-
-        self.group = group
-        self.rank = rank
-        self.world_size = world_size
-
-        self.max_size = _AR_MAX_SIZE
-        self.group = group
-        self._IS_CAPTURING = False
-        self._ptr = 0
-        self.device = device
-
-        fully_connected = True
-        if world_size > 2 and not fully_connected:
-            logger.warning(
-                "Custom allreduce is disabled because it's not supported on"
-                " more than two PCIe-only GPUs. To silence this warning, "
-                "specify disable_custom_all_reduce=True explicitly."
-            )
-            return
-        self.disabled = False
-        self.fully_connected = fully_connected
-        # buffer memory is owned by this Python class and passed to C++
-        # metadata consists of two parts: metadata for synchronization
-        # (256 bytes) and a temporary buffer for storing intermediate
-        # allreduce results.
-        # if current_platform.is_rocm():
-        self.meta_size = aiter_ops.meta_size()
-        self.meta = aiter_ops.allocate_meta_buffer(
-            aiter_ops.meta_size() + self.max_size
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.input_buffer = torch.empty(
-            self.max_size, dtype=torch.uint8, device=self.device
-        )
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs fewer than 10000 registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
-        self.world_size = world_size
-        handle = aiter_ops.get_meta_buffer_ipc_handle(self.meta)
-        shard_data = (
-            handle,  # ipc handle to base ptr
-            0,  # offset of base ptr
-        )
-        handles, offsets = self._gather_ipc_meta(shard_data)
-
-        self._ptr = aiter_ops.init_custom_ar(
-            self.meta, self.rank_data, handles, offsets, self.rank, self.fully_connected
-        )
-        # Register the pre-allocated input staging buffer with all ranks
-        self.register_input_buffer(self.input_buffer)
-
-    def _get_ipc_meta(self, inp: torch.Tensor):
-        import aiter as aiter_ops
-
-        # if current_platform.is_rocm():
-        if 1:
-            # _share_cuda_() doesn't accept meta buffer not allocated from
-            # PyTorch cache allocator, use direct HIP call to get IPC handle
-            handle = aiter_ops.get_meta_buffer_ipc_handle(inp)
-            shard_data = (
-                handle,  # ipc handle to base ptr
-                0,  # offset of base ptr
-            )
-        else:
-            data = inp.untyped_storage()._share_cuda_()
-            shard_data = (
-                data[1],  # ipc handle to base ptr
-                data[3],  # offset of base ptr
-            )
-        return self._gather_ipc_meta(shard_data)
-
-    def _gather_ipc_meta(self, shard_data):
-        # Note: don't use `[[None]] * self.world_size` here
-        # because it will create a list of the same reference
-        all_data: list[list[Any]] = [[None] for i in range(self.world_size)]
-        all_data[self.rank][0] = shard_data
-
-        ranks = dist.get_process_group_ranks(group=self.group)
-        ranks.sort()
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(
-                all_data[i], src=rank, group=self.group, device="cpu"
-            )
-
-        # we cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-
-        handles = []
-        offsets = []
-        for i in range(len(all_data)):
-            handles.append(all_data[i][0][0])  # type: ignore
-            offsets.append(all_data[i][0][1])  # type: ignore
-        return handles, offsets
-
-    def register_input_buffer(self, inp: torch.Tensor):
-        import aiter as aiter_ops
-
-        handles, offsets = self._get_ipc_meta(inp)
-        aiter_ops.register_input_buffer(self._ptr, inp, handles, offsets)
-
-    def register_graph_buffers(self):
-        import aiter as aiter_ops
-
-        handle, offset = aiter_ops.get_graph_buffer_ipc_meta(self._ptr)
-        handles, offsets = self._gather_ipc_meta((handle, offset))
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        aiter_ops.register_graph_buffers(self._ptr, handles, offsets)
-
-    def should_custom_ar(self, inp: torch.Tensor):
-        if self.disabled:
-            return False
-        inp_size = inp.numel() * inp.element_size()
-        # custom allreduce requires input byte size to be a multiple of 16
-        if inp_size % 16 != 0:
-            return False
-        if not is_weak_contiguous(inp):
-            return False
-        # for 4 or more GPUs that are not fully connected, custom allreduce
-        # provides little performance improvement over NCCL.
-        if self.world_size == 2 or self.fully_connected:
-            return inp_size <= (self.max_size / 2)
-        return False
-
-    @contextmanager
-    def capture(self):
-        """Context manager for CUDA graph capture.
-
-        Sets _IS_CAPTURING so the fused op knows to use registered=True,
-        then calls register_graph_buffers after capture completes.
- """ - try: - self._IS_CAPTURING = True - yield - finally: - self._IS_CAPTURING = False - if self._ptr: - self.register_graph_buffers() - - def __del__(self): - self.close() - - def close(self) -> None: - if self._ptr: - try: - import aiter as aiter_ops - - aiter_ops.dispose(self._ptr) - except Exception: - pass - self._ptr = 0 diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 03a74851c04a..717f6d6120d1 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -482,11 +482,7 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None aiter_enabled = rocm_aiter_ops.is_enabled() if aiter_enabled: - from vllm.distributed.device_communicators.aiter_all_reduce import ( - get_aiter_allreduce, - ) - - aiter_ar = get_aiter_allreduce() + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() if aiter_ar is not None: maybe_aiter_context = aiter_ar.capture() # type: ignore From 31d5e0cee02b46ccccadf6d6bb1d4bdafd5d83aa Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 20 Mar 2026 12:08:35 +0000 Subject: [PATCH 09/10] remove unnecessary log Signed-off-by: vllmellm --- vllm/compilation/passes/fusion/allreduce_rms_fusion.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 2d0149c3da55..e27eef39db7b 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -994,12 +994,6 @@ def __init__(self, config: VllmConfig) -> None: config.scheduler_config.max_num_batched_tokens, ) - logger.debug_once( - f"AITER stage-1 fused allreduce max tokens: {self.max_token_num} " - f"(tp={self.tp_size}, hidden={hidden_dim})", - scope="global", - ) - self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" ) From e9a4a877dcc7a9a4454c5d09516dc51e535e7e53 Mon Sep 17 00:00:00 2001 From: Rita Brugarolas Brufau Date: Wed, 1 Apr 2026 17:18:28 -0500 Subject: [PATCH 10/10] [ROCm] Fix allreduce+rmsnorm fusion for DeepSeek MoE layers DeepSeek V2/R1 MoE layers insert a no-op view (final_hidden_states.view(num_tokens, hidden_dim)) between all_reduce and rmsnorm. This breaks the pattern matcher in RocmAiterAllReduceFusionPass because it expects all_reduce -> rmsnorm as adjacent nodes in the FX graph. Add _bypass_noop_views_after_allreduce() to RocmAiterAllReduceFusionPass that removes identity-shaped view/reshape nodes between all_reduce and its consumers before pattern matching runs. This allows the standard all_reduce -> rmsnorm patterns to match for all layers, including MoE. The fix is purely at the compiler pass level -- no model definition changes required. 
Tested with DeepSeek-R1-0528 FP8, TP=8 on MI355X (gfx950):
- All all_reduce -> rmsnorm pairs are now fused (attention + MoE)
- rocm_aiter_fused_allreduce_rmsnorm kernel confirmed in Inductor output
- Server starts and serves inference without errors

Signed-off-by: Rita Brugarolas Brufau
---
 .../passes/fusion/allreduce_rms_fusion.py     | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index e27eef39db7b..4f851f5ec57e 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -1032,9 +1032,64 @@ def __call__(self, graph: fx.Graph):
             logger.debug("ROCmAiterAllReduceRMSNormFusionPass disabled")
             return
 
+        self._bypass_noop_views_after_allreduce(graph)
         self.matched_count = self.patterns.apply(graph)
         logger.debug("Replaced %s patterns", self.matched_count)
 
+    def _bypass_noop_views_after_allreduce(self, graph: fx.Graph) -> None:
+        """Remove no-op view/reshape nodes sitting between all_reduce and
+        rmsnorm so the pattern matcher can fuse them.
+
+        Some models (e.g. DeepSeek MoE) insert
+        ``final_hidden_states.view(num_tokens, hidden_dim)`` after
+        all_reduce. The view is identity-shaped but creates an intermediate
+        node that prevents the ``all_reduce -> rmsnorm`` pattern from
+        matching.
+        """
+        from torch.fx.experimental.symbolic_shapes import statically_known_true
+
+        count = 0
+        for node in list(graph.nodes):
+            if node.op != "call_function" or node.target not in (
+                torch.ops.aten.view.default,
+                torch.ops.aten.reshape.default,
+            ):
+                continue
+
+            input_node = node.args[0]
+            if not isinstance(input_node, fx.Node):
+                continue
+
+            if (
+                input_node.op != "call_function"
+                or input_node.target != torch.ops.vllm.all_reduce.default
+            ):
+                continue
+
+            input_val = input_node.meta.get("val")
+            output_val = node.meta.get("val")
+            if input_val is None or output_val is None:
+                continue
+
+            in_shape = input_val.shape
+            out_shape = output_val.shape
+            if len(in_shape) != len(out_shape):
+                continue
+            if not all(
+                statically_known_true(s == o)
+                for s, o in zip(in_shape, out_shape)
+            ):
+                continue
+
+            node.replace_all_uses_with(input_node)
+            graph.erase_node(node)
+            count += 1
+
+        if count:
+            logger.debug(
+                "Bypassed %s no-op view(s) after all_reduce", count
+            )
+
     def __del__(self) -> None:
         if getattr(self, "disabled", True):
             return
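
Postscript: the same rewrite can be exercised on a toy FX graph outside
vLLM. The sketch below is illustrative only: fake_all_reduce is an assumed
stand-in for vLLM's registered all_reduce op, and the real pass matches
aten view/reshape ops with symbolic-shape checks rather than this
simplified tensor_meta comparison.

    import torch
    import torch.fx as fx
    from torch.fx.passes.shape_prop import ShapeProp


    def fake_all_reduce(x: torch.Tensor) -> torch.Tensor:
        return x  # stand-in for torch.ops.vllm.all_reduce.default


    fx.wrap("fake_all_reduce")  # keep it as a single call_function node


    class Toy(torch.nn.Module):
        def forward(self, x):
            ar = fake_all_reduce(x)
            v = ar.view(4, 8)  # identity-shaped view, as in DeepSeek MoE
            return v * 2.0


    gm = fx.symbolic_trace(Toy())
    ShapeProp(gm).propagate(torch.randn(4, 8))  # fills node.meta["tensor_meta"]

    for node in list(gm.graph.nodes):
        if node.op != "call_method" or node.target != "view":
            continue
        src = node.args[0]
        if src.op == "call_function" and src.target is fake_all_reduce:
            # identity-shaped view: same shape before and after
            if node.meta["tensor_meta"].shape == src.meta["tensor_meta"].shape:
                node.replace_all_uses_with(src)
                gm.graph.erase_node(node)

    gm.recompile()
    assert "view" not in gm.code  # the no-op view is gone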