Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering what's difference between VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering what's difference between VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA?

VLLM_ROCM_USE_AITER is the main (global) switch for aiter; VLLM_ROCM_USE_AITER_MHA is the per-submodule switch that only takes effect when the main switch is enabled.

VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
Expand Down Expand Up @@ -581,6 +582,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")),

# Whether to use aiter mha ops.
# Enabled by default.
"VLLM_ROCM_USE_AITER_MHA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to override #16828 by default?

("true", "1")),

# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
Expand Down
83 changes: 59 additions & 24 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def is_rocm_aiter_rmsnorm_enabled() -> bool:
Expand Down Expand Up @@ -42,36 +43,70 @@ def fused_add_rms_norm(
return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
if is_rocm_aiter_rmsnorm_enabled():

import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    """Apply RMSNorm to ``x`` with the ROCm aiter kernel.

    Inputs with more than two dimensions are flattened to
    ``(-1, hidden_size)`` before the kernel call and the result is
    reshaped back to the original shape afterwards; 1-D/2-D inputs are
    passed through directly.
    """
    # Imported lazily so the module stays importable on non-ROCm builds.
    import aiter as rocm_aiter

    if x.dim() > 2:
        x_original_shape = x.shape
        x = x.reshape(-1, x_original_shape[-1])
        x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
        return x.reshape(x_original_shape)

    return rocm_aiter.rms_norm(x, weight, variance_epsilon)

def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
def rocm_aiter_rms_norm_fake(input: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    # Fake (meta) implementation for the registered custom op: mirror the
    # real kernel's output metadata by handing back a fresh copy of the
    # input tensor without touching the aiter kernel.
    return torch.clone(input)

import aiter as rocm_aiter
# Register the aiter RMSNorm as a vLLM custom op so torch.compile can trace
# through it via the fake implementation.  If registration is unavailable
# (presumably when the vllm op namespace cannot be populated — an
# AttributeError is the observed failure mode; TODO confirm), fall back to
# calling the eager implementation directly.
try:
    direct_register_custom_op(
        op_name="rocm_aiter_rms_norm",
        op_func=rocm_aiter_rms_norm_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_rms_norm_fake,
    )
    rocm_aiter_rms_norm = torch.ops.vllm.rocm_aiter_rms_norm

except AttributeError:
    rocm_aiter_rms_norm = rocm_aiter_rms_norm_impl

def rocm_aiter_fused_add_rms_norm_impl(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Lazy import keeps this module importable when aiter is absent.
    import aiter as rocm_aiter

    # The aiter kernel writes its results into preallocated buffers.
    normed_out = torch.empty_like(x)
    residual_out = torch.empty_like(residual)
    rocm_aiter.rmsnorm2d_fwd_with_add(
        normed_out,   # output
        x,            # input
        residual,     # residual input
        residual_out, # residual output
        weight,
        variance_epsilon,
    )
    return normed_out, residual_out

def rocm_aiter_fused_add_rms_norm_fake(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake (meta) implementation for the registered custom op: reproduce
    # the real kernel's output metadata with fresh copies of both inputs.
    return torch.clone(x), torch.clone(residual)

# Register the fused add+RMSNorm aiter kernel as a vLLM custom op, mirroring
# the plain RMSNorm registration pattern used elsewhere in this file.  If
# registration fails with AttributeError (presumably the vllm op namespace
# is unavailable — TODO confirm), fall back to the eager implementation.
try:
    direct_register_custom_op(
        op_name="rocm_aiter_fused_add_rms_norm",
        op_func=rocm_aiter_fused_add_rms_norm_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_fused_add_rms_norm_fake,
    )
    rocm_aiter_fused_add_rms_norm = \
        torch.ops.vllm.rocm_aiter_fused_add_rms_norm

except AttributeError:
    rocm_aiter_fused_add_rms_norm = rocm_aiter_fused_add_rms_norm_impl


def dispatch_cuda_rmsnorm_func(add_residual: bool):
Expand Down
12 changes: 9 additions & 3 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
and on_mi250_mi300():
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_aiter_fa.AiterFlashAttentionBackend")
else:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
Expand Down
Loading