-
Notifications
You must be signed in to change notification settings - Fork 4k
Add Paged Attention Op for CUDA SM80 support #24595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
6e75e22
paged attention op
aciddelgado aa6fd44
test file amid-debug
aciddelgado 5583aaa
paged attention works
aciddelgado c628e66
everything works and is implemented
aciddelgado 554d0c7
small stuff
aciddelgado 497a22e
lint
aciddelgado c3276d1
address comments
aciddelgado 482240c
fix dmmha rotary and address tianleiwu comments
aciddelgado 6a197dc
increase efficiency
aciddelgado 0f23e6f
correction
aciddelgado cd15f81
update flash and remove redundant inputs
aciddelgado aa9b6f1
bert comment
aciddelgado e0432f8
merge main
aciddelgado 50a3767
docs
aciddelgado 5eccad2
comments
aciddelgado 1c3b53c
docs
aciddelgado File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
|
|
||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/cuda/cuda_common.h" | ||
| #include "core/platform/env_var_utils.h" | ||
| #include "contrib_ops/cuda/bert/paged_attention_impl.h" | ||
| #include "contrib_ops/cuda/bert/paged_attention.h" | ||
| #include "contrib_ops/cuda/bert/paged_attention_helper.h" | ||
| #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" | ||
|
|
||
| using namespace onnxruntime::cuda; | ||
| using namespace ::onnxruntime::common; | ||
| using namespace ONNX_NAMESPACE; | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace cuda { | ||
|
|
||
| #define REGISTER_KERNEL_TYPED(T) \ | ||
| ONNX_OPERATOR_TYPED_KERNEL_EX( \ | ||
| PagedAttention, \ | ||
| kMSDomain, \ | ||
| 1, \ | ||
| T, \ | ||
| kCudaExecutionProvider, \ | ||
| (*KernelDefBuilder::Create()) \ | ||
| .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \ | ||
| .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \ | ||
| .InputMemoryType(OrtMemTypeCPUInput, 7) \ | ||
| .InputMemoryType(OrtMemTypeCPUInput, 8), \ | ||
| PagedAttention<T>); | ||
|
aciddelgado marked this conversation as resolved.
|
||
|
|
||
| REGISTER_KERNEL_TYPED(MLFloat16) | ||
| REGISTER_KERNEL_TYPED(BFloat16) | ||
|
|
||
| template <typename T> | ||
| PagedAttention<T>::PagedAttention(const OpKernelInfo& info) | ||
| : CudaKernel(info) { | ||
| int64_t num_heads = 0; | ||
| int64_t kv_num_heads = 0; | ||
| ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); | ||
| ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); | ||
| num_heads_ = static_cast<int>(num_heads); | ||
| kv_num_heads_ = static_cast<int>(kv_num_heads); | ||
| local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)); | ||
| do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1; | ||
| rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1; | ||
| scale_ = info.GetAttrOrDefault<float>("scale", 0.0f); | ||
| softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f); | ||
|
|
||
| kernel_options_ = this->GetAttentionKernelOptions(); | ||
| disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); | ||
| } | ||
|
|
||
| template <typename T> | ||
| Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const { | ||
| const Tensor* query = context->Input<Tensor>(0); | ||
| const Tensor* key = context->Input<Tensor>(1); | ||
| const Tensor* value = context->Input<Tensor>(2); | ||
| const Tensor* key_cache = context->Input<Tensor>(3); | ||
| const Tensor* value_cache = context->Input<Tensor>(4); | ||
| const Tensor* cumulative_sequence_length = context->Input<Tensor>(5); | ||
| const Tensor* seqlens = context->Input<Tensor>(6); | ||
| const Tensor* max_query_len = context->Input<Tensor>(7); | ||
| const Tensor* max_seq_len = context->Input<Tensor>(8); | ||
| const Tensor* block_table = context->Input<Tensor>(9); | ||
| const Tensor* slot_mappings = context->Input<Tensor>(10); | ||
| const Tensor* cos_cache = context->Input<Tensor>(11); | ||
| const Tensor* sin_cache = context->Input<Tensor>(12); | ||
|
|
||
| auto& device_prop = GetDeviceProp(); | ||
| PagedAttentionParameters parameters; | ||
| typedef typename ToCudaType<T>::MappedType CudaT; | ||
| PagedAttentionData<CudaT> data; | ||
|
|
||
| // Check shapes of inputs to op and set parameters | ||
| ORT_RETURN_IF_ERROR(paged_attention_helper::CheckInputs(query, | ||
| key, | ||
| value, | ||
| key_cache, | ||
| value_cache, | ||
| cumulative_sequence_length, | ||
| seqlens, | ||
| max_query_len, | ||
| max_seq_len, | ||
| block_table, | ||
| slot_mappings, | ||
| cos_cache, | ||
| sin_cache, | ||
| ¶meters, | ||
| num_heads_, | ||
| kv_num_heads_, | ||
| scale_, | ||
| softcap_, | ||
| device_prop.maxThreadsPerBlock)); | ||
| parameters.local_window_size = local_window_size_; | ||
| parameters.do_rotary = do_rotary_; | ||
| parameters.rotary_interleaved = rotary_interleaved_; | ||
|
|
||
| // Check rotary | ||
| if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "cos_cache and sin_cache must be passed to PagedAttention when do_rotary = 1"); | ||
| } | ||
|
|
||
| // Set output tensor shapes | ||
| TensorShapeVector output_shape(2); | ||
| output_shape[0] = static_cast<int64_t>(parameters.token_count); | ||
| output_shape[1] = static_cast<int64_t>(parameters.hidden_size); | ||
| Tensor* output = context->Output(0, output_shape); | ||
|
|
||
| TensorShapeVector key_cache_out_shape(4); | ||
| key_cache_out_shape[0] = static_cast<int64_t>(parameters.num_blocks); | ||
| key_cache_out_shape[1] = static_cast<int64_t>(parameters.block_size); | ||
| key_cache_out_shape[2] = static_cast<int64_t>(parameters.kv_num_heads); | ||
| key_cache_out_shape[3] = static_cast<int64_t>(parameters.head_size); | ||
| Tensor* key_cache_out = context->Output(1, key_cache_out_shape); | ||
|
|
||
| TensorShapeVector value_cache_out_shape(4); | ||
| value_cache_out_shape[0] = static_cast<int64_t>(parameters.num_blocks); | ||
| value_cache_out_shape[1] = static_cast<int64_t>(parameters.block_size); | ||
| value_cache_out_shape[2] = static_cast<int64_t>(parameters.kv_num_heads); | ||
| value_cache_out_shape[3] = static_cast<int64_t>(parameters.head_size); | ||
| Tensor* value_cache_out = context->Output(2, value_cache_out_shape); | ||
|
|
||
| if (key_cache_out != nullptr && key_cache->Data<T>() != key_cache_out->MutableData<T>()) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "key_cache and key_cache_out must be the same buffer"); | ||
| } else if (value_cache_out != nullptr && value_cache->Data<T>() != value_cache_out->MutableData<T>()) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "value_cache and value_cache_out must be the same buffer"); | ||
| } | ||
|
|
||
| // Check flash kernel availability and allocate buffers | ||
| #if USE_FLASH_ATTENTION | ||
| bool use_flash_attention = !disable_flash_attention_ && | ||
| onnxruntime::flash::is_supported(device_prop, | ||
| parameters.head_size, | ||
| parameters.num_heads, | ||
| parameters.kv_num_heads); | ||
| size_t softmax_lse_bytes = 0; | ||
| if (use_flash_attention) { | ||
| softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.batch_size, | ||
| parameters.num_heads, | ||
| parameters.token_count); | ||
| } | ||
|
aciddelgado marked this conversation as resolved.
Outdated
|
||
| auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream()); | ||
| #else | ||
| constexpr bool use_flash_attention = false; | ||
| auto softmax_lse_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr | ||
| #endif | ||
|
aciddelgado marked this conversation as resolved.
|
||
|
|
||
| if (!use_flash_attention) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "Currently PagedAttention is only supported through the FlashAttention kernel."); | ||
| } | ||
|
|
||
| size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1); | ||
| auto cumulative_seqlens_kv_buffer = GetScratchBuffer<void>(cumulative_seqlens_kv_bytes, context->GetComputeStream()); | ||
|
|
||
| size_t workspace_buffer_bytes = 0; | ||
| if (parameters.is_packed_qkv) { // unpacking and rotary can be done with the same buffer in the same operation | ||
| workspace_buffer_bytes = parameters.token_count * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T); | ||
|
aciddelgado marked this conversation as resolved.
Outdated
|
||
| } else if (do_rotary_) { | ||
| workspace_buffer_bytes = 2 * sizeof(T) * parameters.token_count * parameters.num_heads * parameters.head_size; | ||
| } | ||
| auto workspace_buffer = GetScratchBuffer<void>(workspace_buffer_bytes, context->GetComputeStream()); | ||
|
|
||
| // Print debug info | ||
| if (kernel_options_->AllowDebugInfo()) { | ||
|
aciddelgado marked this conversation as resolved.
|
||
| AttentionKernelDebugInfo debug_info; | ||
| debug_info.use_flash_attention = use_flash_attention; | ||
|
|
||
| debug_info.Print("PagedAttention", | ||
| this->Node().Name(), | ||
| std::is_same<T, MLFloat16>::value, | ||
| std::is_same<T, BFloat16>::value); | ||
| } | ||
|
|
||
| // Set up data struct for kernel launch | ||
| data.query = reinterpret_cast<const CudaT*>(query->Data<T>()); | ||
| data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>()); | ||
| data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>()); | ||
| data.key_cache = reinterpret_cast<CudaT*>(const_cast<T*>(key_cache->Data<T>())); | ||
| data.value_cache = reinterpret_cast<CudaT*>(const_cast<T*>(value_cache->Data<T>())); | ||
| data.cumulative_sequence_length = reinterpret_cast<const int*>(cumulative_sequence_length->Data<int>()); | ||
| data.seqlens = reinterpret_cast<const int*>(seqlens->Data<int>()); | ||
| data.cumulative_seqlens_kv = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get()); | ||
| data.block_table = reinterpret_cast<const int*>(block_table->Data<int>()); | ||
| data.slot_mappings = reinterpret_cast<const int*>(slot_mappings->Data<int>()); | ||
| data.output = reinterpret_cast<CudaT*>(output->MutableData<T>()); | ||
| data.use_flash_attention = use_flash_attention; | ||
| if (softmax_lse_buffer != nullptr) { | ||
| data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get()); | ||
| } | ||
| if (workspace_buffer != nullptr) { | ||
| data.workspace_buffer = reinterpret_cast<CudaT*>(workspace_buffer.get()); | ||
| } | ||
| if (parameters.do_rotary) { | ||
| data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>()); | ||
| data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>()); | ||
| } | ||
|
|
||
| cublasHandle_t cublas = GetCublasHandle(context); | ||
|
|
||
| return QkvToContext<CudaT>( | ||
| device_prop, cublas, context->GetComputeStream(), parameters, data); | ||
| } | ||
|
|
||
| } // namespace cuda | ||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <memory> | ||
| #include "core/providers/cuda/cuda_kernel.h" | ||
| #include "contrib_ops/cuda/bert/paged_attention_impl.h" | ||
| #include "contrib_ops/cuda/bert/attention_kernel_options.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace cuda { | ||
|
|
||
| using namespace onnxruntime::cuda; | ||
|
|
||
| template <typename T> | ||
| class PagedAttention final : public CudaKernel { | ||
| public: | ||
| PagedAttention(const OpKernelInfo& info); | ||
| Status ComputeInternal(OpKernelContext* context) const override; | ||
|
|
||
| protected: | ||
| int num_heads_; // number of attention heads | ||
| int kv_num_heads_; // different for k and v for group query attention | ||
| int local_window_size_; | ||
| bool do_rotary_; | ||
| bool rotary_interleaved_; | ||
| float scale_; | ||
| float softcap_; | ||
| bool disable_flash_attention_; | ||
| const AttentionKernelOptions* kernel_options_; | ||
| }; | ||
|
|
||
| } // namespace cuda | ||
| } // namespace contrib | ||
| } // namespace onnxruntime |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.