Commit

[Bugfix]: Fix paged attention unit tests of #372 (#389)
* [Bugfix]: fix paged attention tests based on the updated kernels in `csrc/attention/paged_attention_v1.cu`, `csrc/attention/paged_attention_v2.cu`, and `csrc/rocm/attention.cu`.

* improve code documentation.

* lint

---------

Co-authored-by: vllmellm <[email protected]>
tjtanaa and vllmellm authored Jan 28, 2025
1 parent 5510e8c commit 49dfc1d
Showing 2 changed files with 29 additions and 21 deletions.
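
Both test functions touched below (`test_paged_attention` and `test_multi_query_kv_attention`) live in `tests/kernels/test_attention.py`. One way to run just these tests against a build that includes the updated kernels is a name-filtered pytest invocation; the selection expression here is illustrative, not part of the commit:

```python
# Illustrative only: run the paged attention tests affected by this commit.
import sys

import pytest

sys.exit(
    pytest.main([
        "tests/kernels/test_attention.py",
        "-k", "test_paged_attention or test_multi_query_kv_attention",
    ]))
```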
4 changes: 3 additions & 1 deletion csrc/rocm/attention.cu
@@ -701,13 +701,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(

  __syncthreads();

  // disable rtz conversion due to its impact on accuracy.
  constexpr bool LOGITS_RTZ_CONVERSION = false;

  // write logits to shared mem
  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
    dout[token_depth] *= inv_sum_scale;
    if constexpr (LOGITS_RTZ_CONVERSION) {
      // use rtz conversion for performance, with no visible impact on accuracy
      // use rtz conversion for better performance, with negligible impact on
      // accuracy.
      shared_logits[warpid][token_depth][lane16id][rowid] =
          from_floatx4_rtz<scalar_t>(dout[token_depth]);
    } else {
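
For context on the trade-off the `LOGITS_RTZ_CONVERSION` flag controls: converting the softmax logits with round-toward-zero (RTZ) is cheaper than round-to-nearest-even but biases every value toward zero. The sketch below is not vLLM code and the `bf16_rtz` helper is illustrative (it mimics a truncating float32-to-bfloat16 cast, not the kernel's `from_floatx4_rtz`); it only shows how the two rounding modes compare on random data:

```python
import torch

def bf16_rtz(x: torch.Tensor) -> torch.Tensor:
    # Round-toward-zero float32 -> bfloat16: drop the low 16 bits of the bit pattern.
    bits = x.float().view(torch.int32)
    return (bits & -0x10000).view(torch.float32).to(torch.bfloat16)

x = torch.rand(1 << 20) + 0.5       # positive, logits-like magnitudes
rne = x.to(torch.bfloat16)          # default cast: round-to-nearest-even
rtz = bf16_rtz(x)
print("mean abs error (rne):", (rne.float() - x).abs().mean().item())
print("mean abs error (rtz):", (rtz.float() - x).abs().mean().item())
```

On data like this, truncation roughly doubles the mean error and is one-sided, which is consistent with the kernel comment above: the flag is kept as a compile-time constant but defaults to the round-to-nearest path.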
46 changes: 26 additions & 20 deletions tests/kernels/test_attention.py
@@ -7,7 +7,7 @@
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes
from vllm.utils import get_max_shared_memory_bytes, is_navi

from .allclose_default import get_default_atol, get_default_rtol

Expand All @@ -33,7 +33,7 @@

# This should be sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
@@ -116,7 +116,8 @@ def ref_single_query_cached_kv_attention(


@pytest.mark.parametrize(
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"])
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -181,7 +182,11 @@ def test_paged_attention(
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(0.3, dtype=torch.float)
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32)

    # additional argument for v1/v2 pa kernel
    num_threads = 1024 if current_platform.is_rocm() \
        and not is_navi() else 128

    # Call the paged attention kernel.
    output = torch.empty_like(query)
@@ -203,12 +208,12 @@
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))
        opcheck(
            torch.ops._C.paged_attention_v1,
            (output, query, key_cache, value_cache, num_kv_heads, scale,
             block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
             kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]))

    elif version in ("v2", "rocm"):
        if current_platform.is_rocm():
@@ -247,13 +252,14 @@
                v_scale,
            )

            opcheck(torch.ops._C.paged_attention_v2,
                    (output, exp_sums, max_logits, tmp_output, query,
                     key_cache, value_cache, num_kv_heads, scale, block_tables,
                     seq_lens, block_size, max_seq_len, alibi_slopes,
                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                    cond=(head_size == HEAD_SIZES[0]
                          and block_size == BLOCK_SIZES[0]))
            opcheck(
                torch.ops._C.paged_attention_v2,
                (output, exp_sums, max_logits, tmp_output, query, key_cache,
                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                 k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
                cond=(head_size == HEAD_SIZES[0]
                      and block_size == BLOCK_SIZES[0]))

        else:
            ops.paged_attention_rocm(
@@ -299,14 +305,14 @@ def test_paged_attention(
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = k_scale * dequantized_key_cache
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = v_scale * dequantized_value_cache
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
@@ -434,4 +440,4 @@ def test_multi_query_kv_attention(
    )
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
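
A note on the fp8 hunk above: with `k_scale` and `v_scale` now fixed at 1.0, multiplying the dequantized caches by the scales is a no-op, so the reference path uses `dequantized_key_cache` and `dequantized_value_cache` directly. The sketch below shows the underlying identity using PyTorch's `float8_e4m3fn` dtype instead of vLLM's `ops.convert_fp8`; the `quantize_fp8`/`dequantize_fp8` helpers are illustrative, and a PyTorch build with float8 support is assumed:

```python
import torch

def quantize_fp8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Store x / scale in 8-bit floating point; dequantization multiplies scale back in.
    return (x / scale).to(torch.float8_e4m3fn)

def dequantize_fp8(q: torch.Tensor, scale: float) -> torch.Tensor:
    return q.to(torch.float32) * scale

x = torch.randn(4, 8)
for scale in (0.3, 1.0):
    q = quantize_fp8(x, scale)
    err = (dequantize_fp8(q, scale) - x).abs().max().item()
    # With scale == 1.0 the trailing multiply is a no-op, so a reference path
    # can use the raw dequantized values directly; with any other scale it cannot.
    print(f"scale={scale}: max round-trip error {err:.4f}")
```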
