diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 7ba2f820e9bdb..f6cc816b45ed2 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -67,6 +67,7 @@ Do not modify directly.*
 * com.microsoft.PackedAttention
 * com.microsoft.PackedMultiHeadAttention
 * com.microsoft.Pad
+ * com.microsoft.PagedAttention
 * com.microsoft.QAttention
 * com.microsoft.QGemm
 * com.microsoft.QLinearAdd
@@ -3683,6 +3684,100 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.PagedAttention**
+
+ Paged Attention.
+
+ This op leverages a block-based KV cache to enable continuous batching for LLMs. Currently, it is designed to work with
+ the CUDA Execution Provider only.
+
+ In other attention ops, the sequences in a batch are typically not the same length, so they are padded.
+ Below is a batch with 3 sequences, where * denotes a padding token.
+ Sequence_0: 0, 1*, 2*, 3*
+ Sequence_1: 4, 5, 6*, 7*
+ Sequence_2: 8, 9, 10, 11
+
+ PagedAttention is designed to take packed input, i.e., only the real tokens, without padding.
+ For example, the batch shown above is packed into the tensors below:
+ - query ([q0, q4, q5, q8, q9, q10, q11])
+ - key ([k0, k4, k5, k8, k9, k10, k11])
+ - value ([v0, v4, v5, v8, v9, v10, v11])
+ - cumulative_sequence_length: 0, 1, 1+2, 1+2+4
+ This packing omits padding tokens.
+
+ The query, key, and value tensors contain the hidden states of the real tokens after the input projections.
+ cumulative_sequence_length records the running total of the sequence lengths, marking where each sequence
+ starts and ends in the packed tensors (see the construction sketch after the Inputs list below).
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+<dl>
+
do_rotary : int
+
Whether to use rotary position embedding. Default value is 0.
+
kv_num_heads : int (required)
+
Number of attention heads for k and v
+
local_window_size : int
+
left_window_size for local attention (e.g., Mistral). Default value is -1, meaning unused.
+
num_heads : int (required)
+
Number of attention heads for q
+
rotary_interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
+
scale : float
+
Custom scale will be used if specified. Otherwise the default value 1/sqrt(head_size) is used, as shown in the sketch after this list.
+
softcap : float
+
Softcap value for attention weights. Default value is 0.
+
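A minimal sketch (not the op's implementation) of how the default attention scale follows from the attributes above when `scale` is left unset. The attribute values are hypothetical, and `scale == 0.0f` stands in for "unspecified", mirroring the `GetAttrOrDefault("scale", 0.0f)` default visible later in this diff:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const int num_heads = 32;        // "num_heads" attribute (hypothetical value)
  const int q_hidden_size = 4096;  // width of the query input
  const int head_size = q_hidden_size / num_heads;  // 128
  float scale = 0.0f;              // "scale" attribute left unspecified
  if (scale == 0.0f) {
    // Default documented above: 1/sqrt(head_size).
    scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  }
  std::printf("head_size=%d scale=%f\n", head_size, scale);
  return 0;
}
```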
+ +#### Inputs (8 - 10) + +
+
query : T
+
Query with shape (num_tokens, hidden_size), or packed QKV with shape (num_tokens, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size); the packed layout is sketched at the end of this section.
+
key (optional) : T
+
Key with shape (num_tokens, kv_hidden_size)
+
value (optional) : T
+
Value with shape (num_tokens, kv_hidden_size)
+
key_cache : T
+
Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op.
+
value_cache : T
+
Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op. This should be the same shape as key_cache.
+
cumulative_sequence_length : S
+
A tensor with shape (batch_size + 1). It specifies the cumulative sequence lengths of the packed entries in Q/K/V: entry i spans token positions [cumulative_sequence_length[i], cumulative_sequence_length[i+1]).
+
past_seqlens : S
+
A tensor with shape (batch_size). It specifies the past length of each cached sequence in the KV cache.
+
block_table : S
+
2D tensor with shape (batch_size, max_blocks_per_sequence) that maps each sequence in the batch to its corresponding blocks in the KV cache, as illustrated in the addressing sketch after the Outputs list.
+
cos_cache (optional) : T
+
2D tensor of cosine values for rotary position embedding, with shape (max total sequence length, head_size / 2).
+
sin_cache (optional) : T
+
2D tensor of sine values for rotary position embedding, with shape (max total sequence length, head_size / 2).
+
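As referenced in the description, below is a hedged host-side sketch of constructing the two length tensors for the 3-sequence example (1, 2, and 4 real tokens, nothing cached yet). Tensor creation and binding through the ONNX Runtime API are omitted, and the variable names are illustrative:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Real (unpadded) token counts per sequence, from the example in the description.
  const std::vector<int> new_tokens_per_seq = {1, 2, 4};

  // cumulative_sequence_length has shape (batch_size + 1): a running total of lengths.
  std::vector<int> cumulative_sequence_length = {0};
  for (int n : new_tokens_per_seq) {
    cumulative_sequence_length.push_back(cumulative_sequence_length.back() + n);
  }

  // past_seqlens has shape (batch_size). It is all zeros on the first call because
  // nothing is cached yet; on later steps it holds each sequence's cached length.
  const std::vector<int> past_seqlens(new_tokens_per_seq.size(), 0);

  for (int v : cumulative_sequence_length) std::printf("%d ", v);  // prints: 0 1 3 7
  std::printf("\n%zu\n", past_seqlens.size());                     // prints: 3
  return 0;
}
```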
+ +#### Outputs (1 - 3) + +
+
output : T
+
2D output tensor with shape (num_tokens, hidden_size)
+
key_cache_out (optional) : T
+
Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as key_cache.
+
value_cache_out (optional) : T
+
Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as value_cache.
+
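To make the block-based cache layout concrete, here is a hedged sketch (helper name and values are illustrative, not part of the op) of the addressing implied by the shapes above: token position pos of batch entry seq resides in cache block block_table[seq][pos / block_size], at row pos % block_size:

```cpp
#include <cstdint>
#include <cstdio>

// Flat element offset into a (num_blocks, block_size, kv_num_heads, head_size)
// cache, given one sequence's row of block_table.
int64_t CacheOffset(const int* block_table_row, int64_t pos, int64_t block_size,
                    int64_t kv_num_heads, int64_t head_size, int64_t head, int64_t dim) {
  const int64_t block_id = block_table_row[pos / block_size];  // physical block index
  const int64_t block_row = pos % block_size;                  // row within the block
  return ((block_id * block_size + block_row) * kv_num_heads + head) * head_size + dim;
}

int main() {
  const int block_table_row[] = {7, 2};  // this sequence spans physical blocks 7, then 2
  // Element (head 0, dim 0) of token position 300, with block_size 256:
  std::printf("%lld\n", static_cast<long long>(
                            CacheOffset(block_table_row, 300, 256, 8, 128, 0, 0)));
  return 0;
}
```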
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(bfloat16)
+
Constrain input and output to float tensors.
+
S : tensor(int32)
+
Constrain positional inputs to int tensors.
+
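Finally, a hedged sketch of the packed-QKV row layout described for the query input, assuming the conventional Q-then-K-then-V ordering within each row; head_size is recovered as d / (num_heads + 2 * kv_num_heads), the same arithmetic the Check_QKV helper uses later in this diff. All values are hypothetical:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  const int num_heads = 32;    // query heads
  const int kv_num_heads = 8;  // KV heads; every 4 query heads share one KV head
  const int d = 6144;          // packed row width: (num_heads + 2 * kv_num_heads) * head_size
  assert(d % (num_heads + 2 * kv_num_heads) == 0);
  const int head_size = d / (num_heads + 2 * kv_num_heads);  // 6144 / 48 = 128
  const int q_offset = 0;                                    // Q occupies the first section
  const int k_offset = num_heads * head_size;                // K follows Q
  const int v_offset = k_offset + kv_num_heads * head_size;  // V follows K
  std::printf("head_size=%d q@%d k@%d v@%d\n", head_size, q_offset, k_offset, v_offset);
  return 0;
}
```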
+ + ### **com.microsoft.QAttention** Quantization of Multi-Head Self Attention. @@ -6345,3 +6440,5 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
+ + diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b657c828fbde1..5154c334acc23 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -952,6 +952,7 @@ Do not modify directly.* |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 52dcb990ab67f..651f270230a75 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -227,7 +227,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->is_unidirectional = is_unidirectional_; output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr); output_parameters->do_rotary = do_rotary_; - output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_; + output_parameters->rotary_dim = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_; output_parameters->mask_filter_value = mask_filter_value_; output_parameters->scale = scale_; output_parameters->mask_type = mask_type; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index c3d5128948c6f..77d3089de5d09 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -22,11 +22,12 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; - int num_splits; - int rotary_embedding; + int num_splits; // number of splits for splitkv + int rotary_dim = 0; // rotary embedding dimension int beam_width; bool is_unidirectional; bool past_present_share_buffer; + bool is_packed_qkv = false; // whether qkv is packed bool do_rotary; bool broadcast_attn_bias_dim_0; bool broadcast_attn_bias_dim_1; @@ -46,13 +47,11 @@ struct DecoderMaskedMultiHeadAttentionParameters : AttentionParameters { int beam_width = 1; // Only NeoX style rotary embedding is supported - int rotary_embedding_dim = 0; int t_step = 0; // Weather to use multihead attention(excludes matmul and bias) bool is_mha = false; bool is_cross_attention = false; - bool is_packed_qkv = false; // Useful to better use global memory bandwidth on certain CUDA architectures. // Turned off by default for now until we fully understand performance implications @@ -83,15 +82,12 @@ struct DecoderMaskedMultiHeadAttentionParameters : AttentionParameters { // Parameters deduced from node attributes and inputs/outputs. struct GroupQueryAttentionParameters : AttentionParameters { + int kv_num_heads; // number of heads of key or value + int kv_hidden_size; // hidden size of key or value int seqlen_past_kv_cache; // sequence length of past kv tensor int seqlen_present_kv_cache; // sequence length of present kv tensor - int kv_hidden_size; - int kv_num_heads; - int num_splits; // number of splits for splitkv - int rotary_dim; // rotary embedding dimension - int local_window_size; // The window size excludes current token. It only includes tokens on the left side. + int local_window_size; // The window size excludes current token. It only includes tokens on the left side. bool kv_share_buffer; - bool is_packed_qkv; bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 bool is_first_prompt; // indicates whether this is first decoding step bool rotary_interleaved; @@ -102,18 +98,29 @@ struct GroupQueryAttentionParameters : AttentionParameters { int* zero_ptr; }; +// Parameters deduced from node attributes and inputs/outputs. 
+struct PagedAttentionParameters : AttentionParameters { + int kv_num_heads; // number of heads of key or value + int kv_hidden_size; // hidden size of key or value + int token_count; // number of tokens in packed query + int block_size; // block size for kv cache + int max_num_blocks_per_seq; // max number of blocks per sequence for kv cache + int num_blocks; // number of blocks in kv cache + int local_window_size; // The window size excludes current token. It only includes tokens on the left side. + bool rotary_interleaved; + float softcap; +}; + // Parameters for sparse attention. struct SparseAttentionParameters : AttentionParameters { int kv_hidden_size; // hidden size of key or value int kv_num_heads; // number of heads of key or value bool do_rotary; // whether to use rotary embedding bool rotary_interleaved; // whether to use interleaved rotary embedding - int rotary_dim; // rotary embedding dimension int sparse_block_size; // block size for sparse attention int num_sparse_layout; // number of sparse layout int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices] int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices] - bool is_packed_qkv; // whether qkv is packed int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache int max_cache_sequence_length; // max sequence length for kv cache buffer }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index fa0d33e891f46..338c34acb3cfb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -12,6 +12,183 @@ namespace onnxruntime { namespace contrib { namespace group_query_attention_helper { +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, const int num_heads, const int kv_num_heads, + int& batch_size, int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& query_dims = query->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + query_dims.size()); + } + batch_size = static_cast(query_dims[0]); + sequence_length = static_cast(query_dims[1]); + q_hidden_size = static_cast(query_dims[2]); + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. 
Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + if (kv_hidden_size % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size must be a multiple of kv_num_heads. Got kv_hidden_size % kv_num_heads == ", + kv_hidden_size % kv_num_heads); + } else if (kv_hidden_size / kv_num_heads != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size / kv_num_heads must be equal to head_size. Got kv_hidden_size / kv_num_heads == ", + kv_hidden_size / kv_num_heads); + } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + return Status::OK(); +} + +template +Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const int kv_num_heads, int& batch_size, + int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& packed_dims = packed_qkv->Shape().GetDims(); + if (packed_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + packed_dims.size()); + } + batch_size = static_cast(packed_dims[0]); + sequence_length = static_cast(packed_dims[1]); + head_size = static_cast(static_cast(packed_dims[2])) / (num_heads + 2 * kv_num_heads); + // Check packed qkv + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. 
Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + return Status::OK(); +} + +template +Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, + int& past_sequence_length) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" + "length or past sequence length), got ", + past_key_dims[1]); + } + if (past_key_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' shall have kv_num_heads"); + } + if (past_value_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' shall have kv_num_heads"); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[2]); + + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + return Status::OK(); +} + +template +Status CheckRotaryCaches(const T* cos_cache, const T* sin_cache, int head_size, int total_sequence_length, + int& rotary_dim) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. 
Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] < total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 shall not be less than total_sequence_length."); + } + if (sin_dims[0] < total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 shall not be less than total_sequence_length."); + } + if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (cos_dims[1] != sin_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache dimension 1 must be the same."); + } + rotary_dim = static_cast(cos_dims[1] * 2); + return Status::OK(); +} + template Status CheckInputs(const T* query, const T* key, @@ -37,18 +214,6 @@ Status CheckInputs(const T* query, AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = Q_K_V_BNSH; - const bool is_packed_qkv = key == nullptr; - - const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", - query_dims.size()); - } - - int batch_size = static_cast(query_dims[0]); - int sequence_length = static_cast(query_dims[1]); - int q_hidden_size = static_cast(query_dims[2]); - int head_size = 0; if (num_heads % kv_num_heads != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -56,115 +221,25 @@ Status CheckInputs(const T* query, num_heads % kv_num_heads); } + int batch_size = 0; + int sequence_length = 0; + int q_hidden_size = 0; int kv_hidden_size = 0; - // Check key and value when not packed + int head_size = 0; + const bool is_packed_qkv = key == nullptr; if (!is_packed_qkv) { - head_size = static_cast(q_hidden_size) / num_heads; - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. 
Got head_size % 8 == ", - head_size % 8); - } - if (value == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } else if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != key_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 1 (sequence length)"); - } - kv_hidden_size = static_cast(key_dims[2]); - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } else if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 1 (sequence length)"); - } else if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } + ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length, + q_hidden_size, kv_hidden_size, head_size)); } else { - // Check packed qkv - head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. 
Got head_size % 8 == ", - head_size % 8); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - q_hidden_size = head_size * num_heads; - kv_hidden_size = head_size * kv_num_heads; + qkv_format = QKV_BS3NH; + ORT_RETURN_IF_ERROR(Check_QKV(query, value, num_heads, kv_num_heads, batch_size, sequence_length, q_hidden_size, + kv_hidden_size, head_size)); } // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[2]); - - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, past_sequence_length)); } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent."); @@ -186,35 +261,7 @@ Status CheckInputs(const T* query, int rotary_dim = 0; if (cos_cache != nullptr && sin_cache != nullptr) { - const auto& cos_dims = cos_cache->Shape().GetDims(); - const auto& sin_dims = sin_cache->Shape().GetDims(); - - if (head_size % 16 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size shall be a multiple of 16. 
Got head_size % 16 == ", - head_size % 16); - } - if (cos_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 shall not be less than total_sequence_length."); - } - if (sin_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 shall not be less than total_sequence_length."); - } - if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (cos_dims[1] != sin_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache dimension 1 must be the same."); - } - rotary_dim = static_cast(cos_dims[1] * 2); + ORT_RETURN_IF_ERROR(CheckRotaryCaches(cos_cache, sin_cache, head_size, total_sequence_length, rotary_dim)); } else if (cos_cache != nullptr || sin_cache != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index c7b06d50858b4..691391ccef0d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -180,6 +180,35 @@ struct GroupQueryAttentionData { bool use_memory_efficient_attention = false; }; +template +struct PagedAttentionData { + // Input Tensors + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + T* key_cache = nullptr; + T* value_cache = nullptr; + const int* cumulative_seqlens_q = nullptr; + const int* past_seqlens = nullptr; + const int* block_table = nullptr; + const int* slot_mappings = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; + + // Flash buffers + T* softmax_lse = nullptr; + int* cumulative_seqlens_kv = nullptr; // Flash api takes cumulative sequence length for kv-cache + + // Fused op buffers + T* workspace_buffer = nullptr; + + // Output Tensors + T* output = nullptr; + + // Kernel Flags + bool use_flash_attention = false; +}; + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 122e94d9558e3..a7989df3439ae 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -150,7 +150,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.rotary_embedding, + 3, parameters.do_rotary, parameters.rotary_dim, parameters.past_sequence_length); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index a15b59d0c018a..b4643da58eba5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ 
b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -195,7 +195,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont if (do_rotary_) { ORT_ENFORCE(parameters.head_size == 64 || parameters.head_size == 128, "Current implementation of rotary embedding only supports head size of 64 or 128"); - parameters.rotary_embedding_dim = parameters.head_size; + parameters.rotary_dim = parameters.head_size; parameters.t_step = parameters.past_sequence_length; } diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 75ea7454791b6..6ba5ce66eaa60 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -212,13 +212,13 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio } } - if (params.rotary_embedding_dim > 0) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + if (params.rotary_dim > 0) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_dim; T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; + T* k_smem = q_smem + params.rotary_dim; - const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_rotary_dim = params.rotary_dim / 2; const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; const int smem_pitch = half_rotary_dim; @@ -240,7 +240,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.t_step); + q, k, transpose_idx / tidx_factor, params.rotary_dim, params.t_step); write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h index 08e4293528d5a..586732834f0ad 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h @@ -165,8 +165,8 @@ inline size_t CalcDynamicBlockMemory(const DecoderMaskedMultiHeadAttentionParame size_t red_sz = rows_per_red * params.head_size * sizeof(T) / 2; size_t transpose_rotary_size = 0; - if (params.rotary_embedding_dim > 0) { - transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); + if (params.rotary_dim > 0) { + transpose_rotary_size = 2 * params.rotary_dim * sizeof(T); } // The max. 
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 4aa633ca45e2b..c24bf88fa729b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -70,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params { int seqlen_k_rounded = 0; int d_rounded = 0; int rotary_dim = 0; + int total_q = 0; // The scaling factors for the kernel. float scale_softmax = 0.0; @@ -129,6 +130,7 @@ struct Flash_fwd_params : public Qkv_params { void* __restrict__ alibi_slopes_ptr = nullptr; index_t alibi_slopes_batch_stride = 0; + bool unpadded_lse = false; const cudaDeviceProp* dprops = nullptr; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 453dffaa2e6e6..b0241c26aafc6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -41,7 +41,8 @@ void set_params_fprop(Flash_fwd_params& params, bool use_smooth_softmax, bool kv_bsnh = true, int window_size_left = -1, - int window_size_right = -1) { + int window_size_right = -1, + const bool unpadded_lse = false) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -142,6 +143,7 @@ void set_params_fprop(Flash_fwd_params& params, params.window_size_right = window_size_right; params.is_seqlens_k_cumulative = true; + params.unpadded_lse = unpadded_lse; } size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { @@ -149,6 +151,11 @@ size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) return bytes; } +size_t get_softmax_lse_size(size_t token_count, size_t num_heads) { + size_t bytes = sizeof(float) * token_count * num_heads; + return bytes; +} + size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) { size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; return bytes; @@ -336,6 +343,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, return Status::OK(); } +// TODO(aciddelgado): Baiju wants this https://github.com/Dao-AILab/flash-attention/pull/824 + Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // half (total_q, num_heads, head_size) @@ -353,10 +362,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int head_size, int max_seqlen_q, int max_seqlen_k, + int total_q, float softmax_scale, const float softcap, bool is_causal, bool is_bf16, + int local_window_size, int max_num_blocks_per_seq, int page_block_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -384,8 +395,11 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, is_bf16, false, true, - -1, - is_causal ? 0 : -1); + local_window_size, + is_causal ? 
0 : -1, + /*unpadded_lse*/ true); + + params.total_q = total_q; params.dprops = &dprops; params.num_splits = 0; params.softmax_lseaccum_ptr = nullptr; @@ -394,7 +408,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, params.vnew_ptr = nullptr; params.alibi_slopes_ptr = nullptr; if (paged_KV) { - params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table = block_table; params.block_table_batch_stride = max_num_blocks_per_seq; // params.num_blocks = num_blocks; params.page_block_size = page_block_size; @@ -406,7 +420,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, // params.num_blocks = 0; params.page_block_size = 1; } - run_mha_fwd(params, stream); + + run_mha_fwd(params, stream, paged_KV); return Status::OK(); } @@ -536,7 +551,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.alibi_slopes_ptr = nullptr; if (paged_KV) { - params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table = block_table; params.block_table_batch_stride = max_num_blocks_per_seq; // params.num_blocks = num_blocks; params.page_block_size = page_block_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 57752e8237d6e..e28e38ea3ed93 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -77,10 +77,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int head_size, int max_seqlen_q, int max_seqlen_k, + int total_q, float softmax_scale, const float softcap, bool is_causal, bool is_bf16, + int local_window_size = -1, int max_num_blocks_per_seq = 0, int page_block_size = 1); @@ -121,6 +123,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int page_block_size = 1); size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); +size_t get_softmax_lse_size(size_t token_count, size_t num_heads); std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, size_t head_size, size_t num_SMs); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index d46d9597a758f..4110e715c4391 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -34,6 +34,26 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ auto get_lse_tile(const Params& params, const int bidb, const int bidh, const int m_block, const BlockInfo& binfo) { + // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. + // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. + // Otherwise, it's written as (h, b, seqlen_q). + const bool varlen_q = params.unpadded_lse; + auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; + auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset); + + auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); + auto lse_stride = params.unpadded_lse + ? 
make_stride(params.h * params.total_q, params.total_q, 1) + : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1); + + auto lse_layout = make_layout(lse_shape, lse_stride); + Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); + auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _); + return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); +} + template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -70,10 +90,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi make_stride(params.o_row_stride, params.o_head_stride, _1{})); Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -375,10 +393,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi make_stride(params.o_row_stride, params.o_head_stride, _1{})); Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -938,8 +953,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - + const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)) + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? 
kHeadDim : params.o_row_stride, _1{})); @@ -1047,12 +1061,24 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { const int tidx = threadIdx.x; const int bidx = blockIdx.x; + const index_t lse_size = params.b * params.h * params.seqlen_q; + const index_t row_offset_lse = bidx * kBlockM; Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), Shape, Int>{}, - make_stride(params.b * params.h * params.seqlen_q, _1{})); + make_stride(lse_size, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. @@ -1107,7 +1133,14 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? std::numeric_limits::infinity() : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { - gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } } // Store the scales exp(lse - lse_logsum) in shared memory. #pragma unroll diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 846d2be7bf2e1..07bca3f7fff99 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -589,6 +589,7 @@ Status FlashAttention( PackedMultiHeadAttentionData& data) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; + const int token_count = parameters.token_count; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; @@ -638,6 +639,7 @@ Status FlashAttention( qk_head_size, sequence_length, sequence_length, + token_count, scale, 0.0, false, // is causal diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc new file mode 100644 index 0000000000000..4189965ab9137 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cpu/utils/dump_tensor.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/paged_attention_impl.h" +#include "contrib_ops/cuda/bert/paged_attention.h" +#include "contrib_ops/cuda/bert/paged_attention_helper.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + PagedAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("S", DataTypeImpl::GetTensorType()), \ + PagedAttention); + +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +PagedAttention::PagedAttention(const OpKernelInfo& info) + : CudaKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + + kernel_options_ = this->GetAttentionKernelOptions(); + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); +} + +template +Status PagedAttention::ComputeInternal(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* key_cache = context->Input(3); + const Tensor* value_cache = context->Input(4); + const Tensor* cumulative_seqlens_q = context->Input(5); + const Tensor* past_seqlens = context->Input(6); + const Tensor* block_table = context->Input(7); + const Tensor* cos_cache = context->Input(8); + const Tensor* sin_cache = context->Input(9); + + auto& device_prop = GetDeviceProp(); + PagedAttentionParameters parameters; + typedef typename ToCudaType::MappedType CudaT; + PagedAttentionData data; + + // Check shapes of inputs to op and set parameters + ORT_RETURN_IF_ERROR(paged_attention_helper::CheckInputs(query, + key, + value, + key_cache, + value_cache, + cumulative_seqlens_q, + past_seqlens, + block_table, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + scale_, + softcap_, + device_prop.maxThreadsPerBlock)); + parameters.local_window_size = local_window_size_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + DUMP_STRING_INIT(); + DUMP_STRING("Batch size = ", parameters.batch_size); + DUMP_STRING("Token count = ", parameters.token_count); + DUMP_STRING("Q hidden size = ", parameters.hidden_size); + DUMP_STRING("KV hidden size = ", parameters.kv_hidden_size); + DUMP_STRING("Q num heads = ", parameters.num_heads); + DUMP_STRING("KV num heads = ", parameters.kv_num_heads); + DUMP_STRING("Head size = ", parameters.head_size); + DUMP_STRING("Num blocks = ", 
parameters.num_blocks); + DUMP_STRING("Block size = ", parameters.block_size); + DUMP_STRING("Max num blocks per sequence = ", parameters.max_num_blocks_per_seq); + DUMP_STRING("Rotary dimension = ", parameters.rotary_dim); + DUMP_STRING("Is packed QKV = ", parameters.is_packed_qkv); + + // Check rotary + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to PagedAttention when do_rotary = 1"); + } + + // Set output tensor shapes + TensorShapeVector output_shape(2); + output_shape[0] = static_cast(parameters.token_count); + output_shape[1] = static_cast(parameters.hidden_size); + Tensor* output = context->Output(0, output_shape); + + TensorShapeVector key_cache_out_shape(4); + key_cache_out_shape[0] = static_cast(parameters.num_blocks); + key_cache_out_shape[1] = static_cast(parameters.block_size); + key_cache_out_shape[2] = static_cast(parameters.kv_num_heads); + key_cache_out_shape[3] = static_cast(parameters.head_size); + Tensor* key_cache_out = context->Output(1, key_cache_out_shape); + + TensorShapeVector value_cache_out_shape(4); + value_cache_out_shape[0] = static_cast(parameters.num_blocks); + value_cache_out_shape[1] = static_cast(parameters.block_size); + value_cache_out_shape[2] = static_cast(parameters.kv_num_heads); + value_cache_out_shape[3] = static_cast(parameters.head_size); + Tensor* value_cache_out = context->Output(2, value_cache_out_shape); + + if (key_cache_out != nullptr && key_cache->Data() != key_cache_out->MutableData()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "key_cache and key_cache_out must be the same buffer"); + } else if (value_cache_out != nullptr && value_cache->Data() != value_cache_out->MutableData()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "value_cache and value_cache_out must be the same buffer"); + } + + // Check flash kernel availability and allocate buffers +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.kv_num_heads); + size_t softmax_lse_bytes = 0; + if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count, + parameters.num_heads); + } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); +#else + constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + + if (!use_flash_attention) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Currently PagedAttention is only supported through the FlashAttention kernel."); + } + + size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1); + auto cumulative_seqlens_kv_buffer = GetScratchBuffer(cumulative_seqlens_kv_bytes, context->GetComputeStream()); + + size_t workspace_buffer_bytes = 0; + if (do_rotary_) { + workspace_buffer_bytes = sizeof(T) * parameters.token_count * (parameters.hidden_size + parameters.kv_hidden_size); + } else if (parameters.is_packed_qkv) { + workspace_buffer_bytes = sizeof(T) * parameters.token_count * parameters.hidden_size; + } + auto workspace_buffer = GetScratchBuffer(workspace_buffer_bytes, context->GetComputeStream()); + + // Print debug info + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; 
+ + debug_info.Print("PagedAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + + // Set up data struct for kernel launch + data.query = reinterpret_cast(query->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); + data.key_cache = reinterpret_cast(const_cast(key_cache->Data())); + data.value_cache = reinterpret_cast(const_cast(value_cache->Data())); + data.cumulative_seqlens_q = reinterpret_cast(cumulative_seqlens_q->Data()); + data.past_seqlens = reinterpret_cast(past_seqlens->Data()); + data.cumulative_seqlens_kv = reinterpret_cast(cumulative_seqlens_kv_buffer.get()); + data.block_table = reinterpret_cast(block_table->Data()); + data.output = reinterpret_cast(output->MutableData()); + data.use_flash_attention = use_flash_attention; + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (workspace_buffer != nullptr) { + data.workspace_buffer = reinterpret_cast(workspace_buffer.get()); + } + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } + + cublasHandle_t cublas = GetCublasHandle(context); + + return QkvToContext( + device_prop, cublas, context->GetComputeStream(), parameters, data); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h new file mode 100644 index 0000000000000..a3df144745f61 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/cuda/cuda_kernel.h" +#include "contrib_ops/cuda/bert/paged_attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class PagedAttention final : public CudaKernel { + public: + PagedAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + float softcap_; + bool disable_flash_attention_; + const AttentionKernelOptions* kernel_options_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h new file mode 100644 index 0000000000000..6fb8969aa9d0a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace paged_attention_helper { + +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, const int num_heads, const int kv_num_heads, + int& token_count, int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& query_dims = query->Shape().GetDims(); + if (query_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 2 dimensions, got ", + query_dims.size()); + } + token_count = static_cast(query_dims[0]); + q_hidden_size = static_cast(query_dims[1]); + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 2 dimensions, got ", + key_dims.size()); + } else if (token_count != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (token count)"); + } + kv_hidden_size = static_cast(key_dims[1]); + if (kv_hidden_size % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size must be a multiple of kv_num_heads. Got kv_hidden_size % kv_num_heads == ", + kv_hidden_size % kv_num_heads); + } else if (kv_hidden_size / kv_num_heads != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size / kv_num_heads must be equal to head_size. Got kv_hidden_size / kv_num_heads == ", + kv_hidden_size / kv_num_heads); + } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 2 dimensions, got ", + value_dims.size()); + } else if (token_count != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (token count)"); + } else if (value_dims[1] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + return Status::OK(); +} + +template +Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const int kv_num_heads, int& token_count, + int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& packed_dims = packed_qkv->Shape().GetDims(); + if (packed_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 2 dimensions, got ", + packed_dims.size()); + } + token_count = static_cast(packed_dims[0]); + head_size = static_cast(static_cast(packed_dims[1])) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. 
+  if (value != nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+  }
+  q_hidden_size = head_size * num_heads;
+  kv_hidden_size = head_size * kv_num_heads;
+  return Status::OK();
+}
+
+template <typename T>
+Status CheckKVCache(const T* key_cache, const T* value_cache, const int kv_num_heads, const int head_size,
+                    int& num_blocks, int& block_size) {
+  const auto& key_cache_dims = key_cache->Shape().GetDims();
+  const auto& value_cache_dims = value_cache->Shape().GetDims();
+  if (key_cache_dims.size() != 4) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'key_cache' is expected to have 4 dimensions, got ",
+                           key_cache_dims.size());
+  }
+  if (value_cache_dims.size() != 4) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'value_cache' is expected to have 4 dimensions, got ",
+                           value_cache_dims.size());
+  }
+
+  num_blocks = static_cast<int>(key_cache_dims[0]);
+  block_size = static_cast<int>(key_cache_dims[1]);
+  // TODO(aciddelgado): block size multiple of 8
+  if (block_size % 256 != 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "block_size must be a multiple of 256. Got block_size % 256 == ",
+                           block_size % 256);
+  }
+  if (value_cache_dims[0] != num_blocks) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'value_cache' dimension 0 should be num_blocks, got ",
+                           value_cache_dims[0]);
+  } else if (value_cache_dims[1] != block_size) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'value_cache' dimension 1 should be block_size, got ",
+                           value_cache_dims[1]);
+  }
+
+  if (key_cache_dims[2] != value_cache_dims[2]) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'key_cache' and 'value_cache' dimension 2 (kv num heads) should be the same, got ",
+                           key_cache_dims[2], " and ", value_cache_dims[2]);
+  }
+  if (key_cache_dims[2] != kv_num_heads) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'key_cache' shall have kv_num_heads, got ",
+                           key_cache_dims[2]);
+  }
+  if (value_cache_dims[2] != kv_num_heads) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'value_cache' shall have kv_num_heads, got ",
+                           value_cache_dims[2]);
+  }
+
+  if (key_cache_dims[3] != value_cache_dims[3]) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'key_cache' and 'value_cache' dimension 3 (head size) should be the same, got ",
+                           key_cache_dims[3], " and ", value_cache_dims[3]);
+  }
+  if (key_cache_dims[3] != head_size) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'key_cache' dimension 3 should be same as head_size, got ",
+                           key_cache_dims[3]);
+  }
+  if (value_cache_dims[3] != head_size) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'value_cache' dimension 3 should be same as head_size, got ",
+                           value_cache_dims[3]);
+  }
+  return Status::OK();
+}
+
+template <typename T>
+Status CheckSequenceLengthTensors(const T* cumulative_sequence_length, const T* seqlens, int& batch_size) {
+  const auto& cumulative_seqlen_dim = cumulative_sequence_length->Shape().GetDims();
+  if (cumulative_seqlen_dim.size() != 1 || cumulative_seqlen_dim[0] < 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "cumulative_sequence_length must be shape (batch_size + 1).");
+  }
+  batch_size = static_cast<int>(cumulative_seqlen_dim[0]) - 1;
+
+  const auto& seqlens_dim = seqlens->Shape().GetDims();
+  if (seqlens_dim.size() != 1 || seqlens_dim[0] != batch_size) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "seqlens must be shape (batch_size).");
+  }
+  return Status::OK();
+}
+
+template <typename T>
+Status CheckBlockTable(const T* block_table, const int batch_size, int& max_num_blocks_per_seq) {
+  const auto& block_table_dims = block_table->Shape().GetDims();
+  if (block_table_dims.size() != 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "block_table must be 2D.");
+  } else if (block_table_dims[0] != batch_size) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "block_table dimension 0 should be batch_size, got ",
+                           block_table_dims[0]);
+  }
+  max_num_blocks_per_seq = static_cast<int>(block_table_dims[1]);
+  return Status::OK();
+}
+
+template <typename T>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* key_cache,
+                   const T* value_cache,
+                   const T* cumulative_sequence_length,
+                   const T* seqlens,
+                   const T* block_table,
+                   const T* cos_cache,
+                   const T* sin_cache,
+                   void* parameters,
+                   int num_heads,
+                   int kv_num_heads,
+                   float scale,
+                   float softcap,
+                   int max_threads_per_block) {
+  if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
+  }
+  if (num_heads % kv_num_heads != 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
+                           num_heads % kv_num_heads);
+  }
+
+  // Check query, key, and value
+  int token_count = 0;
+  int q_hidden_size = 0;
+  int kv_hidden_size = 0;
+  int head_size = 0;
+  const bool is_packed_qkv = key == nullptr;
+  if (!is_packed_qkv) {
+    ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, token_count, q_hidden_size,
+                                    kv_hidden_size, head_size));
+  } else {
+    ORT_RETURN_IF_ERROR(Check_QKV(query, value, num_heads, kv_num_heads, token_count, q_hidden_size, kv_hidden_size,
+                                  head_size));
+  }
+
+  // Check KV-Cache
+  int num_blocks = 0;
+  int block_size = 0;
+  ORT_RETURN_IF_ERROR(CheckKVCache(key_cache, value_cache, kv_num_heads, head_size, num_blocks, block_size));
+
+  // Check sequence length tensors
+  int batch_size = 0;
+  ORT_RETURN_IF_ERROR(CheckSequenceLengthTensors(cumulative_sequence_length, seqlens, batch_size));
+
+  // Check block table and slot mappings
+  int max_num_blocks_per_seq = 0;
+  ORT_RETURN_IF_ERROR(CheckBlockTable(block_table, batch_size, max_num_blocks_per_seq));
+
+  // Check rotary cache
+  int rotary_dim = 0;
+  if (cos_cache != nullptr && sin_cache != nullptr) {
+    // 0 to bypass checking rotary cache size
+    ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckRotaryCaches(cos_cache, sin_cache, head_size,
+                                                                        0, rotary_dim));
+  } else if (cos_cache != nullptr || sin_cache != nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
+  }
+
+  if (parameters != nullptr) {
+    PagedAttentionParameters* output_parameters = reinterpret_cast<PagedAttentionParameters*>(parameters);
+    output_parameters->batch_size = batch_size;
+    output_parameters->token_count = token_count;
+    output_parameters->hidden_size = q_hidden_size;
+    output_parameters->kv_hidden_size = kv_hidden_size;
+    output_parameters->num_heads = num_heads;
+    output_parameters->kv_num_heads = kv_num_heads;
+    output_parameters->head_size = head_size;
+    output_parameters->block_size = block_size;
+    output_parameters->max_num_blocks_per_seq = max_num_blocks_per_seq;
+    output_parameters->num_blocks = num_blocks;
+    output_parameters->rotary_dim = rotary_dim;
+    output_parameters->is_packed_qkv = is_packed_qkv;
+    output_parameters->scale = scale;
+    output_parameters->softcap = softcap;
+  }
+
+  return Status::OK();
+}
+
+}  // namespace paged_attention_helper
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
new file mode 100644
index 0000000000000..7ecdf51bdde11
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
@@ -0,0 +1,378 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cassert>
+#include <cuda_fp16.h>
+#include <cub/cub.cuh>
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "contrib_ops/cuda/bert/attention_softmax.h"
+#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/paged_attention_impl.h"
+#include "core/providers/cuda/shared_inc/cuda_call.h"
+#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
+#include <algorithm>
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+////////// Auxiliary Kernels
+
+template <typename T>
+__global__ void UnpackQKVCumulative(const T* packed_qkv, T* unpacked_qkv, const int token_count, const int num_heads,
+                                    const int kv_num_heads, const int head_size) {
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid >= token_count * (num_heads + 2 * kv_num_heads) * head_size) {
+    return;
+  }
+  const int q_hidden_size = num_heads * head_size;
+  const int kv_hidden_size = kv_num_heads * head_size;
+  const int in_seq_stride = q_hidden_size + 2 * kv_hidden_size;
+
+  int packed_i;
+  if (tid < token_count * q_hidden_size) {
+    const int token_id = tid / q_hidden_size;
+    const int offset = tid % q_hidden_size;
+    packed_i = token_id * in_seq_stride + offset;
+  } else if (tid < token_count * (q_hidden_size + kv_hidden_size)) {
+    const int id = tid - token_count * q_hidden_size;
+    const int token_id = id / kv_hidden_size;
+    const int offset = id % kv_hidden_size;
+    packed_i = token_id * in_seq_stride + q_hidden_size + offset;
+  } else if (tid < token_count * (q_hidden_size + 2 * kv_hidden_size)) {
+    const int id = tid - token_count * (q_hidden_size + kv_hidden_size);
+    const int token_id = id / kv_hidden_size;
+    const int offset = id % kv_hidden_size;
+    packed_i = token_id * in_seq_stride + q_hidden_size + kv_hidden_size + offset;
+  }
+  unpacked_qkv[tid] = packed_qkv[packed_i];
+}
+
+// Since QKV is unpacked into a single workspace buffer, this is similar to a transpose
+template <typename T>
+Status LaunchUnpackQKVCumulative(const T* packed_qkv, T* unpacked_qkv, const int token_count, const int num_heads,
+                                 const int kv_num_heads, const int head_size, cudaStream_t stream,
+                                 const int max_threads_per_block) {
+  const int threads = max_threads_per_block;
+  const int blocks = (token_count * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads;
+  UnpackQKVCumulative<<<blocks, threads, 0, stream>>>(packed_qkv, unpacked_qkv, token_count, num_heads, kv_num_heads,
+                                                      head_size);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template <typename T>
+__global__ void UnpackV(const T* input, T* output, const int token_count, const int hidden_size,
+                        const int packed_seq_stride) {
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < token_count * hidden_size) {
+    int offset = tid % hidden_size;
+    int token_id = tid / hidden_size;
+    int packed_i = token_id * packed_seq_stride + offset;
+    output[tid] = input[packed_i];
+  }
+}
+
+template <typename T>
+Status LaunchUnpackCumulative(const T* input, T* output, const int token_count, const int hidden_size,
+                              const int packed_seq_stride, cudaStream_t stream, const int max_threads_per_block) {
+  const int threads = std::min(max_threads_per_block, token_count * hidden_size);
+  const int blocks = (token_count * hidden_size + threads - 1) / threads;
+  UnpackV<<<blocks, threads, 0, stream>>>(input, output, token_count, hidden_size, packed_seq_stride);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template <typename T>
+__global__ void RotaryEmbeddingTNH(T* output,                            // TxNxH
+                                   const T* input,                       // TxNxH
+                                   const T* cos_cache,                   // Mx(H/2)
+                                   const T* sin_cache,                   // Mx(H/2)
+                                   const int32_t* past_seqlens,          // B
+                                   const int32_t* cumulative_seqlens_q,  // B+1
+                                   const int head_size,
+                                   const int rotary_embedding_dim,
+                                   const bool interleaved,
+                                   const int3 in_strides,    // TxNxH
+                                   const int3 out_strides) {  // TxNxH
+  // Use .x in innermost loop to access global memory efficiently
+
+  const int b = blockIdx.y;
+  const int s = blockIdx.x;
+  const int n = blockIdx.z;
+  const int h = threadIdx.x;
+
+  const int sequence_length = cumulative_seqlens_q[b + 1] - cumulative_seqlens_q[b];
+  if (h >= head_size || s >= sequence_length) {
+    return;
+  }
+
+  const int t = cumulative_seqlens_q[b] + s;  // t is the index of the token in the unpadded input/output
+  const T* input_data = input + t * in_strides.x + n * in_strides.y;
+  T* output_data = output + t * out_strides.x + n * out_strides.y;
+
+  if (h >= rotary_embedding_dim) {
+    output_data[h] = input_data[h];
+    return;
+  }
+
+  // Cache is (M, H/2)
+  const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
+  const int position_id = past_seqlens[b] + s;
+  const int cache_offset = position_id * half_rotary_embedding_dim;
+  const T* cos_data = cos_cache + cache_offset;
+  const T* sin_data = sin_cache + cache_offset;
+
+  int cache_idx = 0;
+  T sign = 0;
+  int j = 0;
+  if (interleaved) {
+    cache_idx = (h / 2) % half_rotary_embedding_dim;
+    sign = (h % 2 == 0) ? -1 : 1;
+    j = (h % 2 == 0) ? h + 1 : h - 1;  // i - sign
+  } else {
+    cache_idx = h % half_rotary_embedding_dim;
+    sign = (h < half_rotary_embedding_dim) ? -1 : 1;
+    j = (h + half_rotary_embedding_dim) % rotary_embedding_dim;
+  }
+  output_data[h] = input_data[h] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+}
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int32_t* past_seqlens,
+                                   const int32_t* cumulative_seqlens_q, const T* cos_cache, const T* sin_cache,
+                                   const int batch_size, const int max_seqlen_q, const int num_heads,
+                                   const int head_size, const int rotary_embedding_dim, const bool interleaved,
+                                   const int in_seq_stride, const int max_threads_per_block) {
+  ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
+  int3 in_strides = {in_seq_stride <= 0 ? num_heads * head_size : in_seq_stride, head_size, 1};
+  int3 out_strides = {num_heads * head_size, head_size, 1};
+  int tpb = (head_size + 31) / 32 * 32;
+
+  const dim3 grid(max_seqlen_q, batch_size, num_heads);
+  const dim3 block(tpb);
+  RotaryEmbeddingTNH<<<grid, block, 0, stream>>>(
+      output, input, cos_cache, sin_cache, past_seqlens, cumulative_seqlens_q, head_size, rotary_embedding_dim,
+      interleaved, in_strides, out_strides);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template <int kBlockSize>
+__global__ void GetCumulativeSeqlensKV(int32_t* cumulative_seqlens_kv, const int32_t* cumulative_seqlens_q,
+                                       const int32_t* past_seqlens, const int batch_size) {
+  int id = blockDim.x * blockIdx.x + threadIdx.x;
+
+  if (id == 0) {
+    cumulative_seqlens_kv[0] = 0;
+  }
+
+  typedef cub::BlockScan<int, kBlockSize> BlockScan;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+
+  // Sum past_seqlens to new sequence length (which we get by subtracting cumulative_seqlens_q).
+  // Then do an inclusive sum across present sequence lengths to get the cumulative sequence length
+  if (id < batch_size) {
+    cumulative_seqlens_kv[id + 1] = past_seqlens[id] + cumulative_seqlens_q[id + 1] - cumulative_seqlens_q[id];
+    int length = cumulative_seqlens_kv[id + 1];
+    BlockScan(temp_storage).InclusiveSum(length, length);
+    cumulative_seqlens_kv[id + 1] = length;
+  }
+}
+
+Status LaunchGetCumulativeSeqlensKV(int32_t* cumulative_seqlens_kv, const int32_t* cumulative_seqlens_q,
+                                    const int32_t* past_seqlens, const int batch_size, cudaStream_t stream) {
+  const int threads = 256;
+  const int blocks = (batch_size + threads - 1) / threads;
+  GetCumulativeSeqlensKV<256><<<blocks, threads, 0, stream>>>(cumulative_seqlens_kv, cumulative_seqlens_q, past_seqlens,
+                                                              batch_size);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template <typename T>
+__global__ void ReshapeAndCache(const T* __restrict__ key, const T* __restrict__ value, T* __restrict__ key_cache,
+                                T* __restrict__ value_cache, const int* __restrict__ block_table,
+                                const int* __restrict__ past_seqlens, const int* __restrict__ cumulative_seqlens_q,
+                                const int batch_size, const int max_num_blocks_per_seq, const int token_count,
+                                const int kv_hidden_size, const int block_size, const int key_stride,
+                                const int value_stride) {
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid >= token_count * kv_hidden_size) {
+    return;
+  }
+  const int token_id = tid / kv_hidden_size;
+  const int hidden_offset = tid % kv_hidden_size;
+  int batch_id = 0;
+  for (int i = 0; i < batch_size; ++i) {
+    if (token_id < cumulative_seqlens_q[i + 1]) {
+      batch_id = i;
+      break;
+    }
+  }
+  const int token_offset = token_id - cumulative_seqlens_q[batch_id];
+  const int past_length = past_seqlens[batch_id];
+  const int block_id = block_table[batch_id * max_num_blocks_per_seq + (past_length + token_offset) / block_size];
+  const int block_offset = (past_length + token_offset) % block_size;
+
+  const int key_id = token_id * key_stride + hidden_offset;
+  const int value_id = token_id * value_stride + hidden_offset;
+  const int dst_id = block_id * block_size * kv_hidden_size + block_offset * kv_hidden_size + hidden_offset;
+  key_cache[dst_id] = key[key_id];
+  value_cache[dst_id] = value[value_id];
+}
+
+template <typename T>
+Status LaunchReshapeAndCache(const T* key, const T* value, T* key_cache, T* value_cache, const int* block_table,
+                             const int* past_seqlens, const int* cumulative_seqlens_q, const int batch_size,
+                             const int max_num_blocks_per_seq, const int token_count, const int kv_hidden_size,
+                             const int block_size, const int key_stride, const int value_stride, cudaStream_t stream,
+                             const int max_threads_per_block) {
+  const int total_size = token_count * kv_hidden_size;
+  const int threads(std::min(total_size, max_threads_per_block));
+  const int blocks((total_size + threads - 1) / threads);
+  ReshapeAndCache<<<blocks, threads, 0, stream>>>(key, value, key_cache, value_cache, block_table, past_seqlens,
+                                                  cumulative_seqlens_q, batch_size, max_num_blocks_per_seq,
+                                                  token_count, kv_hidden_size, block_size, key_stride, value_stride);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+////////// Launch Kernels
+
+#if USE_FLASH_ATTENTION
+template <typename T>
+Status FlashAttention(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    contrib::PagedAttentionParameters& parameters,
+    PagedAttentionData<T>& data,
+    float scale) {
+  // Get parameters
+  const int max_threads_per_block = device_prop.maxThreadsPerBlock;
+  const int batch_size = parameters.batch_size;
+  const int token_count = parameters.token_count;
+  const int q_hidden_size = parameters.hidden_size;
+  const int kv_hidden_size = parameters.kv_hidden_size;
+  const int num_heads = parameters.num_heads;
+  const int kv_num_heads = parameters.kv_num_heads;
+  const int head_size = parameters.head_size;
+  const float softcap = parameters.softcap;
+  bool is_bf16 = std::is_same<T, BFloat16>::value;
+  const int local_window_size = parameters.local_window_size;
+  const int max_num_blocks_per_seq = parameters.max_num_blocks_per_seq;
+  const int block_size = parameters.block_size;
+  // The following are passed to flash api but not used by the kernel, so they can be determined heuristically
+  const int max_query_len = token_count - batch_size + 1;
+  const int max_seq_len = parameters.max_num_blocks_per_seq * parameters.block_size;
+
+  T* query = const_cast<T*>(data.query);
+  T* key;
+  T* value;
+  if (!parameters.is_packed_qkv) {
+    key = const_cast<T*>(data.key);
+    value = const_cast<T*>(data.value);
+  } else {
+    key = query + static_cast<size_t>(num_heads * head_size);
+    value = key + static_cast<size_t>(kv_num_heads * head_size);
+  }
+
+  // Calculate cumulative present sequence length in cumulative_seqlens_kv
+  int* cumulative_seqlens_q = const_cast<int*>(data.cumulative_seqlens_q);
+  int* past_seqlens = const_cast<int*>(data.past_seqlens);
+  int* cumulative_seqlens_kv = data.cumulative_seqlens_kv;
+  ORT_RETURN_IF_ERROR(LaunchGetCumulativeSeqlensKV(cumulative_seqlens_kv, cumulative_seqlens_q, past_seqlens,
+                                                   batch_size, stream));
+
+  if (parameters.do_rotary) {
+    // Will unpack Q and K in case of packed_qkv
+    auto q_buffer = data.workspace_buffer;
+    auto k_buffer = data.workspace_buffer + token_count * num_heads * head_size;
+    const int packed_seq_stride = parameters.is_packed_qkv ? (num_heads + 2 * kv_num_heads) * head_size : -1;
+    ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(
+        stream, q_buffer, query, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size,
+        max_query_len, num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride,
+        max_threads_per_block));
+    ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(
+        stream, k_buffer, key, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size,
+        max_query_len, kv_num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride,
+        max_threads_per_block));
+    query = q_buffer;
+    key = k_buffer;
+  } else if (parameters.is_packed_qkv) {
+    // Only unpack Q. K and V are unpacked by ReshapeAndCache.
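+    // A note on strides (see key_stride / value_stride below): when QKV is packed
+    // and rotary was not applied, K and V still live inside the packed buffer, so
+    // ReshapeAndCache reads them with the packed per-token stride
+    // (q_hidden_size + 2 * kv_hidden_size). After rotary, K has already been copied
+    // into the contiguous workspace, so only V keeps the packed stride.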
+    auto q_buffer = data.workspace_buffer;
+    const int packed_seq_stride = q_hidden_size + 2 * kv_hidden_size;
+    ORT_RETURN_IF_ERROR(LaunchUnpackCumulative(
+        query, q_buffer, token_count, q_hidden_size, packed_seq_stride, stream, max_threads_per_block));
+    query = q_buffer;
+  }
+
+  // Insert key and value into block-based KV cache
+  int* block_table = const_cast<int*>(data.block_table);
+  const int key_stride = parameters.is_packed_qkv && !parameters.do_rotary ? q_hidden_size + 2 * kv_hidden_size : kv_hidden_size;
+  const int value_stride = parameters.is_packed_qkv ? q_hidden_size + 2 * kv_hidden_size : kv_hidden_size;
+  ORT_RETURN_IF_ERROR(LaunchReshapeAndCache(key, value, data.key_cache, data.value_cache, block_table, past_seqlens,
+                                            cumulative_seqlens_q, batch_size, max_num_blocks_per_seq, token_count,
+                                            kv_hidden_size, block_size, key_stride, value_stride, stream,
+                                            max_threads_per_block));
+
+  // Launch kernel
+  void* q = reinterpret_cast<void*>(query);
+  void* key_cache = reinterpret_cast<void*>(data.key_cache);
+  void* value_cache = reinterpret_cast<void*>(data.value_cache);
+  void* output = reinterpret_cast<void*>(data.output);
+  void* softmax_lse = reinterpret_cast<void*>(data.softmax_lse);
+  ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_varlen_fwd(
+      device_prop, stream, q, key_cache, value_cache, output, cumulative_seqlens_q, cumulative_seqlens_kv,
+      /*seqused_k*/ nullptr, block_table, softmax_lse, batch_size, num_heads, kv_num_heads, head_size, max_query_len,
+      max_seq_len, token_count, scale, softcap, /*is_causal*/ true, is_bf16, local_window_size, max_num_blocks_per_seq,
+      block_size));
+
+  DUMP_TENSOR_INIT();
+  DUMP_TENSOR("flash attention output", data.output, token_count, num_heads, head_size);
+
+  return Status::OK();
+}
+#endif
+
+////////// API Functions
+
+template <typename T>
+Status QkvToContext(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& /*cublas*/,
+    Stream* ort_stream,
+    contrib::PagedAttentionParameters& parameters,
+    PagedAttentionData<T>& data) {
+  auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
+  const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;
+
+#if USE_FLASH_ATTENTION
+  if (data.use_flash_attention) {
+    return FlashAttention(device_prop, stream, parameters, data, scale);
+  }
+#endif
+
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Paged Attention not implemented.");
+}
+
+template struct PagedAttentionData<half>;
+template Status QkvToContext<half>(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* ort_stream,
+    contrib::PagedAttentionParameters& parameters,
+    PagedAttentionData<half>& data);
+
+template struct PagedAttentionData<BFloat16>;
+template Status QkvToContext<BFloat16>(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* ort_stream,
+    contrib::PagedAttentionParameters& parameters,
+    PagedAttentionData<BFloat16>& data);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h
new file mode 100644
index 0000000000000..7e27556a5c63f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
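+//
+// Kernel pipeline implemented in paged_attention_impl.cu (prefill and decode share it):
+//   1. GetCumulativeSeqlensKV derives cumulative KV lengths from past_seqlens and
+//      cumulative_seqlens_q.
+//   2. Optional rotary embedding and/or unpacking of packed QKV into the workspace buffer.
+//   3. ReshapeAndCache scatters the new K/V tokens into the block-based KV cache via block_table.
+//   4. flash::mha_varlen_fwd runs varlen FlashAttention over the paged cache.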
+
+#pragma once
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#include "contrib_ops/cpu/bert/attention_common.h"
+#include "contrib_ops/cpu/bert/attention_parameters.h"
+#include "contrib_ops/cuda/bert/attention_data.h"
+#include "core/framework/allocator.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+Status QkvToContext(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* stream,
+    contrib::PagedAttentionParameters& parameters,
+    PagedAttentionData<T>& data);
+
+template <typename T>
+Status LaunchUnpackQKVCumulative(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
+                                 const int kv_num_heads, const int head_size, const int token_count, cudaStream_t stream,
+                                 const int max_threads_per_block);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 17f3433aed38a..d016d50d6c445 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -99,6 +99,8 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention);
 class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice);
@@ -311,6 +313,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention)>,
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention)>,
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention)>,
+      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention)>,
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention)>,
       BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention)>,
       BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice)>,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 238dd8d4573de..f2757c2c96471 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1206,6 +1206,202 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
       GroupQueryAttentionTypeAndShapeInference(ctx, 3);
     }));
 
+constexpr const char* PagedAttention_ver1_doc = R"DOC(
+Paged Attention.
+
+This op leverages a block-based KV cache to enable continuous batching for LLMs. Currently, it is designed to work with
+the CUDA Execution Provider only.
+
+In other attention ops, batch entries typically aren't of the same length, so they are padded.
+Below is a batch with 3 sequences where * denotes a padding token.
+  Sequence_0: 0, 1*, 2*, 3*
+  Sequence_1: 4, 5, 6*, 7*
+  Sequence_2: 8, 9, 10, 11
+
+PagedAttention is designed to take in packed input, i.e., only the real tokens without padding.
+For example, the input shown above will be packed into 3 tensors like below:
+  - query ([q0, q4, q5, q8, q9, q10, q11])
+  - key ([k0, k4, k5, k8, k9, k10, k11])
+  - value ([v0, v4, v5, v8, v9, v10, v11])
+  - cumulative_sequence_length: 0, 1, 1+2, 1+2+4
+This packing omits padding tokens.
+
+The query, key and value tensors contain result of hidden embedding of real tokens after input projections.
+cumulative_sequence_length records cumulated length of each sequence length.
+
+)DOC";
+
+// Shape inference for PagedAttention. Here are the shapes of inputs and output:
+//   When Q, K and V are not packed:
+//     Input 'query':                    (token_count, hidden_size)
+//     Input 'key':                      (token_count, kv_hidden_size)
+//     Input 'value':                    (token_count, kv_hidden_size)
+//   When Q, K and V are packed:
+//     Input 'query':                    (token_count, (num_heads + 2 * kv_num_heads) * head_size)
+//     Input 'key':                      None
+//     Input 'value':                    None
+//   Input 'key_cache':                  (num_blocks, block_size, kv_num_heads, head_size)
+//   Input 'value_cache':                (num_blocks, block_size, kv_num_heads, head_size)
+//   Input 'cumulative_sequence_length': (batch_size + 1)
+//   Input 'past_seqlens':               (batch_size)
+//   Input 'block_table':                (batch_size, max_blocks_per_sequence)
+//   Input 'cos_cache':                  (max_seq_len, head_size / 2)
+//   Input 'sin_cache':                  (max_seq_len, head_size / 2)
+//   Output 'output':                    (token_count, hidden_size)
+//   Output 'key_cache_out':             (num_blocks, block_size, kv_num_heads, head_size)
+//   Output 'value_cache_out':           (num_blocks, block_size, kv_num_heads, head_size)
+void PagedAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
+  // Type inference
+  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+  // Shape inference for output tensor
+  if (hasInputShape(ctx, 0)) {
+    auto& query_shape = getInputShape(ctx, 0);
+    auto& query_dims = query_shape.dim();
+
+    if (query_dims.size() != 2) {
+      fail_shape_inference("Input 0 (query) shall be 2 dimensions");
+    }
+
+    if (ctx.hasInput(2)) {
+      propagateShapeFromInputToOutput(ctx, 0, 0);
+    } else {  // packed QKV
+      ONNX_NAMESPACE::TensorShapeProto output_shape;
+      *output_shape.add_dim() = query_dims[0];
+      int64_t num_heads = getAttribute(ctx, "num_heads", 0);
+      int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);
+      int64_t hidden_size = query_dims[1].dim_value();
+      if (hidden_size <= 0 || num_heads <= 0 || kv_num_heads <= 0) {
+        fail_shape_inference("Invalid hidden size or number of heads. Hidden size, num_heads and kv_num_heads must be positive integers.");
+      } else if (hidden_size % (num_heads + 2 * kv_num_heads) != 0) {
+        fail_shape_inference("Hidden size must be divisible by (num_heads + 2 * kv_num_heads).");
+      }
+      int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads);
+      output_shape.add_dim()->set_dim_value(head_size * num_heads);
+      updateOutputShape(ctx, 0, output_shape);
+    }
+  }
+
+  // Shape inference for KV Cache output tensors
+  if (ctx.getNumOutputs() > 1) {  // has kv cache output
+    if (ctx.getNumOutputs() != 3) {
+      fail_shape_inference("Key cache and value cache output tensors must be both present or both absent.");
+    }
+    // types
+    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1);
+    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2);
+    // shapes
+    auto& key_cache_shape = getInputShape(ctx, 3);
+    auto& key_cache_dims = key_cache_shape.dim();
+    if (key_cache_dims.size() != 4) {
+      fail_shape_inference("The block-based KV cache inputs shall be 4 dimensions");
+    }
+    // KV cache in and out share the same buffer, thus they have the same shape
+    ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 3, 1);
+    ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 4, 2);
+    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 3, 1);
+    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 4, 2);
+  }
+}
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+    PagedAttention, 1,
+    OpSchema()
+        .SetDoc(PagedAttention_ver1_doc)
+        .Attr("num_heads", "Number of attention heads for q", AttributeProto::INT)
+        .Attr("kv_num_heads", "Number of attention heads for k and v", AttributeProto::INT)
+        .Attr("scale",
+              "Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
+              AttributeProto::FLOAT,
+              OPTIONAL_VALUE)
+        .Attr("softcap",
+              "Softcap value for attention weights. Default value is 0.",
+              AttributeProto::FLOAT,
+              OPTIONAL_VALUE)
+        .Attr("local_window_size",
+              "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.",
+              AttributeProto::INT,
+              static_cast<int64_t>(-1))
+        .Attr("do_rotary",
+              "Whether to use rotary position embedding. Default value is 0.",
+              AttributeProto::INT,
+              OPTIONAL_VALUE)
+        .Attr("rotary_interleaved",
+              "Rotate using interleaved pattern. Default value is 0 (False).",
+              AttributeProto::INT,
+              OPTIONAL_VALUE)
+        .Input(0,
+               "query",
+               "Query with shape (num_tokens, hidden_size), or packed QKV with shape (num_tokens, d) "
+               "where d is (num_heads * head_size + 2 * kv_num_heads * head_size).",
+               "T")
+        .Input(1,
+               "key",
+               "Key with shape (num_tokens, kv_hidden_size)",
+               "T",
+               OpSchema::Optional)
+        .Input(2,
+               "value",
+               "Value with shape (num_tokens, kv_hidden_size)",
+               "T",
+               OpSchema::Optional)
+        .Input(3,
+               "key_cache",
+               "Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in "
+               "place within the op.",
+               "T")
+        .Input(4,
+               "value_cache",
+               "Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated "
+               "in place within the op. This should be the same shape as key_cache.",
+               "T")
+        .Input(5,
+               "cumulative_sequence_length",
+               "A tensor with shape (batch_size + 1). It specifies the cumulative sequence lengths between the packed "
+               "entries in Q/K/V.",
+               "S")
+        .Input(6,
+               "past_seqlens",
+               "A tensor with shape (batch_size). It specifies the past lengths of cached sequence in the KV cache.",
+               "S")
+        .Input(7,
+               "block_table",
+               "2D tensor with shape (batch_size, max_blocks_per_sequence) that maps each sequence in the batch to its "
+               "corresponding blocks in the KV cache.",
+               "S")
+        .Input(8,
+               "cos_cache",
+               "2D tensor with shape (max total seqlen, head_size / 2).",
+               "T",
+               OpSchema::Optional)
+        .Input(9,
+               "sin_cache",
+               "2D tensor with shape (max total seqlen, head_size / 2).",
+               "T",
+               OpSchema::Optional)
+        .Output(0,
+                "output",
+                "2D output tensor with shape (num_tokens, hidden_size)",
+                "T")
+        .Output(1,
+                "key_cache_out",
+                "Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always "
+                "the same tensor as key_cache.",
+                "T",
+                OpSchema::Optional)
+        .Output(2,
+                "value_cache_out",
+                "Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always "
+                "the same tensor as value_cache.",
+                "T",
+                OpSchema::Optional)
+        .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
+        .TypeConstraint("S", {"tensor(int32)"}, "Constrain positional inputs to int tensor.")
+        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+          PagedAttentionTypeAndShapeInference(ctx);
+        }));
+
 constexpr const char* SparseAttention_ver1_doc = R"DOC(
 Block Sparse Attention used in Phi-3-small (https://arxiv.org/pdf/2404.14219).
 
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index a9a89f756b071..6c20aae94d132 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -87,6 +87,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MoE);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
@@ -197,6 +198,7 @@ class OpSet_Microsoft_ver1 {
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention)>());
+    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad)>());
diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
new file mode 100644
index 0000000000000..cc9a02a8074c0
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
@@ -0,0 +1,735 @@
+# --------------------------------------------------------------------------
+# Copyright 2020 The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# ------------------------------------------------------------------------- +import math +import platform +import random +import unittest + +import numpy +import torch +from einops import rearrange, repeat +from onnx import TensorProto, helper +from packaging import version +from parameterized import parameterized +from test_gqa_cpu import smooth_softmax_ref + +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers + +torch.manual_seed(0) + +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + + +class Config: + batch_size = 0 + sequence_length = 0 + total_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + paged_kv_block_size = 0 + local = False + rotary = False + rotary_interleaved = False + packed = False + softcap = 0.0 + ep = "CUDAExecutionProvider" + + def __init__( + self, + batch_size, + sequence_length, + total_sequence_length, + num_heads, + kv_num_heads, + head_size, + paged_kv_block_size, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.total_sequence_length = total_sequence_length + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + self.head_size = head_size + self.paged_kv_block_size = paged_kv_block_size + self.local = local + self.rotary = rotary + self.rotary_interleaved = rotary_interleaved + self.packed = packed + self.softcap = softcap + + def __repr__(self): + short_ep = self.ep[: -len("ExecutionProvider")].lower() + return ( + f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " + f"total_sequence_length={self.total_sequence_length}, num_heads={self.num_heads}, " + f"kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, " + f"paged_kv_block_size={self.paged_kv_block_size} rotary={self.rotary}, " + f"rotary_interleaved={self.rotary_interleaved}, packed={self.packed}, softcap={self.softcap}, " + f"ep={short_ep})" + ) + + +def create_paged_attention_graph( + config, + num_tokens, + num_blocks, + max_blocks_per_sequence, + local_window_size=-1, +): + nodes = [ + helper.make_node( + "PagedAttention", + [ + "query", + "key" if not config.packed else "", + "value" if not config.packed else "", + "key_cache", + "value_cache", + "cumulative_sequence_length", + "past_seqlens", + "block_table", + "cos_cache" if config.rotary else "", + "sin_cache" if config.rotary else "", + ], + ["output", "key_cache_out", "value_cache_out"], + "PagedAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, + do_rotary=config.rotary, + rotary_interleaved=config.rotary_interleaved, + softcap=config.softcap, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + num_tokens, + (config.num_heads * config.head_size) + if not config.packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), + ], + ), + helper.make_tensor_value_info( + "key_cache", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "value_cache", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "cumulative_sequence_length", + TensorProto.INT32, + [config.batch_size + 1], + ), + helper.make_tensor_value_info( + 
"past_seqlens", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "block_table", + TensorProto.INT32, + [config.batch_size, max_blocks_per_sequence], + ), + ] + if not config.packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + num_tokens, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + num_tokens, + config.kv_num_heads * config.head_size, + ], + ), + ] + if config.rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.total_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.total_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [num_tokens, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "key_cache_out", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "value_cache_out", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "PagedAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def rotary_options_for_current_os(): + # Reference implementation of rotary uses triton, which is not available in Windows. + # So we only test rotary in Linux right now. + return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)] + + +def paged_attention_func( + config, + query, + key, + value, + key_cache, + value_cache, + cumulative_sequence_length, + past_seqlens, + block_table, + cos=None, + sin=None, + window_size=-1, +): + num_tokens = cumulative_sequence_length[-1].item() + num_blocks = key_cache.shape[0] + max_blocks_per_sequence = block_table.shape[1] + onnx_model_str = create_paged_attention_graph( + config, + num_tokens, + num_blocks, + max_blocks_per_sequence, + local_window_size=window_size, + ) + ort_inputs = { + "query": query.detach().cpu().numpy(), + "key_cache": OrtValue.ortvalue_from_numpy(key_cache.detach().cpu().numpy(), "cuda", 0), + "value_cache": OrtValue.ortvalue_from_numpy(value_cache.detach().cpu().numpy(), "cuda", 0), + "cumulative_sequence_length": cumulative_sequence_length.detach().cpu().numpy(), + "past_seqlens": past_seqlens.detach().cpu().numpy(), + "block_table": block_table.detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) + io_binding = ort_session.io_binding() + if key is not None and value is not None: + ort_inputs["key"] = key.detach().cpu().numpy() + ort_inputs["value"] = value.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None and sin is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + 
+    io_binding.bind_input(
+        "key_cache", "cuda", 0, numpy.float16, ort_inputs["key_cache"].shape(), ort_inputs["key_cache"].data_ptr()
+    )
+    io_binding.bind_input(
+        "value_cache", "cuda", 0, numpy.float16, ort_inputs["value_cache"].shape(), ort_inputs["value_cache"].data_ptr()
+    )
+    io_binding.bind_cpu_input("cumulative_sequence_length", ort_inputs["cumulative_sequence_length"])
+    io_binding.bind_cpu_input("past_seqlens", ort_inputs["past_seqlens"])
+    io_binding.bind_cpu_input("block_table", ort_inputs["block_table"])
+    io_binding.bind_output("output")
+    io_binding.bind_ortvalue_output("key_cache_out", ort_inputs["key_cache"])
+    io_binding.bind_ortvalue_output("value_cache_out", ort_inputs["value_cache"])
+    ort_session.run_with_iobinding(io_binding)
+    output, key_cache_out, value_cache_out = io_binding.copy_outputs_to_cpu()
+    output = torch.tensor(numpy.array(output))
+    return output, key_cache_out, value_cache_out
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    query_padding_mask=None,
+    key_padding_mask=None,
+    device=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            col_idx < row_idx + sk - sq - window_size[0],
+        )
+
+
+def attention_ref(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
+    upcast=True,
+    reorder_ops=False,
+    use_smooth_softmax=False,
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads_k, head_dim)
+        v: (batch_size, seqlen_k, nheads_k, head_dim)
+        query_padding_mask: (batch_size, seqlen_q)
+        key_padding_mask: (batch_size, seqlen_k)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        causal: whether to apply causal masking
+        window_size: (int, int), left and right window size
+        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+            output back to fp16/bf16.
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
+            without changing the math. This is to estimate the numerical error from operation
+            reordering.
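+        softcap: float, soft cap applied to the attention scores before softmax
+            (0.0 disables it).
+        use_smooth_softmax: whether to use the smooth-softmax reference from
+            test_gqa_cpu instead of torch.softmax.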
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) + else: + attention = torch.softmax(scores, dim=-1) + + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def rotary_embedding(*args, **kwargs): + # Use local import since triton is not available in Windows. 
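+    # rotary_flash is the local copy of the flash-attn rotary helpers that ships
+    # alongside these transformers tests (rotary_flash.py in the same directory).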
+ from rotary_flash import apply_rotary_emb + + return apply_rotary_emb(*args, **kwargs) + + +def unpad_qkv(config: Config, q, k, v, cum_seqlens): + token_count = cum_seqlens[-1] + q_unpad = torch.zeros( + token_count, + config.num_heads * config.head_size, + dtype=torch.float16, + device="cuda", + ) + k_unpad = torch.zeros( + token_count, + config.kv_num_heads * config.head_size, + dtype=torch.float16, + device="cuda", + ) + v_unpad = torch.zeros( + token_count, + config.kv_num_heads * config.head_size, + dtype=torch.float16, + device="cuda", + ) + for i in range(config.batch_size): + new_seqlen = cum_seqlens[i + 1] - cum_seqlens[i] + q_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(q[i, :new_seqlen], "s n h -> s (n h)") + k_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(k[i, :new_seqlen], "s n h -> s (n h)") + v_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(v[i, :new_seqlen], "s n h -> s (n h)") + return q_unpad, k_unpad, v_unpad + + +def generate_block_kvcache(config: Config, device, dtype): + num_blocks = math.ceil(config.total_sequence_length / config.paged_kv_block_size) * config.batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, config.paged_kv_block_size, config.kv_num_heads, config.head_size, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, config.paged_kv_block_size, config.kv_num_heads, config.head_size, device=device, dtype=dtype + ) + block_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=config.batch_size, + ) + k_cache = rearrange( + # pytorch 1.12 doesn't have indexing with int32 + k_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + v_cache = rearrange( + v_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged + + +def parity_check_paged_attention( + config: Config, + rtol=1e-3, + atol=1e-3, +): + # Generate padded inputs + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k_new = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v_new = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Generate random sequence lengths + past_seqlens = torch.randint( + 0, + config.total_sequence_length - config.sequence_length + 1, # one above highest integer to be drawn + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + new_seqlens = torch.randint( + 1, + config.sequence_length + 1, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + cum_seqlens = torch.cat( + (torch.tensor([0], dtype=torch.int32, device="cuda"), torch.cumsum(new_seqlens, dim=0)) + ).type(torch.int32) + total_seqlens = past_seqlens + new_seqlens + + q_unpad, k_unpad, v_unpad = unpad_qkv(config, q, k_new, v_new, cum_seqlens) + + # Generate kv cache and associated block-based data structures + k_cache, v_cache, block_table, k_cache_paged, v_cache_paged = generate_block_kvcache(config, "cuda", torch.float16) + + # Set window size for local / causal + window_size = (-1, -1) + left_window_size = -1 + if config.local: + left_window_size = random.randint(0, config.total_sequence_length - 1) # random.randint is inclusive + window_size = (left_window_size, 0) + else: + left_window_size = -1 + window_size = (-1, 0) + + # Apply rotary embedding for reference implementation + if config.rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.total_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=past_seqlens, interleaved=config.rotary_interleaved) + k_ro = rotary_embedding(k_new, cos, sin, seqlen_offsets=past_seqlens, interleaved=config.rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, k_new + + # Update reference kv cache + k_cache_ref = k_cache.clone() + v_cache_ref = v_cache.clone() + total_range = rearrange(torch.arange(config.total_sequence_length, device="cuda"), "s -> 1 s") + past_seqlens_expanded = rearrange(past_seqlens, "b -> b 1") + update_mask = torch.logical_and( + past_seqlens_expanded <= total_range, total_range < past_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(v_new, "b s ... 
-> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + + # Create padding masks for reference implementation + total_seqlens_expanded = rearrange(total_seqlens, "b -> b 1") + key_padding_mask = total_range < total_seqlens_expanded + query_range = rearrange(torch.arange(config.sequence_length, device="cuda"), "s -> 1 s") + new_seqlens_expanded = rearrange(new_seqlens, "b -> b 1") + query_padding_mask = query_range < new_seqlens_expanded + + # Run reference implementation of attention + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + softcap=config.softcap, + ) + out_ref = out_ref.detach().cpu().numpy() + + if config.packed: + q_unpad = torch.concatenate([q_unpad, k_unpad, v_unpad], dim=1) + k_unpad = None + v_unpad = None + out, updated_k_cache_paged, updated_v_cache_paged = paged_attention_func( + config, + q_unpad, + k_unpad, + v_unpad, + k_cache_paged, + v_cache_paged, + cum_seqlens, + past_seqlens, + block_table, + cos, + sin, + left_window_size, + ) + num_tokens = q_unpad.shape[0] + out = torch.reshape(out, (num_tokens, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + err_msg = f" with {config}" + # Make sure past-present buffer updating correctly + present_k = rearrange( + updated_k_cache_paged[block_table.to(dtype=torch.long).flatten().cpu()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + present_v = rearrange( + updated_v_cache_paged[block_table.to(dtype=torch.long).flatten().cpu()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + for i in range(config.batch_size): + numpy.testing.assert_allclose( + present_k[i, : total_seqlens[i]], + k_cache_ref[i, : total_seqlens[i]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + numpy.testing.assert_allclose( + present_v[i, : total_seqlens[i]], + v_cache_ref[i, : total_seqlens[i]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + new_seqlen = cum_seqlens[i + 1] - cum_seqlens[i] + out_i = out[cum_seqlens[i] : cum_seqlens[i + 1]] + out_ref_i = out_ref[i, :new_seqlen] + numpy.testing.assert_allclose(out_i, out_ref_i, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) + + +def has_flash_attention(): + if not torch.cuda.is_available(): + return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 and ( + platform.system() == "Linux" + or (platform.system() == "Windows" and version.parse(torch.version.cuda) >= version.parse("12.0")) + ) + + +def paged_attention_test_cases(): + batches = [4] if pipeline_mode else [1, 3, 5] + seqs = ( + [(1025, 2047)] + if pipeline_mode + else [ + (3, 1024), + (1, 339), + (408, 800), + (333, 799), + (64, 2048), + (837, 4000), + (17, 49), + (257, 257), + (459, 459), + ] + ) + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + block_sizes = [256] if pipeline_mode else [256, 512] + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for block_size in block_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + + config = Config( + b, + s, + s2, + n, + n2, + h, + block_size, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ) + yield ( + str(config), + config, + ) + + +@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.") +class TestPagedAttention(unittest.TestCase): + @parameterized.expand(paged_attention_test_cases()) + def test_paged_attention(self, _, config): + parity_check_paged_attention(config, rtol=5e-3, atol=5e-3) + + +if __name__ == "__main__": + unittest.main(verbosity=2)
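+
+# ---------------------------------------------------------------------------
+# Illustrative note (comments only, not executed by the test suite): how a new
+# token finds its slot in the paged KV cache, mirroring the indexing in the
+# ReshapeAndCache CUDA kernel above. For a sequence with past length p, the new
+# token at offset t occupies logical position (p + t); with block size B:
+#     block_id     = block_table[batch_id][(p + t) // B]
+#     block_offset = (p + t) % B
+#     dst          = block_id * B * kv_hidden_size + block_offset * kv_hidden_size + hidden_offset
+# Example with B = 256: p = 300, t = 5 gives logical position 305, which maps to
+# block_table[batch_id][1] at offset 49 within that block.
+# ---------------------------------------------------------------------------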