-
-
Notifications
You must be signed in to change notification settings - Fork 11.5k
[ROCM][KERNEL] Paged attention for V1 #15720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
robertgshaw2-redhat
merged 10 commits into
vllm-project:main
from
ROCm:v1_rocm_paged_attention_integration
Apr 3, 2025
Merged
Changes from 2 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
d2fd7c1
initial working commit with C++ kernel
d8f1598
crash fix for absent element in query
a053de0
Merge branch 'upstream/main' into v1_rocm_paged_attention_integration
e7a77aa
removing duplicated code
5e102cb
removing duplicated code
2f1ad6c
acting on comment about optional tensor
f7e8e7d
moved get_device_properties into function to let GPUs initialized
ee68e0e
borken kernel test fix plus comments application
233454b
Merge branch 'upstream/main' into v1_rocm_paged_attention_integration
67d07de
borken kernel test fix plus build fix
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,8 @@ | |
| #include <hip/hip_bf16.h> | ||
| #include "cuda_compat.h" | ||
|
|
||
| #include <inttypes.h> | ||
|
|
||
| #include <algorithm> | ||
| #include "../attention/dtype_fp8.cuh" | ||
| #include "../quantization/fp8/amd/quant_utils.cuh" | ||
|
|
@@ -272,6 +274,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( | |
| const float scale, | ||
| const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] | ||
| const int* __restrict__ context_lens, // [num_seqs] | ||
| const int* __restrict__ query_start_loc_ptr, // [num_seqs] | ||
| const int max_num_blocks_per_seq, | ||
| const float* __restrict__ alibi_slopes, // [num_heads] | ||
| const int q_stride, | ||
|
|
@@ -291,6 +294,13 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( | |
| const int rowid = laneid / 16; | ||
|
|
||
| const auto seq_idx = blockIdx.x; | ||
| // NOTE queries with sequence len > 1 are prefills and taken care by another | ||
| // kernel. | ||
| if (query_start_loc_ptr != nullptr && | ||
| (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) { | ||
| return; | ||
| } | ||
|
|
||
| const auto partition_idx = blockIdx.y; | ||
|
|
||
| constexpr int T_PAR_SIZE = 256; // token partition size set to 256 | ||
|
|
@@ -377,9 +387,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( | |
| // fetch Q in shared across warps and then write to registers | ||
| const int local_qhead_idx = 4 * warpid + rowid; | ||
| const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; | ||
| const int64_t seq_idx64 = static_cast<int64_t>(seq_idx); | ||
| const int64_t query_start_off = | ||
| query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx; | ||
|
||
| const scalar_t* q_ptr = | ||
| q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; | ||
| q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; | ||
|
|
||
| const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; | ||
| if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { | ||
|
|
@@ -777,6 +788,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( | |
| const float scale, | ||
| const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] | ||
| const int* __restrict__ context_lens, // [num_seqs] | ||
| const int* __restrict__ query_start_loc_ptr, // [num_seqs] | ||
| const int max_num_blocks_per_seq, | ||
| const float* __restrict__ alibi_slopes, // [num_heads] | ||
| const int q_stride, | ||
|
|
@@ -794,6 +806,12 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( | |
| const int lane4id = laneid % 4; | ||
|
|
||
| const auto seq_idx = blockIdx.x; | ||
| // NOTE queries with sequence len > 1 are prefills and taken care by another | ||
| // kernel. | ||
| if (query_start_loc_ptr != nullptr && | ||
| (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { | ||
| return; | ||
| } | ||
| const auto partition_idx = blockIdx.y; | ||
| const auto partition_size = blockDim.x; | ||
| const auto max_num_partitions = gridDim.y; | ||
|
|
@@ -882,9 +900,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( | |
| } | ||
|
|
||
| // fetch q elements | ||
| // every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems | ||
| // every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elemsc | ||
| const int64_t query_start_off = | ||
| query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx; | ||
| const scalar_t* q_ptr = | ||
| q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; | ||
| q + query_start_off * q_stride + wg_start_head_idx * HEAD_SIZE; | ||
| const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr); | ||
| const int qhead_elemh8 = laneid / 4; | ||
|
|
||
|
|
@@ -1267,10 +1287,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( | |
| const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, | ||
| // max_num_partitions, head_size] | ||
| const int* __restrict__ context_lens, // [num_seqs] | ||
| const int* __restrict__ query_start_loc_ptr, // [num_seqs] | ||
| const int max_num_partitions) { | ||
| const auto num_heads = gridDim.x; | ||
| const auto head_idx = blockIdx.x; | ||
| const auto seq_idx = blockIdx.y; | ||
|
|
||
| // NOTE queries with sequence len > 1 are prefills and taken care by another | ||
| // kernel. | ||
| if (query_start_loc_ptr != nullptr && | ||
| (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { | ||
| return; | ||
| } | ||
|
|
||
| const int context_len = context_lens[seq_idx]; | ||
| const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); | ||
| [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; | ||
|
|
@@ -1439,7 +1468,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( | |
| __fdividef(1.0f, shared_global_exp_sum + 1e-6f); | ||
| acc *= inv_global_exp_sum; | ||
|
|
||
| OUTT* out_ptr = out + static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE + | ||
| const int64_t query_start_off = | ||
| query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx; | ||
| OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + | ||
| static_cast<int64_t>(head_idx) * HEAD_SIZE; | ||
| if constexpr (std::is_same<OUTT, bit8_t>::value) { | ||
| out_ptr[threadIdx.x] = | ||
|
|
@@ -1466,6 +1497,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( | |
| const float scale, | ||
| const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] | ||
| const int* __restrict__ context_lens, // [num_seqs] | ||
| const int* __restrict__ query_start_loc_ptr, // [num_seqs] | ||
| const int max_num_blocks_per_seq, | ||
| const float* __restrict__ alibi_slopes, // [num_heads] | ||
| const int q_stride, | ||
|
|
@@ -1492,6 +1524,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( | |
| const float scale, | ||
| const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] | ||
| const int* __restrict__ context_lens, // [num_seqs] | ||
| const int* __restrict__ query_start_loc_ptr, // [num_seqs] | ||
| const int max_num_blocks_per_seq, | ||
| const float* __restrict__ alibi_slopes, // [num_heads] | ||
| const int q_stride, | ||
|
|
@@ -1515,41 +1548,42 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( | |
| const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] | ||
| const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] | ||
| const int* __restrict__ context_lens, // [num_seqs] | ||
| const int* __restrict__ query_start_loc_ptr, // [num_seqs] | ||
| const int max_num_partitions) { | ||
| UNREACHABLE_CODE | ||
| } | ||
| // clang-format on | ||
|
|
||
| #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support | ||
|
|
||
| #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ | ||
| paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \ | ||
| HEAD_SIZE, NTHR, ALIBI_ENABLED, \ | ||
| GQA_RATIO> \ | ||
| <<<grid, block, 0, stream>>>( \ | ||
| query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ | ||
| block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ | ||
| alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ | ||
| exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ | ||
| k_scale_ptr, v_scale_ptr); | ||
|
|
||
| #define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ | ||
| paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \ | ||
| HEAD_SIZE, NTHR, ALIBI_ENABLED, \ | ||
| GQA_RATIO> \ | ||
| <<<grid, block, 0, stream>>>( \ | ||
| query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ | ||
| block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ | ||
| alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ | ||
| exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ | ||
| k_scale_ptr, v_scale_ptr); | ||
| #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ | ||
| paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \ | ||
| HEAD_SIZE, NTHR, ALIBI_ENABLED, \ | ||
| GQA_RATIO> \ | ||
| <<<grid, block, 0, stream>>>( \ | ||
| query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ | ||
| block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ | ||
| max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ | ||
| kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ | ||
| max_ctx_blocks, k_scale_ptr, v_scale_ptr); | ||
|
|
||
| #define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ | ||
| paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \ | ||
| HEAD_SIZE, NTHR, ALIBI_ENABLED, \ | ||
| GQA_RATIO> \ | ||
| <<<grid, block, 0, stream>>>( \ | ||
| query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ | ||
| block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ | ||
| max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ | ||
| kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ | ||
| max_ctx_blocks, k_scale_ptr, v_scale_ptr); | ||
|
|
||
| #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ | ||
| paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \ | ||
| PARTITION_SIZE, NPAR_LOOPS> \ | ||
| <<<reduce_grid, reduce_block, 0, stream>>>( \ | ||
| out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ | ||
| context_lens_ptr, max_num_partitions); | ||
| context_lens_ptr, query_start_loc_ptr, max_num_partitions); | ||
|
|
||
| template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE, | ||
| int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, | ||
|
|
@@ -1559,16 +1593,24 @@ void paged_attention_custom_launcher( | |
| torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, | ||
| torch::Tensor& value_cache, const int num_kv_heads, float scale, | ||
| torch::Tensor& block_tables, torch::Tensor& context_lens, | ||
| int max_context_len, const std::optional<torch::Tensor>& alibi_slopes, | ||
| torch::Tensor& k_scale, torch::Tensor& v_scale) { | ||
| int num_seqs = query.size(0); | ||
| const std::optional<torch::Tensor>& query_start_loc, int max_context_len, | ||
| const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale, | ||
| torch::Tensor& v_scale) { | ||
| int num_seqs = block_tables.size(0); | ||
| int num_heads = query.size(1); | ||
| int head_size = query.size(2); | ||
| int max_num_blocks_per_seq = block_tables.size(1); | ||
| int q_stride = query.stride(0); | ||
| int kv_block_stride = key_cache.stride(0); | ||
| int kv_head_stride = key_cache.stride(1); | ||
|
|
||
| // NOTE: query start location is optional for V0 decode should not be used. | ||
| // If batch contains mix of prefills and decode, prefills should be skipped. | ||
| const int* query_start_loc_ptr = | ||
| query_start_loc | ||
| ? reinterpret_cast<const int*>(query_start_loc.value().data_ptr()) | ||
| : nullptr; | ||
|
|
||
| // NOTE: alibi_slopes is optional. | ||
| const float* alibi_slopes_ptr = | ||
| alibi_slopes | ||
|
|
@@ -1700,8 +1742,8 @@ void paged_attention_custom_launcher( | |
| paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ | ||
| PSIZE, ALIBI_ENABLED>( \ | ||
| out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ | ||
| num_kv_heads, scale, block_tables, context_lens, max_context_len, \ | ||
| alibi_slopes, k_scale, v_scale); | ||
| num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ | ||
| max_context_len, alibi_slopes, k_scale, v_scale); | ||
|
|
||
| #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ | ||
| PSIZE) \ | ||
|
|
@@ -1750,6 +1792,7 @@ void paged_attention( | |
| double scale, | ||
| torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] | ||
| torch::Tensor& context_lens, // [num_seqs] | ||
| const std::optional<torch::Tensor>& query_start_loc, // [num_seqs] | ||
| int64_t block_size, int64_t max_context_len, | ||
| const std::optional<torch::Tensor>& alibi_slopes, | ||
| const std::string& kv_cache_dtype, torch::Tensor& k_scale, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.