diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 15864a0198161..87387d4f281ed 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -27,6 +27,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/activate.cpp
${MLAS_SRC_DIR}/logistic.cpp
${MLAS_SRC_DIR}/tanh.cpp
+ ${MLAS_SRC_DIR}/eltwise.h
+ ${MLAS_SRC_DIR}/eltwise.cpp
${MLAS_SRC_DIR}/erf.cpp
${MLAS_SRC_DIR}/compute.cpp
${MLAS_SRC_DIR}/quantize.cpp
@@ -101,6 +103,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/softmax_kernel_neon.h
${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/eltwise_kernel_neon.h
+ ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
)
set(mlas_platform_preprocess_srcs
@@ -387,6 +392,8 @@ else()
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon.h
${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/eltwise_kernel_neon.h
+ ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -409,6 +416,7 @@ else()
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -423,6 +431,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 274531faaf717..f85ed1e5f146c 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2551,7 +2551,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Softcap value for attention weights. Default value is 0.</dd>
-#### Inputs (7 - 9)
+#### Inputs (7 - 11)
<dt><tt>query</tt> : T</dt>
@@ -2572,6 +2572,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
+<dt><tt>position_ids</tt> (optional) : tensor(int64)</dt>
+<dd>2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element</dd>
+<dt><tt>attention_bias</tt> (optional) : T</dt>
+<dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
#### Outputs
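
A minimal stand-alone sketch of the default position id handling described above (the function name and layout are illustrative only, mirroring the CPU fallback added in group_query_attention.cc later in this patch): when position_ids is omitted and the call is the first prompt, only element 0 (set to 0) is consulted; otherwise one position per token is derived from seqlens_k.

    #include <cstdint>
    #include <vector>

    // Illustrative sketch, not part of the patch: default position ids when the
    // optional position_ids input is absent.
    std::vector<int64_t> DefaultPositionIds(const int32_t* seqlens_k, int batch_size,
                                            int sequence_length, bool is_first_prompt) {
      if (is_first_prompt) {
        return {0};  // only the first element is used for the first prompt
      }
      std::vector<int64_t> pos_ids(static_cast<size_t>(batch_size) * sequence_length);
      for (int b = 0; b < batch_size; ++b) {
        const int total_seqlen = seqlens_k[b] + 1;
        const int past_seqlen = total_seqlen - sequence_length;
        for (int s = 0; s < sequence_length; ++s) {
          pos_ids[b * sequence_length + s] =
              (past_seqlen + s < total_seqlen) ? past_seqlen + s : 1;
        }
      }
      return pos_ids;
    }
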
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 84b9c7c9fc174..1dd145463367b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -520,7 +520,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int64)|
@@ -922,7 +922,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1399,7 +1399,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
index 188fc6e43b5b5..ac32a4445f3ca 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -31,6 +31,11 @@ void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) {
MlasComputeSoftcap(scores, scores, sequence_length, softcap);
}
+template <typename T>
+void ApplyAttentionBias(T* softmax_logits, const T* attention_mask, int N) {
+ MlasEltwiseAdd(softmax_logits, attention_mask, softmax_logits, N);
+}
+
template <typename T>
void PrepareMask(const int32_t* mask_index,
gsl::span<const int64_t> mask_index_dims,
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index 70d66e534ee8a..ff6cb8edc0231 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -50,6 +50,7 @@ class GQAAttentionBase {
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxN_kvxSxH
const T* V, // V data with shape BxN_kvxSxH
+ const Tensor* attention_bias, // Attention bias to add to QxK'
const Tensor* past_key, // past K input tensor (if not using past state)
const Tensor* past_value, // past V input tensor (if not using past state)
Tensor* output, // output tensor
@@ -87,14 +88,18 @@ class GQAAttentionBase {
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;
+ const T* attention_bias_data = attention_bias != nullptr ? attention_bias->Data<T>() : nullptr;
+ auto attention_bias_shape = attention_bias != nullptr ? attention_bias->Shape().GetDims() : gsl::span<const int64_t>{};
+
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
if (gqa_mlas_supported) {
- ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
- sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
- present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
+ ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_bias_data,
+ batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
+ head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
+ tp, allocator);
// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
@@ -104,9 +109,10 @@ class GQAAttentionBase {
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
} else {
- ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
- sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
- present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
+ ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), attention_bias_data,
+ batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache,
+ head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt,
+ tp, allocator);
// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
@@ -126,22 +132,24 @@ class GQAAttentionBase {
// attention_probs(B, N, S, T) = Softmax(attention_probs)
// If T is float32, U is float32. If T is float16, U could be float16 or float32.
template <typename T, typename U>
- void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT
- const T* Q, // Q data. Its size is BxNxSxH
- const T* K, // k data. Its size is BxNxLxH
- const int32_t* seqlens_k, // total - 1 sequence lengths tensor
- const size_t batch_size, // batch size of self-attention
- const size_t sequence_length, // sequence length of self-attention (S)
- const size_t past_buffer_sequence_length, // sequence length of past state
- const size_t present_buffer_sequence_length, // sequence length of present state
- const size_t head_size, // head size of self-attention
- const T* past_key, // past key only
- T* present_key, // present key only
- const bool past_present_share_buffer, // whether present key and value share the same buffer
- const bool packed_qkv, // whether Q, K, V are packed
- const bool is_prompt, // whether it is prompt
- ThreadPool* tp, // thread pool
- AllocatorPtr allocator) const { // allocator for temporary buffer
+ void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT
+ const T* Q, // Q data. Its size is BxNxSxH
+ const T* K, // k data. Its size is BxNxLxH
+ const int32_t* seqlens_k, // total - 1 sequence lengths tensor
+ const T* attention_bias, // optional attention bias
+ const size_t batch_size, // batch size of self-attention
+ const size_t sequence_length, // sequence length of self-attention (S)
+ const gsl::span<const int64_t> attention_bias_shape, // shape of the attention bias
+ const size_t past_buffer_sequence_length, // sequence length of past state
+ const size_t present_buffer_sequence_length, // sequence length of present state
+ const size_t head_size, // head size of self-attention
+ const T* past_key, // past key only
+ T* present_key, // present key only
+ const bool past_present_share_buffer, // whether present key and value share the same buffer
+ const bool packed_qkv, // whether Q, K, V are packed
+ const bool is_prompt, // whether it is prompt
+ ThreadPool* tp, // thread pool
+ AllocatorPtr allocator) const { // allocator for temporary buffer
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
@@ -189,6 +197,24 @@ class GQAAttentionBase {
const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
U* output = attention_probs + output_offset;
+ // Compute attention bias offset based on the batch and head indexes
+ // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting
+ const T* attention_bias_thread = nullptr;
+ ptrdiff_t attention_total_seqlen = 0;
+ if (attention_bias != nullptr) {
+ ptrdiff_t attention_bias_offset = 0;
+ attention_total_seqlen = static_cast<ptrdiff_t>(attention_bias_shape[3]);
+ const ptrdiff_t attention_matrix_size = sequence_length * attention_total_seqlen;
+ if (attention_bias_shape[0] != 1) {
+ attention_bias_offset += SafeInt<ptrdiff_t>(batch_index) * attention_bias_shape[1] * attention_matrix_size;
+ }
+ if (attention_bias_shape[1] != 1) {
+ attention_bias_offset += SafeInt<ptrdiff_t>(head_index) * attention_matrix_size;
+ }
+
+ attention_bias_thread = attention_bias + attention_bias_offset;
+ }
+
const T* k;
if (packed_qkv) {
k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
@@ -242,7 +268,15 @@ class GQAAttentionBase {
U* output_softmax = output;
for (size_t seq = 0; seq < sequence_length; seq++) {
size_t seq_causal_length = past_seqlen + seq + 1;
- if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
+
+ const bool should_apply_local_window = local_window_size_ > 0 &&
+ seq_causal_length > static_cast<size_t>(local_window_size_) + 1;
+
+ const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ - 1 : 0;
+ const size_t window_size = should_apply_local_window ? local_window_size_ + 1 : seq_causal_length;
+
+ // Mask everything before local window, if local window should be applied
+ if (should_apply_local_window) {
for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
if constexpr (std::is_same<U, float>::value) {
output_softmax[total_seq_id] = 0.f;
@@ -250,27 +284,34 @@ class GQAAttentionBase {
output_softmax[total_seq_id] = MLFloat16::FromBits(static_cast<uint16_t>(0));
}
}
- if (softcap_ > 0.f) {
- ComputeAttentionSoftcapInplace(output_softmax + seq_causal_length - local_window_size_ - 1,
- local_window_size_ + 1, static_cast<U>(softcap_));
- }
- if (use_smooth_softmax_) {
- ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
- local_window_size_ + 1, nullptr);
+ }
+
+ if (softcap_ > 0.f) {
+ ComputeAttentionSoftcapInplace(output_softmax + start_offset, static_cast<int>(window_size),
+ static_cast<U>(softcap_));
+ }
+
+ // Add attention bias to QxK' if provided
+ // TODO (#23982): Implement bias addition during softmax computation in GQA CPU operator
+ if (attention_bias_thread != nullptr) {
+ if constexpr (std::is_same_v<T, U>) {
+ ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread + start_offset,
+ static_cast<int>(window_size));
} else {
- ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
- local_window_size_ + 1, nullptr);
+ static_assert(std::is_same_v<T, MLFloat16> && std::is_same_v<U, float>);
+ size_t bytes = window_size * sizeof(float);
+ auto attention_bias_thread_fp32 = static_cast<float*>(allocator->Alloc(bytes));
+ BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator));
+
+ MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size);
+ ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast<int>(window_size));
}
+ }
+
+ if (use_smooth_softmax_) {
+ ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast<int>(window_size), nullptr);
} else {
- if (softcap_ > 0.f) {
- ComputeAttentionSoftcapInplace(output_softmax, static_cast<int>(seq_causal_length),
- static_cast<U>(softcap_));
- }
- if (use_smooth_softmax_) {
- ComputeSmoothSoftmaxInplace(output_softmax, 1, static_cast<int>(seq_causal_length), nullptr);
- } else {
- ComputeAttentionSoftmaxInplace(output_softmax, 1, static_cast<int>(seq_causal_length), nullptr);
- }
+ ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast<int>(window_size), nullptr);
}
// set causal [seq_causal_length, total_seqlen) to 0.f
@@ -283,6 +324,10 @@ class GQAAttentionBase {
}
output_softmax += present_buffer_sequence_length;
+
+ if (attention_bias_thread != nullptr) {
+ attention_bias_thread += attention_total_seqlen;
+ }
}
}
});
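
The broadcasting arithmetic introduced above can be read in isolation as the following sketch (a simplified model with a hypothetical helper name, not the exact kernel code): the bias has shape (batch_size or 1, num_heads or 1, S, T), so the per-chunk base pointer only accounts for dimensions that are not 1, and after each query row the pointer advances by T.

    #include <cstdint>

    // Simplified model of the attention_bias offset logic in ComputeAttentionProbs.
    // bias_shape = {B or 1, N or 1, S, T}; batch_index/head_index identify the chunk
    // processed by one thread. Illustrative only.
    const float* BiasBaseForChunk(const float* bias, const int64_t bias_shape[4],
                                  int64_t batch_index, int64_t head_index) {
      const int64_t matrix_size = bias_shape[2] * bias_shape[3];  // S * T elements per (b, n)
      int64_t offset = 0;
      if (bias_shape[0] != 1) {  // no batch broadcast: step over whole (N, S, T) blocks
        offset += batch_index * bias_shape[1] * matrix_size;
      }
      if (bias_shape[1] != 1) {  // no head broadcast: step over (S, T) blocks
        offset += head_index * matrix_size;
      }
      return bias + offset;  // caller then advances by bias_shape[3] after each query row
    }
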
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 8f662cd388c6d..9c7530f0126bb 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -52,6 +52,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
const Tensor* total_seqlen_tensor = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);
+ const Tensor* position_ids = context->Input<Tensor>(9);
+ const Tensor* attention_bias = context->Input<Tensor>(10);
GroupQueryAttentionParameters parameters = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
@@ -69,6 +71,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
scale_,
softcap_));
+ ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
+ attention_bias,
+ parameters));
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int present_kv_seqlen = parameters.seqlen_present_kv_cache;
@@ -129,9 +135,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
auto* tp = context->GetOperatorThreadPool();
// Generate position ids
const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length;
- std::vector<int64_t> pos_ids(pos_ids_size);
- if (parameters.is_first_prompt) {
- pos_ids[0] = static_cast<int64_t>(0);
+ std::vector<int64_t> default_pos_ids(pos_ids_size);
+ const int64_t* pos_ids_data = default_pos_ids.data();
+
+ if (position_ids != nullptr) {
+ pos_ids_data = position_ids->Data<int64_t>();
+ } else if (parameters.is_first_prompt) {
+ default_pos_ids[0] = static_cast<int64_t>(0);
} else {
// Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1.
for (int b = 0; b < batch_size; b++) {
@@ -139,13 +149,14 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
const int past_seqlen = total_seqlen - sequence_length;
for (int s = 0; s < sequence_length; s++) {
if (past_seqlen + s < total_seqlen) {
- pos_ids[b * sequence_length + s] = static_cast<int64_t>(past_seqlen) + s;
+ default_pos_ids[b * sequence_length + s] = static_cast<int64_t>(past_seqlen) + s;
} else {
- pos_ids[b * sequence_length + s] = static_cast<int64_t>(1);
+ default_pos_ids[b * sequence_length + s] = static_cast<int64_t>(1);
}
}
}
}
+
// Initialize separate buffers for rotary embeddings
const T* q_input;
const T* k_input;
@@ -165,7 +176,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
}
// Run rotary embedding for Q and K
ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input,
- pos_ids.data(), cos_cache->Data<T>(),
+ pos_ids_data, cos_cache->Data<T>(),
sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
rotary_params.num_heads = kv_num_heads_;
@@ -174,7 +185,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride;
}
ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input,
- pos_ids.data(), cos_cache->Data<T>(),
+ pos_ids_data, cos_cache->Data<T>(),
sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
// Pack V into rotary QKV buffer
if (packed_qkv) {
@@ -192,9 +203,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
}
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+
// Compute the attention score and apply the score to V
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
- past_key, past_value, output, present_k, present_v,
+ attention_bias, past_key, past_value, output, present_k, present_v,
seqlens_k, parameters, allocator, context);
}
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index 4cc5a4228dc8c..7bffd768c8f7c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -288,6 +288,50 @@ Status CheckInputs(const T* query,
return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap);
}
+
+template <typename T>
+Status CheckCustomAttentionInputs(const T* position_ids,
+ const T* attention_bias,
+ const GroupQueryAttentionParameters& parameters) {
+ if (position_ids != nullptr) {
+ const auto& pos_ids_shape = position_ids->Shape();
+ if (pos_ids_shape[0] != parameters.batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "position_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]);
+ }
+
+ if (pos_ids_shape[1] < parameters.sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "position_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]);
+ }
+ }
+
+ if (attention_bias != nullptr) {
+ const auto& attn_bias_shape = attention_bias->Shape();
+ if ((attn_bias_shape[0] != parameters.batch_size) && (attn_bias_shape[0] != 1)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "attention_bias dimension 0 must be equal to the batch size or 1, got ", attn_bias_shape[0]);
+ }
+
+ if ((attn_bias_shape[1] != parameters.num_heads) && (attn_bias_shape[1] != 1)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "attention_bias dimension 1 must be equal to the num heads or 1, got ", attn_bias_shape[1]);
+ }
+
+ if (attn_bias_shape[2] != parameters.sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "attention_bias dimension 2 must be equal to the sequence length, got ", attn_bias_shape[2]);
+ }
+
+ if (attn_bias_shape[3] != parameters.total_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "attention_bias dimension 3 must be equal to total_sequence_length, got ", attn_bias_shape[3]);
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace group_query_attention_helper
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index ecc8cb091b1b6..718dd9a4397b5 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1128,6 +1128,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"2D tensor with shape (max_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
+ .Input(9,
+ "position_ids",
+ "2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel "
+ "uses only the first element",
+ "tensor(int64)",
+ OpSchema::Optional)
+ .Input(10,
+ "attention_bias",
+ "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)",
+ "T",
+ OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 1401e27ca77e5..8033eab8262a0 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1030,6 +1030,16 @@ MlasComputeSoftcap(
T cap
);
+template <typename T>
+void
+MLASCALL
+MlasEltwiseAdd(
+ const T* left,
+ const T* right,
+ T* output,
+ size_t N
+ );
+
template <typename T>
void
MLASCALL
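
A minimal caller-side sketch of the new MlasEltwiseAdd declaration (assuming the float specialization defined in eltwise.cpp below; per the comment in eltwise.h, the output pointer may alias an input, so the in-place form used here is valid; the helper name is hypothetical):

    #include <vector>

    #include "mlas.h"

    // Hypothetical helper: adds a bias row onto attention logits in place.
    void AddBiasInPlace(std::vector<float>& logits, const std::vector<float>& bias) {
      // output == left is allowed, so this computes logits[i] += bias[i] for all i.
      MlasEltwiseAdd<float>(logits.data(), bias.data(), logits.data(), logits.size());
    }
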
diff --git a/onnxruntime/core/mlas/lib/eltwise.cpp b/onnxruntime/core/mlas/lib/eltwise.cpp
new file mode 100644
index 0000000000000..f63d71b40bfbb
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/eltwise.cpp
@@ -0,0 +1,71 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ eltwise.cpp
+
+Abstract:
+
+ This module implements routines to compute element-wise operations on two vectors.
+
+ Currently supported element-wise operations:
+ - Add
+
+--*/
+
+#include "mlasi.h"
+#include "eltwise.h"
+
+template <>
+void
+MLASCALL
+MlasEltwiseAdd<float>(
+ const float* left,
+ const float* right,
+ float* output,
+ size_t N
+) {
+ while (N > 0) {
+ if (N >= 4) {
+ MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left);
+ MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right);
+
+ MLAS_FLOAT32X4 ResultVec = MlasAddFloat32x4(LeftVec, RightVec);
+
+ MlasStoreFloat32x4(output, ResultVec);
+
+ left += 4;
+ right += 4;
+ output += 4;
+ N -= 4;
+ } else {
+ *output = *left + *right;
+
+ left += 1;
+ right += 1;
+ output += 1;
+ N -= 1;
+ }
+ }
+}
+
+
+template <>
+void
+MLASCALL
+MlasEltwiseAdd<MLAS_FP16>(
+ const MLAS_FP16* left,
+ const MLAS_FP16* right,
+ MLAS_FP16* output,
+ size_t N
+) {
+ const auto* dispatch = GetMlasPlatform().EltwiseDispatch;
+ if (dispatch == nullptr || dispatch->Add_Fp16 == nullptr) {
+ MLAS_THROW_EX(std::runtime_error, "Add_Fp16 is not supported.");
+ }
+ dispatch->Add_Fp16(left, right, output, N);
+}
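
Because the MLAS_FP16 specialization above throws std::runtime_error when no fp16 dispatch is registered, a caller that cannot guarantee NEON fp16 support might guard the call. A hedged sketch (the wrapper name and scalar fallback are illustrative, using the MLAS_FP16 float conversions exercised by the unit test later in this patch):

    #include <cstddef>
    #include <stdexcept>

    #include "mlas.h"

    // Hypothetical wrapper: use the dispatched fp16 kernel when available,
    // otherwise fall back to a scalar add via float.
    void EltwiseAddFp16(const MLAS_FP16* left, const MLAS_FP16* right,
                        MLAS_FP16* output, size_t n) {
      try {
        MlasEltwiseAdd<MLAS_FP16>(left, right, output, n);
      } catch (const std::runtime_error&) {
        for (size_t i = 0; i < n; ++i) {
          output[i] = MLAS_FP16(left[i].ToFloat() + right[i].ToFloat());
        }
      }
    }
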
diff --git a/onnxruntime/core/mlas/lib/eltwise.h b/onnxruntime/core/mlas/lib/eltwise.h
new file mode 100644
index 0000000000000..a8345c499f6b7
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/eltwise.h
@@ -0,0 +1,37 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ eltwise.h
+
+Abstract:
+
+ This module includes kernel function prototypes and helper functions for
+ element-wise operations.
+
+--*/
+#pragma once
+
+#include "mlasi.h"
+
+struct MLAS_ELTWISE_DISPATCH {
+ /**
+ * @brief Compute the element-wise addition of the two given vectors
+ * @param left Address of the left operand
+ * @param right Address of the right operand
+ * @param output Address of the output array. Could be the same as the input array.
+ * @param N Number of elements in the input arrays
+ */
+ typedef void(Add_Fp16_Fn)(
+ const MLAS_FP16* left,
+ const MLAS_FP16* right,
+ MLAS_FP16* output,
+ size_t N
+ );
+
+ Add_Fp16_Fn* Add_Fp16 = nullptr;
+};
diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp
new file mode 100644
index 0000000000000..415c1281c808e
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp
@@ -0,0 +1,32 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ eltwise_kernel_neon.cpp
+
+Abstract:
+
+ This module implements the element-wise kernels for ARM NEON.
+
+--*/
+
+#include "eltwise.h"
+#include "eltwise_kernel_neon.h"
+
+//
+// Kernel dispatch structure definition.
+//
+const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon = []() {
+ MLAS_ELTWISE_DISPATCH d;
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+ if (MlasFp16AccelerationSupported()) {
+ d.Add_Fp16 = eltwise_neon::Add_Kernel_Fp16;
+ }
+#endif
+ return d;
+}();
diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h
new file mode 100644
index 0000000000000..d99a3e97c21f2
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h
@@ -0,0 +1,28 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ eltwise_kernel_neon.h
+
+Abstract:
+
+ This module includes function declarations and common helper functions for
+ element-wise operations on ARM cpu.
+
+--*/
+
+#pragma once
+
+#include <cstddef>
+
+#include "mlasi.h"
+
+namespace eltwise_neon {
+
+void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N);
+
+} // namespace eltwise_neon
diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp
new file mode 100644
index 0000000000000..decbdb576d5cd
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp
@@ -0,0 +1,118 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ eltwise_kernel_neon_fp16.cpp
+
+Abstract:
+
+ This module implements the fp16 element-wise kernels for ARM NEON.
+
+--*/
+#include <arm_neon.h>
+#include <cstddef>
+
+#include "fp16_common.h"
+#include "eltwise.h"
+#include "eltwise_kernel_neon.h"
+
+namespace eltwise_neon {
+
+void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N) {
+ const auto* left_fp16 = reinterpret_cast<const _mlas_fp16_*>(left);
+ const auto* right_fp16 = reinterpret_cast<const _mlas_fp16_*>(right);
+ auto* output_fp16 = reinterpret_cast<_mlas_fp16_*>(output);
+
+ while (N >= 32) {
+ auto l0 = MlasLoadFloat16x8(left_fp16);
+ auto l1 = MlasLoadFloat16x8(left_fp16 + 8);
+ auto l2 = MlasLoadFloat16x8(left_fp16 + 16);
+ auto l3 = MlasLoadFloat16x8(left_fp16 + 24);
+
+ auto r0 = MlasLoadFloat16x8(right_fp16);
+ auto r1 = MlasLoadFloat16x8(right_fp16 + 8);
+ auto r2 = MlasLoadFloat16x8(right_fp16 + 16);
+ auto r3 = MlasLoadFloat16x8(right_fp16 + 24);
+
+ auto o0 = MlasAddFloat16(l0, r0);
+ auto o1 = MlasAddFloat16(l1, r1);
+ auto o2 = MlasAddFloat16(l2, r2);
+ auto o3 = MlasAddFloat16(l3, r3);
+
+ MlasStoreFloat16x8(output_fp16, o0);
+ MlasStoreFloat16x8(output_fp16 + 8, o1);
+ MlasStoreFloat16x8(output_fp16 + 16, o2);
+ MlasStoreFloat16x8(output_fp16 + 24, o3);
+
+ left_fp16 += 32;
+ right_fp16 += 32;
+ output_fp16 += 32;
+ N -= 32;
+ }
+
+ if (N & 16) {
+ auto l0 = MlasLoadFloat16x8(left_fp16);
+ auto l1 = MlasLoadFloat16x8(left_fp16 + 8);
+
+ auto r0 = MlasLoadFloat16x8(right_fp16);
+ auto r1 = MlasLoadFloat16x8(right_fp16 + 8);
+
+ auto o0 = MlasAddFloat16(l0, r0);
+ auto o1 = MlasAddFloat16(l1, r1);
+
+ MlasStoreFloat16x8(output_fp16, o0);
+ MlasStoreFloat16x8(output_fp16 + 8, o1);
+
+ left_fp16 += 16;
+ right_fp16 += 16;
+ output_fp16 += 16;
+ N -= 16;
+ }
+
+ if (N & 8) {
+ auto l0 = MlasLoadFloat16x8(left_fp16);
+ auto r0 = MlasLoadFloat16x8(right_fp16);
+ auto o0 = MlasAddFloat16(l0, r0);
+ MlasStoreFloat16x8(output_fp16, o0);
+
+ left_fp16 += 8;
+ right_fp16 += 8;
+ output_fp16 += 8;
+ N -= 8;
+ }
+
+ if (N & 4) {
+ auto l0 = MlasLoadFloat16x4(left_fp16);
+ auto r0 = MlasLoadFloat16x4(right_fp16);
+ auto o0 = MlasAddFloat16(l0, r0);
+ MlasStoreFloat16x4(output_fp16, o0);
+
+ left_fp16 += 4;
+ right_fp16 += 4;
+ output_fp16 += 4;
+ N -= 4;
+ }
+
+ if (N == 3) {
+ auto l0 = MlasLoadPartialFloat16x4(left_fp16, 3);
+ auto r0 = MlasLoadPartialFloat16x4(right_fp16, 3);
+ auto o0 = MlasAddFloat16(l0, r0);
+ MlasStorePartialFloat16x4(output_fp16, o0, 3);
+ } else if (N == 2) {
+ auto l0 = MlasLoadPartialFloat16x4(left_fp16, 2);
+ auto r0 = MlasLoadPartialFloat16x4(right_fp16, 2);
+ auto o0 = MlasAddFloat16(l0, r0);
+ MlasStorePartialFloat16x4(output_fp16, o0, 2);
+ } else if (N == 1) {
+ auto l0 = MlasLoadPartialFloat16x4(left_fp16, 1);
+ auto r0 = MlasLoadPartialFloat16x4(right_fp16, 1);
+ auto o0 = MlasAddFloat16(l0, r0);
+ MlasStorePartialFloat16x4(output_fp16, o0, 1);
+ }
+}
+
+} // namespace eltwise_neon
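
The tail handling in Add_Kernel_Fp16 peels the length in descending chunks (32-wide main loop, then 16/8/4, then a 1-3 element partial). A scalar model of that decomposition, for illustration only (hypothetical helper, no NEON):

    #include <cstddef>

    // Models the chunk sizes visited by the kernel above for a given N.
    // The 16/8/4 bit tests plus the final partial consume the residual exactly once.
    size_t TotalElementsVisited(size_t N) {
      size_t visited = (N / 32) * 32;   // 32-wide main loop
      size_t r = N % 32;
      if (r & 16) { visited += 16; r -= 16; }
      if (r & 8)  { visited += 8;  r -= 8;  }
      if (r & 4)  { visited += 4;  r -= 4;  }
      visited += r;                     // remaining 0-3 elements via the partial load/store
      return visited;                   // always equals N
    }
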
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 0681b49252495..8e704b5b94801 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1070,6 +1070,10 @@ extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon;
struct MLAS_SOFTMAX_DISPATCH;
extern const MLAS_SOFTMAX_DISPATCH MlasSoftmaxDispatchNeon;
+// eltwise dispatch structure
+struct MLAS_ELTWISE_DISPATCH;
+extern const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon;
+
//
// Quantized depthwise convolution kernels.
//
@@ -1233,6 +1237,7 @@ struct MLAS_PLATFORM {
const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr};
const MLAS_SOFTMAX_DISPATCH* SoftmaxDispatch{nullptr};
+ const MLAS_ELTWISE_DISPATCH* EltwiseDispatch{nullptr};
};
inline
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 582c1ab944b98..312a624fd160c 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -547,6 +547,7 @@ Return Value:
this->RopeDispatch = &MlasRopeDispatchNeon;
this->HGemmDispatch = &MlasHGemmDispatchNeon;
this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon;
+ this->EltwiseDispatch = &MlasEltwiseDispatchNeon;
//
// Check if the processor supports ASIMD dot product instructions.
diff --git a/onnxruntime/test/mlas/unittest/test_eltwise.cpp b/onnxruntime/test/mlas/unittest/test_eltwise.cpp
new file mode 100644
index 0000000000000..c4d4b9c0eb317
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_eltwise.cpp
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "test_util.h"
+#include "core/mlas/lib/mlasi.h"
+#include "core/mlas/lib/eltwise.h"
+
+class MlasEltwiseAddTest : public MlasTestBase {
+ private:
+ MatrixGuardBuffer<float> BufferInputLeft;
+ MatrixGuardBuffer<float> BufferInputRight;
+ MatrixGuardBuffer<float> BufferOutput;
+ MatrixGuardBuffer<float> BufferOutputReference;
+ MatrixGuardBuffer<MLAS_FP16> BufferInputLeftFp16;
+ MatrixGuardBuffer<MLAS_FP16> BufferInputRightFp16;
+ MatrixGuardBuffer<MLAS_FP16> BufferOutputFp16;
+
+ void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional<float>& ScalarValue = std::nullopt) {
+ float* InputLeft = BufferInputLeft.GetBuffer(N);
+ float* InputRight = BufferInputRight.GetBuffer(N);
+ float* Output = BufferOutput.GetBuffer(N);
+ float* OutputReference = BufferOutputReference.GetBuffer(N);
+
+ std::default_random_engine generator(static_cast<unsigned>(N));
+ std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue);
+
+ for (size_t n = 0; n < N; n++) {
+ InputLeft[n] = distribution(generator);
+ InputRight[n] = ScalarValue.value_or(distribution(generator));
+ }
+
+ for (size_t n = 0; n < N; n++) {
+ OutputReference[n] = InputLeft[n] + InputRight[n];
+ }
+
+ MlasEltwiseAdd(InputLeft, InputRight, Output, N);
+
+ constexpr float AbsoluteTolerance = 1e-6f;
+ constexpr float RelativeTolerance = 1e-6f;
+
+ for (size_t n = 0; n < N; n++) {
+ float diff = std::fabs(Output[n] - OutputReference[n]);
+ ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance)
+ << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n];
+ }
+ }
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+
+ void TestFp16(size_t N, float MinimumValue, float MaximumValue, const std::optional<float>& ScalarValue = std::nullopt) {
+ MLAS_FP16* InputLeft = BufferInputLeftFp16.GetBuffer(N);
+ MLAS_FP16* InputRight = BufferInputRightFp16.GetBuffer(N);
+ MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N);
+
+ std::default_random_engine generator(static_cast<unsigned>(N));
+ std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue);
+
+ for (size_t n = 0; n < N; n++) {
+ InputLeft[n] = MLAS_FP16(distribution(generator));
+ InputRight[n] = MLAS_FP16(ScalarValue.value_or(distribution(generator)));
+ }
+
+ MlasEltwiseAdd(InputLeft, InputRight, Output, N);
+
+ constexpr float AbsoluteTolerance = 5e-4f;
+ constexpr float RelativeTolerance = 1e-3f;
+
+ for (size_t n = 0; n < N; n++) {
+ float inLeft = InputLeft[n].ToFloat();
+ float inRight = InputRight[n].ToFloat();
+ float ref = inLeft + inRight;
+ float out = Output[n].ToFloat();
+ float diff = std::fabs(out - ref);
+ ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance)
+ << " @ " << inLeft << ", " << inRight << ", got: " << out << ", expecting: " << ref
+ << ", r-diff: " << diff / std::fabs(ref);
+ }
+ }
+
+#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+
+ public:
+ static const char* GetTestSuiteName() {
+ static const std::string suite_name("Eltwise_Add");
+ return suite_name.c_str();
+ }
+
+ void ExecuteShort(void) override {
+ for (size_t n = 1; n < 128; n++) {
+ Test(n, -10.f, 10.f);
+ Test(n, -10.f, 10.f, -5000.f);
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+ TestFp16(n, -17.f, 11.f);
+ TestFp16(n, -17.f, 11.f, -5000.f);
+#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+ }
+ }
+};
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
+ size_t count = 0;
+ if (is_short_execute) {
+ count += MlasDirectShortExecuteTests<MlasEltwiseAddTest>::RegisterShortExecute();
+ }
+ return count;
+});
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index 77b4b326bf645..1239affcc04de 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -12,6 +12,7 @@
import math
import random
import unittest
+from dataclasses import dataclass
import numpy
import torch
@@ -41,42 +42,30 @@ class Formats:
BNSH = 1
+@dataclass
class Config:
- batch_size = 0
- sequence_length = 0
- kv_sequence_length = 0
- past_sequence_length = 0
- num_heads = 0
- kv_num_heads = 0
- head_size = 0
-
- def __init__(self, b, s, s2, sp, n, n2, h):
- self.batch_size = b
- self.sequence_length = s
- self.kv_sequence_length = s2
- self.past_sequence_length = sp
- self.num_heads = n
- self.kv_num_heads = n2
- self.head_size = h
-
-
+ batch_size: int = 0
+ sequence_length: int = 0
+ kv_sequence_length: int = 0
+ past_sequence_length: int = 0
+ num_heads: int = 0
+ kv_num_heads: int = 0
+ head_size: int = 0
+ has_position_ids: bool = False
+ has_attention_bias: bool = False
+
+
+@dataclass
class PromptConfig:
- batch_size = 0
- q_sequence_length = 0
- kv_sequence_length = 0
- buffer_sequence_length = 0
- num_heads = 0
- kv_num_heads = 0
- head_size = 0
-
- def __init__(self, b, sq, skv, sb, n, n2, h):
- self.batch_size = b
- self.q_sequence_length = sq
- self.kv_sequence_length = skv
- self.buffer_sequence_length = sb
- self.num_heads = n
- self.kv_num_heads = n2
- self.head_size = h
+ batch_size: int = 0
+ q_sequence_length: int = 0
+ kv_sequence_length: int = 0
+ buffer_sequence_length: int = 0
+ num_heads: int = 0
+ kv_num_heads: int = 0
+ head_size: int = 0
+ has_position_ids: bool = False
+ has_attention_bias: bool = False
# LLaMA Microsoft model
@@ -173,6 +162,8 @@ def create_group_query_attention_graph_prompt(
"total_sequence_length",
"cos_cache" if rotary else "",
"sin_cache" if rotary else "",
+ "position_ids" if config.has_position_ids else "",
+ "attention_bias" if config.has_attention_bias else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
@@ -278,6 +269,24 @@ def create_group_query_attention_graph_prompt(
),
]
+ if config.has_position_ids:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "position_ids",
+ TensorProto.INT64,
+ [config.batch_size, config.kv_sequence_length],
+ ),
+ ]
+
+ if config.has_attention_bias:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "attention_bias",
+ ORT_TYPE,
+ [config.batch_size, 1, config.kv_sequence_length, config.kv_sequence_length],
+ ),
+ ]
+
graph_output = [
helper.make_tensor_value_info(
"output",
@@ -334,6 +343,7 @@ def create_group_query_attention_graph_prompt(
)
model = helper.make_model(graph)
+
return model.SerializeToString()
@@ -365,6 +375,8 @@ def create_group_query_attention_graph_past(
"total_sequence_length",
"cos_cache" if rotary else "",
"sin_cache" if rotary else "",
+ "position_ids" if config.has_position_ids else "",
+ "attention_bias" if config.has_attention_bias else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
@@ -467,6 +479,22 @@ def create_group_query_attention_graph_past(
),
]
+ if config.has_position_ids:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "position_ids", TensorProto.INT64, [config.batch_size, config.sequence_length]
+ ),
+ ]
+
+ if config.has_attention_bias:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "attention_bias",
+ ORT_TYPE,
+ [config.batch_size, 1, config.sequence_length, present_kv_seqlen],
+ ),
+ ]
+
graph_output = [
helper.make_tensor_value_info(
"output",
@@ -681,6 +709,8 @@ def gqa_prompt_func(
cos=None,
sin=None,
seqlens_k=None,
+ position_ids=None,
+ attention_bias=None,
window_size=-1,
past_kv_format=Formats.BSNH,
share_buffer=True,
@@ -699,9 +729,17 @@ def gqa_prompt_func(
softcap=softcap,
use_smooth_softmax=use_smooth_softmax,
)
+
q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
past_k = k.clone() if share_buffer else None
past_v = v.clone() if share_buffer else None
+
+ if config.has_position_ids:
+ assert position_ids is not None
+
+ if config.has_attention_bias:
+ assert attention_bias is not None
+
if new_k is not None:
new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1))
new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1))
@@ -713,6 +751,7 @@ def gqa_prompt_func(
"seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32),
"total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
}
+
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
io_binding = ort_session.io_binding()
@@ -726,6 +765,15 @@ def gqa_prompt_func(
ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
+
+ if config.has_position_ids:
+ ort_inputs["position_ids"] = position_ids.detach().cpu().numpy()
+ io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"])
+
+ if config.has_attention_bias:
+ ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy()
+ io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"])
+
io_binding.bind_cpu_input("query", ort_inputs["query"])
io_binding.bind_input(
"past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
@@ -767,6 +815,15 @@ def gqa_prompt_func(
ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
+
+ if config.has_position_ids:
+ ort_inputs["position_ids"] = position_ids.detach().cpu().numpy()
+ io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"])
+
+ if config.has_attention_bias:
+ ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy()
+ io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"])
+
io_binding.bind_cpu_input("query", ort_inputs["query"])
io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"])
io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"])
@@ -790,6 +847,8 @@ def gqa_past_func(
cos=None,
sin=None,
seqlens_k=None,
+ position_ids=None,
+ attention_bias=None,
past_kv_format=Formats.BSNH,
share_buffer=True,
window_size=-1,
@@ -812,6 +871,13 @@ def gqa_past_func(
q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
past_k = k.clone()
past_v = v.clone()
+
+ if config.has_position_ids:
+ assert position_ids is not None
+
+ if config.has_attention_bias:
+ assert attention_bias is not None
+
if new_k is not None:
new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1))
new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1))
@@ -839,6 +905,15 @@ def gqa_past_func(
ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
+
+ if config.has_position_ids:
+ ort_inputs["position_ids"] = position_ids.detach().cpu().numpy()
+ io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"])
+
+ if config.has_attention_bias:
+ ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy()
+ io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"])
+
io_binding.bind_cpu_input("query", ort_inputs["query"])
io_binding.bind_input(
"past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
@@ -887,6 +962,15 @@ def gqa_past_func(
ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
+
+ if config.has_position_ids:
+ ort_inputs["position_ids"] = position_ids.detach().cpu().numpy()
+ io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"])
+
+ if config.has_attention_bias:
+ ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy()
+ io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"])
+
io_binding.bind_cpu_input("query", ort_inputs["query"])
io_binding.bind_cpu_input("past_key", ort_inputs["past_key"])
io_binding.bind_cpu_input("past_value", ort_inputs["past_value"])
@@ -1056,6 +1140,41 @@ def attention_qkvpacked_ref(
)
+def get_custom_attention_bias(batch_size, sequence_length, total_seq_len, seqlens_k=None, past=False):
+ if past:
+ assert seqlens_k is not None
+ attention_bias = torch.zeros((batch_size, 1, sequence_length, total_seq_len), dtype=TORCH_TYPE)
+ for b in range(batch_size):
+ total_seq_len = seqlens_k[b] + 1
+ past_seq_len = total_seq_len - sequence_length
+
+ # Configure bias
+ for i in range(sequence_length):
+ for j in range(past_seq_len + i + 1, total_seq_len):
+ attention_bias[b][0][i][j] = -5000
+ else:
+ attention_bias = torch.rand(batch_size, 1, sequence_length, total_seq_len, dtype=TORCH_TYPE)
+ attention_bias = torch.triu(attention_bias, diagonal=1)
+
+ return attention_bias
+
+
+def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=False):
+ if past:
+ assert seqlens_k is not None
+ position_ids_data = []
+ for b in range(batch_size):
+ total_seq_len = seqlens_k[b] + 1
+ past_seq_len = total_seq_len - sequence_length
+ position_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length)))
+
+ position_ids = torch.tensor(data=position_ids_data, dtype=torch.int64)
+ else:
+ position_ids = torch.zeros((batch_size, sequence_length), dtype=torch.int64)
+
+ return position_ids
+
+
def parity_check_gqa_prompt(
config,
causal=True,
@@ -1087,6 +1206,7 @@ def parity_check_gqa_prompt(
dtype=TORCH_TYPE,
requires_grad=False,
)
+
v = torch.randn(
config.batch_size,
config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads,
@@ -1154,6 +1274,19 @@ def parity_check_gqa_prompt(
cos, sin = None, None
q_ro, k_ro = q, new_k
+ position_ids = (
+ get_custom_position_ids(config.batch_size, config.kv_sequence_length, seqlens_k=None, past=False)
+ if config.has_position_ids
+ else None
+ )
+ attention_bias = (
+ get_custom_attention_bias(
+ config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False
+ )
+ if config.has_attention_bias
+ else None
+ )
+
rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
@@ -1184,6 +1317,7 @@ def parity_check_gqa_prompt(
v_cache_ref = v_cache_ref.transpose(1, 2)
# Flash function
+ # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1
if packed:
packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
out, present_k, present_v = gqa_prompt_func(
@@ -1195,7 +1329,9 @@ def parity_check_gqa_prompt(
None,
cos,
sin,
- cache_seqlens,
+ cache_seqlens - 1,
+ position_ids,
+ attention_bias,
left_window_size,
past_format,
True,
@@ -1213,7 +1349,9 @@ def parity_check_gqa_prompt(
new_v,
cos,
sin,
- cache_seqlens,
+ cache_seqlens - 1,
+ position_ids,
+ attention_bias,
left_window_size,
past_format,
True,
@@ -1262,6 +1400,10 @@ def parity_check_gqa_prompt(
config.kv_num_heads,
" h:",
config.head_size,
+ " has_position_ids:",
+ config.has_position_ids,
+ " has_attention_bias:",
+ config.has_attention_bias,
" Mean Error:",
numpy.mean(numpy.abs(out - out_ref)),
correct,
@@ -1347,6 +1489,19 @@ def parity_check_gqa_prompt_no_buff(
q_ro, k_ro = q, k_cache_ref
k_cache_ref = k_ro
+ position_ids = (
+ get_custom_position_ids(config.batch_size, config.kv_sequence_length, seqlens_k=None, past=False)
+ if config.has_position_ids
+ else None
+ )
+ attention_bias = (
+ get_custom_attention_bias(
+ config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False
+ )
+ if config.has_attention_bias
+ else None
+ )
+
brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
new_mask = brange < cache_seqlens_expanded
@@ -1371,6 +1526,7 @@ def parity_check_gqa_prompt_no_buff(
v_cache_ref = v_cache_ref.transpose(1, 2)
# Flash function
+ # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1
if packed:
packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
out, present_k, present_v = gqa_prompt_func(
@@ -1383,6 +1539,8 @@ def parity_check_gqa_prompt_no_buff(
cos,
sin,
cache_seqlens - 1,
+ position_ids,
+ attention_bias,
left_window_size,
past_format,
False,
@@ -1401,6 +1559,8 @@ def parity_check_gqa_prompt_no_buff(
cos,
sin,
cache_seqlens - 1,
+ position_ids,
+ attention_bias,
left_window_size,
past_format,
False,
@@ -1449,6 +1609,10 @@ def parity_check_gqa_prompt_no_buff(
config.kv_num_heads,
" h:",
config.head_size,
+ " has_position_ids:",
+ config.has_position_ids,
+ " has_attention_bias:",
+ config.has_attention_bias,
" Mean Error:",
numpy.mean(numpy.abs(out - out_ref)),
correct,
@@ -1589,6 +1753,19 @@ def parity_check_gqa_past(
cache_seqlens += config.sequence_length - 1
+ position_ids = (
+ get_custom_position_ids(config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True)
+ if config.has_position_ids
+ else None
+ )
+ attention_bias = (
+ get_custom_attention_bias(
+ config.batch_size, config.sequence_length, config.kv_sequence_length, seqlens_k=cache_seqlens, past=True
+ )
+ if config.has_attention_bias
+ else None
+ )
+
# ORT function
if packed:
packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
@@ -1602,6 +1779,8 @@ def parity_check_gqa_past(
cos,
sin,
cache_seqlens,
+ position_ids,
+ attention_bias,
past_format,
True,
left_window_size,
@@ -1620,6 +1799,8 @@ def parity_check_gqa_past(
cos,
sin,
cache_seqlens,
+ position_ids,
+ attention_bias,
past_format,
True,
left_window_size,
@@ -1668,6 +1849,10 @@ def parity_check_gqa_past(
config.kv_num_heads,
" h:",
config.head_size,
+ " has_position_ids:",
+ config.has_position_ids,
+ " has_attention_bias:",
+ config.has_attention_bias,
" Mean Error:",
numpy.mean(numpy.abs(out - out_ref)),
correct,
@@ -1814,6 +1999,23 @@ def parity_check_gqa_past_no_buff(
cache_seqlens += config.sequence_length - 1
+ position_ids = (
+ get_custom_position_ids(config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True)
+ if config.has_position_ids
+ else None
+ )
+ attention_bias = (
+ get_custom_attention_bias(
+ config.batch_size,
+ config.sequence_length,
+ config.kv_sequence_length + config.sequence_length,
+ seqlens_k=cache_seqlens,
+ past=True,
+ )
+ if config.has_attention_bias
+ else None
+ )
+
# Flash function
if packed:
packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
@@ -1827,6 +2029,8 @@ def parity_check_gqa_past_no_buff(
cos,
sin,
cache_seqlens,
+ position_ids,
+ attention_bias,
past_format,
False,
window_size=left_window_size,
@@ -1845,6 +2049,8 @@ def parity_check_gqa_past_no_buff(
cos,
sin,
cache_seqlens,
+ position_ids,
+ attention_bias,
past_format,
False,
window_size=left_window_size,
@@ -1889,6 +2095,10 @@ def parity_check_gqa_past_no_buff(
config.kv_num_heads,
" h:",
config.head_size,
+ " has_position_ids:",
+ config.has_position_ids,
+ " has_attention_bias:",
+ config.has_attention_bias,
" Mean Error:",
numpy.mean(numpy.abs(out - out_ref)),
correct,
@@ -1916,6 +2126,11 @@ def test_gqa_no_past(self):
(8000, 8000),
]
)
+ pos_ids_attn_bias = (
+ [(False, False), (True, True)]
+ if pipeline_mode
+ else [(False, False), (True, True), (False, True), (True, False)]
+ )
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
for b in batches:
@@ -1927,30 +2142,41 @@ def test_gqa_no_past(self):
for packed in [False, True]:
for softcap in [0.0, 50.0]:
for use_smooth_softmax in [False, True]:
- config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
- past_kv_format = Formats.BNSH
- all_close = parity_check_gqa_prompt(
- config,
- local=local,
- past_format=past_kv_format,
- rotary=rotary,
- rotary_interleaved=rotary_interleaved,
- packed=packed,
- softcap=softcap,
- use_smooth_softmax=use_smooth_softmax,
- )
- self.assertTrue(all_close)
- all_close = parity_check_gqa_prompt_no_buff(
- config,
- local=local,
- past_format=past_kv_format,
- rotary=rotary,
- rotary_interleaved=rotary_interleaved,
- packed=packed,
- softcap=softcap,
- use_smooth_softmax=use_smooth_softmax,
- )
- self.assertTrue(all_close)
+ for has_position_ids, has_attention_bias in pos_ids_attn_bias:
+ config = PromptConfig(
+ b,
+ sq,
+ skv,
+ sq + skv + 8,
+ n,
+ n2,
+ h,
+ has_position_ids,
+ has_attention_bias,
+ )
+ past_kv_format = Formats.BNSH
+ all_close = parity_check_gqa_prompt(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ softcap=softcap,
+ use_smooth_softmax=use_smooth_softmax,
+ )
+ self.assertTrue(all_close)
+ all_close = parity_check_gqa_prompt_no_buff(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ softcap=softcap,
+ use_smooth_softmax=use_smooth_softmax,
+ )
+ self.assertTrue(all_close)
def test_gqa_past(self):
print("-------- TEST GQA PAST (TOKEN GEN) ---------")
@@ -1972,6 +2198,11 @@ def test_gqa_past(self):
# (128, 128),
]
)
+ pos_ids_attn_bias = (
+ [(False, False), (True, True)]
+ if pipeline_mode
+ else [(False, False), (True, True), (False, True), (True, False)]
+ )
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)
@@ -1984,35 +2215,38 @@ def test_gqa_past(self):
for packed in [False, True]:
for softcap in [0.0, 50.0]:
for use_smooth_softmax in [False, True]:
- sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
- config = Config(b, s, s2, sp, n, n2, h)
- past_kv_format = Formats.BNSH
- all_close = parity_check_gqa_past(
- config,
- local=local,
- past_format=past_kv_format,
- rtol=RTOL,
- atol=ATOL,
- rotary=rotary,
- rotary_interleaved=rotary_interleaved,
- packed=packed,
- softcap=softcap,
- use_smooth_softmax=use_smooth_softmax,
- )
- self.assertTrue(all_close)
- all_close = parity_check_gqa_past_no_buff(
- config,
- local=local,
- past_format=past_kv_format,
- rtol=RTOL,
- atol=ATOL,
- rotary=rotary,
- rotary_interleaved=rotary_interleaved,
- packed=packed,
- softcap=softcap,
- use_smooth_softmax=use_smooth_softmax,
- )
- self.assertTrue(all_close)
+ for has_position_ids, has_attention_bias in pos_ids_attn_bias:
+ sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+ config = Config(
+ b, s, s2, sp, n, n2, h, has_position_ids, has_attention_bias
+ )
+ past_kv_format = Formats.BNSH
+ all_close = parity_check_gqa_past(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rtol=RTOL,
+ atol=ATOL,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ softcap=softcap,
+ use_smooth_softmax=use_smooth_softmax,
+ )
+ self.assertTrue(all_close)
+ all_close = parity_check_gqa_past_no_buff(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rtol=RTOL,
+ atol=ATOL,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ softcap=softcap,
+ use_smooth_softmax=use_smooth_softmax,
+ )
+ self.assertTrue(all_close)
def test_gqa_interactive_one_batch(self):
print("-------- TEST GQA INTERACTIVE ---------")
@@ -2034,6 +2268,11 @@ def test_gqa_interactive_one_batch(self):
# (128, 128),
]
)
+ pos_ids_attn_bias = (
+ [(False, False), (True, True)]
+ if pipeline_mode
+ else [(False, False), (True, True), (False, True), (True, False)]
+ )
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)
@@ -2044,30 +2283,31 @@ def test_gqa_interactive_one_batch(self):
for local in [False, True]:
for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]:
for packed in [False, True]:
- config = Config(b, s, s2, -1, n, n2, h)
- past_kv_format = Formats.BNSH
- all_close = parity_check_gqa_past(
- config,
- local=local,
- past_format=past_kv_format,
- rtol=RTOL,
- atol=ATOL,
- rotary=rotary,
- rotary_interleaved=rotary_interleaved,
- packed=packed,
- )
- self.assertTrue(all_close)
- all_close = parity_check_gqa_past_no_buff(
- config,
- local=local,
- past_format=past_kv_format,
- rtol=RTOL,
- atol=ATOL,
- rotary=rotary,
- rotary_interleaved=rotary_interleaved,
- packed=packed,
- )
- self.assertTrue(all_close)
+ for has_position_ids, has_attention_bias in pos_ids_attn_bias:
+ config = Config(b, s, s2, -1, n, n2, h, has_position_ids, has_attention_bias)
+ past_kv_format = Formats.BNSH
+ all_close = parity_check_gqa_past(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rtol=RTOL,
+ atol=ATOL,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ )
+ self.assertTrue(all_close)
+ all_close = parity_check_gqa_past_no_buff(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rtol=RTOL,
+ atol=ATOL,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ )
+ self.assertTrue(all_close)
if __name__ == "__main__":