2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -9,6 +9,8 @@ set(contrib_ops_excluded_files
"bert/attention.h"
"bert/attention_impl.cu"
"bert/attention_softmax.h"
"bert/cross_attention.cc"
"bert/cross_attention.h"
"bert/embed_layer_norm.cc"
"bert/embed_layer_norm.h"
"bert/embed_layer_norm_impl.cu"
74 changes: 59 additions & 15 deletions docs/ContribOperators.md
@@ -17,6 +17,7 @@ Do not modify directly.*
* <a href="#com.microsoft.ComplexMulConj">com.microsoft.ComplexMulConj</a>
* <a href="#com.microsoft.ConvTransposeWithDynamicPads">com.microsoft.ConvTransposeWithDynamicPads</a>
* <a href="#com.microsoft.CropAndResize">com.microsoft.CropAndResize</a>
* <a href="#com.microsoft.CrossAttention">com.microsoft.CrossAttention</a>
* <a href="#com.microsoft.DecoderAttention">com.microsoft.DecoderAttention</a>
* <a href="#com.microsoft.DequantizeBFP">com.microsoft.DequantizeBFP</a>
* <a href="#com.microsoft.DequantizeLinear">com.microsoft.DequantizeLinear</a>
@@ -106,10 +107,6 @@ Do not modify directly.*
When unidirectional is 1, each token only attends to previous tokens.

Both past and present state are optional. They shall be used together; it is not allowed to use only one of them.

When weights is not provided, key and value are required. In this situation, MatMul for input projection is excluded,
and input is the query after projection. The bias is included for performance consideration.

The qkv_hidden_sizes is required only when K and V have different hidden sizes.

When there is past state, hidden dimension for Q, K and V shall be the same.
@@ -135,27 +132,23 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>

#### Inputs (3 - 9)
#### Inputs (3 - 7)

<dl>
<dt><tt>input</tt> (optional) : T</dt>
<dd>Input tensor with shape (batch_size, sequence_length, input_hidden_size) when weights is available, or query tensor with shape (batch_size, sequence_length, hidden_size) when weights is not available.</dd>
<dt><tt>weights</tt> (optional) : T</dt>
<dt><tt>input</tt> : T</dt>
<dd>Input tensor with shape (batch_size, sequence_length, input_hidden_size)</dd>
<dt><tt>weights</tt> : T</dt>
<dd>Merged Q/K/V weights with shape (input_hidden_size, hidden_size + hidden_size + v_hidden_size)</dd>
<dt><tt>bias</tt> : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection</dd>
<dt><tt>mask_index</tt> (optional) : M</dt>
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size).</dd>
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)</dd>
<dt><tt>past</tt> (optional) : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>extra_add</tt> (optional) : T</dt>
<dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Input for key with shape (batch_size, kv_sequence_length, hidden_size). Required when weights is not available.</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Input for key with shape (batch_size, kv_sequence_length, v_hidden_size). Required when weights is not available.</dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When past_present_share_buffer, specify past_sequence_length for effective past sequence length (could be 0). Needed when past_present_share_buffer is not zero.</dd>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
</dl>
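
As a non-authoritative sketch of the input layout after this change, the snippet below declares an Attention node with only the three required inputs; the tensor names and the num_heads value are assumptions made for the example, not part of the schema.

```python
from onnx import helper

# Sketch only: a com.microsoft.Attention node using just the three required
# inputs (input, weights, bias); the optional inputs (mask_index, past,
# extra_add, past_sequence_length) are simply omitted here.
attention_node = helper.make_node(
    "Attention",
    inputs=["input", "weights", "bias"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=12,  # assumed head count for the example
)
```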

#### Outputs (1 - 2)
@@ -164,7 +157,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
<dt><tt>present</tt> (optional) : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, total_sequence_length, head_size)If past_present_share_buffer, it should be exactly same as past tensor, of shape (2, batch_size, num_heads, max_sequence_length, head_size),while effective_seq_length = (past_sequence_length + kv_sequence_length)</dd>
<dd>past state for key and value with shape (2, batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
</dl>

#### Type Constraints
@@ -961,6 +954,57 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.CrossAttention"></a><a name="com.microsoft.crossattention">**com.microsoft.CrossAttention**</a>

Multi-Head Cross Attention. Bias from input projection is included.

The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), a value of 0 marks a padded
position and a value of 1 marks a valid token. When the key has only right-side padding, the mask may instead have
shape (batch_size), where each element is the actual length of the corresponding key sequence excluding padding.
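
To make the two mask layouts concrete, the following small sketch (batch and sequence sizes are assumptions for illustration) expresses the same right-side padding in both supported forms:

```python
import numpy as np

# Assumed example: batch_size = 2, kv_sequence_length = 4; the second key
# sequence has one padded position on the right.

# 2D form (batch_size, kv_sequence_length): 1 marks a valid token, 0 marks padding.
mask_2d = np.array([[1, 1, 1, 1],
                    [1, 1, 1, 0]], dtype=np.int32)

# 1D form (batch_size): valid only for right-side padding; each entry is the
# actual key length excluding padding.
mask_1d = np.array([4, 3], dtype=np.int32)
```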

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
</dl>

#### Inputs (4 - 5)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size)</dd>
<dt><tt>value</tt> : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
<dt><tt>bias</tt> : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain mask to integer types</dd>
</dl>
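
For reference, the sketch below shows one way this operator could be declared in a model with onnx.helper; the tensor names, dimensions, and head count are assumptions made for the example, and running the node requires an ONNX Runtime build that registers the com.microsoft domain.

```python
import onnx
from onnx import TensorProto, helper

# Assumed example sizes, not part of the schema.
batch, seq_len, kv_seq_len = 2, 8, 16
hidden, v_hidden, num_heads = 64, 64, 4

node = helper.make_node(
    "CrossAttention",
    inputs=["query", "key", "value", "bias", "key_padding_mask"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
)

graph = helper.make_graph(
    [node],
    "cross_attention_example",
    inputs=[
        helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch, seq_len, hidden]),
        helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch, kv_seq_len, hidden]),
        helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch, kv_seq_len, v_hidden]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [hidden + hidden + v_hidden]),
        helper.make_tensor_value_info("key_padding_mask", TensorProto.INT32, [batch, kv_seq_len]),
    ],
    outputs=[
        helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq_len, v_hidden]),
    ],
)

# The contrib op lives in the com.microsoft domain, so import that opset
# alongside the default ONNX opset before saving the model.
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "cross_attention_example.onnx")
```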


### <a name="com.microsoft.DecoderAttention"></a><a name="com.microsoft.decoderattention">**com.microsoft.DecoderAttention**</a>

This DecoderAttention supports self attention and cross attention, key and value cache, and key_padding_mask. The attention mask is not supported at the moment.
7 changes: 4 additions & 3 deletions docs/OperatorKernels.md
@@ -395,7 +395,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
@@ -762,7 +762,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@@ -772,6 +772,7 @@ Do not modify directly.*
|ComplexMul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CrossAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
@@ -1128,7 +1129,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
11 changes: 3 additions & 8 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -10,8 +10,8 @@
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using onnxruntime::narrow;
using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
namespace contrib {

@@ -59,7 +59,7 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
Attention<float>);

template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info, false, true) {
Attention<T>::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info, false) {
}

template <typename T>
@@ -200,20 +200,15 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
const Tensor* past = context->Input<Tensor>(4);
const Tensor* extra_add_qk = context->Input<Tensor>(5);

const Tensor* key = context->Input<Tensor>(6);
const Tensor* value = context->Input<Tensor>(7);

const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_);

AttentionParameters parameters;
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
(nullptr != weights || is_prepack_) ? &weights_shape : nullptr,
weights_shape,
bias->Shape(),
mask_index,
past,
extra_add_qk,
key,
value,
&parameters));

const int batch_size = parameters.batch_size;