diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index c7e9e13f2fd3..296bb12d146d 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -3687,6 +3687,7 @@ steps: - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py # TODO: this test is not supported on ROCm, there are aiter kernels for this. # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py + - pytest -v -s tests/compile/fusions_e2e/test_tp2_ar_rms.py::test_tp2_ar_rms_fusions # - pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm # - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index 4b0a0859b023..fa031c207523 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -19,6 +19,8 @@ FLASHINFER_ATTN, FLASHINFER_MLA_ATTN, FLASHMLA_SPARSE_ATTN, + ROCM_AITER_UNIFIED_ATTN, + ROCM_ATTN, TRITON_ATTN, deepseek_coder_v2_lite_fp8, deepseek_v3_fp8, @@ -33,8 +35,6 @@ qwen3_a3b_fp8, ) -pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( @@ -54,6 +54,7 @@ @pytest.mark.parametrize("n_layers", [4]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") def test_tp2_ar_rms_fp8_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -123,6 +124,7 @@ def test_tp2_ar_rms_fp8_fusions( @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4") +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") def test_tp2_ar_rms_fp4_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -175,10 +177,19 @@ def test_tp2_ar_rms_fp4_fusions( "model_name, matches_fn, model_kwargs, hf_overrides", [llama3_8b, qwen3_a3b, gpt_oss_20b], ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN]) +@pytest.mark.parametrize( "attn_backend", [ TRITON_ATTN, FLASHINFER_ATTN, ROCM_ATTN, ROCM_AITER_UNIFIED_ATTN, ], ) @pytest.mark.parametrize("n_layers", [4]) -@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) +@pytest.mark.parametrize("custom_ops", tuple(custom_ops_combos("rms_norm"))) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test CUDA/ROCm") def test_tp2_ar_rms_fusions( model_name: str, matches_fn: Callable[[int], Matches], @@ -220,4 +231,5 @@ def test_tp2_ar_rms_fusions( compilation_config, matches_check, tp_size=2, + use_aiter=current_platform.is_rocm(), ) diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index a50d3ca8e3e1..c6b79e7db9af 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -8,8 +8,12 @@ import vllm.envs as envs from tests.compile.backend import TestBackend from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test
+from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass +from vllm.compilation.passes.fusion.allreduce_rms_fusion import ( + AllReduceFusionPass, + RocmAiterAllReduceFusionPass, +) from vllm.compilation.passes.utility.fix_functionalization import ( FixFunctionalizationPass, ) @@ -40,13 +44,19 @@ class TestAllReduceRMSNormModel(torch.nn.Module): def __init__( - self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16 + self, + hidden_size=16, + token_num=16, + eps=1e-6, + dtype: torch.dtype = torch.float16, + use_aiter: bool = False, ): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + self.use_aiter = use_aiter def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -74,6 +84,8 @@ def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] def ops_in_model_after(self): + if self.use_aiter: + return [rocm_aiter_ops.get_fused_allreduce_rmsnorm_op()] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -192,12 +204,36 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model, enable_quant_fp8_custom_op", + "test_model, enable_quant_fp8_custom_op, use_aiter", [ - (TestAllReduceRMSNormModel, False), - (TestAllReduceRMSNormStaticQuantFP8Model, True), - (TestAllReduceRMSNormStaticQuantFP8Model, False), - (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), + (TestAllReduceRMSNormModel, False, IS_AITER_FOUND), + pytest.param( + TestAllReduceRMSNormStaticQuantFP8Model, + True, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Not supported on ROCm platform", + ), + ), + pytest.param( + TestAllReduceRMSNormStaticQuantFP8Model, + False, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Not supported on ROCm platform", + ), + ), + pytest.param( + TestAllReduceFusedAddRMSNormStaticQuantFP4Model, + False, + False, + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Not supported on ROCm platform", + ), + ), ], ) @pytest.mark.parametrize("batch_size", [8]) @@ -208,9 +244,18 @@ def ops_in_model_before(self): @pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( - not find_spec("flashinfer") - or not has_module_attribute("flashinfer.comm", "allreduce_fusion") - or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"), + current_platform.is_rocm() and not IS_AITER_FOUND, + reason="aiter is not found", +) +@pytest.mark.skipif( + current_platform.is_cuda() + and ( + not find_spec("flashinfer") + or not has_module_attribute("flashinfer.comm", "allreduce_fusion") + or not has_module_attribute( + "flashinfer.comm", "create_allreduce_fusion_workspace" + ) + ), reason="flashinfer is not found or flashinfer " "is not compiled with allreduce_fusion", ) @@ -223,7 +268,14 @@ def test_all_reduce_fusion_pass_replace( enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter: bool, + monkeypatch: pytest.MonkeyPatch, ): + if use_aiter: + with monkeypatch.context() as m: + m.setenv("VLLM_ROCM_USE_AITER", str(use_aiter)) + 
rocm_aiter_ops.refresh_env_variables() + num_processes = 2 if ( test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model @@ -247,6 +299,8 @@ def run_torch_spawn(fn, nprocs): enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter, + monkeypatch, ), nprocs=nprocs, ) @@ -265,6 +319,8 @@ def all_reduce_fusion_pass_on_test_model( enable_rms_norm_custom_op, enable_quant_fp8_custom_op, flashinfer_allreduce_backend, + use_aiter: bool, + monkeypatch: pytest.MonkeyPatch, ): set_random_seed(0) @@ -311,7 +367,11 @@ def all_reduce_fusion_pass_on_test_model( ) with set_current_vllm_config(vllm_config): initialize_model_parallel(tensor_model_parallel_size=world_size) - all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + all_reduce_fusion_pass = ( + RocmAiterAllReduceFusionPass(vllm_config) + if use_aiter + else AllReduceFusionPass(vllm_config) + ) noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) @@ -321,7 +381,7 @@ def all_reduce_fusion_pass_on_test_model( ) token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num, dtype=dtype) + model = test_model_cls(hidden_size, token_num, dtype=dtype, use_aiter=use_aiter) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) diff --git a/tests/compile/passes/distributed/test_rocm_fused_ar_rmsnorm.py b/tests/compile/passes/distributed/test_rocm_fused_ar_rmsnorm.py new file mode 100644 index 000000000000..a099fc45ab6f --- /dev/null +++ b/tests/compile/passes/distributed/test_rocm_fused_ar_rmsnorm.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Multi-GPU compiler-pass integration test for fused AR+RMSNorm on ROCm. + +Tests that RocmAiterRMSNormQuantFusionPass correctly handles both paths: + +FP4/BF16 path (no FP8 quant consumers): + - fused_allreduce_rmsnorm is preserved as-is in the compiled graph. + - At runtime, AITER's fused AR+RMSNorm kernel handles the operation. + +FP8 path (FP8 quant consumer follows the normed output): + - fused_allreduce_rmsnorm is decomposed into all_reduce + rmsnorm_with_add. + - Then rmsnorm_with_add + fp8_quant are fused into a single AITER op. +""" + +import pytest +import torch + +from vllm._aiter_ops import IS_AITER_FOUND +from vllm.platforms import current_platform + +from tests.utils import multi_gpu_test + +pytestmark = pytest.mark.skipif( + not current_platform.is_rocm() or not IS_AITER_FOUND, + reason="ROCm with AITER required", +) + + +class TestFusedARRMSNormNoQuantModel(torch.nn.Module): + """Model with fused_allreduce_rmsnorm but NO FP8 quant consumers. + + Simulates the FP4/BF16 path: the decomposition pass should NOT + decompose these nodes, preserving the AITER fused AR+RMSNorm kernel. 
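A note on the env toggle used above: the AITER flags are cached on `rocm_aiter_ops` at import time, which is why the test must call `refresh_env_variables()` after `monkeypatch.setenv()`. A minimal, self-contained sketch of that cache-and-refresh pattern (the `FlagCache` class below is illustrative, not vLLM's implementation):

```python
import os

class FlagCache:
    # Cached once at import time, like rocm_aiter_ops._AITER_ENABLED.
    enabled = os.getenv("VLLM_ROCM_USE_AITER", "0").lower() in ("1", "true")

    @classmethod
    def refresh_env_variables(cls) -> None:
        # Re-read the environment so a later setenv() takes effect.
        cls.enabled = os.getenv("VLLM_ROCM_USE_AITER", "0").lower() in ("1", "true")

def test_flag_refresh(monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    FlagCache.refresh_env_variables()
    assert FlagCache.enabled
```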
+ """ + + def __init__(self, hidden_size=64, eps=1e-5): + super().__init__() + from vllm.model_executor.layers.layernorm import RMSNorm + + self.norm = torch.nn.ModuleList([ + RMSNorm(hidden_size, eps, fused_allreduce=False), + RMSNorm(hidden_size, eps, fused_allreduce=True), + RMSNorm(hidden_size, eps, fused_allreduce=True), + ]) + self.w = torch.nn.ParameterList([ + torch.nn.Parameter(torch.rand(hidden_size, hidden_size)) + for _ in range(2) + ]) + + def forward(self, x): + z = torch.relu(x) + resid = z + y = self.norm[0](z) + + z2 = torch.mm(y, self.w[0]) + y2, resid = self.norm[1](z2, resid) + + z3 = torch.mm(y2, self.w[1]) + y3, resid = self.norm[2](z3, resid) + return y3 + + def ops_in_model_before(self): + return [torch.ops.vllm.fused_allreduce_rmsnorm.default] + + +class TestFusedARRMSNormFP8Model(torch.nn.Module): + """Model with fused_allreduce_rmsnorm AND FP8 per-token quant consumer. + + Simulates the FP8 path: the decomposition pass should decompose + fused_allreduce_rmsnorm into all_reduce + rmsnorm_with_add, then the + pattern matcher fuses rmsnorm_with_add + fp8_quant into one AITER op. + """ + + def __init__(self, hidden_size=64, eps=1e-5): + super().__init__() + from vllm.model_executor.layers.layernorm import RMSNorm + + self.norm = torch.nn.ModuleList([ + RMSNorm(hidden_size, eps, fused_allreduce=False), + RMSNorm(hidden_size, eps, fused_allreduce=True), + ]) + self.w = torch.nn.Parameter(torch.rand(hidden_size, hidden_size)) + + def forward(self, x): + z = torch.relu(x) + resid = z + y = self.norm[0](z) + + z2 = torch.mm(y, self.w) + y2, resid = self.norm[1](z2, resid) + + quant_out, scale = torch.ops.vllm.rocm_aiter_per_token_quant( + y2, torch.float8_e4m3fnuz, + ) + return quant_out.to(x.dtype) * scale + + def ops_in_model_before(self): + return [torch.ops.vllm.fused_allreduce_rmsnorm.default] + + +def _run_rocm_fused_ar_test( + local_rank: int, + world_size: int, + test_model_cls: type, + hidden_size: int, + dtype: torch.dtype, + expect_decomposition: bool, +): + """Worker process for the multi-GPU test.""" + from vllm._aiter_ops import rocm_aiter_ops + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + from vllm.compilation.passes.utility.noop_elimination import ( + NoOpEliminationPass, + ) + from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass + from vllm.config import ( + CompilationConfig, + CompilationMode, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, + ) + from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, + ) + from vllm.utils.system_utils import update_environment_variables + from vllm.utils.torch_utils import set_random_seed + + from tests.compile.backend import TestBackend + + set_random_seed(0) + + device = torch.device(f"cuda:{local_rank}") + torch.accelerator.set_device_index(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12346", + "VLLM_ROCM_USE_AITER": "1", + } + ) + + rocm_aiter_ops.refresh_env_variables() + init_distributed_environment() + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm"], + ) + ) + vllm_config.compilation_config.pass_config = PassConfig( + fuse_norm_quant=True, 
eliminate_noops=True + ) + vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + vllm_config.parallel_config.rank = local_rank + vllm_config.model_config = ModelConfig(dtype=dtype) + + with set_current_vllm_config(vllm_config): + initialize_model_parallel(tensor_model_parallel_size=world_size) + + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + + model = test_model_cls(hidden_size) + token_num = 64 + hidden_states = torch.randn( + (token_num, hidden_size), requires_grad=False + ) + + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states) + + results_unfused = model(hidden_states) + results_fused = compiled_model(hidden_states) + torch.testing.assert_close( + results_unfused, results_fused, atol=1e-2, rtol=1e-2 + ) + + fused_ar_op = torch.ops.vllm.fused_allreduce_rmsnorm.default + from vllm.compilation.passes.fx_utils import find_op_nodes + + fused_count_before = len( + list(find_op_nodes(fused_ar_op, backend.graph_pre_pass)) + ) + fused_count_after = len( + list(find_op_nodes(fused_ar_op, backend.graph_post_pass)) + ) + + if expect_decomposition: + assert fused_count_before > 0, ( + "Expected fused_allreduce_rmsnorm in pre-pass graph" + ) + assert fused_count_after < fused_count_before, ( + "Expected decomposition to remove some " + "fused_allreduce_rmsnorm nodes" + ) + else: + assert fused_count_after == fused_count_before, ( + "fused_allreduce_rmsnorm should be preserved " + "when no FP8 consumer" + ) + + del fusion_pass + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_rocm_fused_ar_rmsnorm_no_quant_preserved( + hidden_size: int, + dtype: torch.dtype, +): + """FP4/BF16 path: fused_allreduce_rmsnorm preserved when no FP8 quant.""" + num_processes = 2 + torch.multiprocessing.spawn( + _run_rocm_fused_ar_test, + args=( + num_processes, + TestFusedARRMSNormNoQuantModel, + hidden_size, + dtype, + False, + ), + nprocs=num_processes, + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_rocm_fused_ar_rmsnorm_fp8_decomposed( + hidden_size: int, + dtype: torch.dtype, +): + """FP8 path: fused_allreduce_rmsnorm decomposed when FP8 quant follows.""" + num_processes = 2 + torch.multiprocessing.spawn( + _run_rocm_fused_ar_test, + args=( + num_processes, + TestFusedARRMSNormFP8Model, + hidden_size, + dtype, + True, + ), + nprocs=num_processes, + ) diff --git a/tests/distributed/test_fused_ar_rmsnorm.py b/tests/distributed/test_fused_ar_rmsnorm.py new file mode 100644 index 000000000000..c85473d72a46 --- /dev/null +++ b/tests/distributed/test_fused_ar_rmsnorm.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Multi-GPU tests for fused allreduce+RMSNorm communication op. 
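The worker verifies the pass by counting op occurrences in the pre- and post-pass graphs. A standalone sketch of that counting on a toy traced function (`count_op` is a hypothetical stand-in for the `find_op_nodes` usage above):

```python
import torch
from torch import fx

def count_op(graph: fx.Graph, target) -> int:
    # Equivalent to len(list(find_op_nodes(target, graph))).
    return sum(
        1 for n in graph.nodes if n.op == "call_function" and n.target == target
    )

def f(x):
    return torch.relu(x) + torch.relu(x)

gm = fx.symbolic_trace(f)
assert count_op(gm.graph, torch.relu) == 2
```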
+ +Run: pytest tests/distributed/test_fused_ar_rmsnorm.py +""" + +import pytest +import ray +import torch + +from vllm._aiter_ops import IS_AITER_FOUND +from vllm.platforms import current_platform + +from ..utils import ( + init_test_distributed_environment, + multi_gpu_test, + multi_process_parallel, +) + + +@ray.remote(num_gpus=1, max_calls=1) +def fused_allreduce_rmsnorm_test_worker( + monkeypatch: pytest.MonkeyPatch, + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, +): + """Test that fused_allreduce_rmsnorm produces correct results. + + Compares the fused path (allreduce + add + rmsnorm in one call) + against the split path (manual allreduce, then add, then rmsnorm). + """ + monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) + + device = torch.device(f"cuda:{rank}") + torch.accelerator.set_device_index(device) + init_test_distributed_environment( + tp_size, pp_size, rank, distributed_init_port + ) + + from vllm.distributed import get_tp_group + + tp_group = get_tp_group() + + hidden_size = 256 + num_tokens = 64 + eps = 1e-5 + + torch.manual_seed(42 + rank) + weight = torch.randn(hidden_size, dtype=torch.bfloat16, device=device) + + for _ in range(3): + input_ = torch.randn( + num_tokens, hidden_size, dtype=torch.bfloat16, device=device + ) + residual = torch.randn( + num_tokens, hidden_size, dtype=torch.bfloat16, device=device + ) + + input_ref = input_.clone() + residual_ref = residual.clone() + + normed, resid_out = tp_group._fused_allreduce_rmsnorm_out_place( + input_, residual, weight, eps + ) + + ar_ref = tp_group.all_reduce(input_ref) + combined = ar_ref + residual_ref + variance = combined.pow(2).mean(-1, keepdim=True) + normed_ref = combined * torch.rsqrt(variance + eps) + normed_ref = normed_ref * weight + + torch.testing.assert_close( + resid_out, combined, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close( + normed, normed_ref, atol=1e-2, rtol=1e-2 + ) + + +@ray.remote(num_gpus=1, max_calls=1) +def fused_allreduce_rmsnorm_world_size_1_test_worker( + monkeypatch: pytest.MonkeyPatch, + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, +): + """Test world_size==1 fallback (add + rmsnorm, no allreduce).""" + monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) + + device = torch.device(f"cuda:{rank}") + torch.accelerator.set_device_index(device) + init_test_distributed_environment( + tp_size, pp_size, rank, distributed_init_port + ) + + from vllm.distributed import get_tp_group + + tp_group = get_tp_group() + assert tp_group.world_size == 1 + + hidden_size = 128 + num_tokens = 16 + eps = 1e-5 + + torch.manual_seed(42) + weight = torch.randn(hidden_size, dtype=torch.bfloat16, device=device) + input_ = torch.randn( + num_tokens, hidden_size, dtype=torch.bfloat16, device=device + ) + residual = torch.randn( + num_tokens, hidden_size, dtype=torch.bfloat16, device=device + ) + + normed, resid_out = tp_group.fused_allreduce_rmsnorm( + input_, residual, weight, eps + ) + + combined = input_ + residual + variance = combined.pow(2).mean(-1, keepdim=True) + normed_ref = combined * torch.rsqrt(variance + eps) + normed_ref = normed_ref * weight + + torch.testing.assert_close(resid_out, combined, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(normed, normed_ref, atol=1e-2, rtol=1e-2) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.skipif( + not current_platform.is_rocm() or not IS_AITER_FOUND, + reason="ROCm with AITER required", +) +def test_fused_allreduce_rmsnorm( + monkeypatch: 
pytest.MonkeyPatch, + tp_size: int, +): + multi_process_parallel( + monkeypatch, + tp_size, + 1, + fused_allreduce_rmsnorm_test_worker, + ) + + +@multi_gpu_test(num_gpus=1) +@pytest.mark.skipif( + not current_platform.is_rocm() or not IS_AITER_FOUND, + reason="ROCm with AITER required", +) +def test_fused_allreduce_rmsnorm_world_size_1( + monkeypatch: pytest.MonkeyPatch, +): + multi_process_parallel( + monkeypatch, + 1, + 1, + fused_allreduce_rmsnorm_world_size_1_test_worker, + ) diff --git a/tests/rocm/aiter/test_fused_ar_rmsnorm.py b/tests/rocm/aiter/test_fused_ar_rmsnorm.py new file mode 100644 index 000000000000..9d81dc0e6ce8 --- /dev/null +++ b/tests/rocm/aiter/test_fused_ar_rmsnorm.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the fused allreduce+RMSNorm feature (MI355X). + +Tests cover: +- Custom op registration and fake implementation +- Auto-detection gating (proves non-regression for non-MI355X / non-AITER) +- Graph-level decomposition of fused_allreduce_rmsnorm when followed by FP8 quant +- Preservation of fused_allreduce_rmsnorm when NOT followed by FP8 quant (FP4 path) +- Full-pass pipeline: decompose → pattern-match → fuse (FP8) vs preserve (FP4) +""" + +import operator +from unittest.mock import patch + +import pytest +import torch +from torch import fx + +from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops +from vllm.platforms import current_platform + +pytestmark = pytest.mark.skipif( + not current_platform.is_rocm() or not IS_AITER_FOUND, + reason="ROCm with AITER required", +) + + +# --------------------------------------------------------------------------- +# Custom-op registration / fake-impl tests +# --------------------------------------------------------------------------- + + +def test_fused_allreduce_rmsnorm_op_registered(): + """Verify the custom op is registered and callable.""" + assert hasattr(torch.ops.vllm, "fused_allreduce_rmsnorm") + op = torch.ops.vllm.fused_allreduce_rmsnorm.default + assert op is not None + + +def test_fused_allreduce_rmsnorm_fake_shapes(): + """Verify the fake implementation returns tensors of correct shape.""" + from torch._subclasses.fake_tensor import FakeTensorMode + + hidden = 128 + tokens = 32 + + with FakeTensorMode(): + input_ = torch.randn(tokens, hidden, device="cuda") + residual = torch.randn(tokens, hidden, device="cuda") + weight = torch.randn(hidden, device="cuda") + + out, resid_out = torch.ops.vllm.fused_allreduce_rmsnorm( + input_, residual, weight, 1e-5, "fake_group" + ) + + assert out.shape == (tokens, hidden) + assert resid_out.shape == (tokens, hidden) + + +# --------------------------------------------------------------------------- +# Decomposition-pass graph-level tests +# --------------------------------------------------------------------------- + + +def _build_fused_ar_rmsnorm_graph( + hidden_size: int, + eps: float, + fp8_quant_op, + add_fp8_consumer: bool = True, +) -> tuple[fx.Graph, fx.Node]: + """Build a minimal FX graph with fused_allreduce_rmsnorm. + + Returns (graph, fused_ar_rms_node). + If add_fp8_consumer is True, the normed output (getitem 0) feeds + into an FP8 quant op. 
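The reference math these workers check against, condensed into a single-process sketch (the per-rank `all_reduce` is simulated by summing a list of rank inputs; `ref_fused_allreduce_rmsnorm` is illustrative only):

```python
import torch

def ref_fused_allreduce_rmsnorm(rank_inputs, residual, weight, eps):
    # all_reduce == sum over ranks, then residual add, then RMSNorm.
    combined = torch.stack(rank_inputs).sum(0) + residual
    variance = combined.pow(2).mean(-1, keepdim=True)
    return combined * torch.rsqrt(variance + eps) * weight, combined

xs = [torch.randn(4, 8) for _ in range(2)]
residual = torch.randn(4, 8)
weight = torch.randn(8)
normed, resid_out = ref_fused_allreduce_rmsnorm(xs, residual, weight, 1e-5)
assert normed.shape == resid_out.shape == (4, 8)
```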
+ """ + graph = fx.Graph() + + input_ = graph.placeholder("input_") + residual = graph.placeholder("residual") + weight = graph.placeholder("weight") + + fused_node = graph.call_function( + torch.ops.vllm.fused_allreduce_rmsnorm.default, + args=(input_, residual, weight, eps, "test_group"), + ) + + normed = graph.call_function(operator.getitem, args=(fused_node, 0)) + resid_out = graph.call_function(operator.getitem, args=(fused_node, 1)) + + fake_input = torch.randn(4, hidden_size) + fake_residual = torch.randn(4, hidden_size) + input_.meta["val"] = fake_input + residual.meta["val"] = fake_residual + weight.meta["val"] = torch.randn(hidden_size) + fused_node.meta["val"] = (fake_residual.clone(), fake_residual.clone()) + normed.meta["val"] = fake_residual.clone() + resid_out.meta["val"] = fake_residual.clone() + + if add_fp8_consumer: + quant_node = graph.call_function(fp8_quant_op, args=(normed,)) + quant_out = graph.call_function(operator.getitem, args=(quant_node, 0)) + quant_scale = graph.call_function( + operator.getitem, args=(quant_node, 1) + ) + + quant_node.meta["val"] = ( + fake_residual.to(torch.float8_e4m3fnuz), + torch.randn(4, 1), + ) + quant_out.meta["val"] = fake_residual.to(torch.float8_e4m3fnuz) + quant_scale.meta["val"] = torch.randn(4, 1) + + graph.output((quant_out, resid_out, quant_scale)) + else: + graph.output((normed, resid_out)) + + return graph, fused_node + + +def test_decompose_fused_ar_rmsnorm_with_fp8(monkeypatch: pytest.MonkeyPatch): + """When fused_allreduce_rmsnorm feeds into FP8 quant, it should be + decomposed into all_reduce + rmsnorm_with_add.""" + import vllm.config + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + try: + fp8_quant_op = rocm_aiter_ops.get_per_token_quant_op() + except Exception: + pytest.skip("AITER per_token_quant op not available") + + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, + ) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=torch.bfloat16), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig( + fuse_norm_quant=True, eliminate_noops=True + ), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config): + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) + + graph, _ = _build_fused_ar_rmsnorm_graph( + hidden_size=256, + eps=1e-5, + fp8_quant_op=fp8_quant_op, + add_fp8_consumer=True, + ) + + fused_ar_op = torch.ops.vllm.fused_allreduce_rmsnorm.default + all_reduce_op = torch.ops.vllm.all_reduce.default + rmsnorm_add_op = rocm_aiter_ops.get_rmsnorm_fused_add_op() + + fused_before = sum( + 1 for n in graph.nodes if n.target == fused_ar_op + ) + assert fused_before == 1, "Expected 1 fused_allreduce_rmsnorm node" + + count = fusion_pass._decompose_fused_allreduce_rmsnorm(graph) + + assert count == 1, f"Expected 1 decomposition, got {count}" + + fused_after = sum( + 1 for n in graph.nodes if n.target == fused_ar_op + ) + assert fused_after == 0, "fused_allreduce_rmsnorm should be removed" + + ar_nodes = sum( + 1 for n in graph.nodes if n.target == all_reduce_op + ) + assert ar_nodes == 1, f"Expected 1 all_reduce node, got {ar_nodes}" + + rms_nodes = sum( + 1 for n in graph.nodes if n.target == rmsnorm_add_op + ) + assert rms_nodes == 1, f"Expected 1 rmsnorm_with_add node, got {rms_nodes}" + + +def 
test_preserve_fused_ar_rmsnorm_without_fp8( + monkeypatch: pytest.MonkeyPatch, +): + """When fused_allreduce_rmsnorm does NOT feed into FP8 quant (e.g. FP4), + the node should be preserved.""" + import vllm.config + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + try: + fp8_quant_op = rocm_aiter_ops.get_per_token_quant_op() + except Exception: + pytest.skip("AITER per_token_quant op not available") + + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, + ) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=torch.bfloat16), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig( + fuse_norm_quant=True, eliminate_noops=True + ), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config): + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) + + graph, _ = _build_fused_ar_rmsnorm_graph( + hidden_size=256, + eps=1e-5, + fp8_quant_op=fp8_quant_op, + add_fp8_consumer=False, + ) + + fused_ar_op = torch.ops.vllm.fused_allreduce_rmsnorm.default + + fused_before = sum( + 1 for n in graph.nodes if n.target == fused_ar_op + ) + assert fused_before == 1 + + count = fusion_pass._decompose_fused_allreduce_rmsnorm(graph) + + assert count == 0, "Should not decompose when no FP8 consumer" + + fused_after = sum( + 1 for n in graph.nodes if n.target == fused_ar_op + ) + assert fused_after == 1, "fused_allreduce_rmsnorm should be preserved" + + +def test_decompose_no_op_without_fused_nodes( + monkeypatch: pytest.MonkeyPatch, +): + """Decomposition should be a no-op when graph has no fused_allreduce_rmsnorm.""" + import vllm.config + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, + ) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=torch.bfloat16), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig( + fuse_norm_quant=True, eliminate_noops=True + ), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config): + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) + + graph = fx.Graph() + x = graph.placeholder("x") + y = graph.call_function(torch.relu, args=(x,)) + graph.output(y) + + count = fusion_pass._decompose_fused_allreduce_rmsnorm(graph) + assert count == 0 + + +def test_is_fused_allreduce_rmsnorm_supported(): + """Verify auto-detection method exists and returns a bool.""" + result = rocm_aiter_ops.is_fused_allreduce_rmsnorm_supported() + assert isinstance(result, (bool, type(None))) + + +# --------------------------------------------------------------------------- +# Auto-detection gating tests — proves non-regression for other platforms +# --------------------------------------------------------------------------- + + +def test_auto_detection_disabled_without_gfx950(): + """Feature must be disabled when not on gfx950 (MI355X). + + This proves non-MI355X ROCm GPUs are unaffected by the optimization. 
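The decomposition decision these tests exercise reduces to a user-walk over the FX graph: find `getitem(node, 0)` and check whether any of its users is a quant op. A self-contained sketch on a hand-built graph (`has_getitem0_consumer` is a hypothetical rendering of the check in `_decompose_fused_allreduce_rmsnorm`; `torch.neg` stands in for the quant op):

```python
import operator
import torch
from torch import fx

def has_getitem0_consumer(node: fx.Node, quant_targets: set) -> bool:
    # Follow getitem(node, 0) users and look for a quant op among them.
    for user in node.users:
        if user.target == operator.getitem and user.args[1] == 0:
            if any(
                u.op == "call_function" and u.target in quant_targets
                for u in user.users
            ):
                return True
    return False

g = fx.Graph()
x = g.placeholder("x")
split = g.call_function(torch.split, args=(x, 2))   # tuple-returning op
first = g.call_function(operator.getitem, args=(split, 0))
quant = g.call_function(torch.neg, args=(first,))   # stand-in "quant" op
g.output(quant)

assert has_getitem0_consumer(split, {torch.neg})
assert not has_getitem0_consumer(split, {torch.abs})
```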
+ """ + with patch("vllm._aiter_ops.rocm_aiter_ops._AITER_ENABLED", True), \ + patch("vllm._aiter_ops.rocm_aiter_ops._RMSNORM_ENABLED", True), \ + patch("vllm.platforms.rocm.on_gfx950", return_value=False): + result = rocm_aiter_ops.is_fused_allreduce_rmsnorm_supported() + assert not result, ( + "Fused AR+RMSNorm should be disabled when not on gfx950" + ) + + +def test_auto_detection_disabled_without_aiter(): + """Feature must be disabled when AITER kernels are not available. + + This proves non-AITER environments (e.g. CUDA, older ROCm) are + unaffected. + """ + saved = rocm_aiter_ops._AITER_ENABLED + try: + rocm_aiter_ops._AITER_ENABLED = False + result = rocm_aiter_ops.is_fused_allreduce_rmsnorm_supported() + assert not result, ( + "Fused AR+RMSNorm should be disabled when AITER is not enabled" + ) + finally: + rocm_aiter_ops._AITER_ENABLED = saved + + +def test_auto_detection_disabled_without_rmsnorm(): + """Feature must be disabled when AITER RMSNorm is not available.""" + saved = rocm_aiter_ops._RMSNORM_ENABLED + try: + rocm_aiter_ops._RMSNORM_ENABLED = False + result = rocm_aiter_ops.is_fused_allreduce_rmsnorm_supported() + assert not result, ( + "Fused AR+RMSNorm should be disabled when RMSNorm is not enabled" + ) + finally: + rocm_aiter_ops._RMSNORM_ENABLED = saved + + +# --------------------------------------------------------------------------- +# Full-pass pipeline tests (decompose + pattern-match fusion) +# --------------------------------------------------------------------------- + + +def _make_fusion_pass(monkeypatch): + """Create a RocmAiterRMSNormQuantFusionPass with standard config.""" + import vllm.config + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, + ) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=torch.bfloat16), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig( + fuse_norm_quant=True, eliminate_noops=True + ), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config): + return RocmAiterRMSNormQuantFusionPass(vllm_config) + + +def _wrap_graph_in_module(graph: fx.Graph) -> fx.GraphModule: + """Wrap a bare fx.Graph in a GraphModule. + + The Inductor PatternMatcherPass requires graph.owning_module to be + a GraphModule. Standalone graphs don't have one, so we wrap them. + """ + return fx.GraphModule(torch.nn.Module(), graph) + + +def test_full_pass_fp8_decompose_and_fuse(monkeypatch: pytest.MonkeyPatch): + """Full pass on FP8 graph: verify the complete pass runs and decomposes + fused_allreduce_rmsnorm when FP8 quant consumers are present. + + Decomposition is verified here at graph level. The subsequent + rmsnorm+quant pattern-match fusion requires a properly traced graph + (via torch.compile) and is covered by the multi-GPU integration test + test_rocm_fused_ar_rmsnorm_fp8_decomposed. 
+ """ + try: + fp8_quant_op = rocm_aiter_ops.get_per_token_quant_op() + except Exception: + pytest.skip("AITER per_token_quant op not available") + + fusion_pass = _make_fusion_pass(monkeypatch) + + graph, _ = _build_fused_ar_rmsnorm_graph( + hidden_size=256, + eps=1e-5, + fp8_quant_op=fp8_quant_op, + add_fp8_consumer=True, + ) + gm = _wrap_graph_in_module(graph) + graph = gm.graph + + fused_ar_op = torch.ops.vllm.fused_allreduce_rmsnorm.default + all_reduce_op = torch.ops.vllm.all_reduce.default + rmsnorm_add_op = rocm_aiter_ops.get_rmsnorm_fused_add_op() + + assert sum(1 for n in graph.nodes if n.target == fused_ar_op) == 1 + assert sum(1 for n in graph.nodes if n.target == all_reduce_op) == 0 + + fusion_pass(graph) + + assert sum(1 for n in graph.nodes if n.target == fused_ar_op) == 0, ( + "fused_allreduce_rmsnorm should be decomposed in FP8 path" + ) + assert sum(1 for n in graph.nodes if n.target == all_reduce_op) == 1, ( + "all_reduce should be present after decomposition" + ) + assert sum(1 for n in graph.nodes if n.target == rmsnorm_add_op) >= 1, ( + "rmsnorm_with_add should be present after decomposition" + ) + assert sum(1 for n in graph.nodes if n.target == fp8_quant_op) >= 1, ( + "fp8_quant should still be in graph (pattern matching requires " + "torch.compile traced graphs for full fusion)" + ) + + +def test_full_pass_fp4_preserves_fused(monkeypatch: pytest.MonkeyPatch): + """Full pass on FP4/BF16 graph: fused_allreduce_rmsnorm has no FP8 + consumer, so it should survive the entire pass untouched. + + This verifies that FP4 models keep the AITER fused AR+RMSNorm kernel + and the pass does not alter the graph. + """ + try: + fp8_quant_op = rocm_aiter_ops.get_per_token_quant_op() + except Exception: + pytest.skip("AITER per_token_quant op not available") + + fusion_pass = _make_fusion_pass(monkeypatch) + + graph, _ = _build_fused_ar_rmsnorm_graph( + hidden_size=256, + eps=1e-5, + fp8_quant_op=fp8_quant_op, + add_fp8_consumer=False, + ) + gm = _wrap_graph_in_module(graph) + graph = gm.graph + + fused_ar_op = torch.ops.vllm.fused_allreduce_rmsnorm.default + + assert sum(1 for n in graph.nodes if n.target == fused_ar_op) == 1 + + fusion_pass(graph) + + assert sum(1 for n in graph.nodes if n.target == fused_ar_op) == 1, ( + "fused_allreduce_rmsnorm must be preserved for FP4/BF16 models" + ) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index d59b74782be2..fcea2b660830 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -2,9 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable +from contextlib import contextmanager +from typing import Protocol import torch from torch._ops import OpOverload +from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.platforms import current_platform @@ -33,6 +36,25 @@ def is_aiter_found() -> bool: IS_AITER_FOUND = is_aiter_found() +class AiterCustomAllreduceProto(Protocol): + max_size: int + world_size: int + fully_connected: bool + + @contextmanager + def capture(self): ... + def close(self) -> None: ... + def custom_fused_ar_rms( + self, + input: torch.Tensor, + residual_inp: torch.Tensor, + weight: torch.Tensor, + eps: float, + use_1stage: bool, + ) -> tuple[torch.Tensor, torch.Tensor] | None: ... + def should_custom_ar(self, inp: torch.Tensor) -> bool: ... + + def is_aiter_found_and_supported() -> bool: """Check if AITER library is available and platform supports it. 
@@ -672,6 +694,47 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake( return out, y_scale +def _rocm_aiter_fused_allreduce_rmsnorm_impl( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + assert aiter_ar is not None, "aiter allreduce must be initialized" + + total_bytes = input_.numel() * input_.element_size() + hidden_dim = input_.shape[-1] + token_num = input_.shape[0] + hidden_ok = hidden_dim in (512, 1024, 2048, 4096) + token_ok = token_num <= 80 + world_size = aiter_ar.world_size + full_nvlink = aiter_ar.fully_connected + + if world_size == 2: + size_ok = True + elif full_nvlink and world_size <= 4: + size_ok = total_bytes < 160 * 1024 + elif full_nvlink and world_size <= 8: + size_ok = total_bytes < 80 * 1024 + else: + size_ok = False + + use_1stage = hidden_ok and token_ok and size_ok + result = aiter_ar.custom_fused_ar_rms(input_, residual, weight, epsilon, use_1stage) + assert result is not None + return result[0], result[1] + + +def _rocm_aiter_fused_allreduce_rmsnorm_fake( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(input_), torch.empty_like(residual) + + def _rocm_aiter_per_tensor_quant_impl( x: torch.Tensor, quant_dtype: torch.dtype, @@ -1072,6 +1135,9 @@ class rocm_aiter_ops: # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + _ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2 + _CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None + @classmethod def refresh_env_variables(cls): """ @@ -1239,6 +1305,65 @@ def is_triton_rotary_embed_enabled(cls) -> bool: def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM + @classmethod + @if_aiter_supported + def initialize_aiter_allreduce( + cls, group: ProcessGroup, device: torch.device + ) -> None: + if cls._CUSTOM_ALL_REDUCE is not None: + return + try: + from aiter.dist.device_communicators.custom_all_reduce import ( + CustomAllreduce as AiterCustomAllreduce, + ) + + cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) + except Exception: + cls._CUSTOM_ALL_REDUCE = None + + @classmethod + def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None: + return cls._CUSTOM_ALL_REDUCE + + @classmethod + def destroy_aiter_allreduce(cls) -> None: + if cls._CUSTOM_ALL_REDUCE is not None: + cls._CUSTOM_ALL_REDUCE.close() + cls._CUSTOM_ALL_REDUCE = None + + @classmethod + def get_aiter_allreduce_max_size(cls) -> int: + return cls._ALL_REDUCE_MAX_SIZE + + @classmethod + @if_aiter_supported + def is_fused_allreduce_rmsnorm_supported(cls) -> bool: + """Check if fused allreduce+RMSNorm is supported on this platform. + + Requires gfx950 (MI355X), AITER enabled, RMSNorm kernels available, + and AITER's CustomAllreduce communicator importable. + + Currently only wired in deepseek_v2.py (DeepSeek V2/V3/R1 family). + Other models are unaffected because they do not set + ``fused_allreduce=True`` on their RMSNorm layers. Models that + inherit DeepseekV2DecoderLayer (Eagle, MTP, Mistral Large 3) + automatically benefit when running on MI355X with TP > 1. + + Returns None (falsy) on non-ROCm platforms via @if_aiter_supported. 
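The 1-stage gating inside `_rocm_aiter_fused_allreduce_rmsnorm_impl` is a pure function of shape and topology; restated standalone here as a readability sketch (same thresholds as above, not an additional API):

```python
def use_one_stage(total_bytes, hidden_dim, token_num, world_size, fully_connected):
    # Mirrors the hidden_ok / token_ok / size_ok logic above.
    hidden_ok = hidden_dim in (512, 1024, 2048, 4096)
    token_ok = token_num <= 80
    if world_size == 2:
        size_ok = True
    elif fully_connected and world_size <= 4:
        size_ok = total_bytes < 160 * 1024
    elif fully_connected and world_size <= 8:
        size_ok = total_bytes < 80 * 1024
    else:
        size_ok = False
    return hidden_ok and token_ok and size_ok

# 64 tokens x 4096 hidden in bf16 (512 KiB): 1-stage on TP2, 2-stage on TP8.
assert use_one_stage(64 * 4096 * 2, 4096, 64, 2, True)
assert not use_one_stage(64 * 4096 * 2, 4096, 64, 8, True)
```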
+ """ + from vllm.platforms.rocm import on_gfx950 + + if not (cls._AITER_ENABLED and cls._RMSNORM_ENABLED and on_gfx950()): + return False + try: + from aiter.dist.device_communicators.custom_all_reduce import ( + CustomAllreduce as _, # noqa: F401 + ) + + return True + except ImportError: + return False + @staticmethod @if_aiter_supported def register_ops_once() -> None: @@ -1425,6 +1550,12 @@ def register_ops_once() -> None: fake_impl=_triton_rotary_embedding_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_fused_allreduce_rmsnorm", + op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl, + fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake, + ) + _OPS_REGISTERED = True @staticmethod @@ -1471,6 +1602,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload: def get_triton_rotary_embedding_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default + @staticmethod + def get_fused_allreduce_rmsnorm_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default + @staticmethod def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index a712c013ce99..e67a021ae3c2 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -193,6 +193,7 @@ def __init__( is_scale_transposed: bool = False, is_e8m0: bool = False, is_tma_aligned: bool = False, + match_aiter: bool = False, ) -> None: super().__init__(quant_key) self.quant_matcher = MatcherQuantFP8( diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 09b9a557fe45..e9d8a81c9e58 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -12,12 +12,14 @@ from torch._inductor.pattern_matcher import PatternMatcherPass import vllm.ir.ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.compilation.passes.fusion.rms_quant_fusion import ( _rms_input_weight_dtype_match, ) from vllm.config import VllmConfig from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -889,3 +891,203 @@ def __del__(self) -> None: return with contextlib.suppress(Exception): destroy_fi_ar_workspace() + + +class AiterAllreduceFusedRMSNormPattern(BasePattern): + FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op() + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str | None, + use_aiter_rmsnorm: bool = True, + ) -> None: + super().__init__(dtype, device) + self.dtype = dtype + self.epsilon = epsilon + + def get_inputs(self) -> list[torch.Tensor]: + return [self.empty(5, 16), self.empty(16)] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon) + + return rms, allreduce_output + + def replacement( + input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + residual = torch.empty_like(input) + allreduce = 
self.FUSED_AR_RMSNORM_OP( + input_=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + return allreduce[0], allreduce[1] + + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + pm.fwd_only, + pm_pass, + extra_check=_rms_input_weight_dtype_match, + ) + + +class AiterAllreduceFusedAddRMSNormPattern(BasePattern): + FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op() + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str | None, + use_aiter_rmsnorm: bool = True, + ) -> None: + super().__init__(dtype, device) + self.epsilon = epsilon + self.dtype = dtype + self.rmsnorm_matcher = MatcherFusedAddRMSNorm( + epsilon, match_rocm_aiter=use_aiter_rmsnorm + ) + + def get_inputs(self) -> list[torch.Tensor]: + input, residual, weight = self.rmsnorm_matcher.inputs() + + return [residual, input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + + return rms, residual + + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce = self.FUSED_AR_RMSNORM_OP( + input_=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + return allreduce[0], allreduce[1] + + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + pm.fwd_only, + pm_pass, + extra_check=_rms_input_weight_dtype_match, + ) + + +class RocmAiterAllReduceFusionPass(VllmPatternMatcherPass): + def __init__(self, config: VllmConfig) -> None: + super().__init__(config) + self.disabled = True + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") + return + + if config.model_config is None: + logger.warning_once( + "AllReduce fusion pass is disabled for missing model_config." 
+ ) + return + + device_comm = get_tp_group().device_communicator + if device_comm is None: + logger.warning_once("Device communicator is required.") + return + + ca_comm = getattr(device_comm, "ca_comm", None) + if ca_comm is None: + logger.warning_once("Custom Allreduce is required.") + return + self.ca_comm = ca_comm + + assert isinstance(ca_comm, CustomAllreduce) + + group = get_tp_group().cpu_group + rocm_aiter_ops.initialize_aiter_allreduce(group, self.device) + hidden_dim = config.model_config.get_hidden_size() + element_size = torch.tensor([], dtype=self.model_dtype).element_size() + max_size = rocm_aiter_ops.get_aiter_allreduce_max_size() + if rocm_aiter_ops.get_aiter_allreduce() is None: + logger.warning("AITER custom allreduce failed to initialize; fusion pass disabled.") + return + + max_token_num = (max_size // 2) // (hidden_dim * element_size) + self.max_token_num = min( + max_token_num, + config.scheduler_config.max_num_batched_tokens, + ) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_allreduce_rmsnorm_fusion_pass" + ) + + self.register_patterns() + self.dump_patterns(config, self.patterns) + + @enable_fake_mode + def register_patterns(self): + for epsilon in [1e-5, 1e-6]: + AiterAllreduceFusedRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + ).register(self.patterns) + + AiterAllreduceFusedAddRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + ).register(self.patterns) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + self.disabled = False + + def is_applicable_for_range(self, compile_range: Range) -> bool: + if self.disabled: + logger.warning_once("AllReduce fusion pass is disabled.") + return False + return bool(compile_range.end <= self.max_token_num) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + if self.disabled: + logger.debug("ROCmAiterAllReduceRMSNormFusionPass disabled") + return + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def __del__(self) -> None: + if getattr(self, "disabled", True): + return + with contextlib.suppress(Exception): + rocm_aiter_ops.destroy_aiter_allreduce() + + def uuid(self) -> str: + return VllmInductorPass.hash_source( + self, + AiterAllreduceFusedRMSNormPattern, + AiterAllreduceFusedAddRMSNormPattern, + ) diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index 9a985472371e..14684e022e68 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any +import operator + import torch import torch._inductor.pattern_matcher as pm from torch import fx @@ -283,12 +285,40 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass): This pass fuses aiter rms_norm & vllm/aiter quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm. + + When fused_allreduce_rmsnorm custom ops are present in the graph + (auto-enabled on MI355X with AITER for DeepSeek-family models), + two paths are taken depending on the quantization type: + + FP8 path: fused_allreduce_rmsnorm is decomposed into + all_reduce + rms_norm_with_add, then rms_norm_with_add + fp8_quant + are fused into a single AITER kernel by the pattern matcher.
+ + FP4/BF16 path: fused_allreduce_rmsnorm is preserved as-is because + no FP8 quantization op consumes the normed output. At runtime, + AITER's fused AR+RMSNorm kernel handles the combined operation. """ @enable_fake_mode def __init__(self, config: VllmConfig) -> None: super().__init__(config) + self._fp8_quant_ops: set[OpOverload] = set() + try: + self._fp8_quant_ops.add( + torch.ops.vllm.triton_per_token_group_quant_fp8.default + ) + except AttributeError: + pass + try: + self._fp8_quant_ops.add(rocm_aiter_ops.get_group_quant_op()) + except Exception: + pass + try: + self._fp8_quant_ops.add(rocm_aiter_ops.get_per_token_quant_op()) + except Exception: + pass + self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="rocm_aiter_rms_norm_quant_fusion_pass" ) @@ -321,8 +351,101 @@ def __init__(self, config: VllmConfig) -> None: self.dump_patterns(config, self.patterns) + def _decompose_fused_allreduce_rmsnorm(self, graph: fx.Graph) -> int: + """ + Decompose fused_allreduce_rmsnorm into all_reduce + rmsnorm_with_add + when the normed output feeds into an FP8 quantization op. + + The AITER fused allreduce+rmsnorm kernel wraps both operations into a + single opaque custom op, which prevents torch.compile from recognizing + the rmsnorm for rms_norm + fp8_quant fusion. By decomposing the op + into visible all_reduce and rms_norm_with_add ops, the existing fusion + patterns can match. + + Only decomposes when an FP8 quant op consumes the normed output, + preserving the AITER fused kernel for non-FP8 models (e.g. FP4). + """ + try: + fused_ar_rms_op = torch.ops.vllm.fused_allreduce_rmsnorm.default + except AttributeError: + return 0 + + all_reduce_op = torch.ops.vllm.all_reduce.default + rmsnorm_with_add_op = rocm_aiter_ops.get_rmsnorm_fused_add_op() + + count = 0 + for node in list(graph.nodes): + if node.target != fused_ar_rms_op: + continue + + has_fp8_consumer = False + for user in node.users: + if user.target == operator.getitem and user.args[1] == 0: + for quant_user in user.users: + if ( + quant_user.op == "call_function" + and quant_user.target in self._fp8_quant_ops + ): + has_fp8_consumer = True + break + if has_fp8_consumer: + break + + if not has_fp8_consumer: + continue + + input_ = node.args[0] + residual = node.args[1] + weight = node.args[2] + eps = node.kwargs.get( + "eps", node.args[3] if len(node.args) > 3 else None + ) + group_name = node.kwargs.get( + "group_name", node.args[4] if len(node.args) > 4 else None + ) + + with graph.inserting_before(node): + # All args must be positional to match the schema-normalized + # representation that torch.compile uses for torch.ops.* nodes. + # graph.call_function bypasses schema normalization, so kwargs + # would create a node representation that the pattern matcher + # cannot match against the traced pattern. + ar_args: tuple = (input_,) + if group_name is not None: + ar_args = (input_, group_name) + + ar_node = graph.call_function( + all_reduce_op, + args=ar_args, + ) + + rmsnorm_node = graph.call_function( + rmsnorm_with_add_op, + args=(ar_node, residual, weight, eps), + ) + + # Propagate FakeTensor metadata so downstream passes + # (codegen, shape inference) have correct type info. 
+ if "val" in node.meta: + rmsnorm_node.meta["val"] = node.meta["val"] + if hasattr(input_, "meta") and "val" in input_.meta: + ar_node.meta["val"] = input_.meta["val"] + + node.replace_all_uses_with(rmsnorm_node) + graph.erase_node(node) + count += 1 + + return count + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: + decomposed = self._decompose_fused_allreduce_rmsnorm(graph) + if decomposed: + logger.debug( + "Decomposed %s fused_allreduce_rmsnorm nodes for " + "FP8 quant fusion compatibility", + decomposed, + ) self.matched_count = self.patterns.apply(graph) logger.debug( "%s Replaced %s patterns", self.__class__.__name__, self.matched_count diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index b4823a0afde1..30eddaa2389c 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -18,6 +18,9 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass if rocm_aiter_ops.is_enabled(): + from .fusion.allreduce_rms_fusion import ( + RocmAiterAllReduceFusionPass, + ) from .fusion.rocm_aiter_fusion import ( RocmAiterRMSNormQuantFusionPass, RocmAiterSiluMulFp8GroupQuantFusionPass, @@ -135,7 +138,10 @@ def configure(self, config: VllmConfig) -> None: self.passes += [AsyncTPPass(config)] if self.pass_config.fuse_allreduce_rms: - self.passes += [AllReduceFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterAllReduceFusionPass(config)] + else: + self.passes += [AllReduceFusionPass(config)] if self.pass_config.fuse_norm_quant: self.passes += [RMSNormQuantFusionPass(config)] diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 6229b44d52a8..6e7e91754f02 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -120,6 +120,15 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer + if current_platform.is_rocm(): + from vllm._aiter_ops import rocm_aiter_ops + + return ( + rocm_aiter_ops.is_enabled() + and rocm_aiter_ops.is_rmsnorm_enabled() + and cfg.parallel_config.tensor_parallel_size > 1 + ) + return ( cfg.parallel_config.tensor_parallel_size > 1 and current_platform.is_cuda() diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 5ad99e4e1592..79f8465e87d1 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -14,6 +14,24 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: return get_tp_group().all_reduce(input_) +def tensor_model_parallel_fused_allreduce_rmsnorm( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused allreduce + residual-add + RMSNorm across model parallel group. + + Instead of allreduce inside o_proj/MoE followed by a separate RMSNorm, + this fuses both operations: allreduce(input_) + residual → RMSNorm. + + Returns (normed_output, updated_residual). 
+ """ + return get_tp_group().fused_allreduce_rmsnorm( + input_, residual, weight, eps + ) + + def tensor_model_parallel_all_gather( input_: torch.Tensor, dim: int = -1 ) -> torch.Tensor: diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 4550bdb25629..8c5850b27d9b 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -114,6 +114,25 @@ def __init__( # currently be an MI300 series. self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) + if current_platform.is_rocm() and self.world_size > 1: + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_fused_allreduce_rmsnorm_supported(): + rocm_aiter_ops.initialize_aiter_allreduce( + self.cpu_group, self.device + ) + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + if aiter_ar is not None: + logger.info( + "AITER CustomAllreduce initialized for " + "fused AR+RMSNorm kernel" + ) + else: + logger.warning( + "AITER CustomAllreduce disabled; " + "fused AR+RMSNorm will use split kernels" + ) + if self.use_all2all: if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager @@ -236,6 +255,67 @@ def all_reduce(self, input_): torch.distributed.all_reduce(out, group=self.device_group) return out + def fused_allreduce_rmsnorm( + self, + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Fused allreduce + residual-add + RMSNorm. + + Uses AITER's fused kernel via the global singleton + (rocm_aiter_ops._CUSTOM_ALL_REDUCE) when available. + Falls back to split allreduce + rmsnorm otherwise. + """ + from vllm._aiter_ops import rocm_aiter_ops + + aiter_ca = rocm_aiter_ops.get_aiter_allreduce() + if aiter_ca is not None: + n = input_.shape[-1] + total_bytes = input_.numel() * input_.element_size() + # Constraints from the AITER fused AR+RMS HIP kernel + # (custom_all_reduce.cuh dispatchFusedAllReduceRMSNorm): + # + # n <= 16384: the kernel's rmsnorm step handles n_bytes + # (= n * sizeof(T)) in [1024, 32768]. For bf16 that + # means hidden_dim up to 16384. + # + # total_bytes < 64 MB (8 * 1024 * 8192): AITER's + # CustomAllreduce IPC buffer is 128 MB; 2-stage write + # mode halves usable capacity (max_size / 2). + # + # world_size != 6: the fused kernel only has template + # instantiations for world_size {2, 4, 8}. world_size=6 + # is unsupported by the AITER fused kernel (the base + # allreduce kernel supports it, but the fused variant + # does not). + can_use_fused = ( + n <= 16384 + and total_bytes < 8 * 1024 * 8192 + and self.world_size != 6 + ) + if ( + can_use_fused + and not aiter_ca.disabled + and aiter_ca.should_custom_ar(input_) + ): + # 1-stage is faster for small payloads. AITER uses + # ~80 KB (8 GPU) to ~160 KB (4 GPU) thresholds; + # 128 KB is a conservative middle ground. 
+ use_1stage = total_bytes <= 128 * 1024 + result = aiter_ca.custom_fused_ar_rms( + input_, residual, weight, eps, use_1stage + ) + if result is not None: + return result + + ar_out = self.all_reduce(input_) + out, residual_out = rocm_aiter_ops.rms_norm2d_with_add( + ar_out, residual, weight, eps + ) + return out, residual_out + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): world_size = self.world_size pynccl_comm = self.pynccl_comm diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cef902d9e4e5..0a068ef75f5a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -259,12 +259,44 @@ def patched_fused_scaled_matmul_reduce_scatter( ) +def fused_allreduce_rmsnorm( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, + group_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._fused_allreduce_rmsnorm_out_place( + input_, residual, weight, eps + ) + + +def fused_allreduce_rmsnorm_fake( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, + group_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(residual), torch.empty_like(residual) + + direct_register_custom_op( op_name="all_reduce", op_func=all_reduce, fake_impl=all_reduce_fake, ) +direct_register_custom_op( + op_name="fused_allreduce_rmsnorm", + op_func=fused_allreduce_rmsnorm, + fake_impl=fused_allreduce_rmsnorm_fake, +) + direct_register_custom_op( op_name="reduce_scatter", op_func=reduce_scatter, @@ -470,6 +502,7 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None # only cuda uses this function, # so we don't abstract it into the base class maybe_ca_context = nullcontext() + maybe_aiter_context = nullcontext() from vllm.distributed.device_communicators.cuda_communicator import ( CudaCommunicator, ) @@ -480,13 +513,21 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None if ca_comm is not None: maybe_ca_context = ca_comm.capture() # type: ignore - # ensure all initialization operations complete before attempting to - # capture the graph on another stream + from vllm._aiter_ops import rocm_aiter_ops + + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + if aiter_ar is not None: + maybe_aiter_context = aiter_ar.capture() # type: ignore + curr_stream = torch.cuda.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) - with torch.cuda.stream(stream), maybe_ca_context: + with ( + torch.cuda.stream(stream), + maybe_ca_context, + maybe_aiter_context, + ): yield graph_capture_context def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: @@ -518,6 +559,45 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: raise ValueError("No device communicator found") return self.device_communicator.all_reduce(input_) + def _fused_allreduce_rmsnorm_out_place( + self, + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.device_communicator is None: + raise ValueError("No device communicator found") + return self.device_communicator.fused_allreduce_rmsnorm( + input_, residual, weight, eps + ) + + def fused_allreduce_rmsnorm( + self, + input_: torch.Tensor, + residual: torch.Tensor, + weight: 
torch.Tensor, + eps: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Fused allreduce + residual-add + RMSNorm. + + When world_size == 1, falls back to plain add + rmsnorm (no allreduce). + """ + if self.world_size == 1: + from vllm.model_executor.layers.layernorm import fused_add_rms_norm + + return fused_add_rms_norm(input_, residual, weight, eps) + + if self.use_custom_op_call: + return torch.ops.vllm.fused_allreduce_rmsnorm( + input_, residual, weight, eps, + group_name=self.unique_name, + ) + else: + return self._fused_allreduce_rmsnorm_out_place( + input_, residual, weight, eps + ) + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 9afc4c9c08d6..760a98ce51e3 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -120,6 +120,7 @@ def __init__( var_hidden_size: int | None = None, has_weight: bool = True, dtype: torch.dtype | None = None, + fused_allreduce: bool = False, ) -> None: super().__init__() @@ -128,6 +129,7 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) + self.fused_allreduce = fused_allreduce weight_dtype = dtype or torch.get_default_dtype() self.has_weight = has_weight self.weight = torch.ones(hidden_size, dtype=weight_dtype) @@ -334,7 +336,17 @@ def forward_hip( if self.variance_size_override is not None: return self.forward_native(x, residual) - if residual is not None: + add_residual = residual is not None + if self.fused_allreduce and add_residual: + from vllm.distributed.communication_op import ( + tensor_model_parallel_fused_allreduce_rmsnorm, + ) + + return tensor_model_parallel_fused_allreduce_rmsnorm( + x, residual, self.weight.data, self.variance_epsilon + ) + + if add_residual: return self.rocm_norm_func_with_add( x, residual, self.weight.data, self.variance_epsilon ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 17ddd5edeced..1b68c178904b 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -245,6 +245,7 @@ def __init__( super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() + self.skip_moe_allreduce = False self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) @@ -393,9 +394,26 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states - ) + if self.skip_moe_allreduce: + if self.experts.must_reduce_shared_expert_outputs(): + logger.warning_once( + "Fused AR+RMSNorm: combine kernel already reduces " + "MoE output. This is unexpected with standard TP. " + "Falling back to normal allreduce for this layer." 
+                    )
+                    final_hidden_states = (
+                        self.experts.maybe_all_reduce_tensor_model_parallel(
+                            final_hidden_states
+                        )
+                    )
+            else:
+                final_hidden_states = (
+                    self.experts.maybe_all_reduce_tensor_model_parallel(
+                        final_hidden_states
+                    )
+                )
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -1093,6 +1111,25 @@ def __init__(
             topk_indices_buffer=topk_indices_buffer,
         )
 
+        # MI355X optimization: move allreduce from projections into
+        # RMSNorm layers, enabling AITER's fused AR+RMSNorm kernel.
+        # - FP4/BF16: the fused op is preserved at compile time, executed
+        #   as a single AITER kernel (custom_fused_ar_rms).
+        # - FP8: the fused op is decomposed at compile time into
+        #   all_reduce + rmsnorm_with_add, then rmsnorm_with_add + fp8_quant
+        #   are fused into one AITER op by RocmAiterRMSNormQuantFusionPass.
+        self.fuse_ar_rmsnorm = (
+            rocm_aiter_ops.is_fused_allreduce_rmsnorm_supported()
+            and get_tensor_model_parallel_world_size() > 1
+        )
+        if self.fuse_ar_rmsnorm:
+            self.self_attn.o_proj.reduce_results = False
+            logger.debug(
+                "Layer %d: fused AR+RMSNorm enabled, "
+                "o_proj.reduce_results=False",
+                layer_idx,
+            )
+
         if (
             config.n_routed_experts is not None
             and layer_idx >= config.first_k_dense_replace
@@ -1104,17 +1141,27 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+            if self.fuse_ar_rmsnorm:
+                self.mlp.skip_moe_allreduce = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                reduce_results=not self.fuse_ar_rmsnorm,
                 prefix=f"{prefix}.mlp",
             )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.input_layernorm = RMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            fused_allreduce=self.fuse_ar_rmsnorm and layer_idx > 0,
+        )
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+            fused_allreduce=self.fuse_ar_rmsnorm,
         )
         self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
@@ -1212,8 +1259,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=f"{prefix}.layers",
         )
 
+        fuse_final_norm = (
+            rocm_aiter_ops.is_fused_allreduce_rmsnorm_supported()
+            and get_tensor_model_parallel_world_size() > 1
+        )
+        if fuse_final_norm:
+            logger.info(
+                "Fused AllReduce+RMSNorm enabled for MI355X. "
+                "AllReduce moved from o_proj/MoE into RMSNorm layers."
+            )
         if get_pp_group().is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.norm = RMSNorm(
+                config.hidden_size,
+                eps=config.rms_norm_eps,
+                fused_allreduce=fuse_final_norm,
+            )
         else:
             self.norm = PPMissingLayer()
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(