8 changes: 8 additions & 0 deletions vllm/envs.py
@@ -80,6 +80,7 @@
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
Collaborator:
Wondering what's the difference between VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA?

Contributor Author:
> Wondering what's the difference between VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA?

Main switch and submodule switch: VLLM_ROCM_USE_AITER is the global AITER switch, while VLLM_ROCM_USE_AITER_MHA toggles only the MHA ops.
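
As a reading aid, a minimal sketch (not code from this PR) of how the two flags combine: a global switch plus a per-feature switch, the same condition rocm.py uses later in this diff to pick the attention backend. The variable name is illustrative only.

# Sketch only: both the main AITER switch and the MHA submodule switch
# must be enabled for the AITER MHA path to be taken.
import vllm.envs as envs

use_aiter_mha_path = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA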

VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -566,6 +567,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")),

# Whether to use aiter mha ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MHA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
Collaborator:
Do we want to override #16828 by default?

("true", "1")),

# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
75 changes: 57 additions & 18 deletions vllm/model_executor/layers/layernorm.py
@@ -8,6 +8,7 @@
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -42,29 +43,67 @@ def fused_add_rms_norm(
return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
if is_rocm_aiter_rmsnorm_enabled():

import aiter as rocm_aiter
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def _rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:

import aiter as rocm_aiter
return rocm_aiter.rms_norm(x, weight, variance_epsilon)

def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
def rocm_aiter_rms_norm_fake(input: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
return input.clone()

import aiter as rocm_aiter
try:
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm,
Collaborator:
Nit: to standardize the naming as in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py, rename _rocm_aiter_rms_norm -> rocm_aiter_rms_norm_impl.

mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake,
)
rocm_aiter_rms_norm = torch.ops.vllm.rocm_aiter_rms_norm

except AttributeError:
Collaborator (@tjtanaa, May 15, 2025):
I think we don't need the try/except statement, because the registration must work as vLLM is going to deprecate V0. If registration does not work when aiter is present in a ROCm environment, that would indicate a bug.

An example unit test that checks the registration works: https://github.com/vllm-project/vllm/blob/main/tests/kernels/moe/test_rocm_aiter_topk.py

Contributor Author:
Done

rocm_aiter_rms_norm = _rocm_aiter_rms_norm

def _rocm_aiter_fused_add_rms_norm(
input: torch.Tensor, residual_in: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:

import aiter as rocm_aiter
residual_out = torch.empty_like(residual_in)
output = torch.empty_like(input)
rocm_aiter.rmsnorm2d_fwd_with_add(
output, # output
input, # input
residual_in, # residual input
residual_out, # residual output
weight,
variance_epsilon,
)

# Assuming the correct signature for rmsnorm2d_fwd_with_add
rocm_aiter.rmsnorm2d_fwd_with_add(
x, # output
x, # input
residual, # residual input
residual, # residual output
weight,
variance_epsilon,
)
return x, residual
return output, residual_out

def rocm_aiter_fused_add_rms_norm_fake(
input: torch.Tensor, residual_in: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
return input.clone(), residual_in.clone()

try:
direct_register_custom_op(
op_name="rocm_aiter_fused_add_rms_norm",
op_func=_rocm_aiter_fused_add_rms_norm,
Collaborator:
Nit: to standardize the naming as in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py, rename _rocm_aiter_fused_add_rms_norm -> rocm_aiter_fused_add_rms_norm_impl.

mutates_args=[],
fake_impl=rocm_aiter_fused_add_rms_norm_fake,
)
rocm_aiter_fused_add_rms_norm = \
torch.ops.vllm.rocm_aiter_fused_add_rms_norm

except AttributeError:
rocm_aiter_fused_add_rms_norm = _rocm_aiter_fused_add_rms_norm


def dispatch_cuda_rmsnorm_func(add_residual: bool):
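To make the reviewer's suggestion above concrete, a registration check in the spirit of tests/kernels/moe/test_rocm_aiter_topk.py could look like the sketch below. The test name and tensor shapes are hypothetical and not part of this PR; it assumes a ROCm build of vLLM with aiter installed.

import torch

import vllm.model_executor.layers.layernorm  # noqa: F401  (import triggers op registration)


def test_rocm_aiter_norm_ops_are_registered():
    # direct_register_custom_op exposes the impls under the torch.ops.vllm namespace.
    assert hasattr(torch.ops.vllm, "rocm_aiter_rms_norm")
    assert hasattr(torch.ops.vllm, "rocm_aiter_fused_add_rms_norm")

    # Smoke-test the rms_norm op with its (x, weight, variance_epsilon) signature.
    x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
    w = torch.ones(64, dtype=torch.float16, device="cuda")
    out = torch.ops.vllm.rocm_aiter_rms_norm(x, w, 1e-6)
    assert out.shape == x.shape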
11 changes: 8 additions & 3 deletions vllm/platforms/rocm.py
@@ -184,9 +184,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
Collaborator:
should we add on_mi250_mi300() to the condition?

Contributor Author:
MHA should be used on MI350 too, so I won't add the condition.

Collaborator:
LGTM.
It is fine to leave the condition out if we don't expect Radeon GPU users to use AITER.

logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
else:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
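A hypothetical usage sketch (not part of this diff) of opting back into the Triton V1 backend by turning off only the MHA submodule switch; the model name is an arbitrary placeholder.

import os

# Disable just the AITER MHA submodule; VLLM_ROCM_USE_AITER can stay set for
# the other AITER ops. Values are parsed case-insensitively as "true"/"1".
os.environ["VLLM_ROCM_USE_AITER_MHA"] = "0"

from vllm import LLM  # noqa: E402

# With the flag off, get_attn_backend_cls falls back to TritonAttentionBackend.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")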
211 changes: 189 additions & 22 deletions vllm/v1/attention/backends/flash_attn.py
@@ -28,6 +28,143 @@
if current_platform.is_cuda():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
if current_platform.is_rocm():
import aiter

from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

@triton.jit
def _vllm_layout_trans_kernel(
k_buffer_ptr,
v_buffer_ptr,
k_values_ptr,
v_values_ptr,
b_seq_lens_loc,
block_table,
block_table_stride_0,
E_DIM: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
tl.arange(0, 2))
batch_token_start, batch_token_end = tl.split(batch_token_indexes)
seq_len = batch_token_end - batch_token_start
if block_idx * BLOCK_SIZE < seq_len:
block_mask = (block_idx * BLOCK_SIZE +
tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len

kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
block_idx)

kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
mask=block_mask,
other=0.0)
v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
mask=block_mask,
other=0.0)

kv_values_off = batch_token_start * E_DIM + \
block_idx * BLOCK_SIZE * E_DIM + \
tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
tl.arange(0, E_DIM)[None, :]
tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)

def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer,
max_seq_len, total_tokens):
H_KV = v_buffer.shape[2]
D = v_buffer.shape[3]
BLOCK_SIZE = v_buffer.shape[1]
dtype = k_buffer.dtype
k_values = torch.empty((total_tokens, H_KV, D),
dtype=dtype,
device="cuda")
v_values = torch.empty((total_tokens, H_KV, D),
dtype=dtype,
device="cuda")

grid = (block_table.shape[0],
(max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

_vllm_layout_trans_kernel[grid](k_buffer,
v_buffer,
k_values,
v_values,
b_seq_lens_loc,
block_table,
block_table.stride(0),
E_DIM=H_KV * D,
BLOCK_SIZE=BLOCK_SIZE)

return k_values, v_values

def _flash_attn_varlen_func(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
out: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
total_tokens: int,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
window_size: Optional[list[int]], # -1 means infinite context window
alibi_slopes: Optional[list[float]],
block_table: torch.Tensor,
) -> torch.Tensor:
k, v = vllm_layout_trans(cu_seqlens_k, block_table, k_cache, v_cache,
max_seqlen_k, total_tokens)
output = aiter.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
window_size=window_size,
out=out,
)
return output

def flash_attn_varlen_func_fake(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
out: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
total_tokens: int,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
window_size: Optional[list[int]], # -1 means infinite context window
alibi_slopes: Optional[list[float]],
block_table: torch.Tensor,
) -> torch.Tensor:
return torch.empty(q.shape[0],
q.shape[1],
v_cache.shape[-2],
dtype=torch.float8_e4m3fnuz,
device="cuda")

try:
direct_register_custom_op("flash_attn_varlen_func",
_flash_attn_varlen_func, ["out"],
flash_attn_varlen_func_fake)
flash_attn_varlen_func = torch.ops.vllm.flash_attn_varlen_func

except AttributeError:
flash_attn_varlen_func = _flash_attn_varlen_func

logger = init_logger(__name__)
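
The Triton kernel added above (_vllm_layout_trans_kernel / vllm_layout_trans) gathers the paged KV cache into the contiguous (total_tokens, num_kv_heads, head_dim) layout that aiter.flash_attn_varlen_func expects. The following plain-PyTorch reference of the same transformation is offered only as a reading aid, under the layout assumptions visible in the kernel (paged buffers of shape (num_blocks, block_size, num_kv_heads, head_dim)); the function name is illustrative and not part of this PR.

import torch


def vllm_layout_trans_ref(cu_seq_lens: torch.Tensor, block_table: torch.Tensor,
                          k_buffer: torch.Tensor, v_buffer: torch.Tensor,
                          total_tokens: int):
    # cu_seq_lens holds cumulative sequence lengths with a leading zero,
    # matching the cu_seq_lens metadata added to FlashAttentionMetadata here.
    block_size = k_buffer.shape[1]
    k_out, v_out = [], []
    for batch_idx in range(block_table.shape[0]):
        seq_len = int(cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx])
        num_blocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[batch_idx, :num_blocks].long()
        # Gather this sequence's blocks, flatten blocks*block_size into tokens,
        # and trim the padding inside the last block.
        k_out.append(k_buffer[blocks].flatten(0, 1)[:seq_len])
        v_out.append(v_buffer[blocks].flatten(0, 1)[:seq_len])
    k_values = torch.cat(k_out, dim=0)
    v_values = torch.cat(v_out, dim=0)
    assert k_values.shape[0] == total_tokens
    return k_values, v_values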

@@ -83,6 +220,8 @@ class FlashAttentionMetadata:
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
cu_seq_lens: torch.Tensor
total_tokens: int
block_table: torch.Tensor
slot_mapping: torch.Tensor

@@ -321,12 +460,20 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
total_tokens = self.runner.seq_lens_np[:num_reqs].sum()
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping[:num_actual_tokens]

cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
dtype=torch.int32,
device="cuda")
torch.cumsum(seq_lens,
dim=0,
dtype=cu_seq_lens.dtype,
out=cu_seq_lens[1:])
if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
# For the AOT scheduler we need the sliding window value to be
@@ -440,6 +587,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
cu_seq_lens=cu_seq_lens,
total_tokens=total_tokens,
block_table=block_table,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
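
A small worked example (illustrative values only, not taken from the PR) of the new cu_seq_lens / total_tokens metadata built above and consumed by the ROCm attention path:

import torch

# Three sequences with 5, 3 and 7 cached tokens respectively.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

total_tokens = int(seq_lens.sum())  # 15
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
# cu_seq_lens is now tensor([0, 5, 8, 15], dtype=torch.int32): entries i and
# i + 1 bracket sequence i's tokens in the flattened K/V produced by
# vllm_layout_trans, and the tensor is passed to aiter as cu_seqlens_k.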
@@ -605,28 +754,46 @@ def forward(
scheduler_metadata = attn_metadata.scheduler_metadata

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if current_platform.is_rocm():
cu_seq_lens = attn_metadata.cu_seq_lens
total_tokens = attn_metadata.total_tokens
flash_attn_varlen_func(
query[:num_actual_tokens],
key_cache,
value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
total_tokens=total_tokens,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
window_size=list(self.sliding_window),
block_table=block_table,
cu_seqlens_k=cu_seq_lens,
)
else:
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output

assert not use_local_attn, (