Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions vllm/attention/backends/rocm_aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

def is_aiter_mla_enabled() -> bool:
    """Return True when a ROCm AITER MLA backend should be used.

    Requires the master AITER switch plus at least one MLA implementation:
    the asm/CK path (VLLM_ROCM_USE_AITER_MLA) or the Triton path
    (VLLM_ROCM_USE_AITER_TRITON_MLA).
    """
    # NOTE: the stripped diff left both the old single-flag check and the
    # new two-flag check in the paste; this is the merged (post-change) form.
    return envs.VLLM_ROCM_USE_AITER \
        and (envs.VLLM_ROCM_USE_AITER_MLA
             or envs.VLLM_ROCM_USE_AITER_TRITON_MLA)


class AiterMLABackend(MLACommonBackend):
Expand Down Expand Up @@ -362,21 +363,31 @@ def __init__(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")

from aiter import flash_attn_varlen_func
if envs.VLLM_ROCM_USE_AITER_TRITON_MLA:
from aiter.ops.triton.MHA import flash_attn_varlen_func
else:
from aiter import flash_attn_varlen_func

self.flash_attn_varlen_func = flash_attn_varlen_func

def _flash_attn_varlen_diff_headdims(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
softmax_scale: float, return_softmax_lse: bool,
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
output = self.flash_attn_varlen_func(
result = self.flash_attn_varlen_func(
q,
k,
v,
**kwargs,
)

return output
# Transpose the LSE if Triton MHA is used:
# (q.shape[0], num_q_heads) to (num_q_heads, q.shape[0])
if (envs.VLLM_ROCM_USE_AITER_TRITON_MLA
and type(result) is tuple and return_softmax_lse):
output, lse = result
lse = lse.T.contiguous()
return (output, lse)
return result

def _forward_decode(
self,
Expand Down
5 changes: 5 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM: bool = True
ROCM_TRITON_MOE_PRESHUFFLE_SCALES: bool = True
VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = False
VLLM_ROCM_USE_AITER_TRITON_MLA: bool = False

def get_default_cache_root():
return os.getenv(
Expand Down Expand Up @@ -1237,15 +1238,15 @@
# Use AITER Triton fused RMSNORM + Quantization
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),

Check failure on line 1241 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1241:81: E501 Line too long (92 > 80)
# Use AITER Triton fused elementwise multiply + elementwise addition
"VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD", "1"))),

Check failure on line 1245 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1245:81: E501 Line too long (82 > 80)
# Use AITER Triton fused rope + zeros + reshape_and_cache
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))),

Check failure on line 1249 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1249:81: E501 Line too long (94 > 80)
# Use AITER Triton fused FP8 per-token group quant + FP8 batched GEMM
"VLLM_ROCM_USE_AITER_TRITON_FP8_BMM":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FP8_BMM", "1"))),
Expand All @@ -1271,6 +1272,10 @@
# Apply preshuffling for mxfp4 scales for ROCm backend
"ROCM_TRITON_MOE_PRESHUFFLE_SCALES":
lambda: bool(int(os.getenv("ROCM_TRITON_MOE_PRESHUFFLE_SCALES", "1"))),

# Use AITER Triton MLA
"VLLM_ROCM_USE_AITER_TRITON_MLA":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_MLA", "0"))),
}

# --8<-- [end:env-vars-definition]
Expand Down
20 changes: 15 additions & 5 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

def is_aiter_mla_enabled() -> bool:
    """Return True when a ROCm AITER MLA backend should be used.

    Requires the master AITER switch plus at least one MLA implementation:
    the asm/CK path (VLLM_ROCM_USE_AITER_MLA) or the Triton path
    (VLLM_ROCM_USE_AITER_TRITON_MLA).
    """
    # NOTE: the stripped diff left both the old single-flag check and the
    # new two-flag check in the paste; this is the merged (post-change) form.
    return envs.VLLM_ROCM_USE_AITER \
        and (envs.VLLM_ROCM_USE_AITER_MLA
             or envs.VLLM_ROCM_USE_AITER_TRITON_MLA)


class AiterMLABackend(MLACommonBackend):
Expand Down Expand Up @@ -195,7 +196,10 @@ def __init__(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")

from aiter import flash_attn_varlen_func
if envs.VLLM_ROCM_USE_AITER_TRITON_MLA:
from aiter.ops.triton.mha import flash_attn_varlen_func
else:
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func

def _flash_attn_varlen_diff_headdims(self,
Expand All @@ -205,16 +209,22 @@ def _flash_attn_varlen_diff_headdims(self,
return_softmax_lse=False,
softmax_scale=None,
**kwargs):
output = self.flash_attn_varlen_func(
result = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)

return output
# Transpose the LSE if Triton MHA is used:
# (q.shape[0], num_q_heads) to (num_q_heads, q.shape[0])
if (envs.VLLM_ROCM_USE_AITER_TRITON_MLA
and type(result) is tuple and return_softmax_lse):
output, lse = result
lse = lse.T.contiguous()
return (output, lse)
return result

def _forward_decode(
self,
Expand Down
Loading