Merged. Changes from all commits.
docs/ContribOperators.md (4 changes: 3 additions & 1 deletion)

@@ -2553,7 +2553,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Softcap value for attention weights. Default value is 0.</dd>
</dl>

-#### Inputs (7 - 11)
+#### Inputs (7 - 12)

<dl>
<dt><tt>query</tt> : T</dt>
@@ -2578,6 +2578,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element</dd>
<dt><tt>attention_bias</tt> (optional) : T</dt>
<dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
+<dt><tt>head_sink</tt> (optional) : T</dt>
+<dd>1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.</dd>
</dl>

#### Outputs
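Reading this doc change together with the kernel changes further down in the diff, the semantics of the new `head_sink` input appear to be: for head `h`, the sink value `s = head_sink[h]` contributes one extra exponential term to the softmax denominator and takes part in the maximum used for numerical stability. A sketch of the intended math for one row of `D` scores `x_1..x_D` (my reading of the diff, not text from the PR):

```math
m = \max\bigl(s,\ \max_j x_j\bigr), \qquad y_i = \frac{e^{x_i - m}}{e^{s - m} + \sum_{j=1}^{D} e^{x_j - m}}
```

With `s = 0` this reduces to the existing smooth-softmax behavior, which added `exp(-m)` to the denominator.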
docs/OperatorKernels.md (6 changes: 3 additions & 3 deletions)

@@ -538,7 +538,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
@@ -942,7 +942,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1420,7 +1420,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
onnxruntime/contrib_ops/cpu/bert/attention_helper.h (6 changes: 3 additions & 3 deletions)

@@ -17,13 +17,13 @@ namespace onnxruntime {
namespace contrib {

template <typename T>
-inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
-MlasComputeSoftmax(score, score, N, D, false, true, tp);
+inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) {
+MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp);
}

template <typename T>
inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
-MlasComputeSoftmax(score, score, N, D, false, false, tp);
+MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp);
}

template <typename T>
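For other call sites that may need updating: the helper now softmaxes a single row (it hard-codes `N = 1` in the MLAS call) and takes the sink explicitly. A minimal usage sketch with made-up data, assuming the declarations above are in scope; this is not code from the PR:

```cpp
void ExampleSmoothSoftmaxCall(ThreadPool* tp) {
  // One row of D = 4 attention scores for a single head (made-up values).
  float scores[4] = {1.0f, 2.0f, 0.5f, -1.0f};
  float sink = 0.3f;  // hypothetical head_sink value for this head

  // Old signature: ComputeSmoothSoftmaxInplace(scores, /*N=*/1, /*D=*/4, tp);
  // New signature: pass D and the per-head sink; N is fixed to 1 internally.
  ComputeSmoothSoftmaxInplace(scores, /*D=*/4, sink, tp);

  // Passing sink = 0.0f reproduces the previous smooth-softmax behavior.
}
```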
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (11 changes: 7 additions & 4 deletions)

@@ -51,6 +51,7 @@ class GQAAttentionBase {
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxN_kvxSxH
const T* V, // V data with shape BxN_kvxSxH
+const T* head_sink, // Head sink for smooth softmax, nullptr if not used
const Tensor* attention_bias, // Attention bias to add to QxK'
const Tensor* past_key, // past K input tensor (if not using past state)
const Tensor* past_value, // past V input tensor (if not using past state)
@@ -97,7 +98,7 @@
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;

if (gqa_mlas_supported) {
-ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_bias_data,
+ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
tp, allocator);
@@ -110,7 +111,7 @@
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
} else {
-ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_bias_data,
+ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
tp, allocator);
@@ -136,6 +137,7 @@
void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
+const T* head_sink, // for smooth softmax. Its size is N.
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const T* attention_bias, // optional attention bias
const size_t batch_size, // batch size of self-attention
@@ -310,8 +312,9 @@
}
}

-if (use_smooth_softmax_) {
-ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast<int>(window_size), nullptr);
+if (use_smooth_softmax_ || head_sink != nullptr) {
+float sink = (head_sink != nullptr) ? static_cast<float>(head_sink[head_index]) : 0.0f;
+ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast<int>(window_size), sink, nullptr);
} else {
ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast<int>(window_size), nullptr);
}
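To make the per-row dispatch above easier to follow, here is a standalone sketch of the same decision logic; the function name and shapes are invented for illustration and this is not code from the PR. Smooth softmax is taken when the op attribute requests it or a head_sink tensor is present, and the sink value otherwise defaults to 0:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Per-row softmax mirroring the branch above. head_sink may be nullptr.
void SoftmaxRowSketch(std::vector<float>& row, bool use_smooth_softmax,
                      const float* head_sink, size_t head_index) {
  const bool smooth = use_smooth_softmax || head_sink != nullptr;
  const float sink = (head_sink != nullptr) ? head_sink[head_index] : 0.0f;

  float m = *std::max_element(row.begin(), row.end());
  if (smooth && sink > m) {
    m = sink;  // the sink takes part in the stability max, as in compute.cpp below
  }

  // The sink contributes one extra exponential to the denominator only;
  // it never appears in the output row itself.
  float denom = smooth ? std::exp(sink - m) : 0.0f;
  for (float v : row) denom += std::exp(v - m);
  for (float& v : row) v = std::exp(v - m) / denom;
}
```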
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc (4 changes: 3 additions & 1 deletion)

@@ -206,9 +206,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {

ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

+const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;
+
// Compute the attention score and apply the score to V
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
-attention_bias, past_key, past_value, output, present_k, present_v,
+head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
seqlens_k, parameters, allocator, context);
}
} // namespace contrib
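The hunk starts partway through `Compute`, so the retrieval of `head_sink` itself is not visible here. Given that the schema change below registers it as input index 11, it is presumably fetched along the lines of this sketch (hypothetical, not shown in this diff):

```cpp
// Optional input 11 of GroupQueryAttention; nullptr when the model omits it.
const Tensor* head_sink = context->Input<Tensor>(11);
```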
onnxruntime/core/graph/contrib_ops/bert_defs.cc (5 changes: 5 additions & 0 deletions)

@@ -1184,6 +1184,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
+.Input(11,
+"head_sink",
+"1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.",
+"T",
+OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
onnxruntime/core/mlas/inc/mlas.h (1 change: 1 addition & 0 deletions)

@@ -1020,6 +1020,7 @@ MlasComputeSoftmax(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+float Sink,
MLAS_THREADPOOL* ThreadPool
);

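Since this is a public MLAS entry point, every caller has to pass the new argument; existing callers that do not want a sink pass `0.0f`, as the other hunks in this PR do. A minimal call sketch with made-up data (assumes `mlas.h` is on the include path):

```cpp
#include <vector>

void SoftmaxCallExample() {
  std::vector<float> scores = {1.0f, 2.0f, 0.5f, -1.0f};  // N = 1 row, D = 4
  std::vector<float> probs(scores.size());

  // Sink is a hypothetical head_sink value; nullptr selects MLAS's own threading.
  MlasComputeSoftmax(scores.data(), probs.data(), /*N=*/1, /*D=*/4,
                     /*LogSoftmax=*/false, /*SmoothSoftmax=*/true,
                     /*Sink=*/0.3f, /*ThreadPool=*/nullptr);
}
```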
onnxruntime/core/mlas/lib/compute.cpp (17 changes: 13 additions & 4 deletions)

@@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK {
ptrdiff_t ThreadCountN;
bool LogSoftmax;
bool SmoothSoftmax;
+float Sink;
const T* Input;
T* Output;
size_t N;
@@ -850,6 +851,7 @@ Return Value:
const size_t D = WorkBlock->D;
const bool LogSoftmax = WorkBlock->LogSoftmax;
const bool SmoothSoftmax = WorkBlock->SmoothSoftmax;
+const float Sink = WorkBlock->Sink;

const float* Input = WorkBlock->Input + n * D;
float* Output = WorkBlock->Output + n * D;
@@ -880,11 +882,12 @@
#else
float Maximum = MlasReduceMaximumF32Kernel(Input, D);
#endif
-float NegativeMaximum = -Maximum;
-if (SmoothSoftmax && NegativeMaximum > 0.0f) {
-NegativeMaximum = 0.0f;
+if (SmoothSoftmax && Sink > Maximum) {
+Maximum = Sink;
}

+float NegativeMaximum = -Maximum;
+
//
// Compute the exponential function for each element of the row (save to Temp if provided) and
// compute the sum of these exponential functions.
@@ -897,7 +900,7 @@
#endif

if (SmoothSoftmax) {
-Accumulation += expf(NegativeMaximum);
+Accumulation += expf(Sink + NegativeMaximum);
}

if (LogSoftmax) {
@@ -1014,6 +1017,7 @@ MlasComputeSoftmax(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+float Sink,
MLAS_THREADPOOL* ThreadPool
)
/*++
@@ -1039,6 +1043,8 @@ Routine Description:

SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation.

+Sink - Supplies the smooth factor to use in the softmax operation.
+
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.

@@ -1060,6 +1066,7 @@ Return Value:
WorkBlock.Output = Output;
WorkBlock.N = N;
WorkBlock.D = D;
+WorkBlock.Sink = Sink;

//
// Compute the number of target threads given the complexity of the softmax
@@ -1097,6 +1104,7 @@ MlasComputeSoftmax<float>(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+float Sink,
MLAS_THREADPOOL* ThreadPool
);

@@ -1110,6 +1118,7 @@ MlasComputeSoftmax<MLAS_FP16>(
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
+float Sink,
MLAS_THREADPOOL* ThreadPool
);

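A quick sanity check of the new stability handling: with `Sink = 0` the code reproduces the old behavior (clamping `NegativeMaximum` at zero is the same as taking `max(Maximum, 0)`), and when the sink exceeds every score it becomes the new maximum. A worked example with made-up numbers, D = 2 scores x = (1, 2) and Sink = 3:

```math
m = \max(3, 1, 2) = 3,\quad
\text{denom} = e^{3-3} + e^{1-3} + e^{2-3} \approx 1 + 0.1353 + 0.3679 = 1.5032,\quad
y \approx (0.0900,\ 0.2448)
```

The outputs now sum to about 0.335 rather than 1; the remaining mass is absorbed by the sink term, which appears to be the smoothing effect the `head_sink` input is meant to provide.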
onnxruntime/core/providers/cpu/math/softmax_shared.cc (2 changes: 1 addition & 1 deletion)

@@ -99,7 +99,7 @@ common::Status SoftmaxCPU<float>(size_t N,
float* Ydata,
bool logarithmic,
onnxruntime::concurrency::ThreadPool* thread_pool) {
-MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool);
+MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool);
return Status::OK();
}

onnxruntime/core/providers/cpu/ml/ml_common.h (2 changes: 1 addition & 1 deletion)

@@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span<T> scores, int64_t num_batches_in,
}

if (use_mlas) {
-MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow<size_t>(batch_size), false, false, threadpool);
+MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow<size_t>(batch_size), false, false, 0.0f, threadpool);
} else {
while (s < s_end) {
gsl::span<float> scores_for_batch(s, s + batch_size);
onnxruntime/test/mlas/bench/bench_computesoftmax.cpp (4 changes: 2 additions & 2 deletions)

@@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) {
std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory

// warming up run
-MlasComputeSoftmax(input, output, N, D, false, false, tp.get());
+MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get());

for (auto _ : state) {
-MlasComputeSoftmax(input, output, N, D, false, false, tp.get());
+MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get());
}

free(ptr.underlying_buffer);
onnxruntime/test/mlas/unittest/test_softmax.cpp (4 changes: 2 additions & 2 deletions)

@@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase {
}

void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) {
-MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_);
+MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_);
ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax);

constexpr float AbsoluteTolerance = 1e-6f;
@@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase {
InputReference[nd] = Input[nd].ToFloat();
}

-MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_);
+MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_);
ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax);

constexpr float AbsoluteTolerance = 5e-3f;