241 changes: 240 additions & 1 deletion tests/compile/passes/test_fusion.py
@@ -40,7 +40,7 @@
TritonFp8BlockScaledMMKernel,
_KernelT,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm, RMSNormGated
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
@@ -441,3 +441,242 @@ def test_aiter_fusion_rmsnorm_quant(
_run_fusion_test(
model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens
)


class TestGatedModel(torch.nn.Module):
"""Model that uses RMSNormGated + reshape + group FP8 quant + linear.

Mimics GatedDeltaNetAttention's output projection path where:
- RMSNormGated operates on per-head tensors (N*H, D)
- Output is reshaped to (N, H*D) before group quantization + linear
"""

def __init__(
self,
num_heads: int,
head_dim: int,
eps: float,
force_kernel: type[_KernelT],
group_shape: GroupShape,
dtype: torch.dtype,
use_aiter_quant: bool = True,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
hidden_dim = num_heads * head_dim

self.norm = RMSNormGated(
head_dim,
eps=eps,
group_size=None,
norm_before_gate=True,
)

self.activation_quant_key = create_fp8_quant_key(
static=False, group_shape=group_shape
)
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(group_shape.col, group_shape.col)
)

self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_dim, hidden_dim),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
transpose_weights=True,
input_dtype=dtype,
)
self.fp8_linear.kernel.quant_fp8.use_aiter = use_aiter_quant

def forward(self, x, z):
num_heads = self.num_heads
head_dim = self.head_dim
hidden_dim = num_heads * head_dim
x = torch.relu(x)
z = torch.relu(z)
x_heads = x.reshape(-1, num_heads, head_dim).reshape(-1, head_dim)
z_heads = z.reshape(-1, num_heads, head_dim).reshape(-1, head_dim)
normed = self.norm(x_heads, z_heads)
merged = normed.reshape(-1, hidden_dim)
out = self.fp8_linear(merged)
return out

def ops_in_model_after(self):
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
AiterRMSNormGatedFp8GroupQuantPattern,
)

return [AiterRMSNormGatedFp8GroupQuantPattern.FUSED_OP]
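
# Shape walkthrough for TestGatedModel.forward above (illustrative sketch using
# the test parameters below: num_tokens=8, num_heads=2, head_dim=128):
#   x, z       : (8, 256) -> per-head view (16, 128)
#   normed     : (16, 128), merged back to (8, 256)
#   fp8_linear : group-quantizes (8, 256) with GroupShape(1, 128), i.e. one
#                scale per 128-column group, then applies the FP8 linear.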


class _MockGDNLayer:
"""Minimal mock to populate static_forward_context for pass discovery.

Uses __class__ assignment to pass isinstance checks against
GatedDeltaNetAttention without requiring a full config-based init.
"""

def __init__(self, num_v_heads: int, head_v_dim: int, tp_size: int = 1):
self.num_v_heads = num_v_heads
self.head_v_dim = head_v_dim
self.tp_size = tp_size

from vllm.model_executor.layers.mamba.gdn_linear_attn import (
GatedDeltaNetAttention,
)

self.__class__ = GatedDeltaNetAttention
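        # Rebinding __class__ makes isinstance(self, GatedDeltaNetAttention)
        # return True without running the real __init__, so the fusion pass can
        # discover this mock's num_v_heads / head_v_dim via static_forward_context.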


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_heads", [2])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("num_tokens", [8])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.skipif(
(not current_platform.is_rocm() or not IS_AITER_FOUND),
reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_gated_quant(
dtype: torch.dtype,
num_heads: int,
head_dim: int,
num_tokens: int,
eps: float,
monkeypatch: pytest.MonkeyPatch,
):
group_shape = GroupShape(1, 128)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["-rms_norm", "-silu_and_mul", "-quant_fp8"],
pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
),
)

with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
RocmAiterRMSNormQuantFusionPass,
)

m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()

# Register a mock GDN layer so the pass discovers num_heads/head_dim
mock_gdn = _MockGDNLayer(num_v_heads=num_heads, head_v_dim=head_dim, tp_size=1)
vllm_config.compilation_config.static_forward_context["mock_gdn_layer"] = (
mock_gdn
)

torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)

fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

model = TestGatedModel(
num_heads=num_heads,
head_dim=head_dim,
eps=eps,
force_kernel=AiterFp8BlockScaledMMKernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_quant=True,
)

noop_pass = NoOpEliminationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)

hidden_dim = num_heads * head_dim
x = torch.rand(num_tokens, hidden_dim)
z = torch.rand(num_tokens, hidden_dim)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(z, 0)

model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x, z)

model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x, z)

torch.testing.assert_close(result_fused, result_unfused, atol=1e-2, rtol=1e-2)

assert fusion_pass.matched_count == 1
backend.check_after_ops(model.ops_in_model_after())


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_heads", [2])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("num_tokens", [8])
@pytest.mark.parametrize("eps", [1e-6])
@pytest.mark.skipif(
(not current_platform.is_rocm() or not IS_AITER_FOUND),
reason="Only test on ROCm with aiter package installed",
)
def test_aiter_fusion_rmsnorm_gated_quant_no_gdn_layers(
dtype: torch.dtype,
num_heads: int,
head_dim: int,
num_tokens: int,
eps: float,
monkeypatch: pytest.MonkeyPatch,
):
"""Verify that without GDN layers in static_forward_context,
the gated pattern is not registered and no matches occur."""
group_shape = GroupShape(1, 128)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["-rms_norm", "-silu_and_mul", "-quant_fp8"],
pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True),
),
)

with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
RocmAiterRMSNormQuantFusionPass,
)

m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()

torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)

# No mock GDN layer registered -- pass should not register gated pattern
fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

model = TestGatedModel(
num_heads=num_heads,
head_dim=head_dim,
eps=eps,
force_kernel=AiterFp8BlockScaledMMKernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_quant=True,
)

noop_pass = NoOpEliminationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)

hidden_dim = num_heads * head_dim
x = torch.rand(num_tokens, hidden_dim)
z = torch.rand(num_tokens, hidden_dim)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(z, 0)

model_fused = torch.compile(model, backend=backend)
model_fused(x, z)

assert fusion_pass.matched_count == 0
58 changes: 58 additions & 0 deletions vllm/_aiter_ops.py
@@ -913,6 +913,50 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
)


def _rocm_aiter_fused_rms_gated_fp8_group_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None,
z: torch.Tensor,
eps: float,
norm_before_gate: bool,
activation: str,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fused gated-RMSNorm + FP8 group quantization via aiter Triton kernel."""
from aiter.ops.triton.quant import fused_rms_gated_fp8_group_quant

return fused_rms_gated_fp8_group_quant(
x,
weight,
bias,
z,
eps,
norm_before_gate=norm_before_gate,
activation=activation,
out_dtype=FP8_DTYPE,
group_size=group_size,
)


def _rocm_aiter_fused_rms_gated_fp8_group_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None,
z: torch.Tensor,
eps: float,
norm_before_gate: bool,
activation: str,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
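
# Scale-shape example for the fake impl above (illustrative): for x of shape
# (8, 256) with group_size=128, (256 + 127) // 128 == 2, so the fake op returns
# an (8, 256) FP8 tensor together with an (8, 2) float32 scale tensor.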


def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
@@ -1480,6 +1524,9 @@ def are_gdn_triton_kernels_available(cls) -> bool:
try:
import aiter.ops.triton.causal_conv1d_update_single_token # noqa: F401
import aiter.ops.triton.gated_delta_net # noqa: F401
from aiter.ops.triton.quant import ( # noqa: F401
fused_rms_gated_fp8_group_quant,
)

return True
except (ImportError, ModuleNotFoundError):
@@ -1598,6 +1645,12 @@ def register_ops_once() -> None:
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
)

direct_register_custom_op(
op_name="rocm_aiter_fused_rms_gated_fp8_group_quant",
op_func=_rocm_aiter_fused_rms_gated_fp8_group_quant_impl,
fake_impl=_rocm_aiter_fused_rms_gated_fp8_group_quant_fake,
)

direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
@@ -1689,6 +1742,11 @@ def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload:
def get_rmsnorm_group_fused_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default

@staticmethod
def get_fused_rms_gated_fp8_group_quant_op() -> OpOverload:
"""Return the fused gated-RMSNorm + FP8 group quant custom op."""
return torch.ops.vllm.rocm_aiter_fused_rms_gated_fp8_group_quant.default

@staticmethod
def get_rmsnorm_group_add_fused_quant_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default