Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from torch._ops import OpOverload

from vllm.platforms.rocm import on_gfx12x  # NOTE(review): every other use in this file lazy-imports from vllm.platforms.rocm inside the function (see is_aiter_found_and_supported) to avoid a circular import — move this import into flash_attn_varlen_func
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
Expand Down Expand Up @@ -50,9 +50,9 @@ def is_aiter_found_and_supported() -> bool:
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
"""
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
from vllm.platforms.rocm import on_gfx9, on_gfx12x

return on_gfx9()
return on_gfx9() or on_gfx12x()
return False


Expand Down Expand Up @@ -1769,11 +1769,12 @@ def group_fp8_quant(
group_size: int = 128,
) -> tuple[torch.Tensor, torch.Tensor]:
assert group_size == 128, "Group size must be 128"

return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)

@staticmethod
def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
return (n, k) in [
shapes = [
(1024, 8192),
(2112, 7168),
(3072, 1536),
Expand Down Expand Up @@ -1910,7 +1911,70 @@ def flash_attn_varlen_func(
even when VLLM_ROCM_USE_AITER=0.

Note: This performs lazy import of aiter.flash_attn_varlen_func
For gfx12x (RDNA4) GPUs, uses mha_v3 Triton kernel for better performance.
"""
# Check if we should use mha_v3 for gfx12x
# mha_v3 provides better performance on RDNA4 architectures
if on_gfx12x():
# Check if mha_v3 can be used (no alibi, no dropout, no custom window size)
# window_size=None means infinite context window (default behavior)
can_use_mha_v3 = (
alibi_slopes is None
and dropout_p == 0.0
and (window_size is None or window_size == (-1, -1))
and not return_lse # mha_v3 doesn't support returning LSE
)

if can_use_mha_v3:
try:
from aiter.ops.triton.attention.mha_v3 import (
flash_attn_varlen_func as mha_v3_varlen_func,
)

# mha_v3 has a different signature - convert parameters
# mha_v3 expects window_size as a tuple, default is (-1, -1) for infinite window
mha_v3_window_size = (
(-1, -1) if window_size is None else window_size
)

mha_v3_out = mha_v3_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=mha_v3_window_size,
attention_chunk=0,
softcap=0.0,
deterministic=False,
sm_margin=0,
)

if out is not None:
out.copy_(mha_v3_out)
return out
return mha_v3_out

except Exception as e:
# Fall through to default implementation if mha_v3 fails
# Log the error for debugging (only once to avoid spam)
import warnings

warnings.warn(
f"mha_v3 failed, falling back to default: {e}",
RuntimeWarning,
stacklevel=2,
)

# Default: use standard flash_attn_varlen_func (CK-based or Triton-based)
from aiter import flash_attn_varlen_func

return flash_attn_varlen_func(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
{
"triton_version": "3.5.1+rocm7.2.0.gita272dfa8",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"kpack": 1
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"triton_version": "3.5.1+rocm7.2.0.gita272dfa8",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 1
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4
}
}
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_batched_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,13 +921,15 @@ def _supports_quant_scheme(
) -> bool:
p = current_platform
if p.is_rocm():
from vllm.platforms.rocm import on_gfx9
from vllm.platforms.rocm import on_gfx9, on_gfx12x

is_rocm_on_gfx9 = on_gfx9()
is_rocm_on_gfx12 = on_gfx12x()
else:
is_rocm_on_gfx9 = False
is_rocm_on_gfx12 = False

device_supports_fp8 = is_rocm_on_gfx9 or (
device_supports_fp8 = is_rocm_on_gfx9 or is_rocm_on_gfx12 or (
p.is_cuda() and p.has_device_capability((8, 9))
)

Expand Down
14 changes: 10 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,14 +1953,18 @@ def _supports_quant_scheme(
) -> bool:
p = current_platform
if p.is_rocm():
from vllm.platforms.rocm import on_gfx9
from vllm.platforms.rocm import on_gfx9, on_gfx12x

is_rocm_on_gfx9 = on_gfx9()
is_rocm_on_gfx12 = on_gfx12x()
else:
is_rocm_on_gfx9 = False
is_rocm_on_gfx12 = False

device_supports_fp8 = is_rocm_on_gfx9 or (
p.is_cuda() and p.has_device_capability((8, 9))
device_supports_fp8 = (
is_rocm_on_gfx9
or is_rocm_on_gfx12
or (p.is_cuda() and p.has_device_capability((8, 9)))
)

if not device_supports_fp8:
Expand Down Expand Up @@ -2056,7 +2060,9 @@ def apply(
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids
)

if global_num_experts == -1:
global_num_experts = E

Expand Down
12 changes: 9 additions & 3 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def _get_gcn_arch() -> str:
# These are plain Python bools — fully torch.compile/Dynamo safe.
_GCN_ARCH = _get_gcn_arch()

_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11"])
_ON_GFX12X = any(arch in _GCN_ARCH for arch in ["gfx12"])
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
_ON_GFX942 = "gfx942" in _GCN_ARCH
Expand Down Expand Up @@ -226,6 +227,10 @@ def on_gfx1x() -> bool:
return _ON_GFX1X


def on_gfx12x() -> bool:
    """Return True when the detected GCN arch string contains "gfx12" (RDNA4)."""
    return _ON_GFX12X


def on_mi3xx() -> bool:
return _ON_MI3XX

Expand Down Expand Up @@ -286,7 +291,7 @@ def use_rocm_custom_paged_attention(

@cache
def flash_attn_triton_available() -> bool:
if not on_gfx1x():
if not on_gfx1x() and not on_gfx12x():
return False
try:
from importlib.util import find_spec
Expand Down Expand Up @@ -434,7 +439,7 @@ def get_attn_backend_cls(

# Priority 2: Check for AITER MHA (Flash Attention)
# Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and (on_gfx9() or on_gfx12x()):
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()

Expand Down Expand Up @@ -511,6 +516,7 @@ def get_vit_attn_backend(
# RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
if (
(on_gfx1x() or on_gfx12x())
and flash_attn_triton_available()
and (dtype == torch.float16 or dtype == torch.bfloat16)
):
Expand Down