97 changes: 97 additions & 0 deletions docs/ContribOperators.md
@@ -67,6 +67,7 @@ Do not modify directly.*
* <a href="#com.microsoft.PackedAttention">com.microsoft.PackedAttention</a>
* <a href="#com.microsoft.PackedMultiHeadAttention">com.microsoft.PackedMultiHeadAttention</a>
* <a href="#com.microsoft.Pad">com.microsoft.Pad</a>
* <a href="#com.microsoft.PagedAttention">com.microsoft.PagedAttention</a>
* <a href="#com.microsoft.QAttention">com.microsoft.QAttention</a>
* <a href="#com.microsoft.QGemm">com.microsoft.QGemm</a>
* <a href="#com.microsoft.QLinearAdd">com.microsoft.QLinearAdd</a>
@@ -3683,6 +3684,100 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.PagedAttention"></a><a name="com.microsoft.pagedattention">**com.microsoft.PagedAttention**</a>

Paged Attention.

This op leverages a block-based KV cache to enable continuous batching for LLMs. Currently, it is designed to work with
the CUDA Execution Provider only.

In other attention ops, batch entries typically aren't of the same length, so they are padded.
Below is a batch with 3 sequences where * denotes a padding token.
Sequence_0: 0, 1*, 2*, 3*
Sequence_1: 4, 5, 6*, 7*
Sequence_2: 8, 9, 10, 11

PagedAttention is designed to take packed input, i.e., only the real tokens without padding.
For example, the batch shown above is packed into three tensors, together with a cumulative length tensor:
- query ([q0, q4, q5, q8, q9, q10, q11])
- key ([k0, k4, k5, k8, k9, k10, k11])
- value ([v0, v4, v5, v8, v9, v10, v11])
- cumulative_sequence_length: 0, 1, 1+2, 1+2+4
This packing omits the padding tokens.

The query, key, and value tensors contain the hidden embeddings of the real tokens after the input projections.
cumulative_sequence_length records the running total of the sequence lengths, with a leading zero.
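
As an illustrative sketch (plain NumPy; not part of the op itself), the packing bookkeeping for the batch above looks like this:

```python
import numpy as np

# Per-sequence real-token counts for the batch above:
# Sequence_0 has 1 real token, Sequence_1 has 2, Sequence_2 has 4.
seq_lens = np.array([1, 2, 4], dtype=np.int32)

# cumulative_sequence_length has shape (batch_size + 1):
# a leading 0 followed by the running sum -> [0, 1, 3, 7].
cum_seq_len = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.int32)

# Packing concatenates the real tokens of every sequence, so a packed
# tensor has shape (num_tokens, hidden_size), num_tokens == cum_seq_len[-1].
# Tokens of sequence i occupy rows cum_seq_len[i]:cum_seq_len[i + 1].
```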


#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>do_rotary</tt> : int</dt>
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v.</dd>
<dt><tt>local_window_size</tt> : int</dt>
<dd>Left window size for local attention (as in Mistral). Default value is -1, meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q.</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size).</dd>
<dt><tt>softcap</tt> : float</dt>
<dd>Softcap value for attention weights. Default value is 0 (disabled). See the sketch after this list for how scale and softcap are applied.</dd>
</dl>
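
For intuition, here is a minimal sketch of how scale and softcap are commonly applied to the raw attention scores; the tanh-based capping form is an assumption modeled on similar attention ops, not a statement of this kernel's exact math:

```python
import numpy as np

def attention_scores(q, k, head_size, scale=None, softcap=0.0):
    """Raw scores for one head; q is (q_len, head_size), k is (k_len, head_size)."""
    if scale is None:
        scale = 1.0 / np.sqrt(head_size)  # documented default
    scores = (q @ k.T) * scale
    if softcap > 0.0:
        # Assumed capping form: squashes scores into (-softcap, softcap).
        scores = softcap * np.tanh(scores / softcap)
    return scores
```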

#### Inputs (8 - 10)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (num_tokens, hidden_size), or packed QKV with shape (num_tokens, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (num_tokens, kv_hidden_size).</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (num_tokens, kv_hidden_size).</dd>
<dt><tt>key_cache</tt> : T</dt>
<dd>Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op.</dd>
<dt><tt>value_cache</tt> : T</dt>
<dd>Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op. This should be the same shape as key_cache.</dd>
<dt><tt>cumulative_sequence_length</tt> : S</dt>
<dd>A tensor with shape (batch_size + 1). It specifies the cumulative sequence lengths that delimit the packed entries in Q/K/V.</dd>
<dt><tt>past_seqlens</tt> : S</dt>
<dd>A tensor with shape (batch_size). It specifies the past length of each cached sequence in the KV cache.</dd>
<dt><tt>block_table</tt> : S</dt>
<dd>2D tensor with shape (batch_size, max_blocks_per_sequence) that maps each sequence in the batch to its corresponding blocks in the KV cache (see the addressing sketch after this list).</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max total seqlen, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max total seqlen, head_size / 2).</dd>
</dl>
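
A hedged sketch of the block_table addressing scheme (index arithmetic only; the helper name and variables are hypothetical):

```python
def cache_slot(block_table, seq_idx, token_pos, block_size):
    """Map (sequence, token position) to (block, offset) in the paged cache;
    key_cache[block, offset, :, :] then holds that token's keys."""
    logical_block = token_pos // block_size
    physical_block = block_table[seq_idx][logical_block]
    return physical_block, token_pos % block_size

# Example: with block_size 16, token 35 of a sequence sits in that
# sequence's third logical block (35 // 16 == 2) at offset 3 (35 % 16).
```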

#### Outputs (1 - 3)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>2D output tensor with shape (num_tokens, hidden_size).</dd>
<dt><tt>key_cache_out</tt> (optional) : T</dt>
<dd>Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as key_cache.</dd>
<dt><tt>value_cache_out</tt> (optional) : T</dt>
<dd>Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as value_cache.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>S</tt> : tensor(int32)</dt>
<dd>Constrain positional inputs to int32 tensors.</dd>
</dl>
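
As a usage sketch, a node for this op can be built with the standard ONNX helper API; the tensor names and head counts below are placeholders, not values prescribed by the spec:

```python
from onnx import helper

# num_heads and kv_num_heads are the required attributes.
node = helper.make_node(
    "PagedAttention",
    inputs=[
        "query", "key", "value",
        "key_cache", "value_cache",
        "cumulative_sequence_length", "past_seqlens", "block_table",
    ],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=32,
    kv_num_heads=8,
)
```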


### <a name="com.microsoft.QAttention"></a><a name="com.microsoft.qattention">**com.microsoft.QAttention**</a>

Quantization of Multi-Head Self Attention.
Expand Down Expand Up @@ -6345,3 +6440,5 @@ No versioning maintained for experimental ops.
<dt><tt>T</tt> : tensor(float)</dt>
<dd>Constrain input and output types to float32 tensors.</dd>
</dl>


1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -952,6 +952,7 @@ Do not modify directly.*
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PagedAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* cumulative_sequence_length:**S**<br> *in* past_seqlens:**S**<br> *in* block_table:**S**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* key_cache_out:**T**<br> *out* value_cache_out:**T**|1+|**S** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -227,7 +227,7 @@
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->do_rotary = do_rotary_;
output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
output_parameters->rotary_dim = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;

[GitHub Actions / Optional Lint C++] cpplint (via reviewdog): onnxruntime/contrib_ops/cpu/bert/attention_base.cc:230: Using C-style cast. Use static_cast<int>(...) instead [readability/casting] [4]
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
31 changes: 19 additions & 12 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -22,11 +22,12 @@ struct AttentionParameters {
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
int rotary_embedding;
int num_splits; // number of splits for splitkv
int rotary_dim = 0; // rotary embedding dimension
int beam_width;
bool is_unidirectional;
bool past_present_share_buffer;
bool is_packed_qkv = false; // whether qkv is packed
bool do_rotary;
bool broadcast_attn_bias_dim_0;
bool broadcast_attn_bias_dim_1;
@@ -46,13 +47,11 @@ struct DecoderMaskedMultiHeadAttentionParameters : AttentionParameters {
int beam_width = 1;

// Only NeoX style rotary embedding is supported
int rotary_embedding_dim = 0;
int t_step = 0;

// Whether to use multi-head attention (excludes matmul and bias)
bool is_mha = false;
bool is_cross_attention = false;
bool is_packed_qkv = false;

// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
@@ -83,15 +82,12 @@

// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters : AttentionParameters {
int kv_num_heads; // number of heads of key or value
int kv_hidden_size; // hidden size of key or value
int seqlen_past_kv_cache; // sequence length of past kv tensor
int seqlen_present_kv_cache; // sequence length of present kv tensor
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
int rotary_dim; // rotary embedding dimension
int local_window_size; // The window size excludes current token. It only includes tokens on the left side.
int local_window_size; // The window size excludes current token. It only includes tokens on the left side.
bool kv_share_buffer;
bool is_packed_qkv;
bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
bool is_first_prompt; // indicates whether this is first decoding step
bool rotary_interleaved;
@@ -102,18 +98,29 @@ struct GroupQueryAttentionParameters : AttentionParameters {
int* zero_ptr;
};

// Parameters deduced from node attributes and inputs/outputs.
struct PagedAttentionParameters : AttentionParameters {
int kv_num_heads; // number of heads of key or value
int kv_hidden_size; // hidden size of key or value
int token_count; // number of tokens in packed query
int block_size; // block size for kv cache
int max_num_blocks_per_seq; // max number of blocks per sequence for kv cache
int num_blocks; // number of blocks in kv cache
int local_window_size; // The window size excludes current token. It only includes tokens on the left side.
bool rotary_interleaved;
float softcap;
};

// Parameters for sparse attention.
struct SparseAttentionParameters : AttentionParameters {
int kv_hidden_size; // hidden size of key or value
int kv_num_heads; // number of heads of key or value
bool do_rotary; // whether to use rotary embedding
bool rotary_interleaved; // whether to use interleaved rotary embedding
int rotary_dim; // rotary embedding dimension
int sparse_block_size; // block size for sparse attention
int num_sparse_layout; // number of sparse layout
int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
bool is_packed_qkv; // whether qkv is packed
int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache
int max_cache_sequence_length; // max sequence length for kv cache buffer
};