Merged

Changes from all commits (115 commits)
f39854d
add flash attention v2
tianleiwu Jul 18, 2023
aa1efa8
remove bfloat16 kernels
tianleiwu Jul 18, 2023
1fe60ae
namespace
tianleiwu Jul 18, 2023
6cb6eba
remove commented code
tianleiwu Jul 18, 2023
893b77e
remove backward code
tianleiwu Jul 18, 2023
1195d3f
path
tianleiwu Jul 19, 2023
4e59d95
flash v2 api and api call
aciddelgado Aug 3, 2023
b0f7f47
api build fixes
aciddelgado Aug 4, 2023
0e79ec5
builds successfully
aciddelgado Aug 7, 2023
2f61ac2
some fixes
aciddelgado Aug 9, 2023
372141d
use varlen to run Attention_Mask1D_FP16_B2 test
aciddelgado Aug 11, 2023
95a7e35
Packed MHA cleanup
aciddelgado Aug 16, 2023
52553a7
flash attention runs
aciddelgado Aug 16, 2023
018ecb6
update
aciddelgado Aug 17, 2023
632451e
pre clean
aciddelgado Aug 18, 2023
98a4e5d
clean PMHA
aciddelgado Aug 18, 2023
12f8db8
packed mha flash
aciddelgado Aug 18, 2023
91c1ab8
remove extraneous changes
aciddelgado Aug 18, 2023
2fc7d86
reviewed changes
aciddelgado Aug 18, 2023
7cf0359
reviewed changes
aciddelgado Aug 18, 2023
f6554a1
reviewed changes
aciddelgado Aug 18, 2023
7a12487
reviewed changes
aciddelgado Aug 18, 2023
33a43aa
reviewed changes
aciddelgado Aug 18, 2023
24524cc
namespace and USE_FLASH_ATTENTION flag
aciddelgado Aug 21, 2023
d9002f7
more compile flags flash
aciddelgado Aug 21, 2023
3be4af2
lint
aciddelgado Aug 21, 2023
3798446
gcc warnings in template
aciddelgado Aug 21, 2023
666d3bf
address tianlei comments
aciddelgado Aug 21, 2023
545851a
clean up
Aug 22, 2023
5798e0b
cpplint
Aug 22, 2023
cad2901
refactoring
tianleiwu Aug 22, 2023
fdb189b
workspace as buffer. copyright and cgmanifest.
aciddelgado Aug 22, 2023
14d7db2
undo cgmanifest
aciddelgado Aug 22, 2023
aba9ceb
enable flash attention in MultiHeadAttention op
tianleiwu Aug 23, 2023
a06da2e
fix attention test error in A100
tianleiwu Aug 23, 2023
ab015e6
namespace from flash to onnxruntime::flash
tianleiwu Aug 23, 2023
866414c
undo cgmanifest
tianleiwu Aug 23, 2023
6d682ab
add unit test
tianleiwu Aug 23, 2023
300019e
enable flash attention in Attention op
tianleiwu Aug 23, 2023
7e50d9a
Merge branch 'main' into flash_v2_packed_mha
tianleiwu Aug 24, 2023
7481ec3
set proper nvcc threads to avoid OOM
tianleiwu Aug 24, 2023
2823476
--nvcc_threads=1 in build_cuda_c_api_package.sh
tianleiwu Aug 24, 2023
c880147
test script segfaults
aciddelgado Aug 24, 2023
d2a040b
pass cuda device prop to flash attention
tianleiwu Aug 24, 2023
8b92a9b
add requirements for test_flash_attn.py
tianleiwu Aug 24, 2023
a0393e2
remove nvcc_threads logic
tianleiwu Aug 24, 2023
0830925
flash attn test, pmha works, mha crashes
aciddelgado Aug 24, 2023
1c96739
check head size for efficient attention
tianleiwu Aug 24, 2023
6d8e43d
lint except lambda assignment
aciddelgado Aug 25, 2023
4ceca3b
lint fix
aciddelgado Aug 25, 2023
3bfa3b5
line length < 120
tianleiwu Aug 25, 2023
bdb17d5
flash v2 update
aciddelgado Aug 28, 2023
bfee28e
formatting
aciddelgado Aug 29, 2023
a54a7b9
flash benchmark script
aciddelgado Aug 29, 2023
b064a01
merge with main
aciddelgado Aug 29, 2023
e7b7f2e
Update c-api-noopenmp-packaging-pipelines.yml
aciddelgado Aug 29, 2023
f6927f7
io binding
aciddelgado Aug 30, 2023
345f4e6
update benchmark
tianleiwu Aug 30, 2023
da8bc50
Add bert-base
tianleiwu Aug 30, 2023
40a1f61
Merge remote-tracking branch 'origin/main' into flash_v2_packed_mha
aciddelgado Aug 30, 2023
f760123
merge main into branch for nuget fix
aciddelgado Aug 30, 2023
0dc8613
Merge branch 'flash_v2_packed_mha' into flash_v2_no_cuda52
aciddelgado Aug 30, 2023
92652c3
update benchmark to support more input formats
tianleiwu Aug 30, 2023
ee2296f
Merge branch 'flash_v2_packed_mha' of https://github.com/microsoft/on…
tianleiwu Aug 30, 2023
e998af7
seq len threshold to trigger flash for packed qkv
tianleiwu Aug 30, 2023
599d019
add back 2 lines
tianleiwu Aug 30, 2023
492c59f
flash attention flag in packed attention op test and a few more bench…
aciddelgado Aug 30, 2023
6400a02
flash attention flag in packed attention op test and a few more bench…
aciddelgado Aug 30, 2023
1929de5
Merge remote-tracking branch 'refs/remotes/origin/flash_v2_packed_mha…
aciddelgado Aug 30, 2023
c880f08
specify TNLGv4 model for Turing Team in Benchmark
aciddelgado Aug 30, 2023
30c2f79
remove env variable change from packed attention test
aciddelgado Aug 31, 2023
01443ef
python lint
aciddelgado Aug 31, 2023
6a06d9e
Merge remote-tracking branch 'origin/main' into flash_v2_packed_mha
aciddelgado Aug 31, 2023
e1eb49a
Merge branch 'flash_v2_packed_mha' into flash_v2_no_cuda52
aciddelgado Aug 31, 2023
7605bb4
start work on group query attention
aciddelgado Aug 31, 2023
0697d19
work on check input and group query attention cc
aciddelgado Sep 1, 2023
b9784dc
more work on gqa
aciddelgado Sep 6, 2023
5e7286e
gqa working with causal or without causal
aciddelgado Sep 11, 2023
cb0a96f
push before rebase
aciddelgado Sep 12, 2023
afb493e
merge with main
aciddelgado Sep 12, 2023
6053c86
gqa with past builds
aciddelgado Sep 13, 2023
11608be
gqa working with past kv
aciddelgado Sep 14, 2023
9d31ad1
Merge remote-tracking branch 'origin/main' into aciddelgado/group_que…
aciddelgado Sep 14, 2023
9d2f922
some code cleaning
aciddelgado Sep 14, 2023
bdb3867
some fixes and clean up
aciddelgado Sep 14, 2023
362c6ae
no dumper
aciddelgado Sep 15, 2023
04801df
premerge main
aciddelgado Sep 18, 2023
3a11592
lint
aciddelgado Sep 18, 2023
2941dbc
mergemain
aciddelgado Sep 18, 2023
d78f476
Merge remote-tracking branch 'origin/main' into aciddelgado/group_que…
aciddelgado Sep 18, 2023
2d0b960
fix illegal access memory issue
aciddelgado Sep 19, 2023
5b076f7
clean up
aciddelgado Sep 19, 2023
3bf777c
bytes
aciddelgado Sep 20, 2023
cdc65dc
merge main
aciddelgado Sep 20, 2023
0e33dc1
gqa final touches
aciddelgado Sep 21, 2023
de64ff4
build fixes gqa
aciddelgado Sep 22, 2023
7a2ad7c
lint
aciddelgado Sep 22, 2023
7a47696
benchmark gqa vs dmmha
aciddelgado Sep 22, 2023
437d23c
fix comments
aciddelgado Sep 25, 2023
365d0b5
start work bnsh
aciddelgado Sep 25, 2023
470a8a7
bsnh present
aciddelgado Sep 25, 2023
05d1c56
Support for BNSH format
aciddelgado Sep 25, 2023
27dfac5
bnsh attribute and benchmark
aciddelgado Sep 25, 2023
6d681ee
past-present bnsh, non-cache past-present.
aciddelgado Sep 28, 2023
0e76730
merge bnsh and no buff
aciddelgado Sep 28, 2023
a7482ed
lint and benchmark script
aciddelgado Sep 29, 2023
46f0ce4
fix build issue
aciddelgado Oct 2, 2023
8679214
fix build pipeline
aciddelgado Oct 3, 2023
a0ec0eb
pr cleanup
aciddelgado Oct 3, 2023
b4082d1
int64 past sequence
aciddelgado Oct 4, 2023
befdb2d
small review changes p1
aciddelgado Oct 6, 2023
3fb6b9c
clang-format and update documentation
aciddelgado Oct 6, 2023
fcaba35
ignore whitespace when diff documentation
aciddelgado Oct 6, 2023
3a06b64
ignore blank lines
aciddelgado Oct 6, 2023
bbc47f0
formatting whitespace
aciddelgado Oct 9, 2023
68 changes: 66 additions & 2 deletions docs/ContribOperators.md
@@ -41,6 +41,7 @@ Do not modify directly.*
* <a href="#com.microsoft.GreedySearch">com.microsoft.GreedySearch</a>
* <a href="#com.microsoft.GridSample">com.microsoft.GridSample</a>
* <a href="#com.microsoft.GroupNorm">com.microsoft.GroupNorm</a>
* <a href="#com.microsoft.GroupQueryAttention">com.microsoft.GroupQueryAttention</a>
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
@@ -1169,9 +1170,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dd>past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
</dl>

#### Type Constraints
@@ -2218,6 +2219,69 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.GroupQueryAttention"></a><a name="com.microsoft.groupqueryattention">**com.microsoft.GroupQueryAttention**</a>

Group Query Self/Cross Attention.

Supports a different number of heads for q and kv.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>is_past_bsnh</tt> : int</dt>
<dd>Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 1.</dd>
</dl>

#### Inputs (3 - 6)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
<dt><tt>value</tt> : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When buffered past_key and past_value are used (present_key uses same tensor as past_key), it is required to specify past_sequence_length (could be 0). Otherwise, past_sequence_length is inferred from past_key.</dd>
</dl>

#### Outputs (1 - 3)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32), tensor(int64)</dt>
<dd>Constrain past sequence length to int tensor.</dd>
</dl>
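
For readers who want to try the new operator, here is a minimal sketch (not part of this PR) that builds a float16 model containing a single com.microsoft.GroupQueryAttention node with the onnx helper API. The shapes, head counts, opset versions, and file name below are illustrative assumptions, not values taken from the ONNX Runtime tests.

```python
# Sketch only: construct a one-node GroupQueryAttention model with made-up shapes.
import onnx
from onnx import TensorProto, helper

batch, seq_len, head_size = 2, 128, 64
num_heads, kv_num_heads = 12, 4            # q heads vs. k/v heads (grouped)
hidden_size = num_heads * head_size        # 768
kv_hidden_size = kv_num_heads * head_size  # 256

node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "key", "value"],      # past_key/past_value/past_sequence_length are optional
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,                   # required attribute
    kv_num_heads=kv_num_heads,             # required attribute
    unidirectional=1,                      # causal attention (the default)
)

graph = helper.make_graph(
    [node],
    "gqa_example",
    inputs=[
        helper.make_tensor_value_info("query", TensorProto.FLOAT16, [batch, seq_len, hidden_size]),
        helper.make_tensor_value_info("key", TensorProto.FLOAT16, [batch, seq_len, kv_hidden_size]),
        helper.make_tensor_value_info("value", TensorProto.FLOAT16, [batch, seq_len, kv_hidden_size]),
    ],
    outputs=[
        helper.make_tensor_value_info("output", TensorProto.FLOAT16, [batch, seq_len, hidden_size]),
    ],
)

model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "gqa_example.onnx")

# To run it, something like the following should work with onnxruntime-gpu
# (per the kernel table below, the op is registered for the CUDA execution
# provider with T = float16):
#   import onnxruntime as ort
#   sess = ort.InferenceSession("gqa_example.onnx", providers=["CUDAExecutionProvider"])
```

Note that num_heads is expected to be a multiple of kv_num_heads so that groups of query heads share each key/value head, and the type constraint above limits inputs and outputs to float16.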


### <a name="com.microsoft.Inverse"></a><a name="com.microsoft.inverse">**com.microsoft.Inverse**</a>

#### Version
8 changes: 5 additions & 3 deletions docs/OperatorKernels.md
@@ -67,7 +67,8 @@ Do not modify directly.*
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ConcatFromSequence|*in* input_sequence:**S**<br> *out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|ConstantOfShape|*in* input:**T1**<br> *out* output:**T2**|9+|**T1** = tensor(int64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ConstantOfShape|*in* input:**T1**<br> *out* output:**T2**|20+|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[9, 19]|**T1** = tensor(int64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int32)|
@@ -78,7 +79,7 @@
|Crop|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|CumSum|*in* x:**T**<br> *in* axis:**T2**<br> *out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
@@ -839,6 +840,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)<br/> **T** = tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -935,7 +937,7 @@ Do not modify directly.*
|Crop|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|CumSum|*in* x:**T**<br> *in* axis:**T2**<br> *out* y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int64)|
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -82,6 +82,26 @@ struct PackedAttentionParameters {
bool broadcast_res_pos_bias;
};

// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters {
int batch_size;
int sequence_length;
int past_sequence_length; // actual sequence length of past_key and past_value
int kv_sequence_length; // sequence length of key and value (or new_k and new_v when past is present)
int present_sequence_length; // past_sequence_length + kv_sequence_length
int max_sequence_length; // allocated length of past_key and past_value
int hidden_size;
int num_heads;
int head_size;
int kv_hidden_size;
int kv_num_heads;
bool is_unidirectional; // causal
float scale;
int num_splits; // number of splits for splitkv
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
};

namespace attention {
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
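
As a quick illustration of how the new GroupQueryAttentionParameters fields relate to each other, the snippet below (not part of the PR, with made-up values) simply mirrors the relationships stated in the struct's comments:

```python
# Hypothetical values, chosen only to illustrate the relationships documented
# in the GroupQueryAttentionParameters comments above.
num_heads, kv_num_heads, head_size = 32, 8, 128
past_sequence_length, kv_sequence_length = 512, 16

hidden_size = num_heads * head_size        # query hidden size: 4096
kv_hidden_size = kv_num_heads * head_size  # key/value hidden size: 1024
present_sequence_length = past_sequence_length + kv_sequence_length  # 528

# Grouped-query attention: several query heads share each key/value head.
assert num_heads % kv_num_heads == 0
queries_per_kv_head = num_heads // kv_num_heads  # 4
```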
9 changes: 4 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -287,9 +287,9 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o
T* k_smem = q_smem + rotary_embedding_dim;

const int half_rotary_dim = rotary_embedding_dim / 2;
const int half_idx = (head_idx) / half_rotary_dim;
const int intra_half_idx = (head_idx) % half_rotary_dim;
const int smem_pitch = half_rotary_dim;
const int half_idx = (head_idx) / half_rotary_dim;
const int intra_half_idx = (head_idx) % half_rotary_dim;
const int smem_pitch = half_rotary_dim;

if (do_rotary) {
*reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
@@ -441,7 +441,6 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co
}
}


template <typename T>
__global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) {
// Format 3 for cutlass memory efficient attention
@@ -651,7 +650,7 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
if (qk_head_size != 64 && qk_head_size !=128) {
if (qk_head_size != 64 && qk_head_size != 128) {
ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
44 changes: 22 additions & 22 deletions onnxruntime/contrib_ops/cuda/bert/bert_padding.cu
@@ -367,32 +367,32 @@ __global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
const int* attention_masks,
const int batch_size,
const int sequence_length) {
typedef cub::BlockReduce<int, kMAX_THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

const int batch_id = blockIdx.x;
const int* batch_mask = attention_masks + (batch_id * sequence_length);
const bool leftmost_non_zero = (batch_mask[0] != 0);
int biggest_position = 0;

for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
if (leftmost_non_zero == (batch_mask[i] != 0)) {
biggest_position = i;
} else {
break;
}
typedef cub::BlockReduce<int, kMAX_THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

const int batch_id = blockIdx.x;
const int* batch_mask = attention_masks + (batch_id * sequence_length);
const bool leftmost_non_zero = (batch_mask[0] != 0);
int biggest_position = 0;

for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
if (leftmost_non_zero == (batch_mask[i] != 0)) {
biggest_position = i;
} else {
break;
}
}

int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);
int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);

if (threadIdx.x == 0) {
int batch_offset = batch_id * sequence_length;
trt_mha_padding_offset[2 * batch_id] = batch_offset;
trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1;
if (batch_id == gridDim.x - 1) {
trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
}
if (threadIdx.x == 0) {
int batch_offset = batch_id * sequence_length;
trt_mha_padding_offset[2 * batch_id] = batch_offset;
trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1;
if (batch_id == gridDim.x - 1) {
trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
}
}
}

// only support simple left padding with mask 0s on leading left,
13 changes: 6 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
@@ -86,10 +86,10 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_
}

inline Status ComputeMaskIndex(cudaStream_t stream,
const int sequence_length,
const int batch_size,
const int* mask,
int* mask_index) {
const int sequence_length,
const int batch_size,
const int* mask,
int* mask_index) {
// Mask idx is of length batch_size and assumes the valid region is contiguous starting
// from the beginning of the sequence

@@ -133,7 +133,7 @@ __global__ void EmbedLayerNormKernel(
}
if (nullptr == position_ids) {
position_id = blockIdx.x;
} else if (broadcast_position_ids){
} else if (broadcast_position_ids) {
position_id = position_ids[sequence_position % gridDim.x];
} else {
position_id = position_ids[sequence_position];
@@ -212,13 +212,12 @@ Status LaunchEmbedLayerNormKernel(
void* embedding_sum,
const int* position_ids,
const bool broadcast_position_ids) {

if (mask_index != nullptr) {
if (nullptr == input_mask) {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream));
} else {
ORT_RETURN_IF_ERROR(
ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast<int*>(mask_index)));
ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast<int*>(mask_index)));
}
}

6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -66,7 +66,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int

template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const float* input, const float* bias, float* output, bool /*use_half2*/) {
const float* input, const float* bias, float* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<float, blockSize><<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length,
@@ -77,7 +77,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int

template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const half* input, const half* bias, half* output, bool use_half2) {
const half* input, const half* bias, half* output, bool use_half2) {
constexpr int blockSize = 256;
if (use_half2 && 0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
@@ -101,7 +101,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int

template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;

// remove nv_bfloat162 implementation for now to fix build issue