diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 7ba2f820e9bdb..f6cc816b45ed2 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -67,6 +67,7 @@ Do not modify directly.*
* com.microsoft.PackedAttention
* com.microsoft.PackedMultiHeadAttention
* com.microsoft.Pad
+ * com.microsoft.PagedAttention
* com.microsoft.QAttention
* com.microsoft.QGemm
* com.microsoft.QLinearAdd
@@ -3683,6 +3684,100 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.PagedAttention**
+
+ Paged Attention.
+
+ This op leverages a block-based KV cache to enable continuous batching for LLMs. Currently, it is designed to work with
+ the CUDA Execution Provider only.
+
+ In other attention ops, batch entries typically aren't of the same length, so they are padded.
+ Below is a batch with 3 sequences where * denotes a padding token.
+ Sequence_0: 0, 1*, 2*, 3*
+ Sequence_1: 4, 5, 6*, 7*
+ Sequence_2: 8, 9, 10, 11
+
+ PagedAttention is designed to take packed input, i.e., only the real tokens, without padding.
+ For example, the batch shown above is packed into query, key, and value tensors, together with a
+ cumulative_sequence_length tensor:
+ - query ([q0, q4, q5, q8, q9, q10, q11])
+ - key ([k0, k4, k5, k8, k9, k10, k11])
+ - value ([v0, v4, v5, v8, v9, v10, v11])
+ - cumulative_sequence_length: 0, 1, 1+2, 1+2+4
+ This packing omits padding tokens.
+
+ The query, key, and value tensors contain the hidden embeddings of the real tokens after the input projections.
+ cumulative_sequence_length records the cumulative sequence length of each batch entry.
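
A minimal sketch of how the cumulative_sequence_length above is derived from the per-sequence token counts (illustrative only; `BuildCumulativeSeqlens` is a hypothetical helper, not part of this patch):

```cpp
#include <cstdint>
#include <vector>

// Per-sequence real-token counts from the example batch: {1, 2, 4}.
std::vector<int32_t> BuildCumulativeSeqlens(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cumulative(seqlens.size() + 1, 0);
  for (size_t i = 0; i < seqlens.size(); ++i) {
    cumulative[i + 1] = cumulative[i] + seqlens[i];  // running total of real tokens
  }
  return cumulative;  // {0, 1, 3, 7} for the example batch
}
```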
+
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+<dl>
+<dt><tt>do_rotary</tt> : int</dt>
+<dd>Whether to use rotary position embedding. Default value is 0.</dd>
+<dt><tt>kv_num_heads</tt> : int (required)</dt>
+<dd>Number of attention heads for k and v</dd>
+<dt><tt>local_window_size</tt> : int</dt>
+<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
+<dt><tt>num_heads</tt> : int (required)</dt>
+<dd>Number of attention heads for q</dd>
+<dt><tt>rotary_interleaved</tt> : int</dt>
+<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
+<dt><tt>scale</tt> : float</dt>
+<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
+<dt><tt>softcap</tt> : float</dt>
+<dd>Softcap value for attention weights. Default value is 0.</dd>
+</dl>
+
+#### Inputs (8 - 10)
+
+<dl>
+<dt><tt>query</tt> : T</dt>
+<dd>Query with shape (num_tokens, hidden_size), or packed QKV with shape (num_tokens, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
+<dt><tt>key</tt> (optional) : T</dt>
+<dd>Key with shape (num_tokens, kv_hidden_size)</dd>
+<dt><tt>value</tt> (optional) : T</dt>
+<dd>Value with shape (num_tokens, kv_hidden_size)</dd>
+<dt><tt>key_cache</tt> : T</dt>
+<dd>Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op.</dd>
+<dt><tt>value_cache</tt> : T</dt>
+<dd>Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op and should have the same shape as key_cache.</dd>
+<dt><tt>cumulative_sequence_length</tt> : S</dt>
+<dd>A tensor with shape (batch_size + 1). It specifies the cumulative sequence lengths of the packed entries in Q/K/V.</dd>
+<dt><tt>past_seqlens</tt> : S</dt>
+<dd>A tensor with shape (batch_size). It specifies the past length of each cached sequence in the KV cache.</dd>
+<dt><tt>block_table</tt> : S</dt>
+<dd>2D tensor with shape (batch_size, max_blocks_per_sequence) that maps each sequence in the batch to its corresponding blocks in the KV cache.</dd>
+<dt><tt>cos_cache</tt> (optional) : T</dt>
+<dd>2D tensor with shape (max total seqlen, head_size / 2).</dd>
+<dt><tt>sin_cache</tt> (optional) : T</dt>
+<dd>2D tensor with shape (max total seqlen, head_size / 2).</dd>
+</dl>
+
+#### Outputs (1 - 3)
+
+<dl>
+<dt><tt>output</tt> : T</dt>
+<dd>2D output tensor with shape (num_tokens, hidden_size)</dd>
+<dt><tt>key_cache_out</tt> (optional) : T</dt>
+<dd>Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as key_cache.</dd>
+<dt><tt>value_cache_out</tt> (optional) : T</dt>
+<dd>Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as value_cache.</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
+<dd>Constrain input and output to float tensors.</dd>
+<dt><tt>S</tt> : tensor(int32)</dt>
+<dd>Constrain positional inputs to int32 tensors.</dd>
+</dl>
+
+
### **com.microsoft.QAttention**
Quantization of Multi-Head Self Attention.
@@ -6345,3 +6440,5 @@ No versioning maintained for experimental ops.
<dt><tt>T</tt> : tensor(float)</dt>
<dd>Constrain input and output types to float32 tensors.</dd>
</dl>
+
+
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index b657c828fbde1..5154c334acc23 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -952,6 +952,7 @@ Do not modify directly.*
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|PagedAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* cumulative_sequence_length:**S**<br> *in* past_seqlens:**S**<br> *in* block_table:**S**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* key_cache_out:**T**<br> *out* value_cache_out:**T**|1+|**S** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br> **T2** = tensor(int8)<br> **T3** = tensor(float), tensor(float16)<br> **T4** = tensor(int32)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br> **T1** = tensor(uint8)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index 52dcb990ab67f..651f270230a75 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -227,7 +227,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->do_rotary = do_rotary_;
- output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
+ output_parameters->rotary_dim = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
index c3d5128948c6f..77d3089de5d09 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -22,11 +22,12 @@ struct AttentionParameters {
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
- int num_splits;
- int rotary_embedding;
+ int num_splits; // number of splits for splitkv
+ int rotary_dim = 0; // rotary embedding dimension
int beam_width;
bool is_unidirectional;
bool past_present_share_buffer;
+ bool is_packed_qkv = false; // whether qkv is packed
bool do_rotary;
bool broadcast_attn_bias_dim_0;
bool broadcast_attn_bias_dim_1;
@@ -46,13 +47,11 @@ struct DecoderMaskedMultiHeadAttentionParameters : AttentionParameters {
int beam_width = 1;
// Only NeoX style rotary embedding is supported
- int rotary_embedding_dim = 0;
int t_step = 0;
// Whether to use multihead attention (excludes matmul and bias)
bool is_mha = false;
bool is_cross_attention = false;
- bool is_packed_qkv = false;
// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
@@ -83,15 +82,12 @@ struct DecoderMaskedMultiHeadAttentionParameters : AttentionParameters {
// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters : AttentionParameters {
+ int kv_num_heads; // number of heads of key or value
+ int kv_hidden_size; // hidden size of key or value
int seqlen_past_kv_cache; // sequence length of past kv tensor
int seqlen_present_kv_cache; // sequence length of present kv tensor
- int kv_hidden_size;
- int kv_num_heads;
- int num_splits; // number of splits for splitkv
- int rotary_dim; // rotary embedding dimension
- int local_window_size; // The window size excludes current token. It only includes tokens on the left side.
+ int local_window_size; // The window size excludes current token. It only includes tokens on the left side.
bool kv_share_buffer;
- bool is_packed_qkv;
bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
bool is_first_prompt; // indicates whether this is first decoding step
bool rotary_interleaved;
@@ -102,18 +98,29 @@ struct GroupQueryAttentionParameters : AttentionParameters {
int* zero_ptr;
};
+// Parameters deduced from node attributes and inputs/outputs.
+struct PagedAttentionParameters : AttentionParameters {
+ int kv_num_heads; // number of heads of key or value
+ int kv_hidden_size; // hidden size of key or value
+ int token_count; // number of tokens in packed query
+ int block_size; // block size for kv cache
+ int max_num_blocks_per_seq; // max number of blocks per sequence for kv cache
+ int num_blocks; // number of blocks in kv cache
+ int local_window_size; // The window size excludes current token. It only includes tokens on the left side.
+ bool rotary_interleaved;
+ float softcap;
+};
+
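
As a quick aside, the new struct's cache-related fields obey a few invariants; a hedged sanity-check sketch (hypothetical helper, not part of the patch):

```cpp
#include <cassert>

// Hypothetical sanity checks relating PagedAttentionParameters fields.
void CheckPagedCacheParams(int kv_hidden_size, int kv_num_heads, int head_size,
                           int num_blocks, int block_size, int total_cached_tokens) {
  assert(kv_hidden_size == kv_num_heads * head_size);      // per-token KV width
  assert(num_blocks * block_size >= total_cached_tokens);  // cache capacity bound
}
```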
// Parameters for sparse attention.
struct SparseAttentionParameters : AttentionParameters {
int kv_hidden_size; // hidden size of key or value
int kv_num_heads; // number of heads of key or value
bool do_rotary; // whether to use rotary embedding
bool rotary_interleaved; // whether to use interleaved rotary embedding
- int rotary_dim; // rotary embedding dimension
int sparse_block_size; // block size for sparse attention
int num_sparse_layout; // number of sparse layout
int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
- bool is_packed_qkv; // whether qkv is packed
int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache
int max_cache_sequence_length; // max sequence length for kv cache buffer
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index fa0d33e891f46..338c34acb3cfb 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -12,6 +12,183 @@ namespace onnxruntime {
namespace contrib {
namespace group_query_attention_helper {
+template <typename T>
+Status Check_Q_K_V(const T* query, const T* key, const T* value, const int num_heads, const int kv_num_heads,
+ int& batch_size, int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) {
+ const auto& query_dims = query->Shape().GetDims();
+ if (query_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
+ query_dims.size());
+ }
+ batch_size = static_cast<int>(query_dims[0]);
+ sequence_length = static_cast<int>(query_dims[1]);
+ q_hidden_size = static_cast<int>(query_dims[2]);
+ head_size = static_cast<int>(q_hidden_size) / num_heads;
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ const auto& key_dims = key->Shape().GetDims();
+ if (key_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
+ key_dims.size());
+ } else if (query_dims[0] != key_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 0 (batch size)");
+ } else if (query_dims[1] != key_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 1 (sequence length)");
+ }
+ kv_hidden_size = static_cast<int>(key_dims[2]);
+ if (kv_hidden_size % kv_num_heads != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "kv_hidden_size must be a multiple of kv_num_heads. Got kv_hidden_size % kv_num_heads == ",
+ kv_hidden_size % kv_num_heads);
+ } else if (kv_hidden_size / kv_num_heads != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "kv_hidden_size / kv_num_heads must be equal to head_size. Got kv_hidden_size / kv_num_heads == ",
+ kv_hidden_size / kv_num_heads);
+ }
+ const auto& value_dims = value->Shape().GetDims();
+ if (value_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
+ value_dims.size());
+ } else if (query_dims[0] != value_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 0 (batch size)");
+ } else if (query_dims[1] != value_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 1 (sequence length)");
+ } else if (value_dims[2] != kv_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
+ }
+ return Status::OK();
+}
+
+template <typename T>
+Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const int kv_num_heads, int& batch_size,
+ int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) {
+ const auto& packed_dims = packed_qkv->Shape().GetDims();
+ if (packed_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
+ packed_dims.size());
+ }
+ batch_size = static_cast<int>(packed_dims[0]);
+ sequence_length = static_cast<int>(packed_dims[1]);
+ head_size = static_cast<int>(packed_dims[2]) / (num_heads + 2 * kv_num_heads);
+ // Check packed qkv
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ q_hidden_size = head_size * num_heads;
+ kv_hidden_size = head_size * kv_num_heads;
+ return Status::OK();
+}
+
+template <typename T>
+Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size,
+ int& past_sequence_length) {
+ const auto& past_key_dims = past_key->Shape().GetDims();
+ const auto& past_value_dims = past_value->Shape().GetDims();
+
+ if (past_key_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' is expected to have 4 dimensions, got ",
+ past_key_dims.size());
+ }
+ if (past_value_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' is expected to have 4 dimensions, got ",
+ past_value_dims.size());
+ }
+
+ if (past_key_dims[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' dimension 0 should be batch_size, got ",
+ past_key_dims[0]);
+ }
+ if (past_value_dims[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' dimension 0 should be batch_size, got ",
+ past_value_dims[0]);
+ }
+
+ if (past_key_dims[2] != past_value_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence"
+ "length or past sequence length), got ",
+ past_key_dims[1]);
+ }
+ if (past_key_dims[1] != kv_num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' shall have kv_num_heads");
+ }
+ if (past_value_dims[1] != kv_num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' shall have kv_num_heads");
+ }
+ // We assume all sequence in past kv are right-padded to max or past sequence length
+ past_sequence_length = static_cast<int>(past_key_dims[2]);
+
+ if (past_key_dims[3] != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_key' dimension 3 should be same as head_size, got ",
+ past_key_dims[3]);
+ }
+ if (past_value_dims[3] != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' dimension 3 should be same as head_size, got ",
+ past_value_dims[3]);
+ }
+ return Status::OK();
+}
+
+template <typename T>
+Status CheckRotaryCaches(const T* cos_cache, const T* sin_cache, int head_size, int total_sequence_length,
+ int& rotary_dim) {
+ const auto& cos_dims = cos_cache->Shape().GetDims();
+ const auto& sin_dims = sin_cache->Shape().GetDims();
+
+ if (head_size % 16 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size shall be a multiple of 16. Got head_size % 16 == ",
+ head_size % 16);
+ }
+ if (cos_dims[0] < total_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache dimension 0 shall not be less than total_sequence_length.");
+ }
+ if (sin_dims[0] < total_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "sin_cache dimension 0 shall not be less than total_sequence_length.");
+ }
+ if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+ }
+ if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+ }
+ if (cos_dims[1] != sin_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache and sin_cache dimension 1 must be the same.");
+ }
+ rotary_dim = static_cast<int>(cos_dims[1] * 2);
+ return Status::OK();
+}
+
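
A worked instance of the constraints CheckRotaryCaches enforces above, with assumed values (illustrative only, not code from the patch):

```cpp
#include <cassert>

// Assumed values: head_size = 128, cos/sin caches of shape (max_seqlen, 64).
void RotaryCacheShapeExample() {
  const int head_size = 128;  // must be a multiple of 16
  const int cos_dim1 = 64;    // must be a multiple of 8 and <= head_size / 2
  assert(head_size % 16 == 0);
  assert(cos_dim1 % 8 == 0 && cos_dim1 <= (head_size / 16) * 8);
  const int rotary_dim = cos_dim1 * 2;  // -> 128, i.e., full rotary embedding
  assert(rotary_dim == head_size);
}
```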
template <typename T>
Status CheckInputs(const T* query,
const T* key,
@@ -37,18 +214,6 @@ Status CheckInputs(const T* query,
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = Q_K_V_BNSH;
- const bool is_packed_qkv = key == nullptr;
-
- const auto& query_dims = query->Shape().GetDims();
- if (query_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
- query_dims.size());
- }
-
- int batch_size = static_cast<int>(query_dims[0]);
- int sequence_length = static_cast<int>(query_dims[1]);
- int q_hidden_size = static_cast<int>(query_dims[2]);
- int head_size = 0;
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -56,115 +221,25 @@ Status CheckInputs(const T* query,
num_heads % kv_num_heads);
}
+ int batch_size = 0;
+ int sequence_length = 0;
+ int q_hidden_size = 0;
int kv_hidden_size = 0;
- // Check key and value when not packed
+ int head_size = 0;
+ const bool is_packed_qkv = key == nullptr;
if (!is_packed_qkv) {
- head_size = static_cast<int>(q_hidden_size) / num_heads;
- if (head_size % 8 != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "head_size must be a multiple of 8. Got head_size % 8 == ",
- head_size % 8);
- }
- if (value == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
- }
- const auto& key_dims = key->Shape().GetDims();
- if (key_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
- key_dims.size());
- } else if (query_dims[0] != key_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 0 (batch size)");
- } else if (query_dims[1] != key_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 1 (sequence length)");
- }
- kv_hidden_size = static_cast<int>(key_dims[2]);
- const auto& value_dims = value->Shape().GetDims();
- if (value_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
- value_dims.size());
- } else if (query_dims[0] != value_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'value' shall have same dim 0 (batch size)");
- } else if (query_dims[1] != value_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'value' shall have same dim 1 (sequence length)");
- } else if (value_dims[2] != kv_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
- }
+ ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length,
+ q_hidden_size, kv_hidden_size, head_size));
} else {
- // Check packed qkv
- head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
- if (head_size % 8 != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "head_size must be a multiple of 8. Got head_size % 8 == ",
- head_size % 8);
- }
- if (value != nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
- }
- q_hidden_size = head_size * num_heads;
- kv_hidden_size = head_size * kv_num_heads;
+ qkv_format = QKV_BS3NH;
+ ORT_RETURN_IF_ERROR(Check_QKV(query, value, num_heads, kv_num_heads, batch_size, sequence_length, q_hidden_size,
+ kv_hidden_size, head_size));
}
// Check past-present KV
int32_t past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
- const auto& past_key_dims = past_key->Shape().GetDims();
- const auto& past_value_dims = past_value->Shape().GetDims();
-
- if (past_key_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' is expected to have 4 dimensions, got ",
- past_key_dims.size());
- }
- if (past_value_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' is expected to have 4 dimensions, got ",
- past_value_dims.size());
- }
-
- if (past_key_dims[0] != batch_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' dimension 0 should be batch_size, got ",
- past_key_dims[0]);
- }
- if (past_value_dims[0] != batch_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' dimension 0 should be batch_size, got ",
- past_value_dims[0]);
- }
-
- if (past_key_dims[2] != past_value_dims[2]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence"
- "length or past sequence length), got ",
- past_key_dims[1]);
- }
- if (past_key_dims[1] != kv_num_heads) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' shall have kv_num_heads");
- }
- if (past_value_dims[1] != kv_num_heads) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' shall have kv_num_heads");
- }
- // We assume all sequence in past kv are right-padded to max or past sequence length
- past_sequence_length = static_cast<int>(past_key_dims[2]);
-
- if (past_key_dims[3] != head_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_key' dimension 3 should be same as head_size, got ",
- past_key_dims[3]);
- }
- if (past_value_dims[3] != head_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'past_value' dimension 3 should be same as head_size, got ",
- past_value_dims[3]);
- }
+ ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, past_sequence_length));
} else if (past_key != nullptr || past_value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
@@ -186,35 +261,7 @@ Status CheckInputs(const T* query,
int rotary_dim = 0;
if (cos_cache != nullptr && sin_cache != nullptr) {
- const auto& cos_dims = cos_cache->Shape().GetDims();
- const auto& sin_dims = sin_cache->Shape().GetDims();
-
- if (head_size % 16 != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "head_size shall be a multiple of 16. Got head_size % 16 == ",
- head_size % 16);
- }
- if (cos_dims[0] < total_sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "cos_cache dimension 0 shall not be less than total_sequence_length.");
- }
- if (sin_dims[0] < total_sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "sin_cache dimension 0 shall not be less than total_sequence_length.");
- }
- if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
- }
- if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
- }
- if (cos_dims[1] != sin_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "cos_cache and sin_cache dimension 1 must be the same.");
- }
- rotary_dim = static_cast<int>(cos_dims[1] * 2);
+ ORT_RETURN_IF_ERROR(CheckRotaryCaches(cos_cache, sin_cache, head_size, total_sequence_length, rotary_dim));
} else if (cos_cache != nullptr || sin_cache != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
index c7b06d50858b4..691391ccef0d0 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -180,6 +180,35 @@ struct GroupQueryAttentionData {
bool use_memory_efficient_attention = false;
};
+template <typename T>
+struct PagedAttentionData {
+ // Input Tensors
+ const T* query = nullptr;
+ const T* key = nullptr;
+ const T* value = nullptr;
+ T* key_cache = nullptr;
+ T* value_cache = nullptr;
+ const int* cumulative_seqlens_q = nullptr;
+ const int* past_seqlens = nullptr;
+ const int* block_table = nullptr;
+ const int* slot_mappings = nullptr;
+ const T* cos_cache = nullptr;
+ const T* sin_cache = nullptr;
+
+ // Flash buffers
+ T* softmax_lse = nullptr;
+ int* cumulative_seqlens_kv = nullptr; // Flash api takes cumulative sequence length for kv-cache
+
+ // Fused op buffers
+ T* workspace_buffer = nullptr;
+
+ // Output Tensors
+ T* output = nullptr;
+
+ // Kernel Flags
+ bool use_flash_attention = false;
+};
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 122e94d9558e3..a7989df3439ae 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -150,7 +150,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
- 3, parameters.do_rotary, parameters.rotary_embedding,
+ 3, parameters.do_rotary, parameters.rotary_dim,
parameters.past_sequence_length);
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc
index a15b59d0c018a..b4643da58eba5 100644
--- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc
@@ -195,7 +195,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont
if (do_rotary_) {
ORT_ENFORCE(parameters.head_size == 64 || parameters.head_size == 128,
"Current implementation of rotary embedding only supports head size of 64 or 128");
- parameters.rotary_embedding_dim = parameters.head_size;
+ parameters.rotary_dim = parameters.head_size;
parameters.t_step = parameters.past_sequence_length;
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 75ea7454791b6..6ba5ce66eaa60 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -212,13 +212,13 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
}
}
- if (params.rotary_embedding_dim > 0) {
- const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
+ if (params.rotary_dim > 0) {
+ const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_dim;
T* q_smem = reinterpret_cast<T*>(smem_);
- T* k_smem = q_smem + params.rotary_embedding_dim;
+ T* k_smem = q_smem + params.rotary_dim;
- const int half_rotary_dim = params.rotary_embedding_dim / 2;
+ const int half_rotary_dim = params.rotary_dim / 2;
const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
const int smem_pitch = half_rotary_dim;
@@ -240,7 +240,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
apply_rotary_embedding(
- q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.t_step);
+ q, k, transpose_idx / tidx_factor, params.rotary_dim, params.t_step);
write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
index 08e4293528d5a..586732834f0ad 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
@@ -165,8 +165,8 @@ inline size_t CalcDynamicBlockMemory(const DecoderMaskedMultiHeadAttentionParame
size_t red_sz = rows_per_red * params.head_size * sizeof(T) / 2;
size_t transpose_rotary_size = 0;
- if (params.rotary_embedding_dim > 0) {
- transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
+ if (params.rotary_dim > 0) {
+ transpose_rotary_size = 2 * params.rotary_dim * sizeof(T);
}
// The max.
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
index 4aa633ca45e2b..c24bf88fa729b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -70,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
int seqlen_k_rounded = 0;
int d_rounded = 0;
int rotary_dim = 0;
+ int total_q = 0;
// The scaling factors for the kernel.
float scale_softmax = 0.0;
@@ -129,6 +130,7 @@ struct Flash_fwd_params : public Qkv_params {
void* __restrict__ alibi_slopes_ptr = nullptr;
index_t alibi_slopes_batch_stride = 0;
+ bool unpadded_lse = false;
const cudaDeviceProp* dprops = nullptr;
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index 453dffaa2e6e6..b0241c26aafc6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -41,7 +41,8 @@ void set_params_fprop(Flash_fwd_params& params,
bool use_smooth_softmax,
bool kv_bsnh = true,
int window_size_left = -1,
- int window_size_right = -1) {
+ int window_size_right = -1,
+ const bool unpadded_lse = false) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
@@ -142,6 +143,7 @@ void set_params_fprop(Flash_fwd_params& params,
params.window_size_right = window_size_right;
params.is_seqlens_k_cumulative = true;
+ params.unpadded_lse = unpadded_lse;
}
size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) {
@@ -149,6 +151,11 @@ size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads)
return bytes;
}
+size_t get_softmax_lse_size(size_t token_count, size_t num_heads) {
+ size_t bytes = sizeof(float) * token_count * num_heads;
+ return bytes;
+}
+
size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) {
size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads;
return bytes;
@@ -336,6 +343,8 @@ Status mha_fwd(const cudaDeviceProp& dprops,
return Status::OK();
}
+// TODO(aciddelgado): Baiju wants this https://github.com/Dao-AILab/flash-attention/pull/824
+
Status mha_varlen_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // half (total_q, num_heads, head_size)
@@ -353,10 +362,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
int head_size,
int max_seqlen_q,
int max_seqlen_k,
+ int total_q,
float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
+ int local_window_size,
int max_num_blocks_per_seq,
int page_block_size) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -384,8 +395,11 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
is_bf16,
false,
true,
- -1,
- is_causal ? 0 : -1);
+ local_window_size,
+ is_causal ? 0 : -1,
+ /*unpadded_lse*/ true);
+
+ params.total_q = total_q;
params.dprops = &dprops;
params.num_splits = 0;
params.softmax_lseaccum_ptr = nullptr;
@@ -394,7 +408,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
params.vnew_ptr = nullptr;
params.alibi_slopes_ptr = nullptr;
if (paged_KV) {
- params.block_table = block_table; // TODO(aciddelgado): cast to int pointer
+ params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
// params.num_blocks = num_blocks;
params.page_block_size = page_block_size;
@@ -406,7 +420,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
// params.num_blocks = 0;
params.page_block_size = 1;
}
- run_mha_fwd(params, stream);
+
+ run_mha_fwd(params, stream, paged_KV);
return Status::OK();
}
@@ -536,7 +551,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.alibi_slopes_ptr = nullptr;
if (paged_KV) {
- params.block_table = block_table; // TODO(aciddelgado): cast to int pointer
+ params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
// params.num_blocks = num_blocks;
params.page_block_size = page_block_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 57752e8237d6e..e28e38ea3ed93 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -77,10 +77,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
int head_size,
int max_seqlen_q,
int max_seqlen_k,
+ int total_q,
float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
+ int local_window_size = -1,
int max_num_blocks_per_seq = 0,
int page_block_size = 1);
@@ -121,6 +123,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int page_block_size = 1);
size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads);
+size_t get_softmax_lse_size(size_t token_count, size_t num_heads);
std::tuple<int, int, int> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads,
size_t head_size, size_t num_SMs);
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
index d46d9597a758f..4110e715c4391 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
@@ -34,6 +34,26 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename Kernel_traits, bool Is_even_MN, typename Params>
+__forceinline__ __device__ auto get_lse_tile(const Params& params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN>& binfo) {
+ // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+ // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+ // Otherwise, it's written as (h, b, seqlen_q).
+ const bool varlen_q = params.unpadded_lse;
+ auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+ auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<typename Kernel_traits::ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+ auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+ auto lse_stride = params.unpadded_lse
+ ? make_stride(params.h * params.total_q, params.total_q, 1)
+ : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1);
+
+ auto lse_layout = make_layout(lse_shape, lse_stride);
+ Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+ auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+ return local_tile(mLSE_slice, Shape<Int<Kernel_traits::kBlockM>>{}, make_coord(m_block));
+}
+
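
The helper above selects between two LSE layouts: padded (b, h, seqlen_q), and unpadded (h, total_q) where a sequence's rows start at its cumulative query offset. A scalar sketch of the two addressing schemes (illustrative only; `cu_seqlens_q` stands in for the offsets that BlockInfo::q_offset computes):

```cpp
#include <cstddef>

// Padded layout (b, h, seqlen_q): every sequence occupies seqlen_q slots.
size_t LsePaddedOffset(int b, int h, int q, int num_heads, int seqlen_q) {
  return (static_cast<size_t>(b) * num_heads + h) * seqlen_q + q;
}

// Unpadded layout (h, total_q): only real tokens are stored; rows of batch
// entry b start at cu_seqlens_q[b].
size_t LseUnpaddedOffset(int b, int h, int q, int total_q, const int* cu_seqlens_q) {
  return static_cast<size_t>(h) * total_q + cu_seqlens_q[b] + q;
}
```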
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -70,10 +90,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
- Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
- make_shape(params.b, params.h, params.seqlen_q),
- make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
- Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+
+ Tensor gLSE = get_lse_tile<Kernel_traits, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -375,10 +393,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
make_stride(params.o_row_stride, params.o_head_stride, _1{}));
Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
- Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
- make_shape(params.b, params.h, params.seqlen_q),
- make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
- Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+ Tensor gLSE = get_lse_tile<Kernel_traits, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -938,8 +953,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
- const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-
+ const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)) + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
@@ -1047,12 +1061,24 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
+ const index_t lse_size = params.b * params.h * params.seqlen_q;
+
const index_t row_offset_lse = bidx * kBlockM;
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) + row_offset_lse),
Shape<Int<kMaxSplits>, Int<kBlockM>>{},
- make_stride(params.b * params.h * params.seqlen_q, _1{}));
+ make_stride(lse_size, _1{}));
+
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
+ // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+ Layout flat_layout = make_layout(lse_size);
+ Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+ auto transposed_stride = make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+ Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+ Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+ Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)), final_layout);
+
constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
// Read the LSE values from gmem and store them in shared memory, then transpose them.
@@ -1107,7 +1133,14 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? std::numeric_limits<ElementAccum>::infinity() : logf(lse_sum) + lse_max;
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
- gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+ if (params.unpadded_lse) {
+ const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+ if (lse_offset < lse_size) {
+ gLSE_unpadded(lse_offset) = lse_logsum;
+ }
+ } else {
+ gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+ }
}
// Store the scales exp(lse - lse_logsum) in shared memory.
#pragma unroll
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
index 846d2be7bf2e1..07bca3f7fff99 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
@@ -589,6 +589,7 @@ Status FlashAttention(
PackedMultiHeadAttentionData& data) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
+ const int token_count = parameters.token_count;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
@@ -638,6 +639,7 @@ Status FlashAttention(
qk_head_size,
sequence_length,
sequence_length,
+ token_count,
scale,
0.0,
false, // is causal
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
new file mode 100644
index 0000000000000..4189965ab9137
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
@@ -0,0 +1,218 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/platform/env_var_utils.h"
+#include "contrib_ops/cpu/utils/dump_tensor.h"
+#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
+#include "contrib_ops/cuda/bert/paged_attention_impl.h"
+#include "contrib_ops/cuda/bert/paged_attention.h"
+#include "contrib_ops/cuda/bert/paged_attention_helper.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ PagedAttention, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("S", DataTypeImpl::GetTensorType()), \
+ PagedAttention);
+
+REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
+
+template <typename T>
+PagedAttention<T>::PagedAttention(const OpKernelInfo& info)
+ : CudaKernel(info) {
+ int64_t num_heads = 0;
+ int64_t kv_num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+ ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
+ num_heads_ = static_cast<int>(num_heads);
+ kv_num_heads_ = static_cast<int>(kv_num_heads);
+ local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
+ do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+ rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
+ scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+ softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f);
+
+ kernel_options_ = this->GetAttentionKernelOptions();
+ disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
+}
+
+template <typename T>
+Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* query = context->Input<Tensor>(0);
+ const Tensor* key = context->Input<Tensor>(1);
+ const Tensor* value = context->Input<Tensor>(2);
+ const Tensor* key_cache = context->Input<Tensor>(3);
+ const Tensor* value_cache = context->Input<Tensor>(4);
+ const Tensor* cumulative_seqlens_q = context->Input<Tensor>(5);
+ const Tensor* past_seqlens = context->Input<Tensor>(6);
+ const Tensor* block_table = context->Input<Tensor>(7);
+ const Tensor* cos_cache = context->Input<Tensor>(8);
+ const Tensor* sin_cache = context->Input<Tensor>(9);
+
+ auto& device_prop = GetDeviceProp();
+ PagedAttentionParameters parameters;
+ typedef typename ToCudaType<T>::MappedType CudaT;
+ PagedAttentionData<CudaT> data;
+
+ // Check shapes of inputs to op and set parameters
+ ORT_RETURN_IF_ERROR(paged_attention_helper::CheckInputs(query,
+ key,
+ value,
+ key_cache,
+ value_cache,
+ cumulative_seqlens_q,
+ past_seqlens,
+ block_table,
+ cos_cache,
+ sin_cache,
+ &parameters,
+ num_heads_,
+ kv_num_heads_,
+ scale_,
+ softcap_,
+ device_prop.maxThreadsPerBlock));
+ parameters.local_window_size = local_window_size_;
+ parameters.do_rotary = do_rotary_;
+ parameters.rotary_interleaved = rotary_interleaved_;
+
+ DUMP_STRING_INIT();
+ DUMP_STRING("Batch size = ", parameters.batch_size);
+ DUMP_STRING("Token count = ", parameters.token_count);
+ DUMP_STRING("Q hidden size = ", parameters.hidden_size);
+ DUMP_STRING("KV hidden size = ", parameters.kv_hidden_size);
+ DUMP_STRING("Q num heads = ", parameters.num_heads);
+ DUMP_STRING("KV num heads = ", parameters.kv_num_heads);
+ DUMP_STRING("Head size = ", parameters.head_size);
+ DUMP_STRING("Num blocks = ", parameters.num_blocks);
+ DUMP_STRING("Block size = ", parameters.block_size);
+ DUMP_STRING("Max num blocks per sequence = ", parameters.max_num_blocks_per_seq);
+ DUMP_STRING("Rotary dimension = ", parameters.rotary_dim);
+ DUMP_STRING("Is packed QKV = ", parameters.is_packed_qkv);
+
+ // Check rotary
+ if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache and sin_cache must be passed to PagedAttention when do_rotary = 1");
+ }
+
+ // Set output tensor shapes
+ TensorShapeVector output_shape(2);
+ output_shape[0] = static_cast<int64_t>(parameters.token_count);
+ output_shape[1] = static_cast<int64_t>(parameters.hidden_size);
+ Tensor* output = context->Output(0, output_shape);
+
+ TensorShapeVector key_cache_out_shape(4);
+ key_cache_out_shape[0] = static_cast<int64_t>(parameters.num_blocks);
+ key_cache_out_shape[1] = static_cast<int64_t>(parameters.block_size);
+ key_cache_out_shape[2] = static_cast<int64_t>(parameters.kv_num_heads);
+ key_cache_out_shape[3] = static_cast<int64_t>(parameters.head_size);
+ Tensor* key_cache_out = context->Output(1, key_cache_out_shape);
+
+ TensorShapeVector value_cache_out_shape(4);
+ value_cache_out_shape[0] = static_cast<int64_t>(parameters.num_blocks);
+ value_cache_out_shape[1] = static_cast<int64_t>(parameters.block_size);
+ value_cache_out_shape[2] = static_cast<int64_t>(parameters.kv_num_heads);
+ value_cache_out_shape[3] = static_cast<int64_t>(parameters.head_size);
+ Tensor* value_cache_out = context->Output(2, value_cache_out_shape);
+
+ if (key_cache_out != nullptr && key_cache->Data<T>() != key_cache_out->MutableData<T>()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "key_cache and key_cache_out must be the same buffer");
+ } else if (value_cache_out != nullptr && value_cache->Data<T>() != value_cache_out->MutableData<T>()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "value_cache and value_cache_out must be the same buffer");
+ }
+
+ // Check flash kernel availability and allocate buffers
+#if USE_FLASH_ATTENTION
+ bool use_flash_attention = !disable_flash_attention_ &&
+ onnxruntime::flash::is_supported(device_prop,
+ parameters.head_size,
+ parameters.num_heads,
+ parameters.kv_num_heads);
+ size_t softmax_lse_bytes = 0;
+ if (use_flash_attention) {
+ softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count,
+ parameters.num_heads);
+ }
+ auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
+#else
+ constexpr bool use_flash_attention = false;
+ auto softmax_lse_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
+#endif
+
+ if (!use_flash_attention) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Currently PagedAttention is only supported through the FlashAttention kernel.");
+ }
+
+ size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1);
+ auto cumulative_seqlens_kv_buffer = GetScratchBuffer<void>(cumulative_seqlens_kv_bytes, context->GetComputeStream());
+
+ size_t workspace_buffer_bytes = 0;
+ if (do_rotary_) {
+ workspace_buffer_bytes = sizeof(T) * parameters.token_count * (parameters.hidden_size + parameters.kv_hidden_size);
+ } else if (parameters.is_packed_qkv) {
+ workspace_buffer_bytes = sizeof(T) * parameters.token_count * parameters.hidden_size;
+ }
+ auto workspace_buffer = GetScratchBuffer<void>(workspace_buffer_bytes, context->GetComputeStream());
+
+ // Print debug info
+ if (kernel_options_->AllowDebugInfo()) {
+ AttentionKernelDebugInfo debug_info;
+ debug_info.use_flash_attention = use_flash_attention;
+
+ debug_info.Print("PagedAttention",
+ this->Node().Name(),
+ std::is_same<T, MLFloat16>::value,
+ std::is_same<T, BFloat16>::value);
+ }
+
+ // Set up data struct for kernel launch
+ data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
+ data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
+ data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
+ data.key_cache = reinterpret_cast<CudaT*>(const_cast<T*>(key_cache->Data<T>()));
+ data.value_cache = reinterpret_cast<CudaT*>(const_cast<T*>(value_cache->Data<T>()));
+ data.cumulative_seqlens_q = reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int32_t>());
+ data.past_seqlens = reinterpret_cast<const int*>(past_seqlens->Data<int32_t>());
+ data.cumulative_seqlens_kv = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());
+ data.block_table = reinterpret_cast<const int*>(block_table->Data<int32_t>());
+ data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
+ data.use_flash_attention = use_flash_attention;
+ if (softmax_lse_buffer != nullptr) {
+ data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
+ }
+ if (workspace_buffer != nullptr) {
+ data.workspace_buffer = reinterpret_cast<CudaT*>(workspace_buffer.get());
+ }
+ if (parameters.do_rotary) {
+ data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
+ data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
+ }
+
+ cublasHandle_t cublas = GetCublasHandle(context);
+
+ return QkvToContext<CudaT>(
+ device_prop, cublas, context->GetComputeStream(), parameters, data);
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
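
To make the block_table semantics above concrete, here is a hedged sketch of the slot lookup a paged KV cache performs for one token; the real scatter/gather happens inside the kernels behind paged_attention_impl, and `PagedCacheOffset` is purely illustrative:

```cpp
#include <cstddef>

// Locate the start of the cache row for token position `pos` of one sequence in
// a (num_blocks, block_size, kv_num_heads, head_size) cache.
size_t PagedCacheOffset(const int* block_table_row,  // one row of block_table
                        int pos, int block_size, int kv_num_heads, int head_size) {
  const int block_id = block_table_row[pos / block_size];  // logical -> physical block
  const int block_offset = pos % block_size;               // slot within the block
  return (static_cast<size_t>(block_id) * block_size + block_offset) *
         static_cast<size_t>(kv_num_heads) * head_size;
}
```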
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h
new file mode 100644
index 0000000000000..a3df144745f61
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <memory>
+#include "core/providers/cuda/cuda_kernel.h"
+#include "contrib_ops/cuda/bert/paged_attention_impl.h"
+#include "contrib_ops/cuda/bert/attention_kernel_options.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class PagedAttention final : public CudaKernel {
+ public:
+ PagedAttention(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+ int num_heads_; // number of attention heads
+ int kv_num_heads_; // different for k and v for group query attention
+ int local_window_size_;
+ bool do_rotary_;
+ bool rotary_interleaved_;
+ float scale_;
+ float softcap_;
+ bool disable_flash_attention_;
+ const AttentionKernelOptions* kernel_options_;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h
new file mode 100644
index 0000000000000..6fb8969aa9d0a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h
@@ -0,0 +1,278 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+#include "contrib_ops/cpu/bert/attention_parameters.h"
+#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace paged_attention_helper {
+
+template <typename T>
+Status Check_Q_K_V(const T* query, const T* key, const T* value, const int num_heads, const int kv_num_heads,
+ int& token_count, int& q_hidden_size, int& kv_hidden_size, int& head_size) {
+ const auto& query_dims = query->Shape().GetDims();
+ if (query_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 2 dimensions, got ",
+ query_dims.size());
+ }
+ token_count = static_cast<int>(query_dims[0]);
+ q_hidden_size = static_cast<int>(query_dims[1]);
+ head_size = static_cast<int>(q_hidden_size) / num_heads;
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ const auto& key_dims = key->Shape().GetDims();
+ if (key_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 2 dimensions, got ",
+ key_dims.size());
+ } else if (token_count != key_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 0 (token count)");
+ }
+ kv_hidden_size = static_cast(key_dims[1]);
+ if (kv_hidden_size % kv_num_heads != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "kv_hidden_size must be a multiple of kv_num_heads. Got kv_hidden_size % kv_num_heads == ",
+ kv_hidden_size % kv_num_heads);
+ } else if (kv_hidden_size / kv_num_heads != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "kv_hidden_size / kv_num_heads must be equal to head_size. Got kv_hidden_size / kv_num_heads == ",
+ kv_hidden_size / kv_num_heads);
+ }
+ const auto& value_dims = value->Shape().GetDims();
+ if (value_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 2 dimensions, got ",
+ value_dims.size());
+ } else if (token_count != value_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 0 (token count)");
+ } else if (value_dims[1] != kv_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
+ }
+ return Status::OK();
+}
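+
+// Illustrative example of the shapes accepted above (hypothetical numbers): num_heads = 32,
+// kv_num_heads = 8 and head_size = 128 give query of shape (token_count, 4096) and
+// key/value of shape (token_count, 1024); head_size = 4096 / 32 = 128 is a multiple of 8.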
+
+template <typename T>
+Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const int kv_num_heads, int& token_count,
+ int& q_hidden_size, int& kv_hidden_size, int& head_size) {
+ const auto& packed_dims = packed_qkv->Shape().GetDims();
+ if (packed_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 2 dimensions, got ",
+ packed_dims.size());
+ }
+ token_count = static_cast<int>(packed_dims[0]);
+ head_size = static_cast<int>(packed_dims[1]) / (num_heads + 2 * kv_num_heads);
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ q_hidden_size = head_size * num_heads;
+ kv_hidden_size = head_size * kv_num_heads;
+ return Status::OK();
+}
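+
+// For the packed case above (illustrative numbers): num_heads = 32, kv_num_heads = 8 and a
+// packed dim of 6144 give head_size = 6144 / (32 + 2 * 8) = 128, hence q_hidden_size = 4096
+// and kv_hidden_size = 1024.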
+
+template <typename T>
+Status CheckKVCache(const T* key_cache, const T* value_cache, const int kv_num_heads, const int head_size,
+ int& num_blocks, int& block_size) {
+ const auto& key_cache_dims = key_cache->Shape().GetDims();
+ const auto& value_cache_dims = value_cache->Shape().GetDims();
+ if (key_cache_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key_cache' is expected to have 4 dimensions, got ",
+ key_cache_dims.size());
+ }
+ if (value_cache_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'value_cache' is expected to have 4 dimensions, got ",
+ value_cache_dims.size());
+ }
+
+ num_blocks = static_cast<int>(key_cache_dims[0]);
+ block_size = static_cast<int>(key_cache_dims[1]);
+ // TODO(aciddelgado): block size multiple of 8
+ if (block_size % 256 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "block_size must be a multiple of 256. Got block_size % 256 == ",
+ block_size % 256);
+ }
+ if (value_cache_dims[0] != num_blocks) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'value_cache' dimension 0 should be num_blocks, got ",
+ value_cache_dims[0]);
+ } else if (value_cache_dims[1] != block_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'value_cache' dimension 1 should be block_size, got ",
+ value_cache_dims[1]);
+ }
+
+ if (key_cache_dims[2] != value_cache_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key_cache' and 'value_cache' dimension 2 (kv num heads) should be the same, got ",
+ key_cache_dims[2], " and ", value_cache_dims[2]);
+ }
+ if (key_cache_dims[2] != kv_num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key_cache' shall have kv_num_heads, got ",
+ key_cache_dims[2]);
+ }
+ if (value_cache_dims[2] != kv_num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'value_cache' shall have kv_num_heads, got ",
+ value_cache_dims[2]);
+ }
+
+ if (key_cache_dims[3] != value_cache_dims[3]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key_cache' and 'value_cache' dimension 3 (head size) should be the same, got ",
+ key_cache_dims[3], " and ", value_cache_dims[3]);
+ }
+ if (key_cache_dims[3] != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key_cache' dimension 3 should be same as head_size, got ",
+ key_cache_dims[3]);
+ }
+ if (value_cache_dims[3] != head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'past_value' dimension 3 should be same as head_size, got ",
+ value_cache_dims[3]);
+ }
+ return Status::OK();
+}
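+
+// Example (illustrative numbers): key_cache and value_cache of shape (num_blocks, 256, 8, 128)
+// pass the checks above; note block_size is currently required to be a multiple of 256.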
+
+template <typename T>
+Status CheckSequenceLengthTensors(const T* cumulative_sequence_length, const T* seqlens, int& batch_size) {
+ const auto& cumulative_seqlen_dim = cumulative_sequence_length->Shape().GetDims();
+ if (cumulative_seqlen_dim.size() != 1 || cumulative_seqlen_dim[0] < 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cumulative_sequence_length must be shape (batch_size + 1).");
+ }
+ batch_size = static_cast<int>(cumulative_seqlen_dim[0]) - 1;
+
+ const auto& seqlens_dim = seqlens->Shape().GetDims();
+ if (seqlens_dim.size() != 1 || seqlens_dim[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "seqlens must be shape (batch_size).");
+ }
+ return Status::OK();
+}
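+
+// Example (illustrative numbers): a batch of 3 sequences has cumulative_sequence_length of
+// shape (4), e.g. [0, 1, 3, 7], and past_seqlens of shape (3).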
+
+template <typename T>
+Status CheckBlockTable(const T* block_table, const int batch_size, int& max_num_blocks_per_seq) {
+ const auto& block_table_dims = block_table->Shape().GetDims();
+ if (block_table_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "block_table must be 2D.");
+ } else if (block_table_dims[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "block_table dimension 0 should be batch_size, got ",
+ block_table_dims[0]);
+ }
+ max_num_blocks_per_seq = static_cast<int>(block_table_dims[1]);
+ return Status::OK();
+}
+
+template <typename T = Tensor>
+Status CheckInputs(const T* query,
+ const T* key,
+ const T* value,
+ const T* key_cache,
+ const T* value_cache,
+ const T* cumulative_sequence_length,
+ const T* seqlens,
+ const T* block_table,
+ const T* cos_cache,
+ const T* sin_cache,
+ void* parameters,
+ int num_heads,
+ int kv_num_heads,
+ float scale,
+ float softcap,
+ int max_threads_per_block) {
+ if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
+ }
+ if (num_heads % kv_num_heads != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
+ num_heads % kv_num_heads);
+ }
+
+ // Check query, key, and value
+ int token_count = 0;
+ int q_hidden_size = 0;
+ int kv_hidden_size = 0;
+ int head_size = 0;
+ const bool is_packed_qkv = key == nullptr;
+ if (!is_packed_qkv) {
+ ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, token_count, q_hidden_size,
+ kv_hidden_size, head_size));
+ } else {
+ ORT_RETURN_IF_ERROR(Check_QKV(query, value, num_heads, kv_num_heads, token_count, q_hidden_size, kv_hidden_size,
+ head_size));
+ }
+
+ // Check KV-Cache
+ int num_blocks = 0;
+ int block_size = 0;
+ ORT_RETURN_IF_ERROR(CheckKVCache(key_cache, value_cache, kv_num_heads, head_size, num_blocks, block_size));
+
+ // Check sequence length tensors
+ int batch_size = 0;
+ ORT_RETURN_IF_ERROR(CheckSequenceLengthTensors(cumulative_sequence_length, seqlens, batch_size));
+
+ // Check block table and slot mappings
+ int max_num_blocks_per_seq = 0;
+ ORT_RETURN_IF_ERROR(CheckBlockTable(block_table, batch_size, max_num_blocks_per_seq));
+
+ // Check rotary cache
+ int rotary_dim = 0;
+ if (cos_cache != nullptr && sin_cache != nullptr) {
+ // 0 to bypass checking rotary cache size
+ ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckRotaryCaches(cos_cache, sin_cache, head_size,
+ 0, rotary_dim));
+ } else if (cos_cache != nullptr || sin_cache != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
+ }
+
+ if (parameters != nullptr) {
+ PagedAttentionParameters* output_parameters = reinterpret_cast<PagedAttentionParameters*>(parameters);
+ output_parameters->batch_size = batch_size;
+ output_parameters->token_count = token_count;
+ output_parameters->hidden_size = q_hidden_size;
+ output_parameters->kv_hidden_size = kv_hidden_size;
+ output_parameters->num_heads = num_heads;
+ output_parameters->kv_num_heads = kv_num_heads;
+ output_parameters->head_size = head_size;
+ output_parameters->block_size = block_size;
+ output_parameters->max_num_blocks_per_seq = max_num_blocks_per_seq;
+ output_parameters->num_blocks = num_blocks;
+ output_parameters->rotary_dim = rotary_dim;
+ output_parameters->is_packed_qkv = is_packed_qkv;
+ output_parameters->scale = scale;
+ output_parameters->softcap = softcap;
+ }
+
+ return Status::OK();
+}
+
+} // namespace paged_attention_helper
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
new file mode 100644
index 0000000000000..7ecdf51bdde11
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
@@ -0,0 +1,378 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cuda_fp16.h>
+#include <cub/cub.cuh>
+#include <algorithm>
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "contrib_ops/cuda/bert/attention_softmax.h"
+#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/paged_attention_impl.h"
+#include "core/providers/cuda/shared_inc/cuda_call.h"
+#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
+#include <type_traits>
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+////////// Auxiliary Kernels
+
+template <typename T>
+__global__ void UnpackQKVCumulative(const T* packed_qkv, T* unpacked_qkv, const int token_count, const int num_heads,
+ const int kv_num_heads, const int head_size) {
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+ if (tid >= token_count * (num_heads + 2 * kv_num_heads) * head_size) {
+ return;
+ }
+ const int q_hidden_size = num_heads * head_size;
+ const int kv_hidden_size = kv_num_heads * head_size;
+ const int in_seq_stride = q_hidden_size + 2 * kv_hidden_size;
+
+ int packed_i;
+ if (tid < token_count * q_hidden_size) {
+ const int token_id = tid / q_hidden_size;
+ const int offset = tid % q_hidden_size;
+ packed_i = token_id * in_seq_stride + offset;
+ } else if (tid < token_count * (q_hidden_size + kv_hidden_size)) {
+ const int id = tid - token_count * q_hidden_size;
+ const int token_id = id / kv_hidden_size;
+ const int offset = id % kv_hidden_size;
+ packed_i = token_id * in_seq_stride + q_hidden_size + offset;
+ } else if (tid < token_count * (q_hidden_size + 2 * kv_hidden_size)) {
+ const int id = tid - token_count * (q_hidden_size + kv_hidden_size);
+ const int token_id = id / kv_hidden_size;
+ const int offset = id % kv_hidden_size;
+ packed_i = token_id * in_seq_stride + q_hidden_size + kv_hidden_size + offset;
+ }
+ unpacked_qkv[tid] = packed_qkv[packed_i];
+}
+
+// Since QKV is unpacked into a single workspace buffer, this is similar to a transpose
+template
+Status LaunchUnpackQKVCumulative(const T* packed_qkv, T* unpacked_qkv, const int token_count, const int num_heads,
+ const int kv_num_heads, const int head_size, cudaStream_t stream,
+ const int max_threads_per_block) {
+ const int threads = max_threads_per_block;
+ const int blocks = (token_count * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads;
+ UnpackQKVCumulative<T><<<blocks, threads, 0, stream>>>(packed_qkv, unpacked_qkv, token_count, num_heads, kv_num_heads,
+ head_size);
+ return CUDA_CALL(cudaGetLastError());
+}
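+
+// Worked example of the index mapping above (illustrative numbers): num_heads = 2,
+// kv_num_heads = 1, head_size = 4 give q_hidden_size = 8, kv_hidden_size = 4 and
+// in_seq_stride = 16. With token_count = 2, output element tid = 18 falls in the K region:
+// id = 18 - 2 * 8 = 2, token_id = 2 / 4 = 0, offset = 2, so packed_i = 0 * 16 + 8 + 2 = 10,
+// i.e. element 2 of token 0's key.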
+
+template <typename T>
+__global__ void UnpackV(const T* input, T* output, const int token_count, const int hidden_size,
+ const int packed_seq_stride) {
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+ if (tid < token_count * hidden_size) {
+ int offset = tid % hidden_size;
+ int token_id = tid / hidden_size;
+ int packed_i = token_id * packed_seq_stride + offset;
+ output[tid] = input[packed_i];
+ }
+}
+
+template <typename T>
+Status LaunchUnpackCumulative(const T* input, T* output, const int token_count, const int hidden_size,
+ const int packed_seq_stride, cudaStream_t stream, const int max_threads_per_block) {
+ const int threads = std::min(max_threads_per_block, token_count * hidden_size);
+ const int blocks = (token_count * hidden_size + threads - 1) / threads;
+ UnpackV<T><<<blocks, threads, 0, stream>>>(input, output, token_count, hidden_size, packed_seq_stride);
+ return CUDA_CALL(cudaGetLastError());
+}
+
+template <typename T>
+__global__ void RotaryEmbeddingTNH(T* output, // TxNxH
+ const T* input, // TxNxH
+ const T* cos_cache, // Mx(H/2)
+ const T* sin_cache, // Mx(H/2)
+ const int32_t* past_seqlens, // B
+ const int32_t* cumulative_seqlens_q, // B+1
+ const int head_size,
+ const int rotary_embedding_dim,
+ const bool interleaved,
+ const int3 in_strides, // TxNxH
+ const int3 out_strides) { // TxNxH
+ // Use .x in innermost loop to access global memory efficiently
+
+ const int b = blockIdx.y;
+ const int s = blockIdx.x;
+ const int n = blockIdx.z;
+ const int h = threadIdx.x;
+
+ const int sequence_length = cumulative_seqlens_q[b + 1] - cumulative_seqlens_q[b];
+ if (h >= head_size || s >= sequence_length) {
+ return;
+ }
+
+ const int t = cumulative_seqlens_q[b] + s; // t is the index of the token in the unpadded input/output
+ const T* input_data = input + t * in_strides.x + n * in_strides.y;
+ T* output_data = output + t * out_strides.x + n * out_strides.y;
+
+ if (h >= rotary_embedding_dim) {
+ output_data[h] = input_data[h];
+ return;
+ }
+
+ // Cache is (M, H/2)
+ const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
+ const int position_id = past_seqlens[b] + s;
+ const int cache_offset = position_id * half_rotary_embedding_dim;
+ const T* cos_data = cos_cache + cache_offset;
+ const T* sin_data = sin_cache + cache_offset;
+
+ int cache_idx = 0;
+ T sign = 0;
+ int j = 0;
+ if (interleaved) {
+ cache_idx = (h / 2) % half_rotary_embedding_dim;
+ sign = (h % 2 == 0) ? -1 : 1;
+ j = (h % 2 == 0) ? h + 1 : h - 1; // i - sign
+ } else {
+ cache_idx = h % half_rotary_embedding_dim;
+ sign = (h < half_rotary_embedding_dim) ? -1 : 1;
+ j = (h + half_rotary_embedding_dim) % rotary_embedding_dim;
+ }
+ output_data[h] = input_data[h] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+}
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int32_t* past_seqlens,
+ const int32_t* cumulative_seqlens_q, const T* cos_cache, const T* sin_cache,
+ const int batch_size, const int max_seqlen_q, const int num_heads,
+ const int head_size, const int rotary_embedding_dim, const bool interleaved,
+ const int in_seq_stride, const int max_threads_per_block) {
+ ORT_ENFORCE(head_size <= max_threads_per_block, "head_size must be <= max_threads_per_block");
+ int3 in_strides = {in_seq_stride <= 0 ? num_heads * head_size : in_seq_stride, head_size, 1};
+ int3 out_strides = {num_heads * head_size, head_size, 1};
+ int tpb = (head_size + 31) / 32 * 32;
+
+ const dim3 grid(max_seqlen_q, batch_size, num_heads);
+ const dim3 block(tpb);
+ RotaryEmbeddingTNH<T><<<grid, block, 0, stream>>>(
+ output, input, cos_cache, sin_cache, past_seqlens, cumulative_seqlens_q, head_size, rotary_embedding_dim,
+ interleaved, in_strides, out_strides);
+ return CUDA_CALL(cudaGetLastError());
+}
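+
+// Worked example of the rotation above (illustrative, non-interleaved,
+// rotary_embedding_dim = 4, so half = 2): for h = 0, j = 2 and
+// out[0] = in[0] * cos[0] - in[2] * sin[0]; for h = 2, j = 0 and
+// out[2] = in[2] * cos[0] + in[0] * sin[0], i.e. a standard 2D rotation of the
+// (in[0], in[2]) pair by the angle cached at position_id.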
+
+template <int kBlockSize>
+__global__ void GetCumulativeSeqlensKV(int32_t* cumulative_seqlens_kv, const int32_t* cumulative_seqlens_q,
+ const int32_t* past_seqlens, const int batch_size) {
+ int id = blockDim.x * blockIdx.x + threadIdx.x;
+
+ if (id == 0) {
+ cumulative_seqlens_kv[0] = 0;
+ }
+
+ typedef cub::BlockScan<int, kBlockSize> BlockScan;
+ __shared__ typename BlockScan::TempStorage temp_storage;
+
+ // Sum past_seqlens with the new sequence length (the difference of adjacent
+ // cumulative_seqlens_q entries), then do an inclusive scan across the present sequence
+ // lengths to get the cumulative KV sequence length. The scan is a collective operation,
+ // so every thread in the block participates; only in-range threads write results.
+ // Note this assumes batch_size <= kBlockSize, i.e. a single thread block.
+ int length = 0;
+ if (id < batch_size) {
+ length = past_seqlens[id] + cumulative_seqlens_q[id + 1] - cumulative_seqlens_q[id];
+ }
+ BlockScan(temp_storage).InclusiveSum(length, length);
+ if (id < batch_size) {
+ cumulative_seqlens_kv[id + 1] = length;
+ }
+}
+
+Status LaunchGetCumulativeSeqlensKV(int32_t* cumulative_seqlens_kv, const int32_t* cumulative_seqlens_q,
+ const int32_t* past_seqlens, const int batch_size, cudaStream_t stream) {
+ const int threads = 256;
+ const int blocks = (batch_size + threads - 1) / threads;
+ GetCumulativeSeqlensKV<256><<<blocks, threads, 0, stream>>>(cumulative_seqlens_kv, cumulative_seqlens_q, past_seqlens,
+ batch_size);
+ return CUDA_CALL(cudaGetLastError());
+}
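+
+// Worked example (illustrative numbers): with past_seqlens = [3, 0, 5] and
+// cumulative_seqlens_q = [0, 1, 3, 7], the new lengths are [1, 2, 4], the per-sequence
+// KV lengths are [4, 2, 9], and the inclusive scan yields
+// cumulative_seqlens_kv = [0, 4, 6, 15].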
+
+template <typename T>
+__global__ void ReshapeAndCache(const T* __restrict__ key, const T* __restrict__ value, T* __restrict__ key_cache,
+ T* __restrict__ value_cache, const int* __restrict__ block_table,
+ const int* __restrict__ past_seqlens, const int* __restrict__ cumulative_seqlens_q,
+ const int batch_size, const int max_num_blocks_per_seq, const int token_count,
+ const int kv_hidden_size, const int block_size, const int key_stride,
+ const int value_stride) {
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+ if (tid >= token_count * kv_hidden_size) {
+ return;
+ }
+ const int token_id = tid / kv_hidden_size;
+ const int hidden_offset = tid % kv_hidden_size;
+ int batch_id = 0;
+ for (int i = 0; i < batch_size; ++i) {
+ if (token_id < cumulative_seqlens_q[i + 1]) {
+ batch_id = i;
+ break;
+ }
+ }
+ const int token_offset = token_id - cumulative_seqlens_q[batch_id];
+ const int past_length = past_seqlens[batch_id];
+ const int block_id = block_table[batch_id * max_num_blocks_per_seq + (past_length + token_offset) / block_size];
+ const int block_offset = (past_length + token_offset) % block_size;
+
+ const int key_id = token_id * key_stride + hidden_offset;
+ const int value_id = token_id * value_stride + hidden_offset;
+ const int dst_id = block_id * block_size * kv_hidden_size + block_offset * kv_hidden_size + hidden_offset;
+ key_cache[dst_id] = key[key_id];
+ value_cache[dst_id] = value[value_id];
+}
+
+template <typename T>
+Status LaunchReshapeAndCache(const T* key, const T* value, T* key_cache, T* value_cache, const int* block_table,
+ const int* past_seqlens, const int* cumulative_seqlens_q, const int batch_size,
+ const int max_num_blocks_per_seq, const int token_count, const int kv_hidden_size,
+ const int block_size, const int key_stride, const int value_stride, cudaStream_t stream,
+ const int max_threads_per_block) {
+ const int total_size = token_count * kv_hidden_size;
+ const int threads(std::min(total_size, max_threads_per_block));
+ const int blocks((total_size + threads - 1) / threads);
+ ReshapeAndCache<T><<<blocks, threads, 0, stream>>>(key, value, key_cache, value_cache, block_table, past_seqlens,
+ cumulative_seqlens_q, batch_size, max_num_blocks_per_seq,
+ token_count, kv_hidden_size, block_size, key_stride, value_stride);
+ return CUDA_CALL(cudaGetLastError());
+}
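+
+// Worked example of the cache lookup above (illustrative numbers): with block_size = 256,
+// past_length = 300 and token_offset = 5, the token lands in logical block
+// (300 + 5) / 256 = 1 of its sequence, so block_id = block_table[batch_id * max_num_blocks_per_seq + 1]
+// and block_offset = 305 % 256 = 49.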
+
+////////// Launch Kernels
+
+#if USE_FLASH_ATTENTION
+template <typename T>
+Status FlashAttention(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ contrib::PagedAttentionParameters& parameters,
+ PagedAttentionData<T>& data,
+ float scale) {
+ // Get parameters
+ const int max_threads_per_block = device_prop.maxThreadsPerBlock;
+ const int batch_size = parameters.batch_size;
+ const int token_count = parameters.token_count;
+ const int q_hidden_size = parameters.hidden_size;
+ const int kv_hidden_size = parameters.kv_hidden_size;
+ const int num_heads = parameters.num_heads;
+ const int kv_num_heads = parameters.kv_num_heads;
+ const int head_size = parameters.head_size;
+ const float softcap = parameters.softcap;
+ bool is_bf16 = std::is_same<T, BFloat16>::value;
+ const int local_window_size = parameters.local_window_size;
+ const int max_num_blocks_per_seq = parameters.max_num_blocks_per_seq;
+ const int block_size = parameters.block_size;
+ // The following are passed to flash api but not used by the kernel, so they can be determined heuristically
+ const int max_query_len = token_count - batch_size + 1;
+ const int max_seq_len = parameters.max_num_blocks_per_seq * parameters.block_size;
+
+ T* query = const_cast<T*>(data.query);
+ T* key;
+ T* value;
+ if (!parameters.is_packed_qkv) {
+ key = const_cast<T*>(data.key);
+ value = const_cast<T*>(data.value);
+ } else {
+ key = reinterpret_cast<T*>(query) + static_cast<size_t>(num_heads * head_size);
+ value = reinterpret_cast<T*>(key) + static_cast<size_t>(kv_num_heads * head_size);
+ }
+
+ // Calculate cumulative present sequence length in cumulative_seqlens_kv
+ int* cumulative_seqlens_q = const_cast<int*>(data.cumulative_seqlens_q);
+ int* past_seqlens = const_cast<int*>(data.past_seqlens);
+ int* cumulative_seqlens_kv = data.cumulative_seqlens_kv;
+ ORT_RETURN_IF_ERROR(LaunchGetCumulativeSeqlensKV(cumulative_seqlens_kv, cumulative_seqlens_q, past_seqlens,
+ batch_size, stream));
+
+ if (parameters.do_rotary) {
+ // Will unpack Q and K in case of packed_qkv
+ auto q_buffer = data.workspace_buffer;
+ auto k_buffer = data.workspace_buffer + token_count * num_heads * head_size;
+ const int packed_seq_stride = parameters.is_packed_qkv ? (num_heads + 2 * kv_num_heads) * head_size : -1;
+ ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(
+ stream, q_buffer, query, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size,
+ max_query_len, num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride,
+ max_threads_per_block));
+ ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(
+ stream, k_buffer, key, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size,
+ max_query_len, kv_num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride,
+ max_threads_per_block));
+ query = q_buffer;
+ key = k_buffer;
+ } else if (parameters.is_packed_qkv) {
+ // Only unpack Q. K and V are unpacked by ReshapeAndCache.
+ auto q_buffer = data.workspace_buffer;
+ const int packed_seq_stride = q_hidden_size + 2 * kv_hidden_size;
+ ORT_RETURN_IF_ERROR(LaunchUnpackCumulative(
+ query, q_buffer, token_count, q_hidden_size, packed_seq_stride, stream, max_threads_per_block));
+ query = q_buffer;
+ }
+
+ // Insert key and value into block-based KV cache
+ int* block_table = const_cast<int*>(data.block_table);
+ const int key_stride = parameters.is_packed_qkv && !parameters.do_rotary ? q_hidden_size + 2 * kv_hidden_size : kv_hidden_size;
+ const int value_stride = parameters.is_packed_qkv ? q_hidden_size + 2 * kv_hidden_size : kv_hidden_size;
+ ORT_RETURN_IF_ERROR(LaunchReshapeAndCache(key, value, data.key_cache, data.value_cache, block_table, past_seqlens,
+ cumulative_seqlens_q, batch_size, max_num_blocks_per_seq, token_count,
+ kv_hidden_size, block_size, key_stride, value_stride, stream,
+ max_threads_per_block));
+
+ // Launch kernel
+ void* q = reinterpret_cast<void*>(query);
+ void* key_cache = reinterpret_cast<void*>(data.key_cache);
+ void* value_cache = reinterpret_cast<void*>(data.value_cache);
+ void* output = reinterpret_cast<void*>(data.output);
+ void* softmax_lse = reinterpret_cast<void*>(data.softmax_lse);
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_varlen_fwd(
+ device_prop, stream, q, key_cache, value_cache, output, cumulative_seqlens_q, cumulative_seqlens_kv,
+ /*seqused_k*/ nullptr, block_table, softmax_lse, batch_size, num_heads, kv_num_heads, head_size, max_query_len,
+ max_seq_len, token_count, scale, softcap, /*is_causal*/ true, is_bf16, local_window_size, max_num_blocks_per_seq,
+ block_size));
+
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR("flash attention output", data.output, token_count, num_heads, head_size);
+
+ return Status::OK();
+}
+#endif
+
+////////// API Functions
+
+template <typename T>
+Status QkvToContext(
+ const cudaDeviceProp& device_prop,
+ cublasHandle_t& /*cublas*/,
+ Stream* ort_stream,
+ contrib::PagedAttentionParameters& parameters,
+ PagedAttentionData<T>& data) {
+ auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
+ const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;
+
+#if USE_FLASH_ATTENTION
+ if (data.use_flash_attention) {
+ return FlashAttention(device_prop, stream, parameters, data, scale);
+ }
+#endif
+
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Paged Attention not implemented.");
+}
+
+template struct PagedAttentionData<half>;
+template Status QkvToContext<half>(
+ const cudaDeviceProp& device_prop,
+ cublasHandle_t& cublas,
+ Stream* ort_stream,
+ contrib::PagedAttentionParameters& parameters,
+ PagedAttentionData<half>& data);
+
+template struct PagedAttentionData<BFloat16>;
+template Status QkvToContext<BFloat16>(
+ const cudaDeviceProp& device_prop,
+ cublasHandle_t& cublas,
+ Stream* ort_stream,
+ contrib::PagedAttentionParameters& parameters,
+ PagedAttentionData<BFloat16>& data);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h
new file mode 100644
index 0000000000000..7e27556a5c63f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+#include <cuda_fp16.h>
+#include <cublas_v2.h>
+#include "contrib_ops/cpu/bert/attention_common.h"
+#include "contrib_ops/cpu/bert/attention_parameters.h"
+#include "contrib_ops/cuda/bert/attention_data.h"
+#include "core/framework/allocator.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+Status QkvToContext(
+ const cudaDeviceProp& device_prop,
+ cublasHandle_t& cublas,
+ Stream* stream,
+ contrib::PagedAttentionParameters& parameters,
+ PagedAttentionData<T>& data);
+
+template <typename T>
+Status LaunchUnpackQKVCumulative(const T* packed_qkv, T* unpacked_qkv, const int token_count, const int num_heads,
+ const int kv_num_heads, const int head_size, cudaStream_t stream,
+ const int max_threads_per_block);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 17f3433aed38a..d016d50d6c445 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -99,6 +99,8 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice);
@@ -311,6 +313,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention)>,
+ BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention)>,
+ BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention)>,
 BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention)>,
 BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice)>,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 238dd8d4573de..f2757c2c96471 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1206,6 +1206,202 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
GroupQueryAttentionTypeAndShapeInference(ctx, 3);
}));
+constexpr const char* PagedAttention_ver1_doc = R"DOC(
+Paged Attention.
+
+This op leverages a block-based KV cache to enable continuous batching for LLMs. Currently, it is designed to work with
+the CUDA Execution Provider only.
+
+In other attention ops, batch entries typically aren't of the same length, so they are padded.
+Below is a batch with 3 sequences where * denotes a padding token.
+ Sequence_0: 0, 1*, 2*, 3*
+ Sequence_1: 4, 5, 6*, 7*
+ Sequence_2: 8, 9, 10, 11
+
+PagedAttention is designed to take in packed input, i.e., only the real tokens without padding.
+For example, the input shown above will be packed into the following tensors:
+ - query ([q0, q4, q5, q8, q9, q10, q11])
+ - key ([k0, k4, k5, k8, k9, k10, k11])
+ - value ([v0, v4, v5, v8, v9, v10, v11])
+ - cumulative_sequence_length: 0, 1, 1+2, 1+2+4
+This packing omits padding tokens.
+
+The query, key and value tensors contain the hidden states of the real tokens after the input projections.
+cumulative_sequence_length records the cumulative length of each sequence.
+
+)DOC";
+
+// Shape inference for PagedAttention. Here are the shapes of inputs and output:
+// When Q, K and V are not packed:
+// Input 'query': (token_count, hidden_size)
+// Input 'key': (token_count, kv_hidden_size)
+// Input 'value': (token_count, kv_hidden_size)
+// When Q, K and V are packed:
+// Input 'query': (token_count, (num_heads + 2 * kv_num_heads) * head_size)
+// Input 'key': None
+// Input 'value': None
+// Input 'key_cache': (num_blocks, block_size, kv_num_heads, head_size)
+// Input 'value_cache': (num_blocks, block_size, kv_num_heads, head_size)
+// Input 'cumulative_sequence_length': (batch_size + 1)
+// Input 'past_seqlens': (batch_size)
+// Input 'block_table': (batch_size, max_blocks_per_sequence)
+// Input 'cos_cache': (max_seq_len, head_size / 2)
+// Input 'sin_cache': (max_seq_len, head_size / 2)
+// Output 'output': (token_count, hidden_size)
+// Output 'key_cache_out': (num_blocks, block_size, kv_num_heads, head_size)
+// Output 'value_cache_out': (num_blocks, block_size, kv_num_heads, head_size)
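+// For example (illustrative numbers): packed QKV with num_heads = 32, kv_num_heads = 8 and
+// query shape (num_tokens, 6144) infers head_size = 6144 / (32 + 2 * 8) = 128 and
+// output shape (num_tokens, 32 * 128) = (num_tokens, 4096).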
+void PagedAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
+ // Type inference
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+ // Shape inference for output tensor
+ if (hasInputShape(ctx, 0)) {
+ auto& query_shape = getInputShape(ctx, 0);
+ auto& query_dims = query_shape.dim();
+
+ if (query_dims.size() != 2) {
+ fail_shape_inference("Input 0 (query) shall be 2 dimensions");
+ }
+
+ if (ctx.hasInput(2)) {
+ propagateShapeFromInputToOutput(ctx, 0, 0);
+ } else { // packed QKV
+ ONNX_NAMESPACE::TensorShapeProto output_shape;
+ *output_shape.add_dim() = query_dims[0];
+ int64_t num_heads = getAttribute(ctx, "num_heads", 0);
+ int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);
+ int64_t hidden_size = query_dims[1].dim_value();
+ if (hidden_size <= 0 || num_heads <= 0 || kv_num_heads <= 0) {
+ fail_shape_inference("Invalid hidden size or number of heads. Hidden size, num_heads and kv_num_heads must be positive integers.");
+ } else if (hidden_size % (num_heads + 2 * kv_num_heads) != 0) {
+ fail_shape_inference("Hidden size must be divisible by (num_heads + 2 * kv_num_heads).");
+ }
+ int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads);
+ output_shape.add_dim()->set_dim_value(head_size * num_heads);
+ updateOutputShape(ctx, 0, output_shape);
+ }
+ }
+
+ // Shape inference for KV Cache output tensors
+ if (ctx.getNumOutputs() > 1) { // has kv cache output
+ if (ctx.getNumOutputs() != 3) {
+ fail_shape_inference("Key cache and value cache output tensors must be both present or both absent.");
+ }
+ // types
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1);
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2);
+ // shapes
+ auto& key_cache_shape = getInputShape(ctx, 3);
+ auto& key_cache_dims = key_cache_shape.dim();
+ if (key_cache_dims.size() != 4) {
+ fail_shape_inference("The block-based KV cache inputs shall be 4 dimensions");
+ }
+ // KV cache in and out share the same buffer, thus they have the same shape
+ ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 3, 1);
+ ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 4, 2);
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 3, 1);
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 4, 2);
+ }
+}
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ PagedAttention, 1,
+ OpSchema()
+ .SetDoc(PagedAttention_ver1_doc)
+ .Attr("num_heads", "Number of attention heads for q", AttributeProto::INT)
+ .Attr("kv_num_heads", "Number of attention heads for k and v", AttributeProto::INT)
+ .Attr("scale",
+ "Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
+ .Attr("softcap",
+ "Softcap value for attention weights. Default value is 0.",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
+ .Attr("local_window_size",
+ "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.",
+ AttributeProto::INT,
+ static_cast<int64_t>(-1))
+ .Attr("do_rotary",
+ "Whether to use rotary position embedding. Default value is 0.",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
+ .Attr("rotary_interleaved",
+ "Rotate using interleaved pattern. Default value is 0 (False).",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
+ .Input(0,
+ "query",
+ "Query with shape (num_tokens, hidden_size), or packed QKV with shape (num_tokens, d) "
+ "where d is (num_heads * head_size + 2 * kv_num_heads * head_size).",
+ "T")
+ .Input(1,
+ "key",
+ "Key with shape (num_tokens, kv_hidden_size) ",
+ "T",
+ OpSchema::Optional)
+ .Input(2,
+ "value",
+ "Value with shape (num_tokens, kv_hidden_size)",
+ "T",
+ OpSchema::Optional)
+ .Input(3,
+ "key_cache",
+ "Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in "
+ "place within the op.",
+ "T")
+ .Input(4,
+ "value_cache",
+ "Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated "
+ "in place within the op. This should be the same shape as key_cache.",
+ "T")
+ .Input(5,
+ "cumulative_sequence_length",
+ "A tensor with shape (batch_size + 1). It specifies the cumulative sequence lengths between the packed "
+ "entries in Q/K/V.",
+ "S")
+ .Input(6,
+ "past_seqlens",
+ "A tensor with shape (batch_size). It specifies the past lengths of cached sequence in the KV cache.",
+ "S")
+ .Input(7,
+ "block_table",
+ "2D tensor with shape (batch_size, max_blocks_per_sequence) that maps each sequence in the batch to its"
+ "corresponding blocks in the KV cache.",
+ "S")
+ .Input(8,
+ "cos_cache",
+ "2D tensor with shape (max total seqlen, head_size / 2).",
+ "T",
+ OpSchema::Optional)
+ .Input(9,
+ "sin_cache",
+ "2D tensor with shape (max total seqlen, head_size / 2).",
+ "T",
+ OpSchema::Optional)
+ .Output(0,
+ "output",
+ "3D output tensor with shape (num_tokens, hidden_size)",
+ "T")
+ .Output(1,
+ "key_cache_out",
+ "Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always "
+ "the same tensor as key_cache.",
+ "T",
+ OpSchema::Optional)
+ .Output(2,
+ "value_cache_out",
+ "Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always "
+ "the same tensor as value_cache.",
+ "T",
+ OpSchema::Optional)
+ .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
+ .TypeConstraint("S", {"tensor(int32)"}, "Constrain Positional inputs to int tensor.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ PagedAttentionTypeAndShapeInference(ctx);
+ }));
+
constexpr const char* SparseAttention_ver1_doc = R"DOC(
Block Sparse Attention used in Phi-3-small (https://arxiv.org/pdf/2404.14219).
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index a9a89f756b071..6c20aae94d132 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -87,6 +87,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MoE);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
@@ -197,6 +198,7 @@ class OpSet_Microsoft_ver1 {
 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE)>());
 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention)>());
 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention)>());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention)>());
 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3)>());
 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock)>());
 fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad)>());
diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
new file mode 100644
index 0000000000000..cc9a02a8074c0
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
@@ -0,0 +1,735 @@
+# --------------------------------------------------------------------------
+# Copyright 2020 The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# -------------------------------------------------------------------------
+import math
+import platform
+import random
+import unittest
+
+import numpy
+import torch
+from einops import rearrange, repeat
+from onnx import TensorProto, helper
+from packaging import version
+from parameterized import parameterized
+from test_gqa_cpu import smooth_softmax_ref
+
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers
+
+torch.manual_seed(0)
+
+pipeline_mode = True # Reduces number of tests so pipeline doesn't time out
+
+
+class Config:
+ batch_size = 0
+ sequence_length = 0
+ total_sequence_length = 0
+ num_heads = 0
+ kv_num_heads = 0
+ head_size = 0
+ paged_kv_block_size = 0
+ local = False
+ rotary = False
+ rotary_interleaved = False
+ packed = False
+ softcap = 0.0
+ ep = "CUDAExecutionProvider"
+
+ def __init__(
+ self,
+ batch_size,
+ sequence_length,
+ total_sequence_length,
+ num_heads,
+ kv_num_heads,
+ head_size,
+ paged_kv_block_size,
+ local,
+ rotary,
+ rotary_interleaved,
+ packed,
+ softcap,
+ ):
+ self.batch_size = batch_size
+ self.sequence_length = sequence_length
+ self.total_sequence_length = total_sequence_length
+ self.num_heads = num_heads
+ self.kv_num_heads = kv_num_heads
+ self.head_size = head_size
+ self.paged_kv_block_size = paged_kv_block_size
+ self.local = local
+ self.rotary = rotary
+ self.rotary_interleaved = rotary_interleaved
+ self.packed = packed
+ self.softcap = softcap
+
+ def __repr__(self):
+ short_ep = self.ep[: -len("ExecutionProvider")].lower()
+ return (
+ f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, "
+ f"total_sequence_length={self.total_sequence_length}, num_heads={self.num_heads}, "
+ f"kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, "
+ f"paged_kv_block_size={self.paged_kv_block_size} rotary={self.rotary}, "
+ f"rotary_interleaved={self.rotary_interleaved}, packed={self.packed}, softcap={self.softcap}, "
+ f"ep={short_ep})"
+ )
+
+
+def create_paged_attention_graph(
+ config,
+ num_tokens,
+ num_blocks,
+ max_blocks_per_sequence,
+ local_window_size=-1,
+):
+ nodes = [
+ helper.make_node(
+ "PagedAttention",
+ [
+ "query",
+ "key" if not config.packed else "",
+ "value" if not config.packed else "",
+ "key_cache",
+ "value_cache",
+ "cumulative_sequence_length",
+ "past_seqlens",
+ "block_table",
+ "cos_cache" if config.rotary else "",
+ "sin_cache" if config.rotary else "",
+ ],
+ ["output", "key_cache_out", "value_cache_out"],
+ "PagedAttention_0",
+ num_heads=config.num_heads,
+ kv_num_heads=config.kv_num_heads,
+ local_window_size=local_window_size,
+ do_rotary=config.rotary,
+ rotary_interleaved=config.rotary_interleaved,
+ softcap=config.softcap,
+ domain="com.microsoft",
+ ),
+ ]
+
+ graph_input = [
+ helper.make_tensor_value_info(
+ "query",
+ TensorProto.FLOAT16,
+ [
+ num_tokens,
+ (config.num_heads * config.head_size)
+ if not config.packed
+ else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size),
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "key_cache",
+ TensorProto.FLOAT16,
+ [
+ num_blocks,
+ config.paged_kv_block_size,
+ config.kv_num_heads,
+ config.head_size,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "value_cache",
+ TensorProto.FLOAT16,
+ [
+ num_blocks,
+ config.paged_kv_block_size,
+ config.kv_num_heads,
+ config.head_size,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "cumulative_sequence_length",
+ TensorProto.INT32,
+ [config.batch_size + 1],
+ ),
+ helper.make_tensor_value_info(
+ "past_seqlens",
+ TensorProto.INT32,
+ [config.batch_size],
+ ),
+ helper.make_tensor_value_info(
+ "block_table",
+ TensorProto.INT32,
+ [config.batch_size, max_blocks_per_sequence],
+ ),
+ ]
+ if not config.packed:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "key",
+ TensorProto.FLOAT16,
+ [
+ num_tokens,
+ config.kv_num_heads * config.head_size,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "value",
+ TensorProto.FLOAT16,
+ [
+ num_tokens,
+ config.kv_num_heads * config.head_size,
+ ],
+ ),
+ ]
+ if config.rotary:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "cos_cache",
+ TensorProto.FLOAT16,
+ [
+ config.total_sequence_length,
+ (math.floor(config.head_size / 16) * 16) // 2,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "sin_cache",
+ TensorProto.FLOAT16,
+ [
+ config.total_sequence_length,
+ (math.floor(config.head_size / 16) * 16) // 2,
+ ],
+ ),
+ ]
+
+ graph_output = [
+ helper.make_tensor_value_info(
+ "output",
+ TensorProto.FLOAT16,
+ [num_tokens, config.num_heads * config.head_size],
+ ),
+ helper.make_tensor_value_info(
+ "key_cache_out",
+ TensorProto.FLOAT16,
+ [
+ num_blocks,
+ config.paged_kv_block_size,
+ config.kv_num_heads,
+ config.head_size,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "value_cache_out",
+ TensorProto.FLOAT16,
+ [
+ num_blocks,
+ config.paged_kv_block_size,
+ config.kv_num_heads,
+ config.head_size,
+ ],
+ ),
+ ]
+
+ graph = helper.make_graph(
+ nodes,
+ "PagedAttention_Graph",
+ graph_input,
+ graph_output,
+ )
+
+ model = helper.make_model(graph)
+ return model.SerializeToString()
+
+
+def rotary_options_for_current_os():
+ # The reference rotary implementation uses Triton, which is not available on Windows,
+ # so we only test rotary on Linux right now.
+ return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)]
+
+
+def paged_attention_func(
+ config,
+ query,
+ key,
+ value,
+ key_cache,
+ value_cache,
+ cumulative_sequence_length,
+ past_seqlens,
+ block_table,
+ cos=None,
+ sin=None,
+ window_size=-1,
+):
+ num_tokens = cumulative_sequence_length[-1].item()
+ num_blocks = key_cache.shape[0]
+ max_blocks_per_sequence = block_table.shape[1]
+ onnx_model_str = create_paged_attention_graph(
+ config,
+ num_tokens,
+ num_blocks,
+ max_blocks_per_sequence,
+ local_window_size=window_size,
+ )
+ ort_inputs = {
+ "query": query.detach().cpu().numpy(),
+ "key_cache": OrtValue.ortvalue_from_numpy(key_cache.detach().cpu().numpy(), "cuda", 0),
+ "value_cache": OrtValue.ortvalue_from_numpy(value_cache.detach().cpu().numpy(), "cuda", 0),
+ "cumulative_sequence_length": cumulative_sequence_length.detach().cpu().numpy(),
+ "past_seqlens": past_seqlens.detach().cpu().numpy(),
+ "block_table": block_table.detach().cpu().numpy(),
+ }
+ sess_options = SessionOptions()
+ ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep])
+ io_binding = ort_session.io_binding()
+ if key is not None and value is not None:
+ ort_inputs["key"] = key.detach().cpu().numpy()
+ ort_inputs["value"] = value.detach().cpu().numpy()
+ io_binding.bind_cpu_input("key", ort_inputs["key"])
+ io_binding.bind_cpu_input("value", ort_inputs["value"])
+ if cos is not None and sin is not None:
+ ort_inputs["cos_cache"] = cos.detach().cpu().numpy()
+ ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
+ io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
+ io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
+ io_binding.bind_cpu_input("query", ort_inputs["query"])
+ io_binding.bind_input(
+ "key_cache", "cuda", 0, numpy.float16, ort_inputs["key_cache"].shape(), ort_inputs["key_cache"].data_ptr()
+ )
+ io_binding.bind_input(
+ "value_cache", "cuda", 0, numpy.float16, ort_inputs["value_cache"].shape(), ort_inputs["value_cache"].data_ptr()
+ )
+ io_binding.bind_cpu_input("cumulative_sequence_length", ort_inputs["cumulative_sequence_length"])
+ io_binding.bind_cpu_input("past_seqlens", ort_inputs["past_seqlens"])
+ io_binding.bind_cpu_input("block_table", ort_inputs["block_table"])
+ io_binding.bind_output("output")
+ io_binding.bind_ortvalue_output("key_cache_out", ort_inputs["key_cache"])
+ io_binding.bind_ortvalue_output("value_cache_out", ort_inputs["value_cache"])
+ ort_session.run_with_iobinding(io_binding)
+ output, key_cache_out, value_cache_out = io_binding.copy_outputs_to_cpu()
+ output = torch.tensor(numpy.array(output))
+ return output, key_cache_out, value_cache_out
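+
+# Note: key_cache and value_cache are bound to the same device OrtValue as both an input and
+# an output (bind_ortvalue_output) above, matching the op contract that the block-based
+# KV cache is updated in place.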
+
+
+def construct_local_mask(
+ seqlen_q,
+ seqlen_k,
+ window_size=(-1, -1), # -1 means infinite window size
+ query_padding_mask=None,
+ key_padding_mask=None,
+ device=None,
+):
+ row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+ col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+ sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+ sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+ if window_size[0] < 0:
+ return col_idx > row_idx + sk - sq + window_size[1]
+ else:
+ sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+ return torch.logical_or(
+ col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+ col_idx < row_idx + sk - sq - window_size[0],
+ )
+
+
+def attention_ref(
+ q,
+ k,
+ v,
+ query_padding_mask=None,
+ key_padding_mask=None,
+ dropout_p=0.0,
+ dropout_mask=None,
+ causal=False,
+ window_size=(-1, -1), # -1 means infinite window size
+ softcap=0.0,
+ upcast=True,
+ reorder_ops=False,
+ use_smooth_softmax=False,
+):
+ """
+ Arguments:
+ q: (batch_size, seqlen_q, nheads, head_dim)
+ k: (batch_size, seqlen_k, nheads_k, head_dim)
+ v: (batch_size, seqlen_k, nheads_k, head_dim)
+ query_padding_mask: (batch_size, seqlen_q)
+ key_padding_mask: (batch_size, seqlen_k)
+ dropout_p: float
+ dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+ causal: whether to apply causal masking
+ window_size: (int, int), left and right window size
+ upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+ output back to fp16/bf16.
+ reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
+ without changing the math. This is to estimate the numerical error from operation
+ reordering.
+ Output:
+ output: (batch_size, seqlen_q, nheads, head_dim)
+ attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
+ """
+ if causal:
+ window_size = (window_size[0], 0)
+ dtype_og = q.dtype
+ if upcast:
+ q, k, v = q.float(), k.float(), v.float()
+ seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+ k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+ v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+ d = q.shape[-1]
+ if not reorder_ops:
+ scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+ else:
+ scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+ if softcap > 0:
+ scores = scores / softcap
+ scores = scores.tanh()
+ scores = scores * softcap
+ if key_padding_mask is not None:
+ scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+ if window_size[0] >= 0 or window_size[1] >= 0:
+ local_mask = construct_local_mask(
+ seqlen_q,
+ seqlen_k,
+ window_size,
+ query_padding_mask,
+ key_padding_mask,
+ q.device,
+ )
+ scores.masked_fill_(local_mask, float("-inf"))
+
+ if use_smooth_softmax:
+ attention = smooth_softmax_ref(scores)
+ else:
+ attention = torch.softmax(scores, dim=-1)
+
+ # Some rows might be completely masked out so we fill them with zero instead of NaN
+ if window_size[0] >= 0 or window_size[1] >= 0:
+ attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+ # We want to mask here so that the attention matrix doesn't have any NaNs
+ # Otherwise we'll get NaN in dV
+ if query_padding_mask is not None:
+ attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+ dropout_scaling = 1.0 / (1 - dropout_p)
+ if dropout_mask is not None:
+ attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+ else:
+ attention_drop = attention
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+ if query_padding_mask is not None:
+ output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+ return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
+
+
+def rotary_embedding(*args, **kwargs):
+ # Use local import since triton is not available in Windows.
+ from rotary_flash import apply_rotary_emb
+
+ return apply_rotary_emb(*args, **kwargs)
+
+
+def unpad_qkv(config: Config, q, k, v, cum_seqlens):
+ token_count = cum_seqlens[-1]
+ q_unpad = torch.zeros(
+ token_count,
+ config.num_heads * config.head_size,
+ dtype=torch.float16,
+ device="cuda",
+ )
+ k_unpad = torch.zeros(
+ token_count,
+ config.kv_num_heads * config.head_size,
+ dtype=torch.float16,
+ device="cuda",
+ )
+ v_unpad = torch.zeros(
+ token_count,
+ config.kv_num_heads * config.head_size,
+ dtype=torch.float16,
+ device="cuda",
+ )
+ for i in range(config.batch_size):
+ new_seqlen = cum_seqlens[i + 1] - cum_seqlens[i]
+ q_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(q[i, :new_seqlen], "s n h -> s (n h)")
+ k_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(k[i, :new_seqlen], "s n h -> s (n h)")
+ v_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(v[i, :new_seqlen], "s n h -> s (n h)")
+ return q_unpad, k_unpad, v_unpad
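+
+# Worked example for unpad_qkv (illustrative numbers): with batch_size = 2 and
+# cum_seqlens = [0, 1, 3], sequence 0 fills row 0 and sequence 1 fills rows 1-2 of each
+# unpadded tensor; each row is one real token's flattened (num_heads * head_size) slice.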
+
+
+def generate_block_kvcache(config: Config, device, dtype):
+ num_blocks = math.ceil(config.total_sequence_length / config.paged_kv_block_size) * config.batch_size * 3
+ k_cache_paged = torch.randn(
+ num_blocks, config.paged_kv_block_size, config.kv_num_heads, config.head_size, device=device, dtype=dtype
+ )
+ v_cache_paged = torch.randn(
+ num_blocks, config.paged_kv_block_size, config.kv_num_heads, config.head_size, device=device, dtype=dtype
+ )
+ block_table = rearrange(
+ torch.randperm(num_blocks, dtype=torch.int32, device=device),
+ "(b nblocks) -> b nblocks",
+ b=config.batch_size,
+ )
+ k_cache = rearrange(
+ # pytorch 1.12 doesn't have indexing with int32
+ k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+ "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+ b=config.batch_size,
+ )[:, : config.total_sequence_length]
+ v_cache = rearrange(
+ v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+ "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+ b=config.batch_size,
+ )[:, : config.total_sequence_length]
+ return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged
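+
+# Note on generate_block_kvcache: num_blocks is oversized (3x the per-sequence need times
+# batch_size) so that randperm can hand every sequence a disjoint, shuffled set of blocks;
+# k_cache / v_cache gather the same data back into contiguous (batch, seq, heads, head_size)
+# form for the reference implementation.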
+
+
+def parity_check_paged_attention(
+ config: Config,
+ rtol=1e-3,
+ atol=1e-3,
+):
+ # Generate padded inputs
+ q = torch.randn(
+ config.batch_size,
+ config.sequence_length,
+ config.num_heads,
+ config.head_size,
+ device="cuda",
+ dtype=torch.float16,
+ requires_grad=False,
+ )
+ k_new = torch.randn(
+ config.batch_size,
+ config.sequence_length,
+ config.kv_num_heads,
+ config.head_size,
+ device="cuda",
+ dtype=torch.float16,
+ requires_grad=False,
+ )
+ v_new = torch.randn(
+ config.batch_size,
+ config.sequence_length,
+ config.kv_num_heads,
+ config.head_size,
+ device="cuda",
+ dtype=torch.float16,
+ requires_grad=False,
+ )
+
+ # Generate random sequence lengths
+ past_seqlens = torch.randint(
+ 0,
+ config.total_sequence_length - config.sequence_length + 1, # one above highest integer to be drawn
+ (config.batch_size,),
+ dtype=torch.int32,
+ device="cuda",
+ )
+ new_seqlens = torch.randint(
+ 1,
+ config.sequence_length + 1,
+ (config.batch_size,),
+ dtype=torch.int32,
+ device="cuda",
+ )
+ cum_seqlens = torch.cat(
+ (torch.tensor([0], dtype=torch.int32, device="cuda"), torch.cumsum(new_seqlens, dim=0))
+ ).type(torch.int32)
+ total_seqlens = past_seqlens + new_seqlens
+
+ q_unpad, k_unpad, v_unpad = unpad_qkv(config, q, k_new, v_new, cum_seqlens)
+
+ # Generate kv cache and associated block-based data structures
+ k_cache, v_cache, block_table, k_cache_paged, v_cache_paged = generate_block_kvcache(config, "cuda", torch.float16)
+
+ # Set window size for local / causal
+ if config.local:
+ left_window_size = random.randint(0, config.total_sequence_length - 1) # random.randint is inclusive
+ window_size = (left_window_size, 0)
+ else:
+ left_window_size = -1
+ window_size = (-1, 0)
+
+ # Apply rotary embedding for reference implementation
+ if config.rotary:
+ rotary_fraction = 1.0
+ rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+ angle = torch.rand(config.total_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+ cos = torch.cos(angle).to(dtype=torch.float16)
+ sin = torch.sin(angle).to(dtype=torch.float16)
+ q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=past_seqlens, interleaved=config.rotary_interleaved)
+ k_ro = rotary_embedding(k_new, cos, sin, seqlen_offsets=past_seqlens, interleaved=config.rotary_interleaved)
+ else:
+ cos, sin = None, None
+ q_ro, k_ro = q, k_new
+
+ # Update reference kv cache
+ k_cache_ref = k_cache.clone()
+ v_cache_ref = v_cache.clone()
+ total_range = rearrange(torch.arange(config.total_sequence_length, device="cuda"), "s -> 1 s")
+ past_seqlens_expanded = rearrange(past_seqlens, "b -> b 1")
+ update_mask = torch.logical_and(
+ past_seqlens_expanded <= total_range, total_range < past_seqlens_expanded + config.sequence_length
+ )
+ k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
+ v_cache_ref[update_mask] = rearrange(v_new, "b s ... -> (b s) ...")
+ k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
+ v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
+
+ # Create padding masks for reference implementation
+ total_seqlens_expanded = rearrange(total_seqlens, "b -> b 1")
+ key_padding_mask = total_range < total_seqlens_expanded
+ query_range = rearrange(torch.arange(config.sequence_length, device="cuda"), "s -> 1 s")
+ new_seqlens_expanded = rearrange(new_seqlens, "b -> b 1")
+ query_padding_mask = query_range < new_seqlens_expanded
+
+ # Run reference implementation of attention
+ out_ref, _ = attention_ref(
+ q_ro,
+ k_cache_rep,
+ v_cache_rep,
+ query_padding_mask,
+ key_padding_mask,
+ 0.0,
+ None,
+ causal=True,
+ window_size=window_size,
+ softcap=config.softcap,
+ )
+ out_ref = out_ref.detach().cpu().numpy()
+
+ if config.packed:
+ q_unpad = torch.concatenate([q_unpad, k_unpad, v_unpad], dim=1)
+ k_unpad = None
+ v_unpad = None
+ out, updated_k_cache_paged, updated_v_cache_paged = paged_attention_func(
+ config,
+ q_unpad,
+ k_unpad,
+ v_unpad,
+ k_cache_paged,
+ v_cache_paged,
+ cum_seqlens,
+ past_seqlens,
+ block_table,
+ cos,
+ sin,
+ left_window_size,
+ )
+ num_tokens = q_unpad.shape[0]
+ out = torch.reshape(out, (num_tokens, config.num_heads, config.head_size))
+ out = out.detach().cpu().numpy()
+
+ err_msg = f" with {config}"
+ # Make sure past-present buffer updating correctly
+ present_k = rearrange(
+ updated_k_cache_paged[block_table.to(dtype=torch.long).flatten().cpu()],
+ "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+ b=config.batch_size,
+ )[:, : config.total_sequence_length]
+ present_v = rearrange(
+ updated_v_cache_paged[block_table.to(dtype=torch.long).flatten().cpu()],
+ "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+ b=config.batch_size,
+ )[:, : config.total_sequence_length]
+ for i in range(config.batch_size):
+ numpy.testing.assert_allclose(
+ present_k[i, : total_seqlens[i]],
+ k_cache_ref[i, : total_seqlens[i]].detach().cpu().numpy(),
+ rtol=rtol,
+ atol=atol,
+ equal_nan=True,
+ err_msg=err_msg,
+ )
+ numpy.testing.assert_allclose(
+ present_v[i, : total_seqlens[i]],
+ v_cache_ref[i, : total_seqlens[i]].detach().cpu().numpy(),
+ rtol=rtol,
+ atol=atol,
+ equal_nan=True,
+ err_msg=err_msg,
+ )
+ new_seqlen = cum_seqlens[i + 1] - cum_seqlens[i]
+ out_i = out[cum_seqlens[i] : cum_seqlens[i + 1]]
+ out_ref_i = out_ref[i, :new_seqlen]
+ numpy.testing.assert_allclose(out_i, out_ref_i, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg)
+
+
+def has_flash_attention():
+ if not torch.cuda.is_available():
+ return False
+ if "CUDAExecutionProvider" not in get_available_providers():
+ return False
+ major, _ = torch.cuda.get_device_capability()
+ return major >= 8 and (
+ platform.system() == "Linux"
+ or (platform.system() == "Windows" and version.parse(torch.version.cuda) >= version.parse("12.0"))
+ )
+
+
+def paged_attention_test_cases():
+ batches = [4] if pipeline_mode else [1, 3, 5]
+ seqs = (
+ [(1025, 2047)]
+ if pipeline_mode
+ else [
+ (3, 1024),
+ (1, 339),
+ (408, 800),
+ (333, 799),
+ (64, 2048),
+ (837, 4000),
+ (17, 49),
+ (257, 257),
+ (459, 459),
+ ]
+ )
+ num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+ h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+ block_sizes = [256] if pipeline_mode else [256, 512]
+
+ for b in batches:
+ for s, s2 in seqs:
+ for n, n2 in num_h:
+ for h in h_sizes:
+ for block_size in block_sizes:
+ for local in [False, True]:
+ for rotary, rotary_interleaved in rotary_options_for_current_os():
+ for packed in [False, True]:
+ for softcap in [0.0, 50.0]:
+ if rotary and h % 16 > 0:
+ continue
+
+ config = Config(
+ b,
+ s,
+ s2,
+ n,
+ n2,
+ h,
+ block_size,
+ local,
+ rotary,
+ rotary_interleaved,
+ packed,
+ softcap,
+ )
+ yield (
+ str(config),
+ config,
+ )
+
+
+@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
+class TestPagedAttention(unittest.TestCase):
+ @parameterized.expand(paged_attention_test_cases())
+ def test_paged_attention(self, _, config):
+ parity_check_paged_attention(config, rtol=5e-3, atol=5e-3)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)