diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f6cc816b45ed2..64cbd8d04b295 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2553,7 +2553,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 11)
+#### Inputs (7 - 12)
 query : T
@@ -2578,6 +2578,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
 attention_bias (optional) : T
 additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
+head_sink (optional) : T
+1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
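For intuition, the effect of `head_sink` on the softmax can be sketched in a few lines of NumPy. This is an illustrative reference only (the helper name `softmax_with_head_sink` is hypothetical, not part of ONNX Runtime): for one head with sink value `s`, the probabilities become `exp(x_i) / (sum_j exp(x_j) + exp(s))`, so the sink shrinks the attention weights without ever being attended to.

```python
import numpy as np

def softmax_with_head_sink(scores: np.ndarray, sink: float) -> np.ndarray:
    """Softmax over the last axis with one extra 'sink' logit in the denominator.

    scores: attention logits for a single head, shape (seq_len, total_seq_len).
    sink:   this head's entry of the head_sink input (a scalar).
    """
    # Include the sink when computing the row maximum so the exponentials stay stable.
    m = np.maximum(scores.max(axis=-1, keepdims=True), sink)
    e = np.exp(scores - m)
    # The sink contributes exp(sink) to the denominator but produces no output column.
    return e / (e.sum(axis=-1, keepdims=True) + np.exp(sink - m))
```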
 #### Outputs
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 1ffcabee8cc10..e50702afe9975 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -538,7 +538,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int64)|
@@ -942,7 +942,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1420,7 +1420,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float), tensor(float16)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br>
**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index ac32a4445f3ca..aef47edd5fcd2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -17,13 +17,13 @@ namespace onnxruntime { namespace contrib { template -inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, true, tp); +inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) { + MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp); } template inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, false, tp); + MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index c79508cbae273..cec495ef7391e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -51,6 +51,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH + const T* head_sink, // Head sink for smooth softmax, nullptr if not used const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) @@ -97,7 +98,7 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); @@ -110,7 +111,7 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); @@ -136,6 +137,7 @@ class GQAAttentionBase { void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH + const T* head_sink, // for smooth softmax. Its size is N. const int32_t* seqlens_k, // total - 1 sequence lengths tensor const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention @@ -310,8 +312,9 @@ class GQAAttentionBase { } } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + if (use_smooth_softmax_ || head_sink != nullptr) { + float sink = (head_sink != nullptr) ? 
static_cast(head_sink[head_index]) : 0.0f; + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); } else { ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index a912bd6e6b43c..988151f778806 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -206,9 +206,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; + // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - attention_bias, past_key, past_value, output, present_k, present_v, + head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index f2757c2c96471..c2371487d9187 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1184,6 +1184,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) + .Input(11, + "head_sink", + "1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 3575e30721af7..d0f59eb534b82 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1020,6 +1020,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 96a2398796777..669c73d2b9c06 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; + float Sink; const T* Input; T* Output; size_t N; @@ -850,6 +851,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -880,11 +882,12 @@ Return Value: #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - float NegativeMaximum = -Maximum; - if (SmoothSoftmax && NegativeMaximum > 0.0f) { - NegativeMaximum = 0.0f; + if (SmoothSoftmax && Sink > Maximum) { + Maximum = Sink; } + float NegativeMaximum = -Maximum; + // // Compute the exponential function for each element of the row (save to Temp if provided) and // compute the sum of these exponential functions. 
@@ -897,7 +900,7 @@ Return Value: #endif if (SmoothSoftmax) { - Accumulation += expf(NegativeMaximum); + Accumulation += expf(Sink + NegativeMaximum); } if (LogSoftmax) { @@ -1014,6 +1017,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -1039,6 +1043,8 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. + Sink - Supplies the smooth factor to use in the softmax operation. + ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -1060,6 +1066,7 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; + WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -1097,6 +1104,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1110,6 +1118,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 2817dda9d0085..e123414b03b21 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 3359b2a69fe83..f7cc2523adbf6 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, 0.0f, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index 65822eb294d7d..ea36383f70621 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 041b6c61cd5bf..4d7a45143b311 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, 
bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase { InputReference[nd] = Input[nd].ToFloat(); } - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 5e-3f; diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 461c243b82212..ce0649e55f7c5 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -54,6 +54,7 @@ class Config: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False @dataclass @@ -67,6 +68,7 @@ class PromptConfig: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False # LLaMA Microsoft model @@ -166,6 +168,7 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -289,6 +292,15 @@ def create_group_query_attention_graph_prompt( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -380,6 +392,7 @@ def create_group_query_attention_graph_past( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -441,6 +454,7 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: graph_input += [ helper.make_tensor_value_info( @@ -462,6 +476,7 @@ def create_group_query_attention_graph_past( ], ), ] + if rotary: graph_input += [ helper.make_tensor_value_info( @@ -498,6 +513,15 @@ def create_group_query_attention_graph_past( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -552,17 +576,17 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + q: (batch_size, seqlen_q, num_heads, d) + k: (batch_size, seqlen_k, num_heads_k, d) + v: (batch_size, seqlen_k, num_heads_k, d) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - 
assert v.shape == (batch_size, seqlen_k, nheads_k, d) + batch_size, seqlen_q, num_heads, d = q.shape + _, seqlen_k, num_heads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, num_heads_k, d) + assert v.shape == (batch_size, seqlen_k, num_heads_k, d) if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) @@ -593,7 +617,7 @@ def output_pad_fn(output_unpad): if qkvpacked: assert (query_padding_mask == key_padding_mask).all() - assert nheads == nheads_k + assert num_heads == num_heads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: @@ -714,6 +738,7 @@ def gqa_prompt_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -749,6 +774,11 @@ def gqa_prompt_func( if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -758,9 +788,6 @@ def gqa_prompt_func( "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -797,25 +824,19 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() + if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() @@ -836,11 +857,16 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + 
ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def gqa_past_func( @@ -855,6 +881,7 @@ def gqa_past_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -890,6 +917,11 @@ def gqa_past_func( if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -901,9 +933,7 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -940,11 +970,6 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -958,9 +983,7 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -988,11 +1011,16 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): @@ -1025,11 +1053,28 @@ def construct_local_mask( ) -def smooth_softmax_ref(x): - x_max = x.amax(axis=-1, keepdim=True) - x_max = torch.maximum(x_max, torch.zeros_like(x_max)) - w = torch.exp(x - x_max) - return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) +def smooth_softmax_ref(x, head_sink): + """ + Arguments: + x: (batch_size, num_heads, seqlen_q, seqlen_k) + head_sink: (num_heads) or None + Output: + y: (batch_size, num_heads, seqlen_q, seqlen_k) + """ + assert len(x.shape) == 4 + b, n, s, t = x.shape 
+ + if head_sink is not None: + assert len(head_sink.shape) == 1 + assert head_sink.shape[0] == x.shape[1] + sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) + else: + sink = torch.zeros(b, n, s, 1, dtype=x.dtype) + + y = torch.cat([x, sink], dim=-1) + y = torch.softmax(y, dim=-1) + y = y[..., :-1] + return y def attention_ref( @@ -1046,16 +1091,17 @@ def attention_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): """ Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) + q: (batch_size, seqlen_q, num_heads, head_dim) + k: (batch_size, seqlen_k, num_heads_k, head_dim) + v: (batch_size, seqlen_k, num_heads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + dropout_mask: (batch_size, num_heads, seqlen_q, seqlen_k) causal: whether to apply causal masking window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast @@ -1064,9 +1110,10 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. use_smooth_softmax: whether use smooth softmax or not + head_sink: (num_heads) or None Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + output: (batch_size, seqlen_q, num_heads, head_dim) + attention: (batch_size, num_heads, seqlen_q, seqlen_k), softmax after dropout """ if causal: window_size = (window_size[0], 0) @@ -1098,8 +1145,8 @@ def attention_ref( ) scores.masked_fill_(local_mask, float("-inf")) - if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + if use_smooth_softmax or (head_sink is not None): + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) @@ -1133,6 +1180,7 @@ def attention_qkvpacked_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -1146,6 +1194,7 @@ def attention_qkvpacked_ref( causal=causal, reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) @@ -1186,6 +1235,10 @@ def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=Fa return position_ids +def get_custom_head_sink(num_heads, torch_type=torch.float16): + return torch.rand(num_heads, dtype=torch_type) + + def parity_check_gqa_prompt( config, torch_type, @@ -1248,6 +1301,8 @@ def parity_check_gqa_prompt( requires_grad=False, ) + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + window_size = (-1, -1) left_window_size = -1 if local: @@ -1327,6 +1382,7 @@ def parity_check_gqa_prompt( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1349,6 +1405,7 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, True, @@ -1371,6 +1428,7 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, True, @@ -1531,6 +1589,8 @@ def parity_check_gqa_prompt_no_buff( else None ) + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink 
else None + brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded @@ -1548,6 +1608,7 @@ def parity_check_gqa_prompt_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1570,6 +1631,7 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, False, @@ -1592,6 +1654,7 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, left_window_size, past_format, False, @@ -1759,6 +1822,8 @@ def parity_check_gqa_past( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -1781,6 +1846,7 @@ def parity_check_gqa_past( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1822,6 +1888,7 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, True, left_window_size, @@ -1844,6 +1911,7 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, True, left_window_size, @@ -1882,6 +1950,8 @@ def parity_check_gqa_past( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, " B:", config.batch_size, " S:", @@ -2017,6 +2087,8 @@ def parity_check_gqa_past_no_buff( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -2039,6 +2111,7 @@ def parity_check_gqa_past_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -2080,6 +2153,7 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, False, window_size=left_window_size, @@ -2102,6 +2176,7 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, past_format, False, window_size=left_window_size, @@ -2134,6 +2209,8 @@ def parity_check_gqa_past_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -2202,33 +2279,47 @@ def run_test_config( for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: for has_pos, has_attn in pos_ids_attn_bias: - if config_class == PromptConfig: - config = config_class( - b, s, s2, s + s2 + 8, n, n2, h, has_pos, has_attn - ) - else: # Config - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = config_class(b, s, s2, sp, n, n2, h, has_pos, has_attn) - - params = { - "config": config, - "torch_type": precision["torch_type"], - "numpy_type": 
precision["numpy_type"], - "ort_type": precision["ort_type"], - "rtol": precision["rtol"], - "atol": precision["atol"], - "local": local, - "past_format": Formats.BNSH, - "rotary": rotary, - "rotary_interleaved": rotary_interleaved, - "packed": packed, - "softcap": softcap, - "use_smooth_softmax": use_smooth_softmax, - } - params.update(additional_params) - - all_close = test_func(**params) - self.assertTrue(all_close) + for head_sink in [False, True]: + if use_smooth_softmax and head_sink: + continue + if config_class == PromptConfig: + config = config_class( + b, + s, + s2, + s + s2 + 8, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + ) + else: # Config + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = config_class( + b, s, s2, sp, n, n2, h, has_pos, has_attn, head_sink + ) + + params = { + "config": config, + "torch_type": precision["torch_type"], + "numpy_type": precision["numpy_type"], + "ort_type": precision["ort_type"], + "rtol": precision["rtol"], + "atol": precision["atol"], + "local": local, + "past_format": Formats.BNSH, + "rotary": rotary, + "rotary_interleaved": rotary_interleaved, + "packed": packed, + "softcap": softcap, + "use_smooth_softmax": use_smooth_softmax, + } + params.update(additional_params) + + all_close = test_func(**params) + self.assertTrue(all_close) def test_gqa_no_past(self): print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py index 2f5b638a57d0c..79976a92e54bf 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py @@ -782,7 +782,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py index 410860a324a9d..ca5c9c2ce133f 100644 --- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py +++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py @@ -401,7 +401,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1)