Skip to content
Merged
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ struct GroupQueryAttentionParameters : AttentionParameters {
int* zero_ptr;
};

// Parameters deduced from node attributes and inputs/outputs.
struct PagedAttentionParameters : AttentionParameters {
int token_count; // number of tokens in packed query
int kv_num_heads;
Comment thread
aciddelgado marked this conversation as resolved.
Outdated
int kv_hidden_size;
int block_size; // block size for kv cache
int max_num_blocks_per_seq; // max number of blocks per sequence for kv cache
int num_blocks; // number of blocks in kv cache
int rotary_dim; // rotary embedding dimension
Comment thread
aciddelgado marked this conversation as resolved.
Outdated
int local_window_size; // The window size excludes current token. It only includes tokens on the left side.
bool is_packed_qkv;
Comment thread
aciddelgado marked this conversation as resolved.
Outdated
bool rotary_interleaved;
float softcap;
};

// Parameters for sparse attention.
struct SparseAttentionParameters : AttentionParameters {
int kv_hidden_size; // hidden size of key or value
Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,35 @@ struct GroupQueryAttentionData {
bool use_memory_efficient_attention = false;
};

template <typename T>
struct PagedAttentionData {
// Input Tensors
const T* query = nullptr;
const T* key = nullptr;
const T* value = nullptr;
T* key_cache = nullptr;
T* value_cache = nullptr;
const int* cumulative_sequence_length = nullptr;
Comment thread
aciddelgado marked this conversation as resolved.
Outdated
const int* seqlens = nullptr;
const int* block_table = nullptr;
const int* slot_mappings = nullptr;
const T* cos_cache = nullptr;
const T* sin_cache = nullptr;

// Flash buffers
T* softmax_lse = nullptr;
int* cumulative_seqlens_kv = nullptr; // Flash api takes cumulative sequence length for kv-cache

// Fused op buffers
T* workspace_buffer = nullptr;

// Output Tensors
T* output = nullptr;

// Kernel Flags
bool use_flash_attention = false;
};

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
17 changes: 13 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads)
return bytes;
}

size_t get_softmax_lse_size(size_t token_count, size_t num_heads) {
size_t bytes = sizeof(float) * token_count * num_heads;
return bytes;
}

size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) {
size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads;
return bytes;
Expand Down Expand Up @@ -336,6 +341,8 @@ Status mha_fwd(const cudaDeviceProp& dprops,
return Status::OK();
}

// TODO(aciddelgado): Baiju wants this https://github.com/Dao-AILab/flash-attention/pull/824

Status mha_varlen_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // half (total_q, num_heads, head_size)
Expand All @@ -357,6 +364,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
const float softcap,
bool is_causal,
bool is_bf16,
int local_window_size,
int max_num_blocks_per_seq,
int page_block_size) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
Expand Down Expand Up @@ -384,7 +392,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
is_bf16,
false,
true,
-1,
local_window_size,
is_causal ? 0 : -1);
params.dprops = &dprops;
params.num_splits = 0;
Expand All @@ -394,7 +402,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
params.vnew_ptr = nullptr;
params.alibi_slopes_ptr = nullptr;
if (paged_KV) {
params.block_table = block_table; // TODO(aciddelgado): cast to int pointer
params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
// params.num_blocks = num_blocks;
params.page_block_size = page_block_size;
Expand All @@ -406,7 +414,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
// params.num_blocks = 0;
params.page_block_size = 1;
}
run_mha_fwd(params, stream);

run_mha_fwd(params, stream, paged_KV);
return Status::OK();
}

Expand Down Expand Up @@ -538,7 +547,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,

params.alibi_slopes_ptr = nullptr;
if (paged_KV) {
params.block_table = block_table; // TODO(aciddelgado): cast to int pointer
params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
// params.num_blocks = num_blocks;
params.page_block_size = page_block_size;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
const float softcap,
bool is_causal,
bool is_bf16,
int local_window_size=-1,
int max_num_blocks_per_seq = 0,
Comment thread
aciddelgado marked this conversation as resolved.
int page_block_size = 1);

Expand Down Expand Up @@ -121,6 +122,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int page_block_size = 1);

size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads);
size_t get_softmax_lse_size(size_t token_count, size_t num_heads);

std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads,
size_t head_size, size_t num_SMs);
Expand Down
212 changes: 212 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Comment thread Fixed
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/paged_attention_impl.h"
#include "contrib_ops/cuda/bert/paged_attention.h"
#include "contrib_ops/cuda/bert/paged_attention_helper.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
PagedAttention, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \
.InputMemoryType(OrtMemTypeCPUInput, 7) \
.InputMemoryType(OrtMemTypeCPUInput, 8), \
PagedAttention<T>);
Comment thread
aciddelgado marked this conversation as resolved.

REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

template <typename T>
PagedAttention<T>::PagedAttention(const OpKernelInfo& info)
: CudaKernel(info) {
int64_t num_heads = 0;
int64_t kv_num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f);

kernel_options_ = this->GetAttentionKernelOptions();
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
}

template <typename T>
Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* query = context->Input<Tensor>(0);
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
const Tensor* key_cache = context->Input<Tensor>(3);
const Tensor* value_cache = context->Input<Tensor>(4);
const Tensor* cumulative_sequence_length = context->Input<Tensor>(5);
const Tensor* seqlens = context->Input<Tensor>(6);
const Tensor* max_query_len = context->Input<Tensor>(7);
const Tensor* max_seq_len = context->Input<Tensor>(8);
const Tensor* block_table = context->Input<Tensor>(9);
const Tensor* slot_mappings = context->Input<Tensor>(10);
const Tensor* cos_cache = context->Input<Tensor>(11);
const Tensor* sin_cache = context->Input<Tensor>(12);

auto& device_prop = GetDeviceProp();
PagedAttentionParameters parameters;
typedef typename ToCudaType<T>::MappedType CudaT;
PagedAttentionData<CudaT> data;

// Check shapes of inputs to op and set parameters
ORT_RETURN_IF_ERROR(paged_attention_helper::CheckInputs(query,
key,
value,
key_cache,
value_cache,
cumulative_sequence_length,
seqlens,
max_query_len,
max_seq_len,
block_table,
slot_mappings,
cos_cache,
sin_cache,
&parameters,
num_heads_,
kv_num_heads_,
scale_,
softcap_,
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;

// Check rotary
if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache and sin_cache must be passed to PagedAttention when do_rotary = 1");
}

// Set output tensor shapes
TensorShapeVector output_shape(2);
output_shape[0] = static_cast<int64_t>(parameters.token_count);
output_shape[1] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);

TensorShapeVector key_cache_out_shape(4);
key_cache_out_shape[0] = static_cast<int64_t>(parameters.num_blocks);
key_cache_out_shape[1] = static_cast<int64_t>(parameters.block_size);
key_cache_out_shape[2] = static_cast<int64_t>(parameters.kv_num_heads);
key_cache_out_shape[3] = static_cast<int64_t>(parameters.head_size);
Tensor* key_cache_out = context->Output(1, key_cache_out_shape);

TensorShapeVector value_cache_out_shape(4);
value_cache_out_shape[0] = static_cast<int64_t>(parameters.num_blocks);
value_cache_out_shape[1] = static_cast<int64_t>(parameters.block_size);
value_cache_out_shape[2] = static_cast<int64_t>(parameters.kv_num_heads);
value_cache_out_shape[3] = static_cast<int64_t>(parameters.head_size);
Tensor* value_cache_out = context->Output(2, value_cache_out_shape);

if (key_cache_out != nullptr && key_cache->Data<T>() != key_cache_out->MutableData<T>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"key_cache and key_cache_out must be the same buffer");
} else if (value_cache_out != nullptr && value_cache->Data<T>() != value_cache_out->MutableData<T>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"value_cache and value_cache_out must be the same buffer");
}

// Check flash kernel availability and allocate buffers
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported(device_prop,
parameters.head_size,
parameters.num_heads,
parameters.kv_num_heads);
size_t softmax_lse_bytes = 0;
if (use_flash_attention) {
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.batch_size,
parameters.num_heads,
parameters.token_count);
}
Comment thread
aciddelgado marked this conversation as resolved.
Outdated
auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
#else
constexpr bool use_flash_attention = false;
auto softmax_lse_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif
Comment thread
aciddelgado marked this conversation as resolved.

if (!use_flash_attention) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Currently PagedAttention is only supported through the FlashAttention kernel.");
}

size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1);
auto cumulative_seqlens_kv_buffer = GetScratchBuffer<void>(cumulative_seqlens_kv_bytes, context->GetComputeStream());

size_t workspace_buffer_bytes = 0;
if (parameters.is_packed_qkv) { // unpacking and rotary can be done with the same buffer in the same operation
workspace_buffer_bytes = parameters.token_count * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T);
Comment thread
aciddelgado marked this conversation as resolved.
Outdated
} else if (do_rotary_) {
workspace_buffer_bytes = 2 * sizeof(T) * parameters.token_count * parameters.num_heads * parameters.head_size;
}
auto workspace_buffer = GetScratchBuffer<void>(workspace_buffer_bytes, context->GetComputeStream());

// Print debug info
if (kernel_options_->AllowDebugInfo()) {
Comment thread
aciddelgado marked this conversation as resolved.
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;

debug_info.Print("PagedAttention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}

// Set up data struct for kernel launch
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.key_cache = reinterpret_cast<CudaT*>(const_cast<T*>(key_cache->Data<T>()));
data.value_cache = reinterpret_cast<CudaT*>(const_cast<T*>(value_cache->Data<T>()));
data.cumulative_sequence_length = reinterpret_cast<const int*>(cumulative_sequence_length->Data<int>());
data.seqlens = reinterpret_cast<const int*>(seqlens->Data<int>());
data.cumulative_seqlens_kv = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());
data.block_table = reinterpret_cast<const int*>(block_table->Data<int>());
data.slot_mappings = reinterpret_cast<const int*>(slot_mappings->Data<int>());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.use_flash_attention = use_flash_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
if (workspace_buffer != nullptr) {
data.workspace_buffer = reinterpret_cast<CudaT*>(workspace_buffer.get());
}
if (parameters.do_rotary) {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}

cublasHandle_t cublas = GetCublasHandle(context);

return QkvToContext<CudaT>(
device_prop, cublas, context->GetComputeStream(), parameters, data);
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
37 changes: 37 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/paged_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <memory>
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cuda/bert/paged_attention_impl.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

using namespace onnxruntime::cuda;

template <typename T>
class PagedAttention final : public CudaKernel {
public:
PagedAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;

protected:
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
int local_window_size_;
bool do_rotary_;
bool rotary_interleaved_;
float scale_;
float softcap_;
bool disable_flash_attention_;
const AttentionKernelOptions* kernel_options_;
};

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
Loading