Skip to content
Merged
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(
if version == "v2":
if current_platform.is_rocm():
global PARTITION_SIZE
if not args.custom_paged_attn:
if not args.custom_paged_attn and not current_platform.is_navi():
PARTITION_SIZE = 1024
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
Expand Down Expand Up @@ -166,6 +166,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
scale,
block_tables,
seq_lens,
None,
block_size,
max_seq_len,
alibi_slopes,
Expand Down
2,029 changes: 1,858 additions & 171 deletions csrc/rocm/attention.cu

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion csrc/rocm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);
torch::Tensor& v_scale, bool is_navi);
3 changes: 2 additions & 1 deletion csrc/rocm/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
" Tensor k_scale, Tensor v_scale,"
" bool is_navi) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

Expand Down
8 changes: 7 additions & 1 deletion tests/kernels/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))):
pytest.skip()

if (version == "rocm" and current_platform.is_navi()
and (kv_cache_dtype == "fp8" or head_size != 128
or block_size != 16 or use_alibi)):
pytest.skip()

global PARTITION_SIZE

current_platform.seed_everything(seed)
Expand Down Expand Up @@ -275,6 +280,7 @@ def test_paged_attention(
scale,
block_tables,
seq_lens,
None,
block_size,
max_seq_len,
alibi_slopes,
Expand All @@ -286,7 +292,7 @@ def test_paged_attention(
opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
seq_lens, None, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
Expand Down
26 changes: 20 additions & 6 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,27 @@ def paged_attention_rocm(
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
is_navi: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can you delete this argument. It looks like its unused.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale)
torch.ops._rocm_C.paged_attention(out,
exp_sum,
max_logits,
tmp_out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
query_start_loc,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
is_navi=current_platform.is_navi())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we determine if we are running on navi hardware inside of the kernel? Looking at the archname in the kernel dispatching function seems reasonable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated



def mla_decode_kvcache_cpu(
Expand Down
3 changes: 2 additions & 1 deletion vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,8 @@ def forward(
gqa_ratio = num_heads // self.num_kv_heads
use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window)
decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
Expand Down
3 changes: 2 additions & 1 deletion vllm/attention/ops/chunked_prefill_paged_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
block_size,
num_queries_per_kv,
max_seq_len, sliding_window)
max_seq_len, sliding_window,
kv_cache_dtype, alibi_slopes)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
Expand Down
53 changes: 39 additions & 14 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,24 +100,45 @@ def on_mi250_mi300() -> bool:
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])


def on_navi3_navi4() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])


@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int,
sliding_window: int) -> bool:
def use_rocm_custom_paged_attention(
qtype: torch.dtype,
head_size: int,
block_size: int,
gqa_ratio: int,
max_seq_len: int,
sliding_window: int,
kv_cache_dtype: str,
alibi_slopes: Optional[torch.Tensor] = None) -> bool:

# rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER))
if on_mi250_mi300():
return ((not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER))

else:
return (on_navi3_navi4()
and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 32768 and alibi_slopes is None
and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)


class RocmPlatform(Platform):
Expand Down Expand Up @@ -344,3 +365,7 @@ def use_custom_allreduce(cls) -> bool:
def get_cu_count(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(
device_id).multi_processor_count

@classmethod
def is_navi(cls) -> bool:
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName