diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f6cc816b45ed2..64cbd8d04b295 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2553,7 +2553,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 11)
+#### Inputs (7 - 12)
- query : T
@@ -2578,6 +2578,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- 2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
- attention_bias (optional) : T
- additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
+- head_sink (optional) : T
+- 1D tensor with shape (num_heads). Each head has a smooth factor added to the denominator of the softmax.
#### Outputs
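
The `head_sink` input described above supplies one smoothing value per head. The following is a minimal numpy sketch of the documented behavior only (not the ONNX Runtime kernel); the function name and the random scores are illustrative:

```python
import numpy as np

def softmax_with_head_sink(scores, head_sink=None):
    """scores: (batch, num_heads, seq_len, total_seq_len); head_sink: (num_heads,) or None.
    Each head contributes one extra exp() term (its sink) to that head's softmax denominator."""
    n = scores.shape[1]
    sink = np.zeros((1, n, 1, 1), dtype=scores.dtype) if head_sink is None \
        else head_sink.astype(scores.dtype).reshape(1, n, 1, 1)
    # Subtract a maximum that also covers the sink value, for numerical stability.
    m = np.maximum(scores.max(axis=-1, keepdims=True), sink)
    w = np.exp(scores - m)
    return w / (w.sum(axis=-1, keepdims=True) + np.exp(sink - m))

probs = softmax_with_head_sink(np.random.randn(1, 2, 1, 3).astype(np.float32),
                               head_sink=np.array([0.5, -1.0], dtype=np.float32))
print(probs.sum(axis=-1))  # each row sums to slightly less than 1 because of the sink term
```
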
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 1ffcabee8cc10..e50702afe9975 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -538,7 +538,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int64)|
@@ -942,7 +942,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1420,7 +1420,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
index ac32a4445f3ca..aef47edd5fcd2 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -17,13 +17,13 @@ namespace onnxruntime {
namespace contrib {
template <typename T>
-inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
- MlasComputeSoftmax(score, score, N, D, false, true, tp);
+inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) {
+ MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp);
}
template <typename T>
inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
- MlasComputeSoftmax(score, score, N, D, false, false, tp);
+ MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp);
}
template <typename T>
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index c79508cbae273..cec495ef7391e 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -51,6 +51,7 @@ class GQAAttentionBase {
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxN_kvxSxH
const T* V, // V data with shape BxN_kvxSxH
+ const T* head_sink, // Head sink for smooth softmax, nullptr if not used
const Tensor* attention_bias, // Attention bias to add to QxK'
const Tensor* past_key, // past K input tensor (if not using past state)
const Tensor* past_value, // past V input tensor (if not using past state)
@@ -97,7 +98,7 @@ class GQAAttentionBase {
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
if (gqa_mlas_supported) {
- ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data,
+ ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data,
batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
tp, allocator);
@@ -110,7 +111,7 @@ class GQAAttentionBase {
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
} else {
- ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data,
+ ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data,
batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
tp, allocator);
@@ -136,6 +137,7 @@ class GQAAttentionBase {
void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
+ const T* head_sink, // for smooth softmax. Its size is N.
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const T* attention_bias, // optional attention bias
const size_t batch_size, // batch size of self-attention
@@ -310,8 +312,9 @@ class GQAAttentionBase {
}
}
- if (use_smooth_softmax_) {
- ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr);
+ if (use_smooth_softmax_ || head_sink != nullptr) {
+ float sink = (head_sink != nullptr) ? static_cast(head_sink[head_index]) : 0.0f;
+ ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr);
} else {
ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr);
}
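
The branch added in `ComputeAttentionProbs` chooses, per row, between the plain softmax and the smooth variant, taking the sink from `head_sink[head_index]` when that tensor is present and falling back to 0 when only `use_smooth_softmax_` is set. A rough Python transcription of that decision, with illustrative helper names, might look like:

```python
import numpy as np

def softmax_row(row):
    w = np.exp(row - row.max())
    return w / w.sum()

def smooth_softmax_row(row, sink):
    # The sink joins the max for stability and adds one term to the denominator.
    m = max(row.max(), sink)
    w = np.exp(row - m)
    return w / (w.sum() + np.exp(sink - m))

def softmax_for_head(row, head_index, use_smooth_softmax, head_sink=None):
    """Mirrors the branch in ComputeAttentionProbs: a provided head_sink wins,
    and use_smooth_softmax alone behaves like a per-head sink of 0."""
    if use_smooth_softmax or head_sink is not None:
        sink = float(head_sink[head_index]) if head_sink is not None else 0.0
        return smooth_softmax_row(row, sink)
    return softmax_row(row)
```
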
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index a912bd6e6b43c..988151f778806 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -206,9 +206,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+ const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;
+
// Compute the attention score and apply the score to V
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(),
- attention_bias, past_key, past_value, output, present_k, present_v,
+ head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
seqlens_k, parameters, allocator, context);
}
} // namespace contrib
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index f2757c2c96471..c2371487d9187 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1184,6 +1184,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
+ .Input(11,
+ "head_sink",
+ "1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.",
+ "T",
+ OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 3575e30721af7..d0f59eb534b82 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1020,6 +1020,7 @@ MlasComputeSoftmax(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+ float Sink,
MLAS_THREADPOOL* ThreadPool
);
diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp
index 96a2398796777..669c73d2b9c06 100644
--- a/onnxruntime/core/mlas/lib/compute.cpp
+++ b/onnxruntime/core/mlas/lib/compute.cpp
@@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK {
ptrdiff_t ThreadCountN;
bool LogSoftmax;
bool SmoothSoftmax;
+ float Sink;
const T* Input;
T* Output;
size_t N;
@@ -850,6 +851,7 @@ Return Value:
const size_t D = WorkBlock->D;
const bool LogSoftmax = WorkBlock->LogSoftmax;
const bool SmoothSoftmax = WorkBlock->SmoothSoftmax;
+ const float Sink = WorkBlock->Sink;
const float* Input = WorkBlock->Input + n * D;
float* Output = WorkBlock->Output + n * D;
@@ -880,11 +882,12 @@ Return Value:
#else
float Maximum = MlasReduceMaximumF32Kernel(Input, D);
#endif
- float NegativeMaximum = -Maximum;
- if (SmoothSoftmax && NegativeMaximum > 0.0f) {
- NegativeMaximum = 0.0f;
+ if (SmoothSoftmax && Sink > Maximum) {
+ Maximum = Sink;
}
+ float NegativeMaximum = -Maximum;
+
//
// Compute the exponential function for each element of the row (save to Temp if provided) and
// compute the sum of these exponential functions.
@@ -897,7 +900,7 @@ Return Value:
#endif
if (SmoothSoftmax) {
- Accumulation += expf(NegativeMaximum);
+ Accumulation += expf(Sink + NegativeMaximum);
}
if (LogSoftmax) {
@@ -1014,6 +1017,7 @@ MlasComputeSoftmax(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+ float Sink,
MLAS_THREADPOOL* ThreadPool
)
/*++
@@ -1039,6 +1043,8 @@ Routine Description:
SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation.
+ Sink - Supplies the smooth factor to use in the softmax operation.
+
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
@@ -1060,6 +1066,7 @@ Return Value:
WorkBlock.Output = Output;
WorkBlock.N = N;
WorkBlock.D = D;
+ WorkBlock.Sink = Sink;
//
// Compute the number of target threads given the complexity of the softmax
@@ -1097,6 +1104,7 @@ MlasComputeSoftmax(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+ float Sink,
MLAS_THREADPOOL* ThreadPool
);
@@ -1110,6 +1118,7 @@ MlasComputeSoftmax(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+ float Sink,
MLAS_THREADPOOL* ThreadPool
);
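
The revised row kernel above folds `Sink` into the running maximum and adds `expf(Sink - Maximum)` to the accumulation. The numpy sketch below (illustrative names, not MLAS code) mirrors that arithmetic and checks that it matches the "append one sink logit, softmax, drop it" reference used by the Python tests; with `Sink = 0` it also reproduces the previous SmoothSoftmax behavior:

```python
import numpy as np

def mlas_like_smooth_softmax(row, sink):
    # Same arithmetic as the revised kernel: the sink participates in the maximum,
    # and exp(sink - maximum) is added to the accumulation before normalizing.
    maximum = max(row.max(), sink)
    w = np.exp(row - maximum)
    return w / (w.sum() + np.exp(sink - maximum))

def concat_reference(row, sink):
    # Append the sink as one extra logit, softmax, then drop the appended column.
    y = np.concatenate([row, [sink]])
    y = np.exp(y - y.max())
    return (y / y.sum())[:-1]

row = np.random.randn(16)
for sink in (-3.0, 0.0, 5.0):
    assert np.allclose(mlas_like_smooth_softmax(row, sink), concat_reference(row, sink))
```
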
diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
index 2817dda9d0085..e123414b03b21 100644
--- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc
+++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
@@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N,
float* Ydata,
bool logarithmic,
onnxruntime::concurrency::ThreadPool* thread_pool) {
- MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool);
+ MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool);
return Status::OK();
}
diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h
index 3359b2a69fe83..f7cc2523adbf6 100644
--- a/onnxruntime/core/providers/cpu/ml/ml_common.h
+++ b/onnxruntime/core/providers/cpu/ml/ml_common.h
@@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in,
}
if (use_mlas) {
- MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool);
+ MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, 0.0f, threadpool);
} else {
while (s < s_end) {
gsl::span scores_for_batch(s, s + batch_size);
diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp
index 65822eb294d7d..ea36383f70621 100644
--- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp
+++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp
@@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) {
std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory
// warming up run
- MlasComputeSoftmax(input, output, N, D, false, false, tp.get());
+ MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get());
for (auto _ : state) {
- MlasComputeSoftmax(input, output, N, D, false, false, tp.get());
+ MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get());
}
free(ptr.underlying_buffer);
diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp
index 041b6c61cd5bf..4d7a45143b311 100644
--- a/onnxruntime/test/mlas/unittest/test_softmax.cpp
+++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp
@@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase {
}
void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) {
- MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_);
+ MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_);
ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax);
constexpr float AbsoluteTolerance = 1e-6f;
@@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase {
InputReference[nd] = Input[nd].ToFloat();
}
- MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_);
+ MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_);
ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax);
constexpr float AbsoluteTolerance = 5e-3f;
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index 461c243b82212..ce0649e55f7c5 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -54,6 +54,7 @@ class Config:
head_size: int = 0
has_position_ids: bool = False
has_attention_bias: bool = False
+ has_head_sink: bool = False
@dataclass
@@ -67,6 +68,7 @@ class PromptConfig:
head_size: int = 0
has_position_ids: bool = False
has_attention_bias: bool = False
+ has_head_sink: bool = False
# LLaMA Microsoft model
@@ -166,6 +168,7 @@ def create_group_query_attention_graph_prompt(
"sin_cache" if rotary else "",
"position_ids" if config.has_position_ids else "",
"attention_bias" if config.has_attention_bias else "",
+ "head_sink" if config.has_head_sink else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
@@ -289,6 +292,15 @@ def create_group_query_attention_graph_prompt(
),
]
+ if config.has_head_sink:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "head_sink",
+ ort_type,
+ [config.num_heads],
+ ),
+ ]
+
graph_output = [
helper.make_tensor_value_info(
"output",
@@ -380,6 +392,7 @@ def create_group_query_attention_graph_past(
"sin_cache" if rotary else "",
"position_ids" if config.has_position_ids else "",
"attention_bias" if config.has_attention_bias else "",
+ "head_sink" if config.has_head_sink else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
@@ -441,6 +454,7 @@ def create_group_query_attention_graph_past(
[1],
),
]
+
if not packed:
graph_input += [
helper.make_tensor_value_info(
@@ -462,6 +476,7 @@ def create_group_query_attention_graph_past(
],
),
]
+
if rotary:
graph_input += [
helper.make_tensor_value_info(
@@ -498,6 +513,15 @@ def create_group_query_attention_graph_past(
),
]
+ if config.has_head_sink:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "head_sink",
+ ort_type,
+ [config.num_heads],
+ ),
+ ]
+
graph_output = [
helper.make_tensor_value_info(
"output",
@@ -552,17 +576,17 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False):
"""
Arguments:
- q: (batch_size, seqlen_q, nheads, d)
- k: (batch_size, seqlen_k, nheads_k, d)
- v: (batch_size, seqlen_k, nheads_k, d)
+ q: (batch_size, seqlen_q, num_heads, d)
+ k: (batch_size, seqlen_k, num_heads_k, d)
+ v: (batch_size, seqlen_k, num_heads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
- batch_size, seqlen_q, nheads, d = q.shape
- _, seqlen_k, nheads_k, _ = k.shape
- assert k.shape == (batch_size, seqlen_k, nheads_k, d)
- assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+ batch_size, seqlen_q, num_heads, d = q.shape
+ _, seqlen_k, num_heads_k, _ = k.shape
+ assert k.shape == (batch_size, seqlen_k, num_heads_k, d)
+ assert v.shape == (batch_size, seqlen_k, num_heads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
@@ -593,7 +617,7 @@ def output_pad_fn(output_unpad):
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
- assert nheads == nheads_k
+ assert num_heads == num_heads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
@@ -714,6 +738,7 @@ def gqa_prompt_func(
seqlens_k=None,
position_ids=None,
attention_bias=None,
+ head_sink=None,
window_size=-1,
past_kv_format=Formats.BSNH,
share_buffer=True,
@@ -749,6 +774,11 @@ def gqa_prompt_func(
if new_k is not None:
new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1))
new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1))
+
+ sess_options = SessionOptions()
+ ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
+ io_binding = ort_session.io_binding()
+
if share_buffer:
ort_inputs = {
"query": q.detach().cpu().numpy(),
@@ -758,9 +788,6 @@ def gqa_prompt_func(
"total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
}
- sess_options = SessionOptions()
- ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
- io_binding = ort_session.io_binding()
if new_k is not None:
ort_inputs["key"] = new_k.detach().cpu().numpy()
ort_inputs["value"] = new_v.detach().cpu().numpy()
@@ -797,25 +824,19 @@ def gqa_prompt_func(
io_binding.bind_output("output")
io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"])
io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"])
- ort_session.run_with_iobinding(io_binding)
- ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
- ort_output = numpy.array(ort_output)
- output = torch.tensor(ort_output)
- return output, present_k, present_v
else:
ort_inputs = {
"query": q.detach().cpu().numpy(),
"seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32),
"total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
}
- sess_options = SessionOptions()
- ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
- io_binding = ort_session.io_binding()
+
if new_k is not None:
ort_inputs["key"] = new_k.detach().cpu().numpy()
ort_inputs["value"] = new_v.detach().cpu().numpy()
io_binding.bind_cpu_input("key", ort_inputs["key"])
io_binding.bind_cpu_input("value", ort_inputs["value"])
+
if cos is not None:
ort_inputs["cos_cache"] = cos.detach().cpu().numpy()
ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
@@ -836,11 +857,16 @@ def gqa_prompt_func(
io_binding.bind_output("output")
io_binding.bind_output("present_key")
io_binding.bind_output("present_value")
- ort_session.run_with_iobinding(io_binding)
- ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
- ort_output = numpy.array(ort_output)
- output = torch.tensor(ort_output)
- return output, present_k, present_v
+
+ if config.has_head_sink:
+ ort_inputs["head_sink"] = head_sink.detach().cpu().numpy()
+ io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"])
+
+ ort_session.run_with_iobinding(io_binding)
+ ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
+ ort_output = numpy.array(ort_output)
+ output = torch.tensor(ort_output)
+ return output, present_k, present_v
def gqa_past_func(
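
The refactor above creates the session and `io_binding` once and binds `head_sink` as a CPU input only when the config requests it. A condensed sketch of that run path, assuming a packed-QKV graph without a shared KV buffer and using illustrative variable names:

```python
import numpy as np
import onnxruntime as ort

def run_gqa_cpu(onnx_model_bytes, query, seqlens_k, total_sequence_length, head_sink=None):
    """Sketch of the CPU run path used by the tests (packed QKV, no shared KV buffer):
    bind the required inputs, bind head_sink only when provided, run, copy outputs."""
    sess = ort.InferenceSession(onnx_model_bytes, ort.SessionOptions(),
                                providers=["CPUExecutionProvider"])
    io_binding = sess.io_binding()

    io_binding.bind_cpu_input("query", query)
    io_binding.bind_cpu_input("seqlens_k", seqlens_k.astype(np.int32))
    io_binding.bind_cpu_input("total_sequence_length", total_sequence_length.astype(np.int32))
    if head_sink is not None:
        io_binding.bind_cpu_input("head_sink", head_sink)

    io_binding.bind_output("output")
    io_binding.bind_output("present_key")
    io_binding.bind_output("present_value")

    sess.run_with_iobinding(io_binding)
    output, present_key, present_value = io_binding.copy_outputs_to_cpu()
    return output, present_key, present_value
```
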
@@ -855,6 +881,7 @@ def gqa_past_func(
seqlens_k=None,
position_ids=None,
attention_bias=None,
+ head_sink=None,
past_kv_format=Formats.BSNH,
share_buffer=True,
window_size=-1,
@@ -890,6 +917,11 @@ def gqa_past_func(
if new_k is not None:
new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1))
new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1))
+
+ sess_options = SessionOptions()
+ ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
+ io_binding = ort_session.io_binding()
+
if share_buffer:
ort_inputs = {
"query": q.detach().cpu().numpy(),
@@ -901,9 +933,7 @@ def gqa_past_func(
.cpu()
.numpy(),
}
- sess_options = SessionOptions()
- ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
- io_binding = ort_session.io_binding()
+
if new_k is not None and new_v is not None:
ort_inputs["key"] = new_k.detach().cpu().numpy()
ort_inputs["value"] = new_v.detach().cpu().numpy()
@@ -940,11 +970,6 @@ def gqa_past_func(
io_binding.bind_output("output")
io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"])
io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"])
- ort_session.run_with_iobinding(io_binding)
- ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
- ort_output = numpy.array(ort_output)
- output = torch.tensor(ort_output)
- return output, present_k, present_v
else:
ort_inputs = {
"query": q.detach().cpu().numpy(),
@@ -958,9 +983,7 @@ def gqa_past_func(
.cpu()
.numpy(),
}
- sess_options = SessionOptions()
- ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
- io_binding = ort_session.io_binding()
+
if new_k is not None and new_v is not None:
ort_inputs["key"] = new_k.detach().cpu().numpy()
ort_inputs["value"] = new_v.detach().cpu().numpy()
@@ -988,11 +1011,16 @@ def gqa_past_func(
io_binding.bind_output("output")
io_binding.bind_output("present_key")
io_binding.bind_output("present_value")
- ort_session.run_with_iobinding(io_binding)
- ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
- ort_output = numpy.array(ort_output)
- output = torch.tensor(ort_output)
- return output, present_k, present_v
+
+ if config.has_head_sink:
+ ort_inputs["head_sink"] = head_sink.detach().cpu().numpy()
+ io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"])
+
+ ort_session.run_with_iobinding(io_binding)
+ ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
+ ort_output = numpy.array(ort_output)
+ output = torch.tensor(ort_output)
+ return output, present_k, present_v
def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None):
@@ -1025,11 +1053,28 @@ def construct_local_mask(
)
-def smooth_softmax_ref(x):
- x_max = x.amax(axis=-1, keepdim=True)
- x_max = torch.maximum(x_max, torch.zeros_like(x_max))
- w = torch.exp(x - x_max)
- return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max))
+def smooth_softmax_ref(x, head_sink):
+ """
+ Arguments:
+ x: (batch_size, num_heads, seqlen_q, seqlen_k)
+ head_sink: (num_heads) or None
+ Output:
+ y: (batch_size, num_heads, seqlen_q, seqlen_k)
+ """
+ assert len(x.shape) == 4
+ b, n, s, t = x.shape
+
+ if head_sink is not None:
+ assert len(head_sink.shape) == 1
+ assert head_sink.shape[0] == x.shape[1]
+ sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1)
+ else:
+ sink = torch.zeros(b, n, s, 1, dtype=x.dtype)
+
+ y = torch.cat([x, sink], dim=-1)
+ y = torch.softmax(y, dim=-1)
+ y = y[..., :-1]
+ return y
def attention_ref(
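
The rewritten `smooth_softmax_ref` expresses the smooth softmax as a regular softmax over the scores with one sink column appended per head. A quick torch check (with hypothetical helper names) that, when `head_sink=None`, this reproduces the previous reference formula, which clamped the row maximum at zero and added `exp(-max)` to the denominator:

```python
import torch

def smooth_softmax_old(x):
    # Previous reference: clamp the row max at zero and add exp(-max) to the denominator.
    x_max = x.amax(dim=-1, keepdim=True).clamp(min=0.0)
    w = torch.exp(x - x_max)
    return w / (w.sum(dim=-1, keepdim=True) + torch.exp(-x_max))

def smooth_softmax_new(x, head_sink=None):
    # Rewritten reference: append a per-head sink column (zeros when head_sink is None),
    # take a regular softmax, then drop the appended column.
    b, n, s, _ = x.shape
    sink = torch.zeros(b, n, s, 1, dtype=x.dtype) if head_sink is None \
        else head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1)
    y = torch.softmax(torch.cat([x, sink], dim=-1), dim=-1)
    return y[..., :-1]

x = torch.randn(2, 4, 3, 7, dtype=torch.float64)
assert torch.allclose(smooth_softmax_new(x), smooth_softmax_old(x))
```
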
@@ -1046,16 +1091,17 @@ def attention_ref(
upcast=True,
reorder_ops=False,
use_smooth_softmax=False,
+ head_sink=None,
):
"""
Arguments:
- q: (batch_size, seqlen_q, nheads, head_dim)
- k: (batch_size, seqlen_k, nheads_k, head_dim)
- v: (batch_size, seqlen_k, nheads_k, head_dim)
+ q: (batch_size, seqlen_q, num_heads, head_dim)
+ k: (batch_size, seqlen_k, num_heads_k, head_dim)
+ v: (batch_size, seqlen_k, num_heads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
- dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+ dropout_mask: (batch_size, num_heads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
@@ -1064,9 +1110,10 @@ def attention_ref(
without changing the math. This is to estimate the numerical error from operation
reordering.
use_smooth_softmax: whether use smooth softmax or not
+ head_sink: (num_heads) or None
Output:
- output: (batch_size, seqlen_q, nheads, head_dim)
- attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
+ output: (batch_size, seqlen_q, num_heads, head_dim)
+ attention: (batch_size, num_heads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
@@ -1098,8 +1145,8 @@ def attention_ref(
)
scores.masked_fill_(local_mask, float("-inf"))
- if use_smooth_softmax:
- attention = smooth_softmax_ref(scores)
+ if use_smooth_softmax or (head_sink is not None):
+ attention = smooth_softmax_ref(scores, head_sink)
else:
attention = torch.softmax(scores, dim=-1)
@@ -1133,6 +1180,7 @@ def attention_qkvpacked_ref(
upcast=True,
reorder_ops=False,
use_smooth_softmax=False,
+ head_sink=None,
):
return attention_ref(
qkv[:, :, 0],
@@ -1146,6 +1194,7 @@ def attention_qkvpacked_ref(
causal=causal,
reorder_ops=reorder_ops,
use_smooth_softmax=use_smooth_softmax,
+ head_sink=head_sink,
)
@@ -1186,6 +1235,10 @@ def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=Fa
return position_ids
+def get_custom_head_sink(num_heads, torch_type=torch.float16):
+ return torch.rand(num_heads, dtype=torch_type)
+
+
def parity_check_gqa_prompt(
config,
torch_type,
@@ -1248,6 +1301,8 @@ def parity_check_gqa_prompt(
requires_grad=False,
)
+ head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None
+
window_size = (-1, -1)
left_window_size = -1
if local:
@@ -1327,6 +1382,7 @@ def parity_check_gqa_prompt(
window_size=window_size,
softcap=softcap,
use_smooth_softmax=use_smooth_softmax,
+ head_sink=head_sink,
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -1349,6 +1405,7 @@ def parity_check_gqa_prompt(
cache_seqlens - 1,
position_ids,
attention_bias,
+ head_sink,
left_window_size,
past_format,
True,
@@ -1371,6 +1428,7 @@ def parity_check_gqa_prompt(
cache_seqlens - 1,
position_ids,
attention_bias,
+ head_sink,
left_window_size,
past_format,
True,
@@ -1531,6 +1589,8 @@ def parity_check_gqa_prompt_no_buff(
else None
)
+ head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None
+
brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
new_mask = brange < cache_seqlens_expanded
@@ -1548,6 +1608,7 @@ def parity_check_gqa_prompt_no_buff(
window_size=window_size,
softcap=softcap,
use_smooth_softmax=use_smooth_softmax,
+ head_sink=head_sink,
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -1570,6 +1631,7 @@ def parity_check_gqa_prompt_no_buff(
cache_seqlens - 1,
position_ids,
attention_bias,
+ head_sink,
left_window_size,
past_format,
False,
@@ -1592,6 +1654,7 @@ def parity_check_gqa_prompt_no_buff(
cache_seqlens - 1,
position_ids,
attention_bias,
+ head_sink,
left_window_size,
past_format,
False,
@@ -1759,6 +1822,8 @@ def parity_check_gqa_past(
cos, sin = None, None
q_ro, k_ro = q, new_k
+ head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None
+
arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
update_mask = torch.logical_and(
@@ -1781,6 +1846,7 @@ def parity_check_gqa_past(
window_size=window_size,
softcap=softcap,
use_smooth_softmax=use_smooth_softmax,
+ head_sink=head_sink,
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -1822,6 +1888,7 @@ def parity_check_gqa_past(
cache_seqlens,
position_ids,
attention_bias,
+ head_sink,
past_format,
True,
left_window_size,
@@ -1844,6 +1911,7 @@ def parity_check_gqa_past(
cache_seqlens,
position_ids,
attention_bias,
+ head_sink,
past_format,
True,
left_window_size,
@@ -1882,6 +1950,8 @@ def parity_check_gqa_past(
softcap,
" smooth_softmax:",
use_smooth_softmax,
+ " head_sink:",
+ config.has_head_sink,
" B:",
config.batch_size,
" S:",
@@ -2017,6 +2087,8 @@ def parity_check_gqa_past_no_buff(
cos, sin = None, None
q_ro, k_ro = q, new_k
+ head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None
+
arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
update_mask = torch.logical_and(
@@ -2039,6 +2111,7 @@ def parity_check_gqa_past_no_buff(
window_size=window_size,
softcap=softcap,
use_smooth_softmax=use_smooth_softmax,
+ head_sink=head_sink,
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -2080,6 +2153,7 @@ def parity_check_gqa_past_no_buff(
cache_seqlens,
position_ids,
attention_bias,
+ head_sink,
past_format,
False,
window_size=left_window_size,
@@ -2102,6 +2176,7 @@ def parity_check_gqa_past_no_buff(
cache_seqlens,
position_ids,
attention_bias,
+ head_sink,
past_format,
False,
window_size=left_window_size,
@@ -2134,6 +2209,8 @@ def parity_check_gqa_past_no_buff(
softcap,
" smooth_softmax:",
use_smooth_softmax,
+ " head_sink:",
+ config.has_head_sink,
"past kv format:",
"BSNH" if past_format == Formats.BSNH else "BNSH",
" B:",
@@ -2202,33 +2279,47 @@ def run_test_config(
for softcap in [0.0, 50.0]:
for use_smooth_softmax in [False, True]:
for has_pos, has_attn in pos_ids_attn_bias:
- if config_class == PromptConfig:
- config = config_class(
- b, s, s2, s + s2 + 8, n, n2, h, has_pos, has_attn
- )
- else: # Config
- sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
- config = config_class(b, s, s2, sp, n, n2, h, has_pos, has_attn)
-
- params = {
- "config": config,
- "torch_type": precision["torch_type"],
- "numpy_type": precision["numpy_type"],
- "ort_type": precision["ort_type"],
- "rtol": precision["rtol"],
- "atol": precision["atol"],
- "local": local,
- "past_format": Formats.BNSH,
- "rotary": rotary,
- "rotary_interleaved": rotary_interleaved,
- "packed": packed,
- "softcap": softcap,
- "use_smooth_softmax": use_smooth_softmax,
- }
- params.update(additional_params)
-
- all_close = test_func(**params)
- self.assertTrue(all_close)
+ for head_sink in [False, True]:
+ if use_smooth_softmax and head_sink:
+ continue
+ if config_class == PromptConfig:
+ config = config_class(
+ b,
+ s,
+ s2,
+ s + s2 + 8,
+ n,
+ n2,
+ h,
+ has_pos,
+ has_attn,
+ head_sink,
+ )
+ else: # Config
+ sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+ config = config_class(
+ b, s, s2, sp, n, n2, h, has_pos, has_attn, head_sink
+ )
+
+ params = {
+ "config": config,
+ "torch_type": precision["torch_type"],
+ "numpy_type": precision["numpy_type"],
+ "ort_type": precision["ort_type"],
+ "rtol": precision["rtol"],
+ "atol": precision["atol"],
+ "local": local,
+ "past_format": Formats.BNSH,
+ "rotary": rotary,
+ "rotary_interleaved": rotary_interleaved,
+ "packed": packed,
+ "softcap": softcap,
+ "use_smooth_softmax": use_smooth_softmax,
+ }
+ params.update(additional_params)
+
+ all_close = test_func(**params)
+ self.assertTrue(all_close)
def test_gqa_no_past(self):
print("-------- TEST GQA NO PAST (PROMPT CASE) ---------")
diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py
index 2f5b638a57d0c..79976a92e54bf 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cuda.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py
@@ -782,7 +782,8 @@ def attention_ref(
scores.masked_fill_(local_mask, float("-inf"))
if use_smooth_softmax:
- attention = smooth_softmax_ref(scores)
+ head_sink = None
+ attention = smooth_softmax_ref(scores, head_sink)
else:
attention = torch.softmax(scores, dim=-1)
diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
index 410860a324a9d..ca5c9c2ce133f 100644
--- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
+++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
@@ -401,7 +401,8 @@ def attention_ref(
scores.masked_fill_(local_mask, float("-inf"))
if use_smooth_softmax:
- attention = smooth_softmax_ref(scores)
+ head_sink = None
+ attention = smooth_softmax_ref(scores, head_sink)
else:
attention = torch.softmax(scores, dim=-1)