Commit 7af074d

Enable AMD Radeon GPU Custom Paged Attention on v1
Signed-off-by: Hosang Yoon <[email protected]>
1 parent 7c49487

7 files changed: +150 -62 lines changed


benchmarks/kernels/benchmark_paged_attention.py (2 additions & 4 deletions)

@@ -17,8 +17,6 @@
 NUM_BLOCKS = 128 * 1024
 PARTITION_SIZE = 512
 PARTITION_SIZE_ROCM = 256
-GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
-ON_NAVI = "gfx1" in GPU_ARCH


 @torch.inference_mode()
@@ -88,7 +86,7 @@ def main(
     if version == "v2":
         if current_platform.is_rocm():
             global PARTITION_SIZE
-            if not args.custom_paged_attn and not ON_NAVI:
+            if not args.custom_paged_attn and not current_platform.is_navi():
                 PARTITION_SIZE = 1024
             else:
                 PARTITION_SIZE = PARTITION_SIZE_ROCM
@@ -168,13 +166,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     scale,
                     block_tables,
                     seq_lens,
+                    None,
                     block_size,
                     max_seq_len,
                     alibi_slopes,
                     kv_cache_dtype,
                     k_scale,
                     v_scale,
-                    ON_NAVI,
                 )
             else:
                 raise ValueError(f"Invalid version: {version}")
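
The benchmark now resolves Navi detection through the platform layer instead of parsing gcnArchName itself, and passes None for the new query_start_loc argument since every benchmarked sequence is a single-token decode. A minimal sketch of the resulting partition-size choice (the helper name is illustrative, not part of the diff):

def pick_partition_size(custom_paged_attn: bool, is_navi: bool) -> int:
    # Mirrors the branch above: the stock v2 kernel on non-Navi GPUs uses
    # 1024-token partitions; the custom ROCm kernel and Navi GPUs use the
    # 256-token ROCm partition size.
    return 1024 if (not custom_paged_attn and not is_navi) else 256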

csrc/rocm/attention.cu (84 additions & 27 deletions)

@@ -1581,6 +1581,7 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
   }
 }
 
+// clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
           int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
@@ -1594,6 +1595,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -1604,6 +1606,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
                   // head_size]
     OUTT* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
     int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+  // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;  // 8 warps on gfx11
   const int warpid = threadIdx.x / WARP_SIZE;
   const int laneid = threadIdx.x % WARP_SIZE;
@@ -1613,6 +1616,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   const int rowid = laneid / 16;
 
   const int seq_idx = blockIdx.x;
+  // NOTE queries with sequence len > 1 are prefills and taken care by another
+  // kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
+    return;
+  }
+
   const int partition_idx = blockIdx.y;
 
   constexpr int T_PAR_SIZE = 256;  // token partition size set to 256
@@ -1671,12 +1681,14 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
   // across 2 rows x 8 tokens per lane
 
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+
   if (GQA_RATIO == 1) {
     const int local_qhead_idx = lane16id % GQA_RATIO;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
     const scalar_t* q_ptr =
-        q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
+        q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
     if (lane16id < GQA_RATIO) {
 #pragma unroll
       for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
@@ -1690,9 +1702,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     // fetch Q in shared across warps and then write to registers
     const int local_qhead_idx = 2 * warpid + rowid;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
     const scalar_t* q_ptr =
-        q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
+        q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
 
     const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
     if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
@@ -2024,6 +2035,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -2050,15 +2062,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                            //  max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_partitions) {
-  const int num_heads = gridDim.x;
-  const int head_idx = blockIdx.x;
-  const int seq_idx = blockIdx.y;
+  const auto num_heads = gridDim.x;
+  const auto head_idx = blockIdx.x;
+  const auto seq_idx = blockIdx.y;
+
+  // NOTE queries with sequence len > 1 are prefills and taken care by another
+  // kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
+
   const int context_len = context_lens[seq_idx];
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int warpid = threadIdx.x / WARP_SIZE;
-  const int laneid = threadIdx.x % WARP_SIZE;
+  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
 
   __shared__ float shared_global_exp_sum;
   // max num partitions supported is warp_size * NPAR_LOOPS
@@ -2221,7 +2242,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
   acc *= inv_global_exp_sum;
-  OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+  OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
+                  static_cast<int64_t>(head_idx) * HEAD_SIZE;
   out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
 }
 
@@ -2328,6 +2353,7 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
   }
 }
 
+// clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
           int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
@@ -2341,6 +2367,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -2351,6 +2378,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
                   // head_size]
     OUTT* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
     int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+  // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;  // 8 warps on gfx11
   const int warpid = threadIdx.x / WARP_SIZE;
   const int laneid = threadIdx.x % WARP_SIZE;
@@ -2360,6 +2388,12 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   const int rowid = laneid / 16;
 
   const int seq_idx = blockIdx.x;
+  // NOTE queries with sequence len > 1 are prefills and taken care by another
+  // kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
   const int partition_idx = blockIdx.y;
 
   constexpr int T_PAR_SIZE = 256;  // token partition size set to 256
@@ -2419,11 +2453,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
   // across 2 rows x 8 tokens per lane
 
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+
   if (GQA_RATIO == 1) {
     const int local_qhead_idx = lane16id % GQA_RATIO;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
-    const scalar_t* q_ptr = q + seq_idx64 * q_stride +
+    const scalar_t* q_ptr = q + query_start_off * q_stride +
                             global_qhead_idx * HEAD_SIZE +
                             rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
     if (lane16id < GQA_RATIO) {
@@ -2439,9 +2475,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     // fetch Q in shared across warps and then write to registers
     const int local_qhead_idx = 2 * warpid + rowid;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
     const scalar_t* q_ptr =
-        q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
+        q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
 
     const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
     if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
@@ -2736,6 +2771,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -2762,15 +2798,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                            //  max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_partitions) {
-  const int num_heads = gridDim.x;
-  const int head_idx = blockIdx.x;
-  const int seq_idx = blockIdx.y;
+  const auto num_heads = gridDim.x;
+  const auto head_idx = blockIdx.x;
+  const auto seq_idx = blockIdx.y;
+
+  // NOTE queries with sequence len > 1 are prefills and taken care by another
+  // kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
+
   const int context_len = context_lens[seq_idx];
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int warpid = threadIdx.x / WARP_SIZE;
-  const int laneid = threadIdx.x % WARP_SIZE;
+  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
 
   __shared__ float shared_global_exp_sum;
   // max num partitions supported is warp_size * NPAR_LOOPS
@@ -2933,7 +2978,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   const float inv_global_exp_sum =
      __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
   acc *= inv_global_exp_sum;
-  OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+  OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
+                  static_cast<int64_t>(head_idx) * HEAD_SIZE;
   out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
 }
 
@@ -3201,16 +3250,24 @@ void paged_attention_custom_launcher_navi(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, const int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& context_lens,
-    int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
-    torch::Tensor& k_scale, torch::Tensor& v_scale) {
-  int num_seqs = query.size(0);
+    const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
+    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
+  int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
   int q_stride = query.stride(0);
   int kv_block_stride = key_cache.stride(0);
   int kv_head_stride = key_cache.stride(1);
 
+  // NOTE: query start location is optional for V0 decode should not be used.
+  // If batch contains mix of prefills and decode, prefills should be skipped.
+  const int* query_start_loc_ptr =
+      query_start_loc
+          ? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
+          : nullptr;
+
   // NOTE: Navi does not support alibi_slopes.
   const float* alibi_slopes_ptr = nullptr;
 
@@ -3363,14 +3420,14 @@ void paged_attention_custom_launcher_navi(
     paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
                                     PSIZE, ALIBI_ENABLED>(                    \
         out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
-        num_kv_heads, scale, block_tables, context_lens, max_context_len,     \
-        alibi_slopes, k_scale, v_scale);                                      \
+        num_kv_heads, scale, block_tables, context_lens, query_start_loc,     \
+        max_context_len, alibi_slopes, k_scale, v_scale);                     \
   } else {                                                                    \
     paged_attention_custom_launcher_navi<T, KVT, KV_DTYPE, BLK_SIZE,          \
                                          HEAD_SIZE, T, PSIZE, ALIBI_ENABLED>( \
         out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
-        num_kv_heads, scale, block_tables, context_lens, max_context_len,     \
-        alibi_slopes, k_scale, v_scale);                                      \
+        num_kv_heads, scale, block_tables, context_lens, query_start_loc,     \
+        max_context_len, alibi_slopes, k_scale, v_scale);                     \
   }
 
 #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
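
The recurring change across these hunks is the new query_start_loc_ptr argument. The kernels read query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] as the per-sequence query length, return early for anything other than a single-token decode (prefills are handled by another kernel), and use query_start_loc_ptr[seq_idx] rather than seq_idx to locate the sequence's row in the query and output tensors; when the pointer is null (the V0 path), seq_idx is used as before. A host-side Python sketch of that selection logic, with made-up values:

# query_start_loc is assumed to hold prefix sums of per-sequence query
# lengths, as implied by the seq_idx + 1 indexing in the kernels; the
# numbers below are illustrative only.
query_start_loc = [0, 5, 6, 7, 11]  # seqs 0 and 3 are prefills (lengths 5 and 4)
num_seqs = len(query_start_loc) - 1

for seq_idx in range(num_seqs):
    query_len = query_start_loc[seq_idx + 1] - query_start_loc[seq_idx]
    if query_len != 1:
        continue  # prefill: the custom decode kernel returns early
    query_row = query_start_loc[seq_idx]  # row into the flattened query/output
    print(f"seq {seq_idx}: decode, query/output row {query_row}")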

tests/kernels/attention/test_attention.py (4 additions & 9 deletions)

@@ -148,12 +148,7 @@ def test_paged_attention(
             or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()
 
-    is_rocm_navi = False
-    if current_platform.is_rocm():
-        is_rocm_navi = "gfx1" in torch.cuda.get_device_properties(
-            "cuda").gcnArchName
-
-    if (version == "rocm" and is_rocm_navi
+    if (version == "rocm" and current_platform.is_navi()
            and (kv_cache_dtype == "fp8" or head_size != 128
                 or block_size != 16 or use_alibi)):
         pytest.skip()
@@ -285,20 +280,20 @@ def test_paged_attention(
             scale,
             block_tables,
             seq_lens,
+            None,
             block_size,
             max_seq_len,
             alibi_slopes,
             kv_cache_dtype,
             k_scale,
             v_scale,
-            is_rocm_navi,
         )
 
         opcheck(torch.ops._rocm_C.paged_attention,
                 (output, exp_sums, max_logits, tmp_output, query,
                  key_cache, value_cache, num_kv_heads, scale, block_tables,
-                 seq_lens, block_size, max_seq_len, alibi_slopes,
-                 kv_cache_dtype, k_scale, v_scale, is_rocm_navi),
+                 seq_lens, None, block_size, max_seq_len, alibi_slopes,
+                 kv_cache_dtype, k_scale, v_scale),
                 cond=(head_size == HEAD_SIZES[0]
                       and block_size == BLOCK_SIZES[0]))
 
vllm/_custom_ops.py (19 additions & 6 deletions)

@@ -119,12 +119,25 @@ def paged_attention_rocm(
     v_scale: torch.Tensor,
     is_navi: bool = False,
 ) -> None:
-    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
-                                      key_cache, value_cache, num_kv_heads,
-                                      scale, block_tables, seq_lens,
-                                      query_start_loc, block_size, max_seq_len,
-                                      alibi_slopes, kv_cache_dtype, k_scale,
-                                      v_scale, is_navi)
+    torch.ops._rocm_C.paged_attention(out,
+                                      exp_sum,
+                                      max_logits,
+                                      tmp_out,
+                                      query,
+                                      key_cache,
+                                      value_cache,
+                                      num_kv_heads,
+                                      scale,
+                                      block_tables,
+                                      seq_lens,
+                                      query_start_loc,
+                                      block_size,
+                                      max_seq_len,
+                                      alibi_slopes,
+                                      kv_cache_dtype,
+                                      k_scale,
+                                      v_scale,
+                                      is_navi=current_platform.is_navi())
 
 
 def mla_decode_kvcache_cpu(
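
The wrapper keeps its is_navi parameter but now resolves the flag itself via current_platform.is_navi(), so callers (the benchmark, the tests, and the attention backends) no longer pass it. That helper's implementation is not shown in this diff; a stand-in consistent with the inline check the commit removes elsewhere would be:

import torch

def is_navi() -> bool:
    # Same test the benchmark and test previously did inline: Navi (RDNA)
    # GPUs report a gfx1* architecture name on ROCm builds of PyTorch.
    arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx1" in arch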

vllm/attention/backends/rocm_flash_attn.py (0 additions & 1 deletion)

@@ -908,7 +908,6 @@ def forward(
                         self.kv_cache_dtype,
                         layer._k_scale,
                         layer._v_scale,
-                        _ON_NAVI,
                     )
                 else:
                     output[num_prefill_tokens:] = paged_attn.forward_decode(

vllm/attention/ops/chunked_prefill_paged_decode.py (2 additions & 1 deletion)

@@ -283,7 +283,8 @@ def chunked_prefill_paged_decode(
     use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
                                                  block_size,
                                                  num_queries_per_kv,
-                                                 max_seq_len, sliding_window)
+                                                 max_seq_len, sliding_window,
+                                                 kv_cache_dtype, alibi_slopes)
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
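
use_rocm_custom_paged_attention now also receives kv_cache_dtype and alibi_slopes, presumably so the eligibility check can take them into account. When the custom kernel is selected, the partition count in the surrounding context is a ceiling division of max_seq_len by the 256-token ROCm partition size, for example:

_PARTITION_SIZE_ROCM = 256
max_seq_len = 1000  # illustrative value
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                      _PARTITION_SIZE_ROCM)
assert max_num_partitions == 4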
