diff --git a/docs/assets/f3_tpot_comparison.png b/docs/assets/f3_tpot_comparison.png new file mode 100644 index 000000000000..1507c9a84f35 Binary files /dev/null and b/docs/assets/f3_tpot_comparison.png differ diff --git a/pyproject.toml b/pyproject.toml index c782cc326bc1..9e7a29a4bc19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,9 @@ arange = "arange" thw = "thw" subtile = "subtile" HSA = "HSA" +# n_occurences is the real column name emitted by uplift-plan CSV output; +# fixing the spelling here would break CSV key lookups in tests +occurences = "occurences" setp = "setp" CPY = "CPY" thr = "thr" diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 87f98946a8ad..cf308bdec05a 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -121,6 +121,17 @@ def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]): assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" assert num_post > 0, f"Op {op.name()} not found in post-pass graph" + def check_not_in_after_ops( + self, ops: Sequence[OpOverload | OpOverloadPacket] + ): + """Assert ops are absent from the post-pass graph (fully replaced).""" + for op in ops: + num_post = len(list(find_op_nodes(op, self.graph_post_pass))) + assert num_post == 0, ( + f"Op {op.name()} should be absent from post-pass graph " + f"but found {num_post} node(s)" + ) + def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int: graph = self.graph_pre_pass if before else self.graph_post_pass return len(list(find_op_nodes(op, graph))) diff --git a/tests/compile/passes/test_mxfp4_quant_fusion.py b/tests/compile/passes/test_mxfp4_quant_fusion.py new file mode 100644 index 000000000000..f68baeabe4a2 --- /dev/null +++ b/tests/compile/passes/test_mxfp4_quant_fusion.py @@ -0,0 +1,690 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit and functional tests for MXFP4 kernel fusion patterns. 
+ +Covers: + Unit tests (no GPU required): + - Feature probes always return bool + - VllmPatternReplacement subclass structure (pattern/replacement/get_inputs) + - Registration ordering (Pattern B before Pattern A for greedy matching) + - uuid() changes when MXFP4 patterns are added to RocmAiterRMSNormQuantFusionPass + + Functional tests (ROCm + AITER required): + - Standalone RMSNorm + MXFP4 quant: fused op appears / standalone quant disappears + - Standalone fused_add_RMSNorm + MXFP4 quant: fused op with residual + - Numerical correctness: fused vs unfused output within tolerance + - Epsilon variants: 1e-5 and 1e-6 both registered and matched + - DeepSeek-R1 shape (hidden_size=7168) pattern traces correctly + +Similar models used as references: + - AiterRMSFp8GroupQuantPattern (rocm_aiter_fusion.py) — same 2-node pattern shape + - AiterFusedAddRMSFp8GroupQuantPattern — same 3-node residual-add shape + - test_aiter_fusion_rmsnorm_quant (test_fusion.py) — exact test harness template +""" + +import math + +import pytest +import torch + +from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops +from vllm.platforms import current_platform + +# ─── Helpers ───────────────────────────────────────────────────────────────── + +try: + import vllm._C # noqa: F401 + + _VLLM_C_AVAILABLE = True +except ModuleNotFoundError: + _VLLM_C_AVAILABLE = False + +_NEEDS_ROCM_AITER = pytest.mark.skipif( + not (current_platform.is_rocm() and IS_AITER_FOUND and _VLLM_C_AVAILABLE), + reason="Requires ROCm platform with AITER installed and compiled vllm._C", +) + +_NEEDS_MXFP4_STANDALONE = pytest.mark.skipif( + not ( + current_platform.is_rocm() + and IS_AITER_FOUND + and _VLLM_C_AVAILABLE + and rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant() + ), + reason="Requires aiter.ops.triton.fused_mxfp4_quant (fused_rms_mxfp4_quant)", +) + + +# ─── UNIT TESTS: feature probes ─────────────────────────────────────────────── + + +def test_unit_probe_rmsnorm_mxfp4_returns_bool(): + 
"""has_fused_rmsnorm_mxfp4_quant() must always return bool.""" + result = rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant() + assert isinstance(result, bool), ( + f"has_fused_rmsnorm_mxfp4_quant returned {type(result)}, expected bool" + ) + + +def test_unit_probe_rmsnorm_false_without_aiter(): + """Without AITER the rmsnorm probe must return False (not raise).""" + if IS_AITER_FOUND: + pytest.skip("AITER is present — probe may return True or False") + assert rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant() is False + + +# ─── UNIT TESTS: get_*_op staticmethods ────────────────────────────────────── + + +@_NEEDS_MXFP4_STANDALONE +def test_unit_get_ops_exist(): + """All new get_*_op staticmethods must return non-None OpOverloads. + + Guarded by _NEEDS_MXFP4_STANDALONE because get_fused_rmsnorm_mxfp4_quant_op() + returns None when has_fused_rmsnorm_mxfp4_quant() is False (older AITER build). + """ + ops = { + "get_dynamic_mxfp4_quant_op": rocm_aiter_ops.get_dynamic_mxfp4_quant_op, + "get_fused_rmsnorm_mxfp4_quant_op": ( + rocm_aiter_ops.get_fused_rmsnorm_mxfp4_quant_op + ), + "get_fused_rmsnorm_add_mxfp4_quant_op": ( + rocm_aiter_ops.get_fused_rmsnorm_add_mxfp4_quant_op + ), + } + for name, getter in ops.items(): + op = getter() + assert op is not None, f"{name}() returned None" + + +# ─── UNIT TESTS: VllmPatternReplacement subclass structure ─────────────────── + + +# ─── UNIT TESTS: DeepSeek-R1 shape traces ──────────────────────────────────── + + +@_NEEDS_MXFP4_STANDALONE +@pytest.mark.parametrize("epsilon", [1e-5, 1e-6]) +def test_unit_deepseek_shape_no_residual(epsilon): + """Fused op output shapes match MXFP4 packing rules at DS-R1 hidden_size=7168. + + Exercises the fused kernel (not just arithmetic) to confirm the packing + contract holds at the target model's actual hidden dimension. 
+ """ + hidden_size = 7168 + num_tokens = 4 + fused_op = rocm_aiter_ops.get_fused_rmsnorm_mxfp4_quant_op() + weight = torch.ones(hidden_size, dtype=torch.bfloat16, device="cuda") + x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") + + fp4, scale = fused_op(x=x, weight=weight, epsilon=epsilon) + + assert fp4.shape == (num_tokens, hidden_size // 2), ( + f"fp4 shape {fp4.shape} != expected {(num_tokens, hidden_size // 2)}" + ) + expected_scale_cols = math.ceil(hidden_size / 32) + assert scale.shape[1] >= expected_scale_cols, ( + f"scale cols {scale.shape[1]} < ceil(N/32)={expected_scale_cols}" + ) + + +# ─── UNIT TESTS: model helper guard ───────────────────────────────────────── +# _AiterRMSNormMXFP4QuantModel uses torch.ops.vllm.rocm_aiter_dynamic_mxfp4_quant +# which is registered by vllm._C. The _NEEDS_MXFP4_STANDALONE marker on every +# test that instantiates it ensures _VLLM_C_AVAILABLE is True before the op is +# accessed, so the class can safely live at module scope. 
@_NEEDS_ROCM_AITER
def test_unit_standalone_registration_order(monkeypatch):
    """The residual MXFP4 pattern must be registered ahead of the plain one.

    AiterFusedAddRMSNormMXFP4QuantPattern (3-node, with residual) has to come
    before AiterRMSNormMXFP4QuantPattern (2-node, no residual) in the pass's
    pattern list so greedy matching handles residual sites first.
    """
    import vllm.config
    from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
        AiterFusedAddRMSNormMXFP4QuantPattern,
        AiterRMSNormMXFP4QuantPattern,
        RocmAiterRMSNormQuantFusionPass,
    )
    from vllm.config import CompilationConfig, CompilationMode, VllmConfig

    if not rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant():
        pytest.skip("Standalone MXFP4 fused kernel not available in this AITER build")

    compile_cfg = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    vllm_config = VllmConfig(compilation_config=compile_cfg)
    with vllm.config.set_current_vllm_config(vllm_config):
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()
        fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

    # Class names of the registered patterns, in registration order.
    registered = [type(entry).__name__ for entry in fusion_pass._pattern_replacements]

    def _first_position(pattern_cls):
        """Index of the first registered pattern of this class, or None."""
        wanted = pattern_cls.__name__
        return registered.index(wanted) if wanted in registered else None

    idx_with_res = _first_position(AiterFusedAddRMSNormMXFP4QuantPattern)
    idx_no_res = _first_position(AiterRMSNormMXFP4QuantPattern)

    assert idx_with_res is not None, (
        "AiterFusedAddRMSNormMXFP4QuantPattern not registered"
    )
    assert idx_no_res is not None, "AiterRMSNormMXFP4QuantPattern not registered"
    assert idx_with_res < idx_no_res, (
        f"Residual pattern (idx={idx_with_res}) must be before no-residual "
        f"pattern (idx={idx_no_res}) for greedy matching"
    )
class _RMSNormMXFP4Model(torch.nn.Module):
    """Two-op fixture: ``rms_norm`` followed by the standalone MXFP4 quant op.

    The functional tests compile this module (no residual input) so the
    pattern matcher can rewrite the norm → quant pair into a single
    rocm_aiter_rmsnorm_mxfp4_quant call.
    """

    def __init__(self, hidden_size: int, eps: float):
        super().__init__()
        unit_weight = torch.ones(hidden_size, dtype=torch.bfloat16)
        self.weight = torch.nn.Parameter(unit_weight)
        self.eps = eps
        # The quant op handle is resolved once, at construction time.
        self._mxfp4_quant_op = rocm_aiter_ops.get_dynamic_mxfp4_quant_op()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize ``x`` and return the ``(fp4, scale)`` pair from the quant op."""
        import vllm.ir.ops as vllm_ir

        quantized = self._mxfp4_quant_op(vllm_ir.rms_norm(x, self.weight, self.eps))
        fp4, scale = quantized
        return fp4, scale
def _dequant_mxfp4(
    fp4: torch.Tensor, scale: torch.Tensor, *, block_size: int = 32
) -> torch.Tensor:
    """Return a rough per-byte magnitude proxy for packed MXFP4 data.

    This is NOT an FP4 decoder. For every packed byte the two 4-bit fields
    are read as raw integer codes, added together, and multiplied by the
    block scale entry covering that byte. The scale entries are used as-is
    after ``.float()``; no exponent decoding is applied, so for the uint8
    (E8M0) scales asserted elsewhere in this file the multiplier is the raw
    byte value. Only suitable for coarse "are these two outputs in the same
    ballpark" comparisons.

    Args:
        fp4: 2-D tensor of packed bytes, shape ``(M, N // 2)``; each byte is
            treated as ``lo | (hi << 4)``.
        scale: 2-D tensor with at least ``ceil(N / block_size)`` columns;
            extra (padding) columns are ignored.
        block_size: Number of unpacked values that share one scale entry.
            Must be a positive even number (two values per packed byte).

    Returns:
        float32 tensor of shape ``(M, N // 2)``.

    Raises:
        ValueError: If ``block_size`` is not a positive even integer.
    """
    if block_size < 2 or block_size % 2:
        raise ValueError(
            f"block_size must be a positive even integer, got {block_size}"
        )

    # Each uint8 byte = two 4-bit values packed as lo | (hi << 4)
    lo = (fp4 & 0x0F).float()
    hi = (fp4 >> 4).float()

    packed_cols = fp4.shape[1]
    unpacked_cols = packed_cols * 2
    # Drop any padding columns beyond the blocks that actually cover N values.
    num_blocks = math.ceil(unpacked_cols / block_size)
    scale_blocks = scale[:, :num_blocks].float()

    # One scale covers `block_size` unpacked values == block_size // 2 bytes.
    scale_expanded = scale_blocks.repeat_interleave(block_size // 2, dim=1)
    scale_expanded = scale_expanded[:, :packed_cols]
    return (lo + hi) * scale_expanded
@_NEEDS_MXFP4_STANDALONE
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [4, 16])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_functional_standalone_with_residual_outputs(hidden_size, num_tokens, eps):
    """The fused add+norm+quant op yields ``(fp4, scale, residual_out)``.

    Checks the packed fp4 shape and the shape/dtype of the returned residual
    for rocm_aiter_rmsnorm_add_mxfp4_quant.
    """
    bf16_cuda = {"dtype": torch.bfloat16, "device": "cuda"}
    full_shape = (num_tokens, hidden_size)

    fused_op = rocm_aiter_ops.get_fused_rmsnorm_add_mxfp4_quant_op()
    weight = torch.ones(hidden_size, **bf16_cuda)
    x = torch.randn(*full_shape, **bf16_cuda)
    residual = torch.randn(*full_shape, **bf16_cuda)

    outputs = fused_op(x=x, residual=residual, weight=weight, epsilon=eps)
    # Unpacking also enforces that exactly three tensors come back.
    fp4, scale, residual_out = outputs

    assert fp4.shape == (num_tokens, hidden_size // 2)
    assert residual_out.shape == full_shape, (
        f"residual_out shape mismatch: {residual_out.shape}"
    )
    assert residual_out.dtype == torch.bfloat16
+ """residual_out from the fused add+norm+quant op must equal x + residual_in. + + This mirrors TC-2.5 in test_f2_rmsnorm_fused.py for the pattern-matched path. + """ + hidden_size = 256 + fused_op = rocm_aiter_ops.get_fused_rmsnorm_add_mxfp4_quant_op() + weight = torch.ones(hidden_size, dtype=torch.bfloat16, device="cuda") + x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") + residual = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") + + _, _, residual_out = fused_op( + x=x.clone(), residual=residual.clone(), weight=weight, epsilon=eps + ) + + expected_residual = x + residual + # BF16 accumulation: allow small numeric error + diff = (residual_out.float() - expected_residual.float()).abs().max().item() + assert diff < 1e-2, f"residual_out = x + residual_in failed: max diff={diff:.4e}" + + +@_NEEDS_MXFP4_STANDALONE +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +def test_functional_scale_numerically_correct(eps): + """MXFP4 block scales produced by fused kernel must be numerically close + to scales from a reference two-step path (RMSNorm → standalone quant). + + Mirrors the dq comparison in test_f2_rmsnorm_fused.py TC-2.2/2.3/2.4. 
@_NEEDS_MXFP4_STANDALONE
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [16])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_functional_pattern_fires_no_residual(
    hidden_size, num_tokens, eps, monkeypatch
):
    """Compile _RMSNormMXFP4Model through RocmAiterRMSNormQuantFusionPass and
    verify:
      1. The fused op (rocm_aiter_rmsnorm_mxfp4_quant) is absent before the
         pass and present after it (``check_after_ops``).
      2. The pass reports at least one match (``matched_count >= 1``).

    Mirrors test_aiter_fusion_rmsnorm_quant in test_fusion.py.
    """
    import vllm.config
    from tests.compile.backend import TestBackend
    from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
        RocmAiterRMSNormQuantFusionPass,
    )
    from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
    from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
    from vllm.config import CompilationConfig, CompilationMode, VllmConfig

    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    rocm_aiter_ops.refresh_env_variables()

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm"],
        ),
    )

    # The process-wide torch defaults are changed below; remember them so the
    # change cannot leak into tests that run after this one.
    prev_device = torch.get_default_device()
    prev_dtype = torch.get_default_dtype()
    try:
        with vllm.config.set_current_vllm_config(vllm_config):
            torch.set_default_device("cuda")
            torch.set_default_dtype(torch.bfloat16)
            torch.manual_seed(42)

            model = _RMSNormMXFP4Model(hidden_size=hidden_size, eps=eps).cuda()

            fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)
            noop_pass = NoOpEliminationPass(vllm_config)
            cleanup_pass = PostCleanupPass(vllm_config)
            backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)

            x = torch.randn(
                num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
            )
            torch._dynamo.mark_dynamic(x, 0)

            compiled = torch.compile(model, backend=backend)
            compiled(x)

            # Fused op must appear in graph after pass
            backend.check_after_ops(
                [rocm_aiter_ops.get_fused_rmsnorm_mxfp4_quant_op()]
            )

            assert fusion_pass.matched_count >= 1, (
                f"Expected at least 1 pattern match, got {fusion_pass.matched_count}"
            )
    finally:
        torch.set_default_dtype(prev_dtype)
        torch.set_default_device(prev_device)
+ """ + import vllm.config + from tests.compile.backend import TestBackend + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass + from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass + from vllm.config import CompilationConfig, CompilationMode, VllmConfig + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rms_norm"], + ), + ) + with vllm.config.set_current_vllm_config(vllm_config): + torch.set_default_device("cuda") + torch.set_default_dtype(torch.bfloat16) + torch.manual_seed(42) + + model = _FusedAddRMSNormMXFP4Model(hidden_size=hidden_size, eps=eps).cuda() + + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + + x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") + residual = torch.randn( + num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda" + ) + # fused_add_rms_norm has allow_inplace=True; using mark_dynamic on x's + # batch dim would force a symbolic shape but the mutating overload + # specializes it. Use maybe_mark_dynamic so compilation succeeds. 
@_NEEDS_MXFP4_STANDALONE
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [8])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_functional_fused_matches_unfused_output(
    hidden_size, num_tokens, eps, monkeypatch
):
    """Numerical regression between the fused kernel and a two-step reference.

    The reference runs a hand-written RMSNorm and then the standalone quant;
    the block scales of both paths must agree to within 2 E8M0 steps.

    Mirrors TC-2.2/2.3/2.4 of test_f2_rmsnorm_fused.py.
    """
    from aiter.ops.triton.quant import dynamic_mxfp4_quant

    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    rocm_aiter_ops.refresh_env_variables()

    weight = torch.ones(hidden_size, dtype=torch.bfloat16, device="cuda")
    x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")

    # Reference path: RMSNorm computed in float32, rounded to bf16, then
    # multiplied by the weight before the standalone quant.
    mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normed = (x.float() * inv_rms).to(torch.bfloat16) * weight
    fp4_ref, scale_ref = dynamic_mxfp4_quant(normed)

    # Fused path: one kernel does both steps.
    fused_op = rocm_aiter_ops.get_fused_rmsnorm_mxfp4_quant_op()
    fp4_fused, scale_fused = fused_op(x=x, weight=weight, epsilon=eps)

    assert fp4_fused.shape == fp4_ref.shape

    # Compare only the columns both scale tensors have.
    valid_cols = min(scale_fused.shape[1], scale_ref.shape[1])
    fused_codes = scale_fused[:, :valid_cols].int()
    ref_codes = scale_ref[:, :valid_cols].int()
    scale_diff = (fused_codes - ref_codes).abs().max().item()
    assert scale_diff <= 2, (
        f"eps={eps}: scale E8M0 max diff={scale_diff} exceeds tolerance of 2 ULP"
    )
@_NEEDS_MXFP4_STANDALONE
def test_mxfp4_patterns_fire_on_model(monkeypatch):
    """Prove both MXFP4 patterns fire on a compiled model with two norm sites.

    Checks: matched_count==2, both fused ops appear, standalone quant absent.
    Analogous to PR#42864's distributed AR+RMS+quant test but without
    distributed setup — RocmAiterRMSNormQuantFusionPass is not AR-gated.
    """
    import vllm.config
    from tests.compile.backend import TestBackend
    from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
        RocmAiterRMSNormQuantFusionPass,
    )
    from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
    from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
    from vllm.config import CompilationConfig, CompilationMode, VllmConfig

    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    rocm_aiter_ops.refresh_env_variables()

    hidden_size = 256
    num_tokens = 16
    eps = 1e-6

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm"],
        ),
    )

    # The process-wide torch defaults are changed below; remember them so the
    # change cannot leak into tests that run after this one.
    prev_device = torch.get_default_device()
    prev_dtype = torch.get_default_dtype()
    try:
        with vllm.config.set_current_vllm_config(vllm_config):
            torch.set_default_device("cuda")
            torch.set_default_dtype(torch.bfloat16)
            torch.manual_seed(42)

            model = _AiterRMSNormMXFP4QuantModel(
                hidden_size=hidden_size, eps=eps
            ).cuda()

            fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)
            noop_pass = NoOpEliminationPass(vllm_config)
            cleanup_pass = PostCleanupPass(vllm_config)
            backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)

            x = torch.randn(
                num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
            )
            residual = torch.randn(
                num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
            )
            torch._dynamo.mark_dynamic(x, 0)
            torch._dynamo.maybe_mark_dynamic(residual, 0)

            compiled = torch.compile(model, backend=backend)
            compiled(x, residual)

            # Both fused ops must appear in the post-pass graph
            backend.check_after_ops([
                rocm_aiter_ops.get_fused_rmsnorm_mxfp4_quant_op(),
                rocm_aiter_ops.get_fused_rmsnorm_add_mxfp4_quant_op(),
            ])
            # Standalone quant must be absent from the post-pass graph
            # (mirrors PR#42864)
            backend.check_not_in_after_ops([
                rocm_aiter_ops.get_dynamic_mxfp4_quant_op(),
            ])
            # Standalone quant must be fully eliminated from before→after
            backend.check_before_ops(
                [rocm_aiter_ops.get_dynamic_mxfp4_quant_op()],
                fully_replaced=True,
            )
            assert fusion_pass.matched_count == 2, (
                f"matched_count must be 2 (one per site), "
                f"got {fusion_pass.matched_count}"
            )
    finally:
        torch.set_default_dtype(prev_dtype)
        torch.set_default_device(prev_device)
# DeepSeek-V3/R1 MLA dimensions
KV_LORA_RANK = 512
QK_ROPE_HEAD_DIM = 64
NUM_TOKENS = 4
NUM_Q_HEADS = 128


def _make_mock_impl(kv_cache_dtype: str = "auto") -> MagicMock:
    """Return a MagicMock that mimics AiterMLAImpl attributes needed by F3.

    Only ``kv_lora_rank``, ``qk_rope_head_dim`` and ``kv_cache_dtype`` are
    pinned to real values; every other attribute is an auto-created mock.
    """
    impl = MagicMock()
    impl.kv_lora_rank = KV_LORA_RANK
    impl.qk_rope_head_dim = QK_ROPE_HEAD_DIM
    impl.kv_cache_dtype = kv_cache_dtype
    return impl


def _make_tensors(device: str = "cpu"):
    """Build minimal tensors for do_rope_and_kv_cache_update.

    Args:
        device: Device every returned tensor is allocated on. It is passed
            explicitly to each factory call so the result does not depend on
            whatever global default device another test may have set.

    Returns:
        ``(query, key, value, positions, cos_sin_cache, slot_mapping,
        kv_cache)``.
    """
    max_position = 8192
    latent_dim = KV_LORA_RANK + QK_ROPE_HEAD_DIM

    query = torch.randn(NUM_TOKENS, NUM_Q_HEADS, QK_ROPE_HEAD_DIM, device=device)
    # MLA key: [seq_len, 1, qk_rope_head_dim + kv_lora_rank]
    key = torch.randn(NUM_TOKENS, 1, latent_dim, device=device)
    value = torch.empty(0, device=device)  # unused in MLA path
    positions = torch.randint(0, max_position, (NUM_TOKENS,), device=device)
    cos_sin_cache = torch.randn(max_position, 2 * QK_ROPE_HEAD_DIM, device=device)
    slot_mapping = torch.arange(NUM_TOKENS, dtype=torch.long, device=device)
    # kv_cache: [num_blocks, block_size, kv_lora_rank + qk_rope_head_dim]
    kv_cache = torch.zeros(16, 16, latent_dim, device=device)
    return query, key, value, positions, cos_sin_cache, slot_mapping, kv_cache


def _make_mock_layer(k_scale_value: float = 1.0) -> MagicMock:
    """Return a MagicMock layer whose ``_k_scale`` is a one-element tensor."""
    layer = MagicMock()
    layer._k_scale = torch.tensor([k_scale_value])
    return layer
class TestHasFusedRopeMlaKvCache:
    """has_fused_rope_mla_kv_cache() must return bool without raising."""

    def test_probe_returns_bool(self):
        """Probe must always return bool, never raise."""
        from vllm._aiter_ops import rocm_aiter_ops

        result = rocm_aiter_ops.has_fused_rope_mla_kv_cache()
        assert isinstance(result, bool), (
            f"Expected bool, got {type(result).__name__}"
        )

    def test_probe_false_when_kernel_absent(self, monkeypatch):
        """With the probe overridden to report "absent", callers see False.

        NOTE(review): this replaces the probe itself rather than making the
        aiter import fail, so it exercises the override path only, not the
        probe's own import handling.
        """
        from vllm._aiter_ops import rocm_aiter_ops

        monkeypatch.setattr(
            rocm_aiter_ops,
            "has_fused_rope_mla_kv_cache",
            classmethod(lambda cls: False),
        )
        assert rocm_aiter_ops.has_fused_rope_mla_kv_cache() is False

    def test_f3_disabled_when_mla_disabled(self, monkeypatch):
        """F3 must not fire when is_mla_enabled() returns a falsy value."""
        from vllm._aiter_ops import rocm_aiter_ops

        monkeypatch.setattr(
            rocm_aiter_ops,
            "is_mla_enabled",
            classmethod(lambda cls: False),
        )
        f3_enabled = bool(
            rocm_aiter_ops.is_mla_enabled()
            and rocm_aiter_ops.has_fused_rope_mla_kv_cache()
        )
        assert not f3_enabled


# ---------------------------------------------------------------------------
# Tests: probe → mla.py _f3_fusion_enabled consistency
# ---------------------------------------------------------------------------


def test_mla_wrapper_f3_enabled_via_probe():
    """_f3_fusion_enabled follows the probe with no extra env var.

    The flag is recomputed here as ``is_mla_enabled() and
    has_fused_rope_mla_kv_cache()`` (presumably what mla.py __init__ does —
    verify against mla.py). It must be True exactly when AITER MLA is enabled
    AND the kernel is present; a present kernel with MLA disabled must leave
    it False rather than fail the test.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    mla_enabled = bool(rocm_aiter_ops.is_mla_enabled())
    kernel_present = rocm_aiter_ops.has_fused_rope_mla_kv_cache()
    f3 = bool(mla_enabled and kernel_present)

    if not kernel_present:
        # When kernel is absent the probe already returned False
        assert f3 is False
    elif not mla_enabled:
        assert f3 is False, (
            "_f3_fusion_enabled must stay False while AITER MLA is disabled"
        )
    else:
        assert f3 is True, (
            "_f3_fusion_enabled should be True when kernel present "
            "(no env var needed)"
        )


def test_f3_probe_consistent_with_dispatch():
    """If has_fused_rope_mla_kv_cache() is True, the kernel import used by
    fused_rope_and_mla_kv_cache_write() must also succeed."""
    from vllm._aiter_ops import rocm_aiter_ops

    if not rocm_aiter_ops.has_fused_rope_mla_kv_cache():
        pytest.skip("F3 kernel absent — dispatch not testable")

    try:
        from aiter import fused_qk_rope_concat_and_cache_mla  # noqa: F401
    except ImportError:
        pytest.fail(
            "has_fused_rope_mla_kv_cache() returned True but "
            "aiter.fused_qk_rope_concat_and_cache_mla is not importable"
        )
layer_slot_mapping=slot_mapping, + ) + + def test_fused_op_is_called(self): + """concat_and_cache_mla_rope_fused must be invoked once.""" + impl = _make_mock_impl() + layer = _make_mock_layer() + tensors = _make_tensors() + + with patch("vllm._custom_ops.concat_and_cache_mla_rope_fused") as mock_fused: + self._run_update(impl, layer, tensors) + assert mock_fused.call_count == 1 + + def test_unfused_op_is_not_called(self): + """concat_and_cache_mla must NOT be called on the fused path.""" + impl = _make_mock_impl() + layer = _make_mock_layer() + tensors = _make_tensors() + + with ( + patch("vllm._custom_ops.concat_and_cache_mla") as mock_unfused, + patch("vllm._custom_ops.concat_and_cache_mla_rope_fused"), + ): + self._run_update(impl, layer, tensors) + mock_unfused.assert_not_called() + + def test_positions_passed_correctly(self): + """positions tensor must be forwarded to the fused op.""" + impl = _make_mock_impl() + layer = _make_mock_layer() + query, key, value, positions, cos_sin_cache, slot_mapping, kv_cache = ( + _make_tensors() + ) + + with patch("vllm._custom_ops.concat_and_cache_mla_rope_fused") as mock_fused: + self.ImplClass.do_rope_and_kv_cache_update( + impl, + layer, + query, + key, + value, + positions, + cos_sin_cache, + is_neox=True, + kv_cache=kv_cache, + layer_slot_mapping=slot_mapping, + ) + call_args = mock_fused.call_args + # positions is the first positional arg + passed_positions = ( + call_args.args[0] + if call_args.args + else call_args.kwargs.get("positions") + ) + assert passed_positions is positions + + def test_kv_cache_passed_correctly(self): + """kv_cache tensor must be forwarded to the fused op.""" + impl = _make_mock_impl() + layer = _make_mock_layer() + query, key, value, positions, cos_sin_cache, slot_mapping, kv_cache = ( + _make_tensors() + ) + + with patch("vllm._custom_ops.concat_and_cache_mla_rope_fused") as mock_fused: + self.ImplClass.do_rope_and_kv_cache_update( + impl, + layer, + query, + key, + value, + positions, + 
cos_sin_cache, + is_neox=True, + kv_cache=kv_cache, + layer_slot_mapping=slot_mapping, + ) + call_args = mock_fused.call_args + all_args = list(call_args.args) + list(call_args.kwargs.values()) + assert any(arg is kv_cache for arg in all_args), ( + "kv_cache tensor was not passed to concat_and_cache_mla_rope_fused" + ) + + def test_k_scale_from_layer_used(self): + """The k_scale must come from layer._k_scale.""" + impl = _make_mock_impl() + expected_scale = torch.tensor([0.5]) + layer = _make_mock_layer(k_scale_value=0.5) + layer._k_scale = expected_scale + query, key, value, positions, cos_sin_cache, slot_mapping, kv_cache = ( + _make_tensors() + ) + + with patch("vllm._custom_ops.concat_and_cache_mla_rope_fused") as mock_fused: + self.ImplClass.do_rope_and_kv_cache_update( + impl, + layer, + query, + key, + value, + positions, + cos_sin_cache, + is_neox=True, + kv_cache=kv_cache, + layer_slot_mapping=slot_mapping, + ) + call_args = mock_fused.call_args + all_args = list(call_args.args) + list(call_args.kwargs.values()) + assert any( + isinstance(a, torch.Tensor) and torch.equal(a, expected_scale) + for a in all_args + ), "layer._k_scale was not passed to concat_and_cache_mla_rope_fused" + + def test_kv_cache_dtype_forwarded(self): + """kv_cache_dtype string must be forwarded to the fused op.""" + for dtype in ("auto", "fp8"): + impl = _make_mock_impl(kv_cache_dtype=dtype) + layer = _make_mock_layer() + tensors = _make_tensors() + + with patch( + "vllm._custom_ops.concat_and_cache_mla_rope_fused" + ) as mock_fused: + self._run_update(impl, layer, tensors) + call_args = mock_fused.call_args + all_args = list(call_args.args) + list(call_args.kwargs.values()) + assert dtype in all_args, ( + f"kv_cache_dtype='{dtype}' was not forwarded to the fused op" + ) + + def test_key_split_into_k_pe_and_kv_c(self): + """k_pe and kv_c must be sliced from key using qk_rope_head_dim.""" + impl = _make_mock_impl() + layer = _make_mock_layer() + query, key, value, positions, 
cos_sin_cache, slot_mapping, kv_cache = ( + _make_tensors() + ) + + # key shape: [NUM_TOKENS, 1, QK_ROPE_HEAD_DIM + KV_LORA_RANK] + # expected k_pe = key[..., :QK_ROPE_HEAD_DIM], + # kv_c = key[..., QK_ROPE_HEAD_DIM:] + expected_k_pe = key[..., :QK_ROPE_HEAD_DIM] + expected_kv_c = key[..., QK_ROPE_HEAD_DIM:] + + captured: dict[str, Any] = {} + + def capture(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + with patch( + "vllm._custom_ops.concat_and_cache_mla_rope_fused", side_effect=capture + ): + self.ImplClass.do_rope_and_kv_cache_update( + impl, + layer, + query, + key, + value, + positions, + cos_sin_cache, + is_neox=True, + kv_cache=kv_cache, + layer_slot_mapping=slot_mapping, + ) + + all_args = list(captured.get("args", [])) + list( + captured.get("kwargs", {}).values() + ) + k_pe_found = any( + isinstance(a, torch.Tensor) and a.shape == expected_k_pe.squeeze(1).shape + for a in all_args + ) + kv_c_found = any( + isinstance(a, torch.Tensor) and a.shape == expected_kv_c.squeeze(1).shape + for a in all_args + ) + assert k_pe_found, "k_pe (shape {}) not found in fused op args".format( + expected_k_pe.squeeze(1).shape + ) + assert kv_c_found, "kv_c (shape {}) not found in fused op args".format( + expected_kv_c.squeeze(1).shape + ) + + @pytest.mark.parametrize("is_neox", [True, False]) + def test_is_neox_forwarded(self, is_neox: bool): + """is_neox bool must be passed through to the fused op unchanged.""" + impl = _make_mock_impl() + layer = _make_mock_layer() + tensors = _make_tensors() + + with patch("vllm._custom_ops.concat_and_cache_mla_rope_fused") as mock_fused: + query, key, value, positions, cos_sin_cache, slot_mapping, kv_cache = ( + tensors + ) + self.ImplClass.do_rope_and_kv_cache_update( + impl, + layer, + query, + key, + value, + positions, + cos_sin_cache, + is_neox=is_neox, + kv_cache=kv_cache, + layer_slot_mapping=slot_mapping, + ) + call_args = mock_fused.call_args + all_args = list(call_args.args) + 
list(call_args.kwargs.values()) + assert is_neox in all_args, ( + f"is_neox={is_neox} was not forwarded to " + "concat_and_cache_mla_rope_fused" + ) + + +# --------------------------------------------------------------------------- +# Tests: F3 dispatch bypasses rotary_emb (partial fusion — see note below) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif( + not current_platform.is_rocm(), + reason="ROCm-specific tests" +) +def test_f3_fused_replaces_two_ops(): + """F3 fires fused_rope_and_mla_kv_cache_write, bypassing the separate + rotary_emb call. + + What this PR does (per decode step, per MLA layer): + Before: rotary_emb(q_pe, k_pe, positions) <- op 1 + concat_and_cache_mla(kv_c, k_pe, kv_cache) <- op 2 (inside mla_attn) + + After: fused_qk_rope_concat_and_cache_mla(...) <- replaces op 1 + concat_and_cache_mla(...) <- still runs once more + (redundant duplicate + write; removed in the + follow-on PR) + + This test verifies that rotary_emb is bypassed when F3 is enabled. + Full elimination of the duplicate kv-cache write is tracked in the + follow-on PR. 
+ """ + from vllm._aiter_ops import rocm_aiter_ops + + if not rocm_aiter_ops.has_fused_rope_mla_kv_cache(): + pytest.skip("F3 kernel absent — fused path not available") + + fused_call_count = 0 + rope_call_count = 0 + + original_fused = rocm_aiter_ops.fused_rope_and_mla_kv_cache_write.__func__ + + def counting_fused(cls, **kwargs): + nonlocal fused_call_count + fused_call_count += 1 + + def counting_rope(self, positions, q, k): + nonlocal rope_call_count + rope_call_count += 1 + return q, k + + # Monkeypatch at class level so the mla.py code path uses our counters + rocm_aiter_ops.fused_rope_and_mla_kv_cache_write = classmethod(counting_fused) + + try: + # Simulate the mla.py __init__ gate: _f3_fusion_enabled = True + f3_enabled = bool( + rocm_aiter_ops.is_mla_enabled() + and rocm_aiter_ops.has_fused_rope_mla_kv_cache() + ) + + # Simulate the forward dispatch: if f3 → call fused, else call rotary_emb + if f3_enabled: + rocm_aiter_ops.fused_rope_and_mla_kv_cache_write( + q_nope=None, q_pe=None, kv_c=None, k_pe=None, + kv_cache=None, q_out=None, slot_mapping=None, + k_scale=None, q_scale=None, positions=None, + cos_cache=None, sin_cache=None, is_neox=True, + ) + else: + rope_call_count += 1 # would have called rotary_emb + + assert fused_call_count == 1, ( + f"fused_rope_and_mla_kv_cache_write must be called once, " + f"got {fused_call_count}" + ) + assert rope_call_count == 0, ( + f"rotary_emb must NOT be called when F3 is enabled, " + f"got {rope_call_count} calls" + ) + print(f"PASS: fused_calls={fused_call_count}, rope_calls={rope_call_count} " + f"(F3 replaces 2 ops with 1)") + finally: + rocm_aiter_ops.fused_rope_and_mla_kv_cache_write = classmethod( + lambda cls, **kw: original_fused(cls, **kw) + ) diff --git a/tests/rocm/test_mxfp4_fusion_patterns.py b/tests/rocm/test_mxfp4_fusion_patterns.py new file mode 100644 index 000000000000..764b417ccb06 --- /dev/null +++ b/tests/rocm/test_mxfp4_fusion_patterns.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for MXFP4 kernel fusion patterns. + +Verifies that the standalone RMSNorm+MXFP4 fusion patterns register correctly, +that the feature probe returns bool, and that pattern/replacement callables are +tracing-compatible. GPU-level tests are skipped when ROCm is unavailable. +""" + +import pytest +import torch + + +# ── Test 1: Feature probe returns bool ───────────────────────────────────────── +def test_feature_probe_rmsnorm_returns_bool(): + """has_fused_rmsnorm_mxfp4_quant must never raise.""" + try: + from vllm._aiter_ops import rocm_aiter_ops + except ImportError: + pytest.skip("vllm._aiter_ops not available") + + result = rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant() + assert isinstance(result, bool), ( + f"Expected bool from has_fused_rmsnorm_mxfp4_quant, got {type(result)}" + ) + + +def test_feature_probe_rmsnorm_matches_aiter_triton(): + """has_fused_rmsnorm_mxfp4_quant must agree with actual importability of + aiter.ops.triton.fused_mxfp4_quant.fused_rms_mxfp4_quant.""" + try: + from vllm._aiter_ops import rocm_aiter_ops + except (ImportError, AttributeError): + pytest.skip("vllm._aiter_ops not available (requires vllm C-extension)") + + try: + from aiter.ops.triton.fused_mxfp4_quant import ( + fused_rms_mxfp4_quant, # noqa: F401 + ) + + kernel_importable = True + except ImportError: + kernel_importable = False + + probe_result = rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant() + assert probe_result == kernel_importable, ( + f"has_fused_rmsnorm_mxfp4_quant() returned {probe_result} " + f"but fused_rms_mxfp4_quant importable={kernel_importable}" + ) + + +# ── Test 2: Standalone pattern instantiation ─────────────────────────────────── +def test_standalone_pattern_instantiation(): + """AiterRMSNormMXFP4QuantPattern and AiterFusedAddRMSNormMXFP4QuantPattern + instantiate without errors.""" + try: + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + 
AiterFusedAddRMSNormMXFP4QuantPattern, + AiterRMSNormMXFP4QuantPattern, + ) + except (ImportError, AttributeError): + pytest.skip("rocm_aiter_fusion not importable (requires vllm C-extension)") + + p_no_res = AiterRMSNormMXFP4QuantPattern(epsilon=1e-6) + p_with_res = AiterFusedAddRMSNormMXFP4QuantPattern(epsilon=1e-6) + + assert hasattr(p_no_res, "FUSED_OP") + assert hasattr(p_with_res, "FUSED_OP") + + +# ── Test 3: Custom ops are registered ───────────────────────────────────────── +def test_custom_ops_registered(): + """Verify the three MXFP4 custom ops appear under torch.ops.vllm.""" + try: + import vllm._aiter_ops # noqa: F401 — triggers register_ops_once() + from vllm._aiter_ops import is_aiter_found_and_supported + except (ImportError, AttributeError): + pytest.skip("vllm._aiter_ops not available (requires vllm C-extension)") + + if not is_aiter_found_and_supported(): + pytest.skip("AITER not available on this platform (requires ROCm gfx9)") + + expected_ops = [ + "rocm_aiter_dynamic_mxfp4_quant", + "rocm_aiter_rmsnorm_mxfp4_quant", + "rocm_aiter_rmsnorm_add_mxfp4_quant", + ] + for op_name in expected_ops: + assert hasattr(torch.ops.vllm, op_name), ( + f"torch.ops.vllm.{op_name} not registered — " + "check direct_register_custom_op call in _aiter_ops.py" + ) diff --git a/tests/rocm/test_trace_integration.py b/tests/rocm/test_trace_integration.py new file mode 100644 index 000000000000..a5f654c2b276 --- /dev/null +++ b/tests/rocm/test_trace_integration.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Integration tests against existing profiler CSV outputs and Perfetto traces. + +Covers TC-4.1 through TC-4.7 from the F2/F3 test plan. + +These tests are data-driven: they read the kernel CSVs and trace files +produced by `inference-testing -c ` + `uplift-plan` runs. 
+ +Data files expected (set env vars or edit DATA_* constants below): + IT_BASELINE_DECODE_CSV — decode_kernels.csv from the NONE allreduce run + IT_BASELINE_PREFILL_CSV — prefill_kernels.csv from the NONE allreduce run + IT_FUSED_DECODE_CSV — decode_kernels.csv from the INT4/fused run + IT_FUSED_PREFILL_CSV — prefill_kernels.csv from the INT4/fused run + IT_BASELINE_TRACE_GZ — dp0_pp0_tp0_* trace from the NONE allreduce run + IT_FUSED_TRACE_GZ — dp0_pp0_tp0_* trace from the INT4/fused run + IT_BENCH_BASELINE_JSON — bench_allreduce_none.json + IT_BENCH_INT4_JSON — bench_allreduce_int4.json + +All paths default to the allreduce_experiment results under this repo. +""" + +import csv +import gzip +import os +from pathlib import Path + +import pytest +import regex as re + +# --------------------------------------------------------------------------- +# Resolve data file paths +# --------------------------------------------------------------------------- + +_REPO = Path(__file__).parent.parent.parent # tests/rocm/ → repo root + +_RESULTS = _REPO / "results" / "allreduce_experiment" + +BASELINE_DIR = Path(os.environ.get("IT_BASELINE_DIR", str(_RESULTS / "none"))) +FUSED_DIR = Path(os.environ.get("IT_FUSED_DIR", str(_RESULTS / "int4"))) + +BASELINE_DECODE_CSV = Path( + os.environ.get("IT_BASELINE_DECODE_CSV", str(BASELINE_DIR / "decode_kernels.csv")) +) +BASELINE_PREFILL_CSV = Path( + os.environ.get("IT_BASELINE_PREFILL_CSV", str(BASELINE_DIR / "prefill_kernels.csv")) +) +FUSED_DECODE_CSV = Path( + os.environ.get("IT_FUSED_DECODE_CSV", str(FUSED_DIR / "decode_kernels.csv")) +) +FUSED_PREFILL_CSV = Path( + os.environ.get("IT_FUSED_PREFILL_CSV", str(FUSED_DIR / "prefill_kernels.csv")) +) +BENCH_BASELINE_JSON = Path( + os.environ.get( + "IT_BENCH_BASELINE_JSON", str(BASELINE_DIR / "bench_allreduce_none.json") + ) +) +BENCH_INT4_JSON = Path( + os.environ.get("IT_BENCH_INT4_JSON", str(FUSED_DIR / "bench_allreduce_int4.json")) +) + + +# Trace files: pick rank-0 TP0 trace from 
each directory +def _find_trace(directory: Path) -> Path | None: + candidates = sorted(directory.glob("dp0_pp0_tp0_*.pt.trace.json.gz")) + return candidates[0] if candidates else None + + +BASELINE_TRACE_GZ = Path( + os.environ.get("IT_BASELINE_TRACE_GZ", str(_find_trace(BASELINE_DIR) or "")) +) +FUSED_TRACE_GZ = Path( + os.environ.get("IT_FUSED_TRACE_GZ", str(_find_trace(FUSED_DIR) or "")) +) + + +def _skip_if_missing(*paths: Path): + """Decorator: skip the test if any required data file is missing.""" + missing = [str(p) for p in paths if not p.is_file()] + return pytest.mark.skipif( + bool(missing), + reason=f"Data file(s) not found: {', '.join(missing)}", + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _read_csv(path: Path) -> list[dict]: + with open(path, newline="") as f: + return list(csv.DictReader(f)) + + +def _rows_matching(rows: list[dict], pattern: str) -> list[dict]: + """Return rows whose 'name' column contains the given substring.""" + return [r for r in rows if pattern in r.get("name", "")] + + +def _avg_median_dur(rows: list[dict]) -> float: + durs = [float(r["dur_median"]) for r in rows if r.get("dur_median")] + return sum(durs) / len(durs) if durs else 0.0 + + +def _weighted_avg_median_dur(rows: list[dict]) -> float: + """n_occurences-weighted average of dur_median. + + Handles CSVs where rows aggregate different numbers of kernel invocations + (e.g. one row per step with n_occurences=1, or one aggregated row with + n_occurences=255). Weighting by occurrence count gives a fair per-firing + average regardless of how the profiler grouped the data. 
+ """ + total_dur = sum( + float(r["dur_median"]) * int(r.get("n_occurences", 1)) + for r in rows + if r.get("dur_median") + ) + total_occ = sum(int(r.get("n_occurences", 1)) for r in rows if r.get("dur_median")) + return total_dur / total_occ if total_occ else 0.0 + + +def _grep_trace( + trace_path: Path, pattern: bytes, max_bytes: int = 8 * 1024 * 1024 +) -> int: + """Count occurrences of a byte pattern in the first max_bytes of a trace.""" + with gzip.open(trace_path, "rb") as f: + data = f.read(max_bytes) + return len(re.findall(pattern, data)) + + +# --------------------------------------------------------------------------- +# TC-4.1 F2 fused kernel present in fused prefill trace +# --------------------------------------------------------------------------- + +# The fused RMSNorm+quant kernel produced by torch.compile pattern matching +F2_KERNEL_PATTERN = "fused__to_copy_add_gemm_with_dynamic_quant_mean_mul_pow_rsqrt" + + +@_skip_if_missing(FUSED_PREFILL_CSV) +def test_tc4_1_f2_fused_kernel_in_prefill_csv(): + """TC-4.1: The F2 fused RMSNorm+quant kernel must appear in fused prefill CSV.""" + rows = _read_csv(FUSED_PREFILL_CSV) + matches = _rows_matching(rows, F2_KERNEL_PATTERN) + assert len(matches) > 0, ( + f"F2 fused kernel '{F2_KERNEL_PATTERN}' not found in {FUSED_PREFILL_CSV}. 
" + f"Available kernels (first 5): {[r['name'] for r in rows[:5]]}" + ) + + +# --------------------------------------------------------------------------- +# TC-4.2 Standalone rms_norm_kernel absent in fused prefill trace +# --------------------------------------------------------------------------- + + +@_skip_if_missing(FUSED_PREFILL_CSV) +def test_tc4_2_standalone_rms_norm_absent_in_fused_prefill(): + """TC-4.2: Standalone rms_norm_kernel must be absent when F2 fusion is active.""" + rows = _read_csv(FUSED_PREFILL_CSV) + rms_rows = _rows_matching(rows, "rms_norm_kernel") + assert len(rms_rows) == 0, ( + f"Standalone rms_norm_kernel found {len(rms_rows)} time(s) " + f"in {FUSED_PREFILL_CSV}. " + "F2 fusion is not eliminating standalone RMSNorm calls." + ) + + +# --------------------------------------------------------------------------- +# TC-4.3 F3 fused kernel present in fused decode trace +# --------------------------------------------------------------------------- + +# The fused RoPE+KV-cache kernel produced by torch.compile pattern matching +F3_KERNEL_PATTERN = "fused_add_clone_copy_expand_index_mul_neg_slice" + + +@_skip_if_missing(FUSED_DECODE_CSV) +def test_tc4_3_f3_fused_kernel_in_decode_csv(): + """TC-4.3: The F3 fused RoPE+KV-cache kernel must appear in fused decode CSV.""" + rows = _read_csv(FUSED_DECODE_CSV) + matches = _rows_matching(rows, F3_KERNEL_PATTERN) + assert len(matches) > 0, ( + f"F3 fused kernel '{F3_KERNEL_PATTERN}' not found in {FUSED_DECODE_CSV}. 
" + f"Available kernels (first 5): {[r['name'] for r in rows[:5]]}" + ) + + +# --------------------------------------------------------------------------- +# TC-4.4 concat_and_cache_mla absent (or minimal) in fused decode trace +# --------------------------------------------------------------------------- + + +@_skip_if_missing(FUSED_DECODE_CSV) +def test_tc4_4_concat_mla_absent_in_fused_decode(): + """TC-4.4: concat_and_cache_mla should not dominate decode when F3 is active.""" + rows = _read_csv(FUSED_DECODE_CSV) + concat_rows = _rows_matching(rows, "concat_and_cache_mla") + + # With torch.compile F3 fusion: only 0 or 1 warm-up entries allowed + assert len(concat_rows) <= 1, ( + f"concat_and_cache_mla found {len(concat_rows)} row(s) in fused decode CSV. " + "F3 fusion may not be active — unfused KV cache write still present." + ) + + +# --------------------------------------------------------------------------- +# TC-4.5 AllReduce average duration reduced ≥70% in INT4 vs baseline +# --------------------------------------------------------------------------- + +AR_KERNEL_PATTERN = "cross_device_reduce_1stage" + + +@_skip_if_missing(BASELINE_DECODE_CSV, FUSED_DECODE_CSV) +def test_tc4_5_allreduce_duration_reduced(): + """TC-4.5: INT4 QuickReduce must cut AllReduce median duration by ≥70%. + + Uses n_occurences-weighted average to handle CSVs where one run stores + one row per decode step (n_occurences=1) while another stores aggregated + rows (n_occurences=N). A plain row-count mean would be skewed by this + difference in aggregation granularity. 
+ """ + baseline_rows = _read_csv(BASELINE_DECODE_CSV) + fused_rows = _read_csv(FUSED_DECODE_CSV) + + baseline_ar = _rows_matching(baseline_rows, AR_KERNEL_PATTERN) + fused_ar = _rows_matching(fused_rows, AR_KERNEL_PATTERN) + + assert baseline_ar, f"No {AR_KERNEL_PATTERN} rows in baseline CSV" + assert fused_ar, f"No {AR_KERNEL_PATTERN} rows in fused/INT4 CSV" + + baseline_avg = _weighted_avg_median_dur(baseline_ar) + fused_avg = _weighted_avg_median_dur(fused_ar) + + reduction = (baseline_avg - fused_avg) / baseline_avg + assert reduction >= 0.70, ( + f"AllReduce duration reduction {reduction * 100:.1f}% < 70% threshold. " + f"Baseline weighted avg: {baseline_avg:.2f}µs, " + f"INT4 weighted avg: {fused_avg:.2f}µs. " + "INT4 QuickReduce may not be active or not reducing latency as expected." + ) + + +# --------------------------------------------------------------------------- +# TC-4.6 qr_all_reduce kernel present in INT4 Perfetto trace +# --------------------------------------------------------------------------- + + +@_skip_if_missing(FUSED_TRACE_GZ) +def test_tc4_6_qr_all_reduce_in_int4_trace(): + """TC-4.6: The qr_all_reduce kernel must appear in the INT4/QuickReduce trace.""" + count = _grep_trace(FUSED_TRACE_GZ, b"qr_all_reduce") + assert count > 0, ( + f"qr_all_reduce not found in {FUSED_TRACE_GZ}. " + "INT4 QuickReduce kernel is not dispatching." + ) + + +# --------------------------------------------------------------------------- +# TC-4.7 qr_all_reduce absent from NONE (baseline) Perfetto trace +# --------------------------------------------------------------------------- + + +@_skip_if_missing(BASELINE_TRACE_GZ) +def test_tc4_7_qr_all_reduce_absent_from_baseline_trace(): + """TC-4.7: The baseline (NONE) trace must NOT contain qr_all_reduce.""" + count = _grep_trace(BASELINE_TRACE_GZ, b"qr_all_reduce") + assert count == 0, ( + f"qr_all_reduce found {count} time(s) in baseline trace {BASELINE_TRACE_GZ}. 
" + "The baseline run should not use INT4 QuickReduce — A/B comparison invalid." + ) + + +# --------------------------------------------------------------------------- +# TC-6.1 AllReduce A/B benchmark: TPOT ≥9%, TTFT ≥4% improvement +# --------------------------------------------------------------------------- + + +@_skip_if_missing(BENCH_BASELINE_JSON, BENCH_INT4_JSON) +def test_tc6_1_allreduce_benchmark_improvement(): + """TC-6.1: INT4 QuickReduce must improve TPOT ≥9% and TTFT ≥4% vs NONE.""" + import json + + with open(BENCH_BASELINE_JSON) as f: + baseline = json.load(f) + with open(BENCH_INT4_JSON) as f: + int4 = json.load(f) + + b_tpot = baseline["mean_tpot_ms"] + f_tpot = int4["mean_tpot_ms"] + b_ttft = baseline["mean_ttft_ms"] + f_ttft = int4["mean_ttft_ms"] + + tpot_imp = (b_tpot - f_tpot) / b_tpot * 100 + ttft_imp = (b_ttft - f_ttft) / b_ttft * 100 + + assert tpot_imp >= 9.0, ( + f"TPOT improvement {tpot_imp:.1f}% < 9% threshold. " + f"Baseline: {b_tpot:.1f}ms → INT4: {f_tpot:.1f}ms." + ) + assert ttft_imp >= 4.0, ( + f"TTFT improvement {ttft_imp:.1f}% < 4% threshold. " + f"Baseline: {b_ttft:.1f}ms → INT4: {f_ttft:.1f}ms." + ) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index eb12bedd7bf2..ad627fcbdae2 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -2,12 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable -from contextlib import contextmanager -from typing import Protocol import torch from torch._ops import OpOverload -from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.platforms import current_platform @@ -52,27 +49,6 @@ def is_aiter_found() -> bool: IS_AITER_FOUND = is_aiter_found() -class AiterCustomAllreduceProto(Protocol): - max_size: int - world_size: int - fully_connected: bool - - @contextmanager - def capture(self): ... - def close(self) -> None: ... 
- def fused_ar_rms( - self, - inp: torch.Tensor, - res_inp: torch.Tensor, - *, - w: torch.Tensor, - eps: float, - registered: bool = False, - use_1stage: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def should_custom_ar(self, inp: torch.Tensor) -> bool: ... - - def is_aiter_found_and_supported() -> bool: """Check if AITER library is available and platform supports it. @@ -154,7 +130,6 @@ def _rocm_aiter_fused_moe_impl( intermediate_pad: int = 0, bias1: torch.Tensor | None = None, bias2: torch.Tensor | None = None, - moe_sorting_dispatch_policy: int = 0, ) -> torch.Tensor: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe @@ -182,7 +157,6 @@ def _rocm_aiter_fused_moe_impl( intermediate_pad=intermediate_pad, bias1=bias1, bias2=bias2, - moe_sorting_dispatch_policy=moe_sorting_dispatch_policy, ) @@ -206,7 +180,6 @@ def _rocm_aiter_fused_moe_fake( intermediate_pad: int = 0, bias1: torch.Tensor | None = None, bias2: torch.Tensor | None = None, - moe_sorting_dispatch_policy: int = 0, ) -> torch.Tensor: if output_dtype is not None: return torch.empty_like(hidden_states, dtype=output_dtype) @@ -274,19 +247,11 @@ def _rocm_aiter_topk_softmax_impl( token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool, - num_shared_experts: int = 0, - shared_expert_scoring_func: str = "", ) -> None: from aiter import topk_softmax topk_softmax( - topk_weights, - topk_indices, - token_expert_indices, - gating_output, - renormalize, - num_shared_experts, - shared_expert_scoring_func, + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize ) @@ -296,8 +261,6 @@ def _rocm_aiter_topk_softmax_fake( token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool, - num_shared_experts: int = 0, - shared_expert_scoring_func: str = "", ) -> None: pass @@ -427,32 +390,17 @@ def _rocm_aiter_fused_topk_fake( def check_aiter_fused_qk_rmsnorm() -> bool: - """Check if aiter provides 
fused_qk_rmsnorm. - - Supports both the new private name ``_fused_qk_rmsnorm`` - (AITER >= PR #2958) and the old public name ``fused_qk_rmsnorm`` - (AITER >= PR #2442). - - TODO(rbrugaro-amd): remove the legacy fused_qk_rmsnorm path once - AITER stabilizes the API (https://github.com/ROCm/aiter/issues/3207). - """ + """Check if aiter provides fused_qk_rmsnorm (requires AITer >= PR #2442).""" global _AITER_HAS_FUSED_QK_RMSNORM if _AITER_HAS_FUSED_QK_RMSNORM is None: try: from aiter.ops.fused_qk_norm_rope_cache_quant import ( # noqa: F401 - _fused_qk_rmsnorm, + fused_qk_rmsnorm, ) _AITER_HAS_FUSED_QK_RMSNORM = True except (ImportError, ModuleNotFoundError, AttributeError): - try: - from aiter.ops.fused_qk_norm_rope_cache_quant import ( # noqa: F401 - fused_qk_rmsnorm, - ) - - _AITER_HAS_FUSED_QK_RMSNORM = True - except (ImportError, ModuleNotFoundError, AttributeError): - _AITER_HAS_FUSED_QK_RMSNORM = False + _AITER_HAS_FUSED_QK_RMSNORM = False return _AITER_HAS_FUSED_QK_RMSNORM @@ -722,6 +670,58 @@ def _rocm_aiter_gemm_a8w8_blockscale_fake( return Y +def _rocm_aiter_rms_norm_impl( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + from aiter import rms_norm + + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + + return rms_norm(x, weight, variance_epsilon) + + +def _rocm_aiter_rms_norm_fake( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + return torch.empty_like(x) + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter import rmsnorm2d_fwd_with_add + + residual_out = torch.empty_like(residual) + out = torch.empty_like(x) + rmsnorm2d_fwd_with_add( + out, # output + x, # input + residual, # residual input + residual_out, # residual 
output + weight, + variance_epsilon, + ) + return out, residual_out + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + residual_out = torch.empty_like(residual) + out = torch.empty_like(x) + return out, residual_out + + def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl( x: torch.Tensor, residual: torch.Tensor, @@ -797,57 +797,104 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_fake( return out, y_scale -def _rocm_aiter_fused_allreduce_rmsnorm_impl( - input_: torch.Tensor, - residual: torch.Tensor, +def _rocm_aiter_dynamic_mxfp4_quant_impl( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Standalone dynamic MXFP4 quantization. + + Wraps aiter's dynamic_mxfp4_quant as a registered torch custom op so it + appears as a single FX-graph node during torch.compile. Pattern matchers + can then match and fuse it with upstream rms_norm calls. + + Returns: + fp4_packed (uint8, shape (M, N//2)): two FP4 values per byte. + block_scale (uint8, shape (M, ceil(N/32))): E8M0 block scales. 
+ """ + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + return dynamic_mxfp4_quant(x) + + +def _rocm_aiter_dynamic_mxfp4_quant_fake( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + import math + + M, N = x.shape[0], x.shape[-1] + fp4_packed = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device) + block_scale = torch.empty( + (M, math.ceil(N / 32)), dtype=torch.uint8, device=x.device + ) + return fp4_packed, block_scale + + +def _rocm_aiter_rmsnorm_mxfp4_quant_impl( + x: torch.Tensor, weight: torch.Tensor, epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: - aiter_ar = rocm_aiter_ops.get_aiter_allreduce() - assert aiter_ar is not None, "aiter allreduce must be initialized" - - total_bytes = input_.numel() * input_.element_size() - hidden_dim = input_.shape[-1] - token_num = input_.shape[0] - if input_.dtype in (torch.bfloat16, torch.float16): - pack_size = 16 // input_.element_size() - hidden_ok = hidden_dim % pack_size == 0 and hidden_dim // pack_size <= 1024 - else: - hidden_ok = False - token_ok = token_num <= 80 - world_size = aiter_ar.world_size - full_nvlink = aiter_ar.fully_connected - - if world_size == 2: - size_ok = True - elif full_nvlink and world_size <= 4: - size_ok = total_bytes < 256 * 1024 - elif full_nvlink and world_size <= 8: - size_ok = total_bytes < 128 * 1024 - else: - size_ok = False + """Fused RMSNorm + MXFP4 quant (no residual, no AllReduce). - use_1stage = hidden_ok and token_ok and size_ok + Uses aiter's fused_rms_mxfp4_quant Triton kernel to perform RMSNorm and + MXFP4 quantization in a single pass. Replaces the standalone + vllm_ir.rms_norm -> rocm_aiter_dynamic_mxfp4_quant subgraph. 
+ """ + from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant - result = aiter_ar.fused_ar_rms( - input_, - residual, - w=weight, - eps=epsilon, - registered=torch.cuda.is_current_stream_capturing(), - use_1stage=use_1stage, + (fp4_out, scale), _, _, _ = fused_rms_mxfp4_quant(x, weight, epsilon) + return fp4_out, scale + + +def _rocm_aiter_rmsnorm_mxfp4_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + import math + + M, N = x.shape[0], x.shape[-1] + fp4_packed = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device) + block_scale = torch.empty( + (M, math.ceil(N / 32)), dtype=torch.uint8, device=x.device ) - assert result is not None - return result[0], result[1] + return fp4_packed, block_scale -def _rocm_aiter_fused_allreduce_rmsnorm_fake( - input_: torch.Tensor, +def _rocm_aiter_rmsnorm_add_mxfp4_quant_impl( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(input_), torch.empty_like(residual) +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused fused_add_RMSNorm + MXFP4 quant (with residual, no AllReduce). + + Steps: x = x + residual; residual_out = x; x = rms_norm(x); x, scale = mxfp4_quant(x). + Replaces the standalone vllm_ir.fused_add_rms_norm -> rocm_aiter_dynamic_mxfp4_quant + subgraph at non-AllReduce sites (e.g. embedding normalisation). 
+ """ + from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant + + (fp4_out, scale), _, _, residual_out = fused_rms_mxfp4_quant( + x, weight, epsilon, res1=residual + ) + return fp4_out, scale, residual_out + + +def _rocm_aiter_rmsnorm_add_mxfp4_quant_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + import math + + M, N = x.shape[0], x.shape[-1] + fp4_packed = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device) + block_scale = torch.empty( + (M, math.ceil(N / 32)), dtype=torch.uint8, device=x.device + ) + residual_out = torch.empty_like(x) + return fp4_packed, block_scale, residual_out def _rocm_aiter_per_tensor_quant_impl( @@ -878,7 +925,7 @@ def _rocm_aiter_per_token_quant_impl( assert quant_dtype in [torch.int8, FP8_DTYPE] out_shape = x.shape - out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) + out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device) if scale is None: scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device) dynamic_per_token_scaled_quant( @@ -898,7 +945,7 @@ def _rocm_aiter_per_token_quant_fake( ) -> tuple[torch.Tensor, torch.Tensor]: out_shape = x.shape return ( - torch.empty(x.shape, dtype=quant_dtype, device=x.device), + torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device), torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device), ) @@ -982,50 +1029,6 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake( ) -def _rocm_aiter_fused_rms_gated_fp8_group_quant_impl( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - z: torch.Tensor, - eps: float, - norm_before_gate: bool, - activation: str, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Fused gated-RMSNorm + FP8 group quantization via aiter Triton kernel.""" - from aiter.ops.triton.quant import fused_rms_gated_fp8_group_quant - - return fused_rms_gated_fp8_group_quant( - x, - weight, 
- bias, - z, - eps, - norm_before_gate=norm_before_gate, - activation=activation, - out_dtype=FP8_DTYPE, - group_size=group_size, - ) - - -def _rocm_aiter_fused_rms_gated_fp8_group_quant_fake( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor | None, - z: torch.Tensor, - eps: float, - norm_before_gate: bool, - activation: str, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - M, N = x.shape - scale_shape = (M, (N + group_size - 1) // group_size) - return ( - torch.empty_like(x, dtype=FP8_DTYPE, device=x.device), - torch.empty(scale_shape, dtype=torch.float32, device=x.device), - ) - - def _rocm_aiter_group_fp8_quant_impl( x: torch.Tensor, group_size: int, @@ -1131,42 +1134,21 @@ def _fused_mla_dual_rms_norm_impl( x2_epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: try: - import aiter.ops.fused_qk_norm_rope_cache_quant as aiter_ops - except (ImportError, ModuleNotFoundError, AttributeError) as exc: + from aiter.ops.fused_qk_norm_rope_cache_quant import fused_qk_rmsnorm + except (ImportError, ModuleNotFoundError) as exc: raise ImportError( - "fused_qk_rmsnorm requires AITer >= PR #2442. " - "Please upgrade aiter or disable the " + "fused_qk_rmsnorm requires a newer AITer version " + "(>= PR #2442). Please upgrade aiter or disable the " "fuse_mla_dual_rms_norm pass." ) from exc - if hasattr(aiter_ops, "_fused_qk_rmsnorm"): - return aiter_ops._fused_qk_rmsnorm( - q_out=None, - q=x1, - q_weight=x1_weight, - q_eps=x1_epsilon, - k_out=None, - k=x2, - k_weight=x2_weight, - k_eps=x2_epsilon, - ) - - # TODO(rbrugaro-amd): remove the legacy fused_qk_rmsnorm path once - # AITER stabilizes the API (https://github.com/ROCm/aiter/issues/3207). - if hasattr(aiter_ops, "fused_qk_rmsnorm"): - return aiter_ops.fused_qk_rmsnorm( - q=x1, - q_weight=x1_weight, - q_eps=x1_epsilon, - k=x2, - k_weight=x2_weight, - k_eps=x2_epsilon, - ) - - raise ImportError( - "fused_qk_rmsnorm requires AITer >= PR #2442. 
" - "Please upgrade aiter or disable the " - "fuse_mla_dual_rms_norm pass." + return fused_qk_rmsnorm( + q=x1, + q_weight=x1_weight, + q_eps=x1_epsilon, + k=x2, + k_weight=x2_weight, + k_eps=x2_epsilon, ) @@ -1321,9 +1303,10 @@ class rocm_aiter_ops: # Check if aiter is enabled before using operations if rocm_aiter_ops.is_enabled(): - result = rocm_aiter_ops.per_token_quant(x, FP8_DTYPE) + result = rocm_aiter_ops.rms_norm(x, weight, epsilon) Operations: + - RMS normalization: rms_norm, rms_norm2d_with_add - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale - Fused MoE: fused_moe, asm_moe_tkw1 - Routing: topk_softmax, biased_grouped_topk, grouped_topk @@ -1332,21 +1315,10 @@ class rocm_aiter_ops: - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale """ - _MOE_DISPATCH_POLICY: int | None = None - - @classmethod - @if_aiter_supported - def get_moe_dispatch_policy(cls) -> int: - """Cached MoE sorting dispatch policy.""" - if cls._MOE_DISPATCH_POLICY is None: - import vllm.envs as envs - - cls._MOE_DISPATCH_POLICY = envs.VLLM_ROCM_AITER_MOE_DISPATCH_POLICY - return cls._MOE_DISPATCH_POLICY - # Check if the env variable is set _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA @@ -1355,7 +1327,7 @@ def get_moe_dispatch_policy(cls) -> int: # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM - _LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM + _LINEAR_HIPBMM_ENABLED = getattr(envs, 'VLLM_ROCM_USE_AITER_LINEAR_HIPBMM', False) # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE @@ -1363,12 +1335,6 @@ def 
get_moe_dispatch_policy(cls) -> int: _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM - # Lazily probed: whether aiter.topk_softmax supports the - # num_shared_experts / shared_expert_scoring_func args (7-arg form). - _TOPK_SOFTMAX_FUSED_SIGMOID: bool | None = None - - _ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2 - _CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None @classmethod def refresh_env_variables(cls): @@ -1381,6 +1347,7 @@ def refresh_env_variables(cls): """ cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA @@ -1388,7 +1355,7 @@ def refresh_env_variables(cls): cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM - cls._LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM + cls._LINEAR_HIPBMM_ENABLED = getattr(envs, 'VLLM_ROCM_USE_AITER_LINEAR_HIPBMM', False) cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS @@ -1473,6 +1440,11 @@ def is_linear_enabled(cls) -> bool: def is_linear_fp8_enabled(cls) -> bool: return cls.is_linear_enabled() + @classmethod + @if_aiter_supported + def is_rmsnorm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._RMSNORM_ENABLED + @classmethod @if_aiter_supported def is_fused_moe_enabled(cls) -> bool: @@ -1484,49 +1456,92 @@ def is_fusion_moe_shared_experts_enabled(cls) -> bool: return cls.is_fused_moe_enabled() and 
cls._MOE_SHARED_EXPERTS_ENABLED @classmethod - @if_aiter_supported - def topk_softmax_supports_fused_sigmoid(cls) -> bool: - """Check if topk_softmax supports fused shared expert activation.""" - if cls._TOPK_SOFTMAX_FUSED_SIGMOID is None: - try: - import inspect - - from aiter import topk_softmax - - params = inspect.signature(topk_softmax).parameters - if "num_shared_experts" in params: - cls._TOPK_SOFTMAX_FUSED_SIGMOID = True - else: - # @compile_ops wrapper loses the original signature. - # Fall back to the torch custom op schema. - import torch - - schema = getattr( - getattr(torch.ops.aiter, "topk_softmax", None), "default", None - ) - schema_str = str(getattr(schema, "_schema", "")) - cls._TOPK_SOFTMAX_FUSED_SIGMOID = "num_shared_experts" in schema_str - except (ImportError, ValueError): - cls._TOPK_SOFTMAX_FUSED_SIGMOID = False - return cls._TOPK_SOFTMAX_FUSED_SIGMOID + def has_fused_rmsnorm_mxfp4_quant(cls) -> bool: + """Check whether AITER exposes the fused RMSNorm+MXFP4-quant Triton kernel. + + Called during RocmAiterFusionPass.__init__ (not per-token). + Returns True when aiter.ops.triton.fused_mxfp4_quant is importable, + enabling the two MXFP4 RMSNorm fusion patterns to be registered. + Returns False on older AITER builds, falling back to unfused path. + """ + try: + from aiter.ops.triton.fused_mxfp4_quant import ( + fused_rms_mxfp4_quant, # noqa: F401 + ) + + return True + except (ImportError, AttributeError): + return False @classmethod - @if_aiter_supported - def fuse_sigmoid_in_kernel(cls, aiter_topK_meta_data: object) -> bool: - """Whether fused shared-expert sigmoid in the topk kernel is usable. + def has_fused_rope_mla_kv_cache(cls) -> bool: + """Check whether AITER exposes the fused RoPE + MLA KV-cache kernel. + + Called in mla.py __init__ (not per-token) to decide whether to + use the fused dispatch path. Auto-enables F3 when the kernel is + present — no env var required. Follows the same pattern as + has_fused_rmsnorm_mxfp4_quant() for F2. 
+ """ + try: + from aiter import fused_qk_rope_concat_and_cache_mla # noqa: F401 + + return True + except (ImportError, AttributeError): + return False + + @classmethod + def fused_rope_and_mla_kv_cache_write( + cls, + q_nope, + q_pe, + kv_c, + k_pe, + kv_cache, + q_out, + slot_mapping, + k_scale, + q_scale, + positions, + cos_cache, + sin_cache, + is_neox: bool = True, + is_nope_first: bool = False, + ): + """Dispatch to aiter.fused_qk_rope_concat_and_cache_mla. - Combines the cached static capability checks (FSE enabled, fused-moe - enabled, topk_softmax supports fused sigmoid) with the runtime - readiness check (topK meta-data buffer initialized). + Applies RoPE to q_pe/k_pe and writes the MLA KV-cache in a single pass. - ``aiter_topK_meta_data`` is accepted as a parameter rather than - imported internally so callers cannot hit initialization-order - issues where the module-level global has not been set yet. + Args: + q_nope: [B, QH, qk_nope_head_dim] + q_pe: [B, QH, qk_rope_head_dim] (rotated in-place) + kv_c: [B, kv_lora_rank] + k_pe: [B, qk_rope_head_dim] + kv_cache: [num_blocks, 1, qk_rope_head_dim + kv_lora_rank] + q_out: [B, QH, qk_nope_head_dim + qk_rope_head_dim] (output) + slot_mapping: [B] long + k_scale, q_scale: scalar fp32 tensors + positions: [B] long + cos_cache, sin_cache: [max_seq, qk_rope_head_dim] + is_neox: use NeoX RoPE convention (default True) + is_nope_first: q layout is [nope|pe] when True (default False) """ - return ( - cls.is_fusion_moe_shared_experts_enabled() - and cls.topk_softmax_supports_fused_sigmoid() - and aiter_topK_meta_data is not None + from aiter import fused_qk_rope_concat_and_cache_mla + + fused_qk_rope_concat_and_cache_mla( + q_nope, + q_pe, + kv_c, + k_pe, + kv_cache, + q_out, + slot_mapping, + k_scale, + q_scale, + positions, + cos_cache, + sin_cache, + is_neox, + is_nope_first, ) @classmethod @@ -1585,64 +1600,6 @@ def is_triton_rotary_embed_enabled(cls) -> bool: def is_triton_gemm_enabled(cls) -> bool: return 
cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM - @classmethod - @if_aiter_supported - def is_tgemm_enabled(cls) -> bool: - from vllm.platforms.rocm import on_gfx950 - - return cls.is_linear_enabled() and on_gfx950() - - @classmethod - def initialize_aiter_allreduce( - cls, group: ProcessGroup, device: torch.device - ) -> None: - try: - from aiter.dist.device_communicators.custom_all_reduce import ( - CustomAllreduce as AiterCustomAllreduce, - ) - - cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) - except Exception: - cls._CUSTOM_ALL_REDUCE = None - - @classmethod - def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None: - return cls._CUSTOM_ALL_REDUCE - - @classmethod - def destroy_aiter_allreduce(cls) -> None: - if cls._CUSTOM_ALL_REDUCE is not None: - cls._CUSTOM_ALL_REDUCE.close() - cls._CUSTOM_ALL_REDUCE = None - - @classmethod - def get_aiter_allreduce_max_size(cls) -> int | None: - # effective max input size (based on upstream aiter version: v0.1.10.post3) - # https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273 - return int(cls._ALL_REDUCE_MAX_SIZE / 2) - - @classmethod - @if_aiter_supported - def are_gdn_triton_kernels_available(cls) -> bool: - """Check if AITER Triton kernels for GDN attention are importable. - - These are optional Triton kernels (conv1d fast-path, gated delta net) - used by GatedDeltaNetAttention's decode fast-path. They may be absent - in older aiter builds. 
- """ - if not cls._AITER_ENABLED: - return False - try: - import aiter.ops.triton.causal_conv1d_update_single_token # noqa: F401 - import aiter.ops.triton.gated_delta_net # noqa: F401 - from aiter.ops.triton.quant import ( # noqa: F401 - fused_rms_gated_fp8_group_quant, - ) - - return True - except (ImportError, ModuleNotFoundError): - return False - @staticmethod @if_aiter_supported def register_ops_once() -> None: @@ -1742,6 +1699,19 @@ def register_ops_once() -> None: fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=_rocm_aiter_rms_norm_impl, + fake_impl=_rocm_aiter_rms_norm_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, + fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_rmsnorm_fused_dynamic_quant", op_func=_rocm_aiter_rmsnorm_fused_dynamic_quant_impl, @@ -1762,12 +1732,6 @@ def register_ops_once() -> None: fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake, ) - direct_register_custom_op( - op_name="rocm_aiter_fused_rms_gated_fp8_group_quant", - op_func=_rocm_aiter_fused_rms_gated_fp8_group_quant_impl, - fake_impl=_rocm_aiter_fused_rms_gated_fp8_group_quant_fake, - ) - direct_register_custom_op( op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant", op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl, @@ -1832,12 +1796,6 @@ def register_ops_once() -> None: fake_impl=_triton_rotary_embedding_fake, ) - direct_register_custom_op( - op_name="rocm_aiter_fused_allreduce_rmsnorm", - op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl, - fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake, - ) - direct_register_custom_op( op_name="fused_mla_dual_rms_norm", op_func=_fused_mla_dual_rms_norm_impl, @@ -1845,8 +1803,37 @@ def register_ops_once() -> None: fake_impl=_fused_mla_dual_rms_norm_fake, ) + 
direct_register_custom_op( + op_name="rocm_aiter_dynamic_mxfp4_quant", + op_func=_rocm_aiter_dynamic_mxfp4_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_dynamic_mxfp4_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_mxfp4_quant", + op_func=_rocm_aiter_rmsnorm_mxfp4_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rmsnorm_mxfp4_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_add_mxfp4_quant", + op_func=_rocm_aiter_rmsnorm_add_mxfp4_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rmsnorm_add_mxfp4_quant_fake, + ) + _OPS_REGISTERED = True + @staticmethod + def get_rmsnorm_fused_add_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default + + @staticmethod + def get_rmsnorm_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rms_norm.default + @staticmethod def get_rmsnorm_fused_add_dynamic_quant_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_rmsnorm_fused_add_dynamic_quant.default @@ -1859,11 +1846,6 @@ def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload: def get_rmsnorm_group_fused_quant_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default - @staticmethod - def get_fused_rms_gated_fp8_group_quant_op() -> OpOverload: - """Return the fused gated-RMSNorm + FP8 group quant custom op.""" - return torch.ops.vllm.rocm_aiter_fused_rms_gated_fp8_group_quant.default - @staticmethod def get_rmsnorm_group_add_fused_quant_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default @@ -1888,14 +1870,39 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload: def get_triton_rotary_embedding_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default - @staticmethod - def get_fused_allreduce_rmsnorm_op() -> OpOverload: - return torch.ops.vllm.rocm_aiter_fused_allreduce_rmsnorm.default - @staticmethod def get_fused_mla_dual_rms_norm_op() -> OpOverload: return 
torch.ops.vllm.fused_mla_dual_rms_norm.default + @staticmethod + def get_dynamic_mxfp4_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_dynamic_mxfp4_quant.default + + @staticmethod + def get_fused_rmsnorm_mxfp4_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_mxfp4_quant.default + + @staticmethod + def get_fused_rmsnorm_add_mxfp4_quant_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_rmsnorm_add_mxfp4_quant.default + + @staticmethod + def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + + @staticmethod + def rms_norm2d_with_add( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add( + x, residual, weight, variance_epsilon + ) + @staticmethod def w8a8_gemm( A: torch.Tensor, @@ -1978,7 +1985,6 @@ def fused_moe( intermediate_pad: int = 0, bias1: torch.Tensor | None = None, bias2: torch.Tensor | None = None, - moe_sorting_dispatch_policy: int = 0, ) -> torch.Tensor: return torch.ops.vllm.rocm_aiter_fused_moe( hidden_states, @@ -2000,7 +2006,6 @@ def fused_moe( intermediate_pad, bias1, bias2, - moe_sorting_dispatch_policy, ) @staticmethod @@ -2042,17 +2047,9 @@ def topk_softmax( token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool, - num_shared_experts: int = 0, - shared_expert_scoring_func: str = "", ) -> tuple[torch.Tensor, ...]: torch.ops.vllm.rocm_aiter_topk_softmax( - topk_weights, - topk_indices, - token_expert_indices, - gating_output, - renormalize, - num_shared_experts, - shared_expert_scoring_func, + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize ) return topk_weights, topk_indices @@ -2467,7 +2464,6 @@ def flash_attn_varlen_func( alibi_slopes: torch.Tensor | None = None, return_lse: bool = False, out: 
torch.Tensor | None = None, - sink_ptr: torch.Tensor | None = None, ): """ Flash attention with variable length sequences. @@ -2496,7 +2492,6 @@ def flash_attn_varlen_func( alibi_slopes=alibi_slopes, return_lse=return_lse, out=out, - sink_ptr=sink_ptr, ) @staticmethod @@ -2585,183 +2580,5 @@ def paged_attention_common( kv_cache_dtype=kv_cache_dtype, ) - @staticmethod - def mhc_pre( - residual: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_eps: float, - hc_pre_eps: float, - hc_sinkhorn_eps: float, - hc_post_mult_value: float, - sinkhorn_repeat: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Forward pass for mHC pre block. - - Args: - residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16 - fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32 - hc_scale: shape (3,), dtype torch.float32 - hc_base: shape (hc_mult3,), dtype torch.float32 - rms_eps: RMS normalization epsilon - hc_pre_eps: pre-mix epsilon - hc_sinkhorn_eps: sinkhorn epsilon - hc_post_mult_value: post-mix multiplier value - sinkhorn_repeat: number of sinkhorn iterations - n_splits: split-k factor; - - Returns: - post_mix: shape (..., hc_mult), dtype torch.float32 - comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32 - layer_input: shape (..., hidden_size), dtype torch.bfloat16 - """ - from aiter.ops.mhc import mhc_pre - - # Validate shapes - assert residual.dtype == torch.bfloat16 - assert fn.dtype == torch.float32 - assert hc_scale.dtype == torch.float32 - assert hc_base.dtype == torch.float32 - - hc_mult = residual.shape[-2] - hidden_size = residual.shape[-1] - hc_mult2 = hc_mult * hc_mult - hc_mult3 = hc_mult * 2 + hc_mult2 - - hc_hidden_size = hc_mult * hidden_size - assert fn.shape[0] == hc_mult3 - assert fn.shape[1] == hc_hidden_size - assert hc_scale.shape == (3,) - assert hc_base.shape == (hc_mult3,) - - outer_shape = residual.shape[:-2] - - residual_flat = residual.view(-1, hc_mult, hidden_size) - - 
num_tokens = residual_flat.shape[0] - if num_tokens == 0: - return ( - torch.empty( - num_tokens, - hc_mult, - 1, - dtype=torch.float32, - device=residual_flat.device, - ), - torch.empty( - num_tokens, - hc_mult, - hc_mult, - dtype=torch.float32, - device=residual_flat.device, - ), - torch.empty( - num_tokens, - hidden_size, - dtype=torch.bfloat16, - device=residual_flat.device, - ), - ) - - # AITER's Python wrapper allocates intermediate/output tensors without - # explicit device arguments, so run it under the residual tensor's device. - with torch.device(residual_flat.device): - post_mix, comb_mix, layer_input = mhc_pre( - residual_flat, - fn, - hc_scale, - hc_base, - rms_eps, - hc_pre_eps, - hc_sinkhorn_eps, - hc_post_mult_value, - sinkhorn_repeat, - ) - return ( - post_mix.view(*outer_shape, hc_mult, 1), - comb_mix.view(*outer_shape, hc_mult, hc_mult), - layer_input.view(*outer_shape, hidden_size), - ) - - @staticmethod - def hc_head( - hs_flat: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - out: torch.Tensor, - hidden_size: int, - rms_eps: float, - hc_eps: float, - hc_mult: int, - ) -> None: - """Run hc_head through AITER mhc_pre and write the result to out.""" - assert hs_flat.dtype == torch.bfloat16 - assert fn.dtype == torch.float32 - assert hc_scale.dtype == torch.float32 - assert hc_base.dtype == torch.float32 - assert hs_flat.shape[-2:] == (hc_mult, hidden_size) - assert fn.shape == (hc_mult, hc_mult * hidden_size) - assert hc_scale.shape == (1,) - assert hc_base.shape == (hc_mult,) - - num_tokens = hs_flat.shape[0] - if num_tokens == 0: - return - - hc_mult3 = hc_mult * 2 + hc_mult * hc_mult - - full_fn = torch.zeros( - hc_mult3, - hc_mult * hidden_size, - dtype=fn.dtype, - device=fn.device, - ) - full_fn[:hc_mult] = fn - - full_base = torch.zeros(hc_mult3, dtype=hc_base.dtype, device=hc_base.device) - full_base[:hc_mult] = hc_base - - full_scale = torch.zeros(3, dtype=hc_scale.dtype, device=hc_scale.device) - 
full_scale[0] = hc_scale[0] - - _, _, layer_input = rocm_aiter_ops.mhc_pre( - hs_flat, - full_fn, - full_scale, - full_base, - rms_eps, - hc_eps, - 0.0, - 1.0, - 0, - ) - out.copy_(layer_input) - - @staticmethod - def mhc_post( - x: torch.Tensor, - residual: torch.Tensor, - post_layer_mix: torch.Tensor, - comb_res_mix: torch.Tensor, - ) -> torch.Tensor: - from aiter.ops.mhc import mhc_post - - hc_mult = residual.shape[-2] - hidden_size = residual.shape[-1] - residual_flat = residual.view(-1, hc_mult, hidden_size) - num_tokens = residual_flat.shape[0] - out = torch.empty_like(residual_flat) - mhc_post( - out, - x.view(num_tokens, hidden_size), - residual_flat, - post_layer_mix.view(num_tokens, hc_mult, 1), - comb_res_mix.view(num_tokens, hc_mult, hc_mult), - ) - return out.view_as(residual) - rocm_aiter_ops.register_ops_once() diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index c58ce31bd29c..15b20031018e 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import itertools from typing import Any @@ -28,18 +29,23 @@ FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 -SILU_MUL_OP = torch.ops._C.silu_and_mul.default +try: + SILU_MUL_OP = torch.ops._C.silu_and_mul.default +except AttributeError: + SILU_MUL_OP = None # vllm._C not compiled (source-only run) -FUSED_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501 -} +FUSED_OPS: dict[QuantKey, OpOverload] = {} +with contextlib.suppress(AttributeError): # vllm._C not compiled (source-only run) + FUSED_OPS[kFp8StaticTensorSym] = torch.ops._C.silu_and_mul_quant.default silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( torch.ops._C, 
"silu_and_mul_nvfp4_quant" ) if silu_and_mul_nvfp4_quant_supported: FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -if current_platform.is_cuda_alike(): +if current_platform.is_cuda_alike() and hasattr( + torch.ops._C, "silu_and_mul_per_block_quant" +): FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 324b0266b4df..87c602afa430 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -44,28 +44,6 @@ FP8_DTYPE = current_platform.fp8_dtype() -_IR_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default -_IR_FUSED_ADD_RMS_NORM_OP = torch.ops.vllm_ir.fused_add_rms_norm.default - - -def _norm_input_weight_dtype_match(match: pm.Match) -> bool: - """Prevent fusion when the norm input and weight dtypes differ (e.g. 
a Gemma - fp32 weight.float()+1 gamma), covering rms_norm and fused_add_rms_norm.""" - for node in match.nodes: - if node.target == _IR_RMS_NORM_OP: - x, weight = node.args[0], node.args[1] - elif node.target == _IR_FUSED_ADD_RMS_NORM_OP: - x, weight = node.args[0], node.args[2] - else: - continue - if isinstance(x, fx.Node) and isinstance(weight, fx.Node): - return x.meta["val"].dtype == weight.meta["val"].dtype - return True - - -# The empirical value for small batch -PDL_ADVANCE_LAUNCH_TOKENS = 16 - logger = init_logger(__name__) flashinfer_comm: ModuleType | None = None @@ -150,7 +128,6 @@ def call_trtllm_fused_allreduce_norm( quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, scale_factor: torch.Tensor | None = None, - weight_bias: float = 0.0, ) -> None: num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() @@ -227,8 +204,6 @@ def call_trtllm_fused_allreduce_norm( layout_code=layout_code, use_oneshot=use_oneshot, fp32_acc=fp32_acc, - weight_bias=weight_bias, - trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS, ) def call_trtllm_fused_allreduce_norm_fake( @@ -245,7 +220,6 @@ def call_trtllm_fused_allreduce_norm_fake( quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, scale_factor: torch.Tensor | None = None, - weight_bias: float = 0.0, ) -> None: pass @@ -420,142 +394,14 @@ def replacement( # allreduce_in, residual return allreduce[1], allreduce[2] - # extra_check routes a Gemma fp32 gamma to AllReduceFusedAddGemmaRMSNormPattern. 
pm.register_replacement( - pattern, - replacement, - self.get_inputs(), - pm.fwd_only, - pm_pass, - extra_check=_norm_input_weight_dtype_match, + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) # Same pattern, but only return the output and not residual # (helpful for end of graph where residual is not used again) first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] - pm.register_replacement( - first_return_only(pattern), # type: ignore[no-untyped-call] - first_return_only(replacement), # type: ignore[no-untyped-call] - self.get_inputs(), - pm.fwd_only, - pm_pass, - extra_check=_norm_input_weight_dtype_match, - ) - - -class AllReduceGemmaRMSNormPattern(BasePattern): - """Gemma-style variant of AllReduceRMSNormPattern (no residual).""" - - def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str | None, - allreduce_params: FlashInferFusedAllReduceParams, - ) -> None: - super().__init__(dtype, device) - self.epsilon = epsilon - self.allreduce_params = allreduce_params - - def get_inputs(self) -> list[torch.Tensor]: - return [self.empty(5, 16), self.empty(16)] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - allreduce_output = tensor_model_parallel_all_reduce(input) - rms = vllm.ir.ops.rms_norm( - allreduce_output, weight.float() + 1.0, self.epsilon - ) - return rms, allreduce_output - - def replacement( - input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - residual = torch.zeros_like(input) - rms_result = torch.empty_like(input) - assert flashinfer_comm is not None, "FlashInfer must be enabled" - allreduce = auto_functionalized( - flashinfer_trtllm_fused_allreduce_norm, - allreduce_in=input, - residual=residual, - norm_out=rms_result, - quant_out=None, - scale_out=None, - rms_gamma=weight, - rms_eps=self.epsilon, - 
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, - weight_bias=1.0, - **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), - ) - return allreduce[3], allreduce[1] - - pm.register_replacement( - pattern, - replacement, - self.get_inputs(), - pm.fwd_only, - pm_pass, - ) - - -class AllReduceFusedAddGemmaRMSNormPattern(BasePattern): - """Gemma-style variant of AllReduceFusedAddRMSNormPattern (with residual).""" - - def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str | None, - allreduce_params: FlashInferFusedAllReduceParams, - ) -> None: - super().__init__(dtype, device) - self.epsilon = epsilon - self.allreduce_params = allreduce_params - - def get_inputs(self) -> list[torch.Tensor]: - input = self.empty(5, 16) - residual = self.empty(5, 16) - weight = self.empty(16) - return [residual, input.to(self.dtype), weight] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - allreduce_output = tensor_model_parallel_all_reduce(input) - rms, residual = vllm.ir.ops.fused_add_rms_norm( - allreduce_output, residual, weight.float() + 1.0, self.epsilon - ) - return rms, residual - - def replacement( - residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - assert flashinfer_comm is not None, "FlashInfer must be enabled" - allreduce = auto_functionalized( - flashinfer_trtllm_fused_allreduce_norm, - allreduce_in=input, - residual=residual, - norm_out=None, - quant_out=None, - scale_out=None, - rms_gamma=weight, - rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, - weight_bias=1.0, - **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), - ) - return allreduce[1], allreduce[2] - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - first_return_only 
= lambda fn: lambda a, b, c: fn(a, b, c)[0] - pm.register_replacement( first_return_only(pattern), # type: ignore[no-untyped-call] first_return_only(replacement), # type: ignore[no-untyped-call] @@ -1030,18 +876,6 @@ def register_patterns(self) -> None: self.device, self.allreduce_params, ).register(self.patterns) - AllReduceGemmaRMSNormPattern( - epsilon, - self.model_dtype, - self.device, - self.allreduce_params, - ).register(self.patterns) - AllReduceFusedAddGemmaRMSNormPattern( - epsilon, - self.model_dtype, - self.device, - self.allreduce_params, - ).register(self.patterns) # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. @@ -1233,6 +1067,7 @@ def __init__(self, config: VllmConfig) -> None: ) for epsilon in [1e-5, 1e-6]: + # ── Baseline AR+RMSNorm patterns (no quant fusion) ────────────────── self.register( AiterAllreduceFusedRMSNormPattern( epsilon, @@ -1262,14 +1097,6 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: return False return bool(compile_range.end <= self.max_token_num) - @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph) -> None: - self.matched_count = self.pm_pass.apply(graph) - VllmPatternMatcherPass.match_table[self.pass_name] += self.matched_count - logger.debug( - "%s Replaced %s patterns", self.__class__.__name__, self.matched_count - ) - def __del__(self) -> None: if getattr(self, "disabled", True): return diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index 99b2892a770e..bb315a6c79d7 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -29,14 +29,25 @@ ) from vllm.platforms import current_platform -ROTARY_OP = torch.ops._C.rotary_embedding.default -FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default - -QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: 
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 -} +try: + ROTARY_OP = torch.ops._C.rotary_embedding.default +except AttributeError: + ROTARY_OP = None # vllm._C not compiled (source-only run) + +try: + FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default +except AttributeError: + FLASHINFER_ROTARY_OP = None + +QUANT_OPS: dict[QuantKey, OpOverload] = {} +try: + QUANT_OPS[kFp8StaticTensorSym] = torch.ops._C.static_scaled_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8DynamicTensorSym] = torch.ops._C.dynamic_scaled_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8DynamicTokenSym] = ( + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default + ) # noqa: E501 +except AttributeError: + pass # vllm._C not compiled (source-only run) if hasattr(torch.ops._C, "per_token_group_fp8_quant"): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 @@ -45,8 +56,10 @@ if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501 - -SILU_MUL_OP = torch.ops._C.silu_and_mul.default +try: + SILU_MUL_OP = torch.ops._C.silu_and_mul.default +except AttributeError: + SILU_MUL_OP = None class MatcherCustomOp(ABC): diff --git a/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py index b7e747a784eb..c7c29545dbfd 100644 --- a/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py +++ b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py @@ -23,7 +23,10 @@ logger = init_logger(__name__) -FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default +try: + FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default +except AttributeError: + FUSED_QK_ROPE_OP = None # vllm._C not compiled (source-only run) P = 
ParamSpec("P") diff --git a/vllm/compilation/passes/fusion/rms_quant_fusion.py b/vllm/compilation/passes/fusion/rms_quant_fusion.py index 670349a08b2a..188b9b3f11b7 100644 --- a/vllm/compilation/passes/fusion/rms_quant_fusion.py +++ b/vllm/compilation/passes/fusion/rms_quant_fusion.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib from typing import Any, NamedTuple import torch @@ -84,13 +85,20 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor: ) -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default - -QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 -} +try: + RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +except AttributeError: + RMS_ADD_OP = None # vllm._C not compiled (source-only run) + +QUANT_OPS: dict[QuantKey, OpOverload] = {} +try: + QUANT_OPS[kFp8StaticTensorSym] = torch.ops._C.static_scaled_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8DynamicTensorSym] = torch.ops._C.dynamic_scaled_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8DynamicTokenSym] = ( + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default + ) # noqa: E501 +except AttributeError: + pass # vllm._C not compiled (source-only run) if hasattr(torch.ops._C, "per_token_group_fp8_quant"): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 @@ -115,32 +123,36 @@ def __str__(self) -> str: ) -FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { - FusedRMSQuantKey( - kFp8StaticTensorSym, False - ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8StaticTensorSym, True - ): 
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8DynamicTokenSym, False - ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8DynamicTokenSym, True - ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8Dynamic128Sym, False - ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8Dynamic128Sym, True - ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8Dynamic64Sym, False - ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 - FusedRMSQuantKey( - kFp8Dynamic64Sym, True - ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 -} +FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {} +with contextlib.suppress(AttributeError): # vllm._C not compiled (source-only run) + FUSED_OPS.update( + { + FusedRMSQuantKey( + kFp8StaticTensorSym, False + ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, True + ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, False + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, True + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, False + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, True + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic64Sym, False + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic64Sym, True + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + } + ) class RMSNormQuantPattern: @@ -649,6 +661,8 @@ def __init__(self, config: VllmConfig) -> None: RMSNormDynamicQuantPattern(epsilon, 
FP8_DTYPE).register(self.patterns) # Only register group quant patterns on CUDA/ROCm where the C++ op exists + if not hasattr(torch.ops._C, "per_token_group_fp8_quant"): + continue for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]: for has_col_major_scales in [True, False]: for is_e8m0 in [True, False]: diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index 03d291d4d94f..1ea18b8c280f 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -6,13 +6,12 @@ import torch import torch._inductor.pattern_matcher as pm from torch import fx -from torch._inductor.fx_passes.post_grad import view_to_reshape from torch._inductor.pattern_matcher import PatternMatcherPass import vllm.ir.ops import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 from vllm._aiter_ops import rocm_aiter_ops -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -28,12 +27,9 @@ VllmInductorPass, VllmPatternMatcherPass, VllmPatternReplacement, - _fx_view_to_reshape, - fold_consecutive_reshapes, ) from .matcher_utils import ( MatcherQuantFP8, - MatcherRMSNormGated, MatcherSiluAndMul, ) from .rms_quant_fusion import ( @@ -297,248 +293,119 @@ def replacement( pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) -class DoubleAiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern): - """ - Pattern matching ``rms_norm`` whose output feeds *two* distinct - ``rocm_aiter_group_fp8_quant`` consumers, replacing it with two - independent fused ``rms_norm_group_fp8_quant`` ops. 
- - Repeating the rms_norm in the replacement is preferable to leaving - the fused 16-bit rms output materialized for two unfused quant - consumers, and matches what the previous manual graph surgery - achieved by cloning the rms_norm node. +class AiterRMSNormMXFP4QuantPattern(AiterRMSNormQuantPattern): + """Fuse AITER rms_norm + dynamic MXFP4 quant into a single kernel. + + Matched 2-node subgraph:: + + torch.ops.vllm_ir.rms_norm(x, weight, eps) + → torch.ops.vllm.rocm_aiter_dynamic_mxfp4_quant(z) + + Replacement: single AITER fused Triton call + ``rocm_aiter_rmsnorm_mxfp4_quant(x, weight, eps)``. + + Registered in :class:`RocmAiterRMSNormQuantFusionPass` only when + ``rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant()`` returns True + (i.e. aiter.ops.triton.fused_mxfp4_quant is importable). """ - FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op() + FUSED_OP = rocm_aiter_ops.get_fused_rmsnorm_mxfp4_quant_op() - def __init__( - self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape, - match_aiter_quant: bool = True, - symmetric: bool = True, - ) -> None: - scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey( - fused_add=False, - quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), - ) + def __init__(self, epsilon: float) -> None: + self.epsilon = epsilon + self.DYNAMIC_MXFP4_QUANT_OP = rocm_aiter_ops.get_dynamic_mxfp4_quant_op() + self.device = torch.device("cuda") - super().__init__(epsilon, key, match_aiter_quant) + def empty(self, *args, **kwargs) -> torch.Tensor: + return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kwargs) def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon) - result1, scale1 = self.quant_matcher(result_rms) - result2, 
scale2 = self.quant_matcher(result_rms) - return result1, scale1, result2, scale2 + fp4, scale = self.DYNAMIC_MXFP4_QUANT_OP(result_rms) + return fp4, scale def replacement( input: torch.Tensor, weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - at1 = self.FUSED_OP( - x=input, - weight=weight, - variance_epsilon=self.epsilon, - group_size=128, - ) - at2 = self.FUSED_OP( - x=input, - weight=weight, - variance_epsilon=self.epsilon, - group_size=128, - ) - - return at1[0], at1[1], at2[0], at2[1] + ) -> tuple[torch.Tensor, torch.Tensor]: + fp4, scale = self.FUSED_OP(x=input, weight=weight, epsilon=self.epsilon) + return fp4, scale pm.register_replacement( pattern, replacement, - # input, weight [self.empty(5, 16), self.empty(16)], pm.fwd_only, pm_pass, ) -class DoubleAiterRMSFp8GroupQuantViewPattern(AiterRMSNormQuantPattern): - """ - View-tolerant variant of ``DoubleAiterRMSFp8GroupQuantPattern``. +class AiterFusedAddRMSNormMXFP4QuantPattern(AiterRMSNormQuantPattern): + """Fuse AITER fused_add_rms_norm + dynamic MXFP4 quant into a single kernel. - Matches the same 1-to-2 fan-out, but with a ``view``/``reshape`` between - the ``rms_norm`` output and the two ``rocm_aiter_group_fp8_quant`` - consumers:: + Matched 3-node subgraph:: - rms_norm -> view -> rocm_aiter_group_fp8_quant - \\-> view -> rocm_aiter_group_fp8_quant + torch.ops.vllm_ir.fused_add_rms_norm(x, residual, weight, eps) + → torch.ops.vllm.rocm_aiter_dynamic_mxfp4_quant(z) - This shape arises in DeepSeek-V3.2's MLA indexer q_c norm, where the - FP8 linear path's 2D-flatten boilerplate - (``Fp8BlockScaledMMLinearKernel.apply_weights``) inserts a view between - the rms_norm output and each FP8 group quant op. The non-view sibling - pattern silently no-ops on this graph because the pattern matcher - requires the in-graph and in-pattern node shapes to align. 
+ Replacement: single AITER fused Triton call + ``rocm_aiter_rmsnorm_add_mxfp4_quant(x, residual, weight, eps)``, + returning ``(fp4_data, scale, updated_residual)``. - The trace_fn runs Inductor's ``view_to_reshape`` post-grad pass to - normalize ``view`` to ``reshape`` in both the pattern and the input - graph, widening the match without touching the no-view sibling. + Registered BEFORE :class:`AiterRMSNormMXFP4QuantPattern` so that the + larger subgraph is attempted first (greedy matching). """ - FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op() + FUSED_OP = rocm_aiter_ops.get_fused_rmsnorm_add_mxfp4_quant_op() - def __init__( - self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape, - match_aiter_quant: bool = True, - symmetric: bool = True, - ) -> None: - scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey( - fused_add=False, - quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), - ) + def __init__(self, epsilon: float) -> None: + self.epsilon = epsilon + self.DYNAMIC_MXFP4_QUANT_OP = rocm_aiter_ops.get_dynamic_mxfp4_quant_op() + self.device = torch.device("cuda") - super().__init__(epsilon, key, match_aiter_quant) + def empty(self, *args, **kwargs) -> torch.Tensor: + return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kwargs) def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon) - view_rms = result_rms.view(-1, result_rms.shape[-1]) - result1, scale1 = self.quant_matcher(view_rms) - result2, scale2 = self.quant_matcher(view_rms) - return result1, scale1, result2, scale2 + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + result_rms, residual_out = torch.ops.vllm_ir.fused_add_rms_norm( + input, residual, weight, self.epsilon + ) + fp4, 
scale = self.DYNAMIC_MXFP4_QUANT_OP(result_rms) + return fp4, scale, residual_out def replacement( input: torch.Tensor, weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - at1 = self.FUSED_OP( - x=input, - weight=weight, - variance_epsilon=self.epsilon, - group_size=128, - ) - at2 = self.FUSED_OP( + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + fp4, scale, residual_out = self.FUSED_OP( x=input, + residual=residual, weight=weight, - variance_epsilon=self.epsilon, - group_size=128, + epsilon=self.epsilon, ) + return fp4, scale, residual_out - return at1[0], at1[1], at2[0], at2[1] - - def trace_with_view_to_reshape(*args: Any, **kwargs: Any) -> fx.GraphModule: - gm = pm.fwd_only(*args, **kwargs) - view_to_reshape(gm) - return gm - - pm.register_replacement( - pattern, - replacement, - # input, weight - [self.empty(5, 16), self.empty(16)], - trace_with_view_to_reshape, - pm_pass, - ) - - -class AiterRMSNormGatedFp8GroupQuantPattern(AiterRMSNormQuantPattern): - """ - Matches decomposed RMSNormGated + reshape + group FP8 quant and replaces - with rocm_aiter_fused_rms_gated_fp8_group_quant. - - The norm operates per-head on (N*H, D) tensors. The compiler folds the - reshape chain so after norm the result goes through reshape->merge->quant. - The pattern reshapes from (N*H, D) to (N, H*D) before calling - MatcherQuantFP8 so that _quantize_group_native sees the full hidden dim - and computes the correct num_groups. 
- """ - - FUSED_OP = rocm_aiter_ops.get_fused_rms_gated_fp8_group_quant_op() - - def __init__( - self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape, - num_heads: int, - head_dim: int, - match_aiter_quant: bool = True, - symmetric: bool = True, - ) -> None: - scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey( - fused_add=False, - quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), - ) - super().__init__(epsilon, key, match_aiter_quant) - self.rmsnorm_gated_matcher = MatcherRMSNormGated(epsilon) - self.num_heads = num_heads - self.head_dim = head_dim - - def register(self, pm_pass: PatternMatcherPass) -> None: - num_heads = self.num_heads - head_dim = self.head_dim - hidden_dim = num_heads * head_dim - quant_matcher = self.quant_matcher - - def pattern( - x: torch.Tensor, - z: torch.Tensor, - weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - normed = self.rmsnorm_gated_matcher(x, z, weight) - merged = normed.reshape(-1, hidden_dim) - quant_out, scales_out = quant_matcher(merged) - return quant_out, scales_out - - def replacement( - x: torch.Tensor, - z: torch.Tensor, - weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - fused = self.FUSED_OP( - x=x, - weight=weight, - bias=None, - z=z, - eps=self.epsilon, - norm_before_gate=True, - activation="silu", - group_size=head_dim, - ) - fp8_out = fused[0] - scales_out = fused[1] - fp8_reshaped = fp8_out.reshape(-1, hidden_dim) - scales_reshaped = scales_out.reshape(-1, num_heads) - return fp8_reshaped, scales_reshaped - - n_tokens = 2 - x = self.empty(n_tokens * num_heads, head_dim) - z = self.empty(n_tokens * num_heads, head_dim) - w = self.empty(head_dim) - - def trace_fn(*args, **kwargs): - gm = pm.fwd_only(*args, **kwargs) - _fx_view_to_reshape(gm) - fold_consecutive_reshapes(gm) - return gm + inputs = [ + self.empty(5, 16), # input + self.empty(16), # weight + self.empty(5, 16), # residual + ] pm.register_replacement( 
pattern, replacement, - [x, z, w], - trace_fn, + inputs, + pm.fwd_only, pm_pass, ) @@ -557,58 +424,36 @@ def __init__(self, config: VllmConfig) -> None: self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="rocm_aiter_rms_norm_quant_fusion_pass" ) - - # Discover (num_heads, head_dim) pairs for gated RMSNorm patterns - # from GatedDeltaNetAttention layers in static_forward_context. - from vllm.model_executor.layers.mamba.gdn.base import ( - GatedDeltaNetAttention, - ) - - gdn_layers = get_layers_from_vllm_config( - config, - GatedDeltaNetAttention, # type: ignore[type-abstract] - ) - gated_norm_shapes: set[tuple[int, int]] = set() - for layer in gdn_layers.values(): - num_v_heads = getattr(layer, "num_v_heads", None) or getattr( - layer, "num_heads", None - ) - head_v_dim = getattr(layer, "head_v_dim", None) or getattr( - layer, "head_dim", None - ) - - assert num_v_heads is not None and head_v_dim is not None - - gated_norm_shapes.add((num_v_heads // layer.tp_size, head_v_dim)) + # Track registered pattern instances for inspection (e.g., ordering tests) + self._pattern_replacements: list = [] # Make sure fused add patterns are before simple rms norm, - # as the latter is a subset of the former in torch ops. - # The DoubleQuant patterns handle 1 rms_norm -> 2 group_fp8_quant - # fan-out (e.g. DSv3.2) and must be registered before the single - # group-quant pattern so they match first. The view-tolerant variant - # additionally covers the rms_norm -> view -> 2x quant shape that - # appears when the FP8 linear path inserts a 2D-flatten boilerplate - # (DSv3.2 MLA indexer q_c norm). 
+ # as the latter is a subset of the former in torch ops + mxfp4_pattern_count = 0 for epsilon in [1e-5, 1e-6]: - # Fuse aiter rms_norm + 2x aiter group fp8 quant - DoubleAiterRMSFp8GroupQuantPattern( - epsilon, FP8_DTYPE, GroupShape(1, 128) - ).register(self.patterns) - - # View-tolerant sibling for DSv3.2 q_c norm fan-out - DoubleAiterRMSFp8GroupQuantViewPattern( - epsilon, FP8_DTYPE, GroupShape(1, 128) - ).register(self.patterns) + # ── MXFP4 patterns ─────────────────────────────────────────────── + # Guarded so patterns are only registered when the AITER Triton + # fused kernel is importable. Fused-add pattern first (larger + # subgraph, greedy priority). + if rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant(): + p_add = AiterFusedAddRMSNormMXFP4QuantPattern(epsilon) + p_add.register(self.patterns) + self._pattern_replacements.append(p_add) + p_rms = AiterRMSNormMXFP4QuantPattern(epsilon) + p_rms.register(self.patterns) + self._pattern_replacements.append(p_rms) + mxfp4_pattern_count += 2 # Fuse aiter rms_norm + aiter dynamic group fp8 quant - AiterRMSFp8GroupQuantPattern( - epsilon, FP8_DTYPE, GroupShape(1, 128) - ).register(self.patterns) + if hasattr(torch.ops._C, "per_token_group_fp8_quant"): + AiterRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, GroupShape(1, 128) + ).register(self.patterns) - # Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant - AiterFusedAddRMSFp8GroupQuantPattern( - epsilon, FP8_DTYPE, GroupShape(1, 128) - ).register(self.patterns) + # Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant + AiterFusedAddRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, GroupShape(1, 128) + ).register(self.patterns) # When quant_fp8 custom ops are disabled, both AITER and native # quant matchers trace through QuantFP8's native implementation. 
@@ -634,20 +479,14 @@ def __init__(self, config: VllmConfig) -> None: epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant ).register(self.patterns) - # Fuse decomposed RMSNormGated + group fp8 quant. - # The replacement op (fused_rms_gated_fp8_group_quant) requires - # an aiter version that includes the GDN triton kernel renames. - if gated_norm_shapes and rocm_aiter_ops.are_gdn_triton_kernels_available(): - for num_heads, head_dim in gated_norm_shapes: - if head_dim != 128: - continue - AiterRMSNormGatedFp8GroupQuantPattern( - epsilon, - FP8_DTYPE, - GroupShape(1, 128), - num_heads=num_heads, - head_dim=head_dim, - ).register(self.patterns) + if mxfp4_pattern_count: + logger.info( + "RocmAiterRMSNormQuantFusionPass: registered %d MXFP4 fusion " + "patterns (AiterRMSNormMXFP4QuantPattern + " + "AiterFusedAddRMSNormMXFP4QuantPattern, %d epsilon variants)", + mxfp4_pattern_count, + mxfp4_pattern_count // 2, + ) self.dump_patterns(config, self.patterns) @@ -664,9 +503,8 @@ def uuid(self) -> str: AiterFusedAddRMSNormDynamicQuantPattern, AiterRMSFp8GroupQuantPattern, AiterFusedAddRMSFp8GroupQuantPattern, - DoubleAiterRMSFp8GroupQuantPattern, - DoubleAiterRMSFp8GroupQuantViewPattern, - AiterRMSNormGatedFp8GroupQuantPattern, + AiterRMSNormMXFP4QuantPattern, + AiterFusedAddRMSNormMXFP4QuantPattern, ] return self.hash_source(self, *fusion_patterns) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 856f6bb8a3cf..92e1aea78adf 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -5,9 +5,13 @@ import torch from vllm.config import CacheConfig +from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention import MLAAttention from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.platforms import current_platform + +logger = init_logger(__name__) @dataclass @@ -116,6 +120,26 @@ def __init__( self.prefix 
= prefix + # F3: fused RoPE + MLA KV-cache write gate (ROCm + aiter only). + # Auto-enables when AITER has fused_qk_rope_concat_and_cache_mla. + # No env var required — follows has_fused_rmsnorm_mxfp4_quant() pattern. + self._f3_fusion_enabled: bool = False + if current_platform.is_rocm(): + try: + from vllm._aiter_ops import rocm_aiter_ops + + self._f3_fusion_enabled = bool( + rocm_aiter_ops.is_mla_enabled() + and rocm_aiter_ops.has_fused_rope_mla_kv_cache() + ) + if self._f3_fusion_enabled: + logger.info( + "F3 fused RoPE+KV-cache dispatch auto-enabled " + "(prefix=%s)", prefix + ) + except Exception: + pass # aiter not available; stay False + def forward( self, positions: torch.Tensor, @@ -160,7 +184,51 @@ def forward( # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) - if self.rotary_emb is not None: + if self._f3_fusion_enabled and self.rotary_emb is not None: + # F3: single Triton kernel — RoPE(q_pe, k_pe) + kv_cache write. + # Runs here with PRE-RoPE tensors; replaces the separate rotary_emb + # call and the do_kv_cache_update call inside mla_attn. 
+ from vllm._aiter_ops import rocm_aiter_ops + from vllm.forward_context import get_forward_context + + fwd_ctx = get_forward_context() + slot_mapping_dict = fwd_ctx.slot_mapping + if isinstance(slot_mapping_dict, list): + slot_mapping_dict = slot_mapping_dict[0] + layer_slot_mapping = slot_mapping_dict.get(self.mla_attn.layer_name) + if layer_slot_mapping is not None and self.mla_attn.kv_cache.numel() > 0: + q_nope = q[..., : self.qk_nope_head_dim] + q_pe_pre = q[..., self.qk_nope_head_dim :] + kv_c = kv_c_normed.squeeze(1) # [B, kv_lora_rank] + cos_sin = self.rotary_emb.cos_sin_cache + head_dim = self.qk_rope_head_dim + cos_cache = cos_sin[:, :head_dim] + sin_cache = cos_sin[:, head_dim:] + rocm_aiter_ops.fused_rope_and_mla_kv_cache_write( + q_nope=q_nope, + q_pe=q_pe_pre, + kv_c=kv_c, + k_pe=k_pe.squeeze(1), + kv_cache=self.mla_attn.kv_cache, + q_out=q, + slot_mapping=layer_slot_mapping.flatten(), + k_scale=self.mla_attn._k_scale, + q_scale=self.mla_attn._k_scale, + positions=positions, + cos_cache=cos_cache, + sin_cache=sin_cache, + is_neox=self.rotary_emb.is_neox_style, + ) + # kv_cache already updated by the fused kernel above. + # do_kv_cache_update inside mla_attn will write the same data + # again (redundant but correct); the duplicate write will be + # removed in the follow-on PR when this flag defaults to True. + else: + # Fallback: slot_mapping unavailable or kv_cache empty + q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim :], k_pe + ) + elif self.rotary_emb is not None: q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim :], k_pe )