kernel: support slightly faster merge_state_v2 cuda kernel #5381
Merged

Commits (5):
- 2112e0e  kernel: support merge_attn_states cuda kernel (DefTruth)
- 5914254  kernel: support merge_attn_states cuda kernel (DefTruth)
- c5a28f6  kernel: support merge_state_v2 cuda kernel (DefTruth)
- 8845f70  Merge branch 'main' into cuda-merge-attn-states (DefTruth)
- d781b20  Merge branch 'main' into cuda-merge-attn-states (DefTruth)
@@ -0,0 +1,191 @@
// Adapted from https://github.com/vllm-project/vllm/pull/16173
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// half / __nv_bfloat16 come from the CUDA toolkit headers; included
// explicitly in case pytorch_extension_utils.h does not pull them in.
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <algorithm>
#include <limits>  // std::numeric_limits, used in the kernel below
#include <optional>

#include "pytorch_extension_utils.h"

// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
inline __device__ float to_float(float u) {
  return u;
}
inline __device__ float to_float(half u) {
  return __half2float(u);
}
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}
inline __device__ void from_float(float& d, float s) {
  d = s;
}
inline __device__ void from_float(half& d, float s) {
  d = __float2half(s);
}
inline __device__ void from_float(__nv_bfloat16& d, float s) {
  d = __float2bfloat16(s);
}
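
// Overload resolution picks the matching conversion for each scalar_t, so
// the kernel below can accumulate in fp32 regardless of the storage dtype.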

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 and
// can be used to combine partial attention results (in the split-KV case).
// Reference blog: https://zhuanlan.zhihu.com/p/1892966682634473987
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
    scalar_t* output,
    float* output_lse,
    const scalar_t* prefix_output,
    const float* prefix_lse,
    const scalar_t* suffix_output,
    const float* suffix_lse,
    const uint num_tokens,
    const uint num_heads,
    const uint head_size) {
  // 16-byte (128-bit) vectorized load/store type.
  using pack_128b_t = uint4;
  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;

  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
  scalar_t* output_head_ptr = output + head_offset;

  // An infinite LSE marks an empty/invalid partial result; map it to -inf
  // so the corresponding partial output gets zero weight in the merge below.
  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
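
  // Merge derivation: with weights exp(p_lse) and exp(s_lse),
  //   out = (exp(p_lse) * prefix + exp(s_lse) * suffix) / (exp(p_lse) + exp(s_lse)).
  // Subtracting max_lse from both exponents below leaves the ratio unchanged
  // (the common factor exp(max_lse) cancels) but keeps expf() from overflowing.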
  const float max_lse = fmaxf(p_lse, s_lse);
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = expf(p_lse);
  const float s_se = expf(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;
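
  // Worked example: p_lse = 2.0, s_lse = 1.0 -> max_lse = 2.0,
  // p_se = e^0 = 1.0, s_se = e^-1 ~= 0.368, out_se ~= 1.368,
  // p_scale ~= 0.731, s_scale ~= 0.269, out_lse = logf(1.368) + 2.0 ~= 2.313.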
|
|
||
| if (pack_offset < head_size) { | ||
| // Pack 128b load | ||
| pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size]; | ||
| pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size]; | ||
| pack_128b_t o_out_pack; | ||
|
|
||
| #pragma unroll | ||
| for (uint i = 0; i < pack_size; ++i) { | ||
| // Always use float for FMA to keep high precision. | ||
| // half(uint16_t), bfloat16, float -> float. | ||
| const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]); | ||
| const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]); | ||
| // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale) | ||
| const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale); | ||
| // float -> half(uint16_t), bfloat16, float. | ||
| from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f); | ||
| } | ||
|
|
||
| // Pack 128b storage | ||
| reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack; | ||
| } | ||
| // We only need to write to output_lse once per head. | ||
| if (output_lse != nullptr && pack_idx == 0) { | ||
| float out_lse = logf(out_se) + max_lse; | ||
| output_lse[head_idx * num_tokens + token_idx] = out_lse; | ||
| } | ||
| } | ||
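
// Thread mapping example: with half precision (pack_size = 8) and
// head_size = 128, threads_per_head = 16, so each thread merges one
// 8-element pack and 16 consecutive threads cover one (token, head) pair.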

// The following macro is used to dispatch the conversion function based on
// the output data type. fn is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                        \
  {                                                                       \
    if (scalar_dtype == at::ScalarType::Float) {                          \
      fn(float);                                                          \
    } else if (scalar_dtype == at::ScalarType::Half) {                    \
      fn(half);                                                           \
    } else if (scalar_dtype == at::ScalarType::BFloat16) {                \
      fn(__nv_bfloat16);                                                  \
    } else {                                                              \
      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype);   \
    }                                                                     \
  }
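
// For example, DISPATCH_BY_SCALAR_DTYPE(dtype, fn) with dtype == Half
// expands to fn(half), which instantiates
// merge_attn_states_launcher<half>(...) via the CALL_* macro below.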

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)             \
  {                                                                 \
    merge_attn_states_kernel<scalar_t, NUM_THREADS>                 \
        <<<grid, block, 0, stream>>>(                               \
            reinterpret_cast<scalar_t*>(output.data_ptr()),         \
            output_lse_ptr,                                         \
            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),  \
            reinterpret_cast<float*>(prefix_lse.data_ptr()),        \
            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),  \
            reinterpret_cast<float*>(suffix_lse.data_ptr()),        \
            num_tokens,                                             \
            num_heads,                                              \
            head_size);                                             \
  }

/* @brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention states.
 * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention states.
 */
template <typename scalar_t>
void merge_attn_states_launcher(
    at::Tensor& output,                    // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    std::optional<at::Tensor> output_lse,  // [NUM_HEADS, NUM_TOKENS]
    const at::Tensor& prefix_output,       // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& prefix_lse,          // [NUM_HEADS, NUM_TOKENS]
    const at::Tensor& suffix_output,       // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& suffix_lse           // [NUM_HEADS, NUM_TOKENS]
) {
  constexpr uint NUM_THREADS = 128;
  const uint num_tokens = output.size(0);
  const uint num_heads = output.size(1);
  const uint head_size = output.size(2);
  const uint pack_size = 16 / sizeof(scalar_t);
  TORCH_CHECK(head_size % pack_size == 0, "head_size must be a multiple of pack_size:", pack_size);
  float* output_lse_ptr = nullptr;
  if (output_lse.has_value()) {
    output_lse_ptr = output_lse.value().data_ptr<float>();
  }
  // Process one pack of elements per thread. float -> 4, half/bf16 -> 8.
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

  // Launch on the current PyTorch CUDA stream rather than the default stream;
  // <ATen/cuda/CUDAContext.h> and <c10/cuda/CUDAGuard.h> are included for this.
  const c10::cuda::OptionalCUDAGuard device_guard(output.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
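
// Sizing example: n = 4096 tokens, h = 32 heads, d = 128, fp16 output:
// pack_size = 8, threads_per_head = 16, total_threads = 4096 * 32 * 16 =
// 2,097,152, so grid = 2,097,152 / 128 = 16,384 blocks of 128 threads.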

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t)                     \
  {                                                                   \
    merge_attn_states_launcher<scalar_t>(output, output_lse,          \
                                         prefix_output, prefix_lse,   \
                                         suffix_output, suffix_lse);  \
  }

void merge_attn_states(
    at::Tensor& output,
    std::optional<at::Tensor> output_lse,
    const at::Tensor& prefix_output,
    const at::Tensor& prefix_lse,
    const at::Tensor& suffix_output,
    const at::Tensor& suffix_lse) {
  DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
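
For orientation, a host-side call might look like the sketch below. The tensor shapes follow the launcher's comments ([n,h,d] outputs, [h,n] LSEs); the example() wrapper, the concrete sizes, and the randn inputs are illustrative assumptions, not part of this PR.

// Hypothetical usage sketch (not part of the diff): merge two partial
// attention results produced by a split-KV pass.
#include <torch/torch.h>

void example() {
  const int64_t n = 4096, h = 32, d = 128;  // tokens, heads, head size
  auto o_opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto lse_opts = torch::dtype(torch::kFloat).device(torch::kCUDA);

  auto prefix_output = torch::randn({n, h, d}, o_opts);
  auto suffix_output = torch::randn({n, h, d}, o_opts);
  auto prefix_lse = torch::randn({h, n}, lse_opts);  // note the [h, n] layout
  auto suffix_lse = torch::randn({h, n}, lse_opts);

  auto output = torch::empty({n, h, d}, o_opts);
  auto output_lse = torch::empty({h, n}, lse_opts);

  merge_attn_states(output, output_lse, prefix_output, prefix_lse,
                    suffix_output, suffix_lse);
}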