131 changes: 131 additions & 0 deletions tests/compile/test_fuse_act_padding.py
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import torch

import vllm.config
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.utils import rocm_unquantized_gemm

from .backend import TestBackend


class TestModel(torch.nn.Module):
    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_local_experts: int,
        x_pad_to_multiple: int,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
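        # right-padding needed to reach the next multiple of x_pad_to_multiple,
        # e.g. hidden_size=2880 padded to 256 gives pad_dim=192 (2880 + 192 = 3072)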
        self.pad_dim = x_pad_to_multiple - (hidden_size % x_pad_to_multiple)

        self.norm = [RMSNorm(hidden_size, eps=1e-5) for _ in range(num_layers)]
        self.router = [
            torch.nn.Linear(hidden_size, num_local_experts) for _ in range(num_layers)
        ]

    def forward(self, x):
        # avoid having graph input be an arg to a pattern directly
        x = resid = torch.relu(x)
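        # per layer: trim to hidden_size, fused-add RMSNorm, router GEMM, then
        # re-pad; the fusion pass should collapse the RMSNorm + pad pair into one op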
        all_router_logits = []
        for layer in range(self.num_layers):
            x = x[:, : self.hidden_size]
            x, resid = self.norm[layer](x, resid)
            router_logits = rocm_unquantized_gemm(
                self, x, self.router[layer].weight, self.router[layer].bias
            )
            x = torch.nn.functional.pad(
                x, (0, self.pad_dim), mode="constant", value=0.0
            )
            all_router_logits.append(router_logits)

        return x, resid, *all_router_logits

    def ops_in_model_before(self):
        return [
            rocm_aiter_ops.get_rmsnorm_fused_add_op(),
            torch.ops.aten.constant_pad_nd,
        ]

    def ops_in_model_after(self):
        return [rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()]


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_layers", [3])
@pytest.mark.parametrize("hidden_size", [2880])
@pytest.mark.parametrize("num_local_experts", [128])
@pytest.mark.parametrize("x_pad_to_multiple", [256])
@pytest.mark.skipif(
    not is_aiter_found_and_supported(),
    reason="Only test on ROCm with AITER installed and supported",
)
def test_fuse_act_padding(
    dtype: torch.dtype,
    num_layers: int,
    hidden_size: int,
    num_local_experts: int,
    x_pad_to_multiple: int,
    monkeypatch: pytest.MonkeyPatch,
):
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm"],
            pass_config=PassConfig(fuse_act_padding=True, eliminate_noops=True),
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        from vllm.compilation.rocm_aiter_fusion import (
            RocmAiterTritonAddRMSNormPadFusionPass,
        )

        torch.set_default_device("cuda")
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

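        # enable AITER for this test and re-read env vars so rocm_aiter_ops
        # picks up the patched value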
        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        fusion_pass = RocmAiterTritonAddRMSNormPadFusionPass(vllm_config)
        passes = [
            NoOpEliminationPass(vllm_config),
            fusion_pass,
            PostCleanupPass(vllm_config),
        ]
        backend = TestBackend(*passes)
        model = TestModel(num_layers, hidden_size, num_local_experts, x_pad_to_multiple)

        x = torch.rand(1, hidden_size)
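        # mark the batch dim dynamic so compilation is not specialized to M == 1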
        torch._dynamo.mark_dynamic(x, 0)

        outputs_unfused = model(x)

        model_fused = torch.compile(model, backend=backend)
        outputs_fused = model_fused(x)

        torch.testing.assert_close(outputs_unfused, outputs_fused)

        assert fusion_pass.matched_count == num_layers

        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())
4 changes: 2 additions & 2 deletions tests/compile/test_fusion.py
@@ -410,7 +410,7 @@ def test_aiter_fusion_rmsnorm_quant(
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
-       from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
+       from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormQuantFusionPass

        m.setenv("VLLM_ROCM_USE_AITER", "1")

@@ -420,7 +420,7 @@ def test_aiter_fusion_rmsnorm_quant(
        torch.set_default_dtype(dtype)
        torch.manual_seed(1)

-       fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
+       fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config)

        model = TestModel(
            hidden_size=hidden_size,
46 changes: 46 additions & 0 deletions vllm/_aiter_ops.py
@@ -790,6 +790,41 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
    return x_fp8, out_bs


def _rocm_aiter_triton_add_rmsnorm_pad_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    residual: torch.Tensor,
    x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
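    # import lazily so aiter is only required when the op is actually invoked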
    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad

    return fused_add_rmsnorm_pad(
        x,
        weight,
        variance_epsilon,
        residual,
        x_pad_to_multiple=x_pad_to_multiple,
    )


def _rocm_aiter_triton_add_rmsnorm_pad_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
    residual: torch.Tensor,
    x_pad_to_multiple: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    M, N = x.shape
    if x_pad_to_multiple > 0:
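        # round N up to the next multiple, e.g. N=2880, multiple=256 -> N_out=3072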
        N_out = (N + x_pad_to_multiple - 1) // x_pad_to_multiple * x_pad_to_multiple
    else:
        N_out = N
    out = torch.empty((M, N_out), dtype=x.dtype, device=x.device)
    residual_out = torch.empty_like(residual)
    return out, residual_out


# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False

@@ -1108,6 +1143,13 @@ def register_ops_once() -> None:
        fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
    )

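    # fake_impl supplies the shape/meta function used while tracing;
    # dispatch_key binds the real impl to the current platform's backend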
    direct_register_custom_op(
        op_name="rocm_aiter_triton_add_rmsnorm_pad",
        op_func=_rocm_aiter_triton_add_rmsnorm_pad_impl,
        fake_impl=_rocm_aiter_triton_add_rmsnorm_pad_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    direct_register_custom_op(
        op_name="rocm_aiter_group_fp8_quant",
        op_func=_rocm_aiter_group_fp8_quant_impl,
@@ -1175,6 +1217,10 @@ def get_group_quant_op() -> OpOverload:
    def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
        return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default

    @staticmethod
    def get_triton_add_rmsnorm_pad_op() -> OpOverload:
        return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default

    @staticmethod
    def rms_norm(
        x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
8 changes: 6 additions & 2 deletions vllm/compilation/pass_manager.py
@@ -18,8 +18,9 @@

if rocm_aiter_ops.is_enabled():
    from vllm.compilation.rocm_aiter_fusion import (
-       RocmAiterRMSNormFusionPass,
+       RocmAiterRMSNormQuantFusionPass,
        RocmAiterSiluMulFp8GroupQuantFusionPass,
        RocmAiterTritonAddRMSNormPadFusionPass,
    )

if current_platform.is_cuda_alike():
@@ -123,13 +124,16 @@ def configure(self, config: VllmConfig) -> None:
            self.passes += [RMSNormQuantFusionPass(config)]
            if rocm_aiter_ops.is_enabled():
                self.passes += [
-                   RocmAiterRMSNormFusionPass(config),
+                   RocmAiterRMSNormQuantFusionPass(config),
                ]
        if self.pass_config.fuse_act_quant:
            self.passes += [ActivationQuantFusionPass(config)]
            if rocm_aiter_ops.is_enabled():
                self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]

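        # act-padding fusion is AITER-only: it folds fused-add RMSNorm + pad
        # into a single triton op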
        if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
            self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]

        if self.pass_config.fuse_attn_quant:
            self.passes += [AttnFusionPass(config)]

105 changes: 104 additions & 1 deletion vllm/compilation/rocm_aiter_fusion.py
@@ -266,7 +266,7 @@ def replacement(
        )


-class RocmAiterRMSNormFusionPass(VllmPatternMatcherPass):
+class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
into a fused rms_norm_quant op.
@@ -399,3 +399,106 @@ def uuid(self) -> str:
            AiterSiluMulFp8GroupQuantPattern,
        ]
        return VllmInductorPass.hash_source(self, *fusion_patterns)


class AddAiterRMSNormPadPattern:
"""
This pattern replaces an aiter_rmsnorm_with_add & a pad op
with a custom triton_add_rmsnorm_pad op from AITER.
"""

    AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()

    def __init__(
        self,
        epsilon: float,
        hidden_size: int,
        x_pad_to_multiple: int,
    ):
        self.epsilon = epsilon
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        input, weight, residual = self.rmsnorm_matcher.inputs()
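        # dummy router weight/bias used only for pattern tracing; the small
        # shapes are placeholders, not real model dimensions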
        router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
        router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
        return [input, weight, residual, router_weight, router_bias]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            pad_size = self.x_pad_to_multiple - (
                self.hidden_size % self.x_pad_to_multiple
            )
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
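            # the router GEMM reads the unpadded rms output, so it is captured
            # in the pattern and rewired onto a slice in the replacement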
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_rms, router_weight, router_bias
            )
            result = torch.nn.functional.pad(
                result_rms, (0, pad_size), mode="constant", value=0.0
            )
            return result, residual_out, router_logits

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                residual=residual,
                x_pad_to_multiple=self.x_pad_to_multiple,
            )
            result_padded = at[0]
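            # the fused op returns the padded output; the router GEMM now
            # consumes only the first hidden_size columns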
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_padded[:, : self.hidden_size], router_weight, router_bias
            )
            residual_out = at[1]
            return result_padded, residual_out, router_logits

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
"""
This pass replaces an AITER CK RMSNorm + residual add and a pad op
with an triton_add_rmsnorm_pad op from AITER.
"""

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
        )

        # gpt-oss has hidden size 2880,
        # padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
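        # patterns for both multiples are registered; only the variant that
        # actually appears in the graph will match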
        hidden_size = 2880
        for epsilon in [1e-5, 1e-6]:
            for x_pad_to_multiple in [128, 256]:
                AddAiterRMSNormPadPattern(
                    epsilon, hidden_size, x_pad_to_multiple
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)