diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 15864a0198161..87387d4f281ed 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -27,6 +27,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/activate.cpp ${MLAS_SRC_DIR}/logistic.cpp ${MLAS_SRC_DIR}/tanh.cpp + ${MLAS_SRC_DIR}/eltwise.h + ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp ${MLAS_SRC_DIR}/quantize.cpp @@ -101,6 +103,9 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/softmax_kernel_neon.h ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -387,6 +392,8 @@ else() ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/softmax_kernel_neon.h ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -409,6 +416,7 @@ else() ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -423,6 +431,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 274531faaf717..f85ed1e5f146c 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2551,7 +2551,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 9) +#### Inputs (7 - 11)
query : T
@@ -2572,6 +2572,10 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache (optional) : T
2D tensor with shape (max_sequence_length, head_size / 2).
+
position_ids (optional) : tensor(int64)
+
2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
+
attention_bias (optional) : T
+
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
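For illustration only (not part of this change): a minimal sketch of how a model builder could wire the two new optional inputs into the contrib op, mirroring the graph construction used in test_gqa_cpu.py. The tensor names and the num_heads/kv_num_heads values are placeholders.

```python
from onnx import helper

# Hypothetical GroupQueryAttention node; inputs 9 and 10 are the new optional
# position_ids and attention_bias described above.
gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value",
        "past_key", "past_value",
        "seqlens_k", "total_sequence_length",
        "cos_cache", "sin_cache",
        "position_ids",    # optional, tensor(int64), shape (batch_size, sequence_length)
        "attention_bias",  # optional, type T, shape (batch or 1, num_heads or 1, sequence_length, total_sequence_length)
    ],
    outputs=["output", "present_key", "present_value"],
    name="GroupQueryAttention_0",
    domain="com.microsoft",
    num_heads=32,     # placeholder attribute values
    kv_num_heads=8,
)
```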
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 84b9c7c9fc174..1dd145463367b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -520,7 +520,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int64)|
@@ -922,7 +922,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1399,7 +1399,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br> **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float), tensor(float16)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br>
**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 188fc6e43b5b5..ac32a4445f3ca 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -31,6 +31,11 @@ void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) { MlasComputeSoftcap(scores, scores, sequence_length, softcap); } +template +void ApplyAttentionBias(T* softmax_logits, const T* attention_mask, int N) { + MlasEltwiseAdd(softmax_logits, attention_mask, softmax_logits, N); +} + template void PrepareMask(const int32_t* mask_index, gsl::span mask_index_dims, diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 70d66e534ee8a..ff6cb8edc0231 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -50,6 +50,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH + const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor @@ -87,14 +88,18 @@ class GQAAttentionBase { const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + const T* attention_bias_data = attention_bias != nullptr ? attention_bias->Data() : nullptr; + auto attention_bias_shape = attention_bias != nullptr ? attention_bias->Shape().GetDims() : gsl::span{}; + bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, - sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -104,9 +109,10 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, - sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -126,22 +132,24 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = Softmax(attention_probs) // If T is float32, U is float32. If T is float16, U could be float16 or float32. template - void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - const int32_t* seqlens_k, // total - 1 sequence lengths tensor - const size_t batch_size, // batch size of self-attention - const size_t sequence_length, // sequence length of self-attention (S) - const size_t past_buffer_sequence_length, // sequence length of past state - const size_t present_buffer_sequence_length, // sequence length of present state - const size_t head_size, // head size of self-attention - const T* past_key, // past key only - T* present_key, // present key only - const bool past_present_share_buffer, // whether present key and value share the same buffer - const bool packed_qkv, // whether Q, K, V are packed - const bool is_prompt, // whether it is prompt - ThreadPool* tp, // thread pool - AllocatorPtr allocator) const { // allocator for temporary buffer + void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const T* attention_bias, // optional attention bias + const size_t batch_size, // batch size of self-attention + const size_t sequence_length, // sequence length of self-attention (S) + const gsl::span attention_bias_shape, // shape of the attention bias + const size_t past_buffer_sequence_length, // sequence length of past state + const size_t present_buffer_sequence_length, // sequence length of present state + const size_t head_size, // head size of self-attention + const T* past_key, // past key only + T* present_key, // present key only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt + ThreadPool* tp, // thread pool + AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); @@ -189,6 +197,24 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; + // Compute attention bias offset based on the batch and head indexes + // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting + const T* attention_bias_thread = nullptr; + ptrdiff_t attention_total_seqlen = 0; + if (attention_bias != nullptr) { + ptrdiff_t attention_bias_offset = 0; + attention_total_seqlen = static_cast(attention_bias_shape[3]); + const ptrdiff_t attention_matrix_size = sequence_length * attention_total_seqlen; + if (attention_bias_shape[0] != 1) { + attention_bias_offset += SafeInt(batch_index) * attention_bias_shape[1] * attention_matrix_size; + } + if (attention_bias_shape[1] != 1) { + attention_bias_offset += SafeInt(head_index) * attention_matrix_size; + } + + attention_bias_thread = attention_bias + attention_bias_offset; + } + const T* k; if (packed_qkv) { k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); @@ -242,7 +268,15 @@ class GQAAttentionBase { U* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { size_t seq_causal_length = past_seqlen + seq + 1; - if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { + + const bool should_apply_local_window = local_window_size_ > 0 && + seq_causal_length > static_cast(local_window_size_) + 1; + + const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ - 1 : 0; + const size_t window_size = should_apply_local_window ? local_window_size_ + 1 : seq_causal_length; + + // Mask everything before local window, if local window should be applied + if (should_apply_local_window) { for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { if constexpr (std::is_same::value) { output_softmax[total_seq_id] = 0.f; @@ -250,27 +284,34 @@ class GQAAttentionBase { output_softmax[total_seq_id] = MLFloat16::FromBits(static_cast(0)); } } - if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax + seq_causal_length - local_window_size_ - 1, - local_window_size_ + 1, static_cast(softcap_)); - } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, - local_window_size_ + 1, nullptr); + } + + if (softcap_ > 0.f) { + ComputeAttentionSoftcapInplace(output_softmax + start_offset, static_cast(window_size), + static_cast(softcap_)); + } + + // Add attention bias to QxK' if provided + // TODO (#23982): Implement bias addition during softmax computation in GQA CPU operator + if (attention_bias_thread != nullptr) { + if constexpr (std::is_same_v) { + ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread + start_offset, + static_cast(window_size)); } else { - ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, - local_window_size_ + 1, nullptr); + static_assert(std::is_same_v && std::is_same_v); + size_t bytes = window_size * sizeof(float); + auto attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); + BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); + + MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size); + 
ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast(window_size)); } + } + + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); } else { - if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax, static_cast(seq_causal_length), - static_cast(softcap_)); - } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); - } + ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); } // set causal [seq_causal_length, total_seqlen) to 0.f @@ -283,6 +324,10 @@ class GQAAttentionBase { } output_softmax += present_buffer_sequence_length; + + if (attention_bias_thread != nullptr) { + attention_bias_thread += attention_total_seqlen; + } } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 8f662cd388c6d..9c7530f0126bb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -52,6 +52,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); + const Tensor* position_ids = context->Input(9); + const Tensor* attention_bias = context->Input(10); GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -69,6 +71,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { scale_, softcap_)); + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, + attention_bias, + parameters)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int present_kv_seqlen = parameters.seqlen_present_kv_cache; @@ -129,9 +135,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { auto* tp = context->GetOperatorThreadPool(); // Generate position ids const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length; - std::vector pos_ids(pos_ids_size); - if (parameters.is_first_prompt) { - pos_ids[0] = static_cast(0); + std::vector default_pos_ids(pos_ids_size); + const int64_t* pos_ids_data = default_pos_ids.data(); + + if (position_ids != nullptr) { + pos_ids_data = position_ids->Data(); + } else if (parameters.is_first_prompt) { + default_pos_ids[0] = static_cast(0); } else { // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1. 
for (int b = 0; b < batch_size; b++) { @@ -139,13 +149,14 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const int past_seqlen = total_seqlen - sequence_length; for (int s = 0; s < sequence_length; s++) { if (past_seqlen + s < total_seqlen) { - pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; + default_pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; } else { - pos_ids[b * sequence_length + s] = static_cast(1); + default_pos_ids[b * sequence_length + s] = static_cast(1); } } } } + // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; @@ -165,7 +176,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, - pos_ids.data(), cos_cache->Data(), + pos_ids_data, cos_cache->Data(), sin_cache->Data(), q_rotary, rotary_interleaved_)); rotary_params.num_heads = kv_num_heads_; @@ -174,7 +185,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride; } ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, - pos_ids.data(), cos_cache->Data(), + pos_ids_data, cos_cache->Data(), sin_cache->Data(), k_rotary, rotary_interleaved_)); // Pack V into rotary QKV buffer if (packed_qkv) { @@ -192,9 +203,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - past_key, past_value, output, present_k, present_v, + attention_bias, past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 4cc5a4228dc8c..7bffd768c8f7c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -288,6 +288,50 @@ Status CheckInputs(const T* query, return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap); } + +template +Status CheckCustomAttentionInputs(const T* position_ids, + const T* attention_bias, + const GroupQueryAttentionParameters& parameters) { + if (position_ids != nullptr) { + const auto& pos_ids_shape = position_ids->Shape(); + if (pos_ids_shape[0] != parameters.batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "position_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); + } + + if (pos_ids_shape[1] < parameters.sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "position_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); + } + } + + if (attention_bias != nullptr) { + const auto& attn_bias_shape = attention_bias->Shape(); + if ((attn_bias_shape[0] != parameters.batch_size) && (attn_bias_shape[0] != 1)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "attention_bias dimension 0 must be equal to the batch size or 1, got ", attn_bias_shape[0]); + } + + if ((attn_bias_shape[1] != parameters.num_heads) && (attn_bias_shape[1] != 1)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, 
INVALID_ARGUMENT, + "attention_bias dimension 1 must be equal to the num heads or 1, got ", attn_bias_shape[1]); + } + + if (attn_bias_shape[2] != parameters.sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "attention_bias dimension 2 must be equal to the sequence length, got ", attn_bias_shape[2]); + } + + if (attn_bias_shape[3] != parameters.total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "attention_bias dimension 3 must be equal to total_sequence_length, got ", attn_bias_shape[3]); + } + } + + return Status::OK(); +} + } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index ecc8cb091b1b6..718dd9a4397b5 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1128,6 +1128,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "2D tensor with shape (max_sequence_length, head_size / 2).", "T", OpSchema::Optional) + .Input(9, + "position_ids", + "2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel " + "uses only the first element", + "tensor(int64)", + OpSchema::Optional) + .Input(10, + "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 1401e27ca77e5..8033eab8262a0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1030,6 +1030,16 @@ MlasComputeSoftcap( T cap ); +template +void +MLASCALL +MlasEltwiseAdd( + const T* left, + const T* right, + T* output, + size_t N + ); + template void MLASCALL diff --git a/onnxruntime/core/mlas/lib/eltwise.cpp b/onnxruntime/core/mlas/lib/eltwise.cpp new file mode 100644 index 0000000000000..f63d71b40bfbb --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise.cpp @@ -0,0 +1,71 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise.cpp + +Abstract: + + This module implements routines to compute element-wise operations on two vectors. 
+ + Currently supported element-wise operations: + - Add + +--*/ + +#include "mlasi.h" +#include "eltwise.h" + +template <> +void +MLASCALL +MlasEltwiseAdd( + const float* left, + const float* right, + float* output, + size_t N +) { + while (N > 0) { + if (N >= 4) { + MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left); + MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right); + + MLAS_FLOAT32X4 ResultVec = MlasAddFloat32x4(LeftVec, RightVec); + + MlasStoreFloat32x4(output, ResultVec); + + left += 4; + right += 4; + output += 4; + N -= 4; + } else { + *output = *left + *right; + + left += 1; + right += 1; + output += 1; + N -= 1; + } + } +} + + +template <> +void +MLASCALL +MlasEltwiseAdd( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N +) { + const auto* dispatch = GetMlasPlatform().EltwiseDispatch; + if (dispatch == nullptr || dispatch->Add_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Add_Fp16 is not supported."); + } + dispatch->Add_Fp16(left, right, output, N); +} diff --git a/onnxruntime/core/mlas/lib/eltwise.h b/onnxruntime/core/mlas/lib/eltwise.h new file mode 100644 index 0000000000000..a8345c499f6b7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + element-wise operations. + +--*/ +#pragma once + +#include "mlasi.h" + +struct MLAS_ELTWISE_DISPATCH { + /** + * @brief Compute the element-wise addition of the two given vectors + * @param left Address of the left operand + * @param right Address of the right operand + * @param output Address of the output array. Could be the same as the input array. + * @param N Number of elements in the input arrays + */ + typedef void(Add_Fp16_Fn)( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N + ); + + Add_Fp16_Fn* Add_Fp16 = nullptr; +}; diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp new file mode 100644 index 0000000000000..415c1281c808e --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.cpp + +Abstract: + + This module implements the element-wise kernels for ARM NEON. + +--*/ + +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon = []() { + MLAS_ELTWISE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.Add_Fp16 = eltwise_neon::Add_Kernel_Fp16; + } +#endif + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h new file mode 100644 index 0000000000000..d99a3e97c21f2 --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + element-wise operations on ARM cpu. 
+ +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N); + +} // namespace eltwise_neon diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp new file mode 100644 index 0000000000000..decbdb576d5cd --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp @@ -0,0 +1,118 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 element-wise kernels for ARM NEON. + +--*/ +#include +#include + +#include "fp16_common.h" +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N) { + const auto* left_fp16 = reinterpret_cast(left); + const auto* right_fp16 = reinterpret_cast(right); + auto* output_fp16 = reinterpret_cast<_mlas_fp16_*>(output); + + while (N >= 32) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + auto l2 = MlasLoadFloat16x8(left_fp16 + 16); + auto l3 = MlasLoadFloat16x8(left_fp16 + 24); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + auto r2 = MlasLoadFloat16x8(right_fp16 + 16); + auto r3 = MlasLoadFloat16x8(right_fp16 + 24); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + auto o2 = MlasAddFloat16(l2, r2); + auto o3 = MlasAddFloat16(l3, r3); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + MlasStoreFloat16x8(output_fp16 + 16, o2); + MlasStoreFloat16x8(output_fp16 + 24, o3); + + left_fp16 += 32; + right_fp16 += 32; + output_fp16 += 32; + N -= 32; + } + + if (N & 16) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + + left_fp16 += 16; + right_fp16 += 16; + output_fp16 += 16; + N -= 16; + } + + if (N & 8) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto r0 = MlasLoadFloat16x8(right_fp16); + auto o0 = MlasAddFloat16(l0, r0); + MlasStoreFloat16x8(output_fp16, o0); + + left_fp16 += 8; + right_fp16 += 8; + output_fp16 += 8; + N -= 8; + } + + if (N & 4) { + auto l0 = MlasLoadFloat16x4(left_fp16); + auto r0 = MlasLoadFloat16x4(right_fp16); + auto o0 = MlasAddFloat16(l0, r0); + MlasStoreFloat16x4(output_fp16, o0); + + left_fp16 += 4; + right_fp16 += 4; + output_fp16 += 4; + N -= 4; + } + + if (N == 3) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 3); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 3); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 3); + } else if (N == 2) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 2); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 2); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 2); + } else if (N == 1) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 1); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 1); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 1); + } +} + +} // namespace eltwise_neon diff --git 
a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0681b49252495..8e704b5b94801 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1070,6 +1070,10 @@ extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon; struct MLAS_SOFTMAX_DISPATCH; extern const MLAS_SOFTMAX_DISPATCH MlasSoftmaxDispatchNeon; +// eltwise dispatch structure +struct MLAS_ELTWISE_DISPATCH; +extern const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon; + // // Quantized depthwise convolution kernels. // @@ -1233,6 +1237,7 @@ struct MLAS_PLATFORM { const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; const MLAS_SOFTMAX_DISPATCH* SoftmaxDispatch{nullptr}; + const MLAS_ELTWISE_DISPATCH* EltwiseDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 582c1ab944b98..312a624fd160c 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -547,6 +547,7 @@ Return Value: this->RopeDispatch = &MlasRopeDispatchNeon; this->HGemmDispatch = &MlasHGemmDispatchNeon; this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; + this->EltwiseDispatch = &MlasEltwiseDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. diff --git a/onnxruntime/test/mlas/unittest/test_eltwise.cpp b/onnxruntime/test/mlas/unittest/test_eltwise.cpp new file mode 100644 index 0000000000000..c4d4b9c0eb317 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_eltwise.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/eltwise.h" + +class MlasEltwiseAddTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputLeft; + MatrixGuardBuffer BufferInputRight; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputLeftFp16; + MatrixGuardBuffer BufferInputRightFp16; + MatrixGuardBuffer BufferOutputFp16; + + void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { + float* InputLeft = BufferInputLeft.GetBuffer(N); + float* InputRight = BufferInputRight.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = distribution(generator); + InputRight[n] = ScalarValue.value_or(distribution(generator)); + } + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = InputLeft[n] + InputRight[n]; + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 1e-6f; + constexpr float RelativeTolerance = 1e-6f; + + for (size_t n = 0; n < N; n++) { + float diff = std::fabs(Output[n] - OutputReference[n]); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) + << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; + } + } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + void TestFp16(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { + MLAS_FP16* InputLeft = BufferInputLeftFp16.GetBuffer(N); + MLAS_FP16* InputRight = 
BufferInputRightFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = MLAS_FP16(distribution(generator)); + InputRight[n] = MLAS_FP16(ScalarValue.value_or(distribution(generator))); + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + for (size_t n = 0; n < N; n++) { + float inLeft = InputLeft[n].ToFloat(); + float inRight = InputRight[n].ToFloat(); + float ref = inLeft + inRight; + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << inLeft << ", " << inRight << ", got: " << out << ", expecting: " << ref + << ", r-diff: " << diff / std::fabs(ref); + } + } + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Eltwise_Add"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { + Test(n, -10.f, 10.f); + Test(n, -10.f, 10.f, -5000.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -17.f, 11.f); + TestFp16(n, -17.f, 11.f, -5000.f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 77b4b326bf645..1239affcc04de 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -12,6 +12,7 @@ import math import random import unittest +from dataclasses import dataclass import numpy import torch @@ -41,42 +42,30 @@ class Formats: BNSH = 1 +@dataclass class Config: - batch_size = 0 - sequence_length = 0 - kv_sequence_length = 0 - past_sequence_length = 0 - num_heads = 0 - kv_num_heads = 0 - head_size = 0 - - def __init__(self, b, s, s2, sp, n, n2, h): - self.batch_size = b - self.sequence_length = s - self.kv_sequence_length = s2 - self.past_sequence_length = sp - self.num_heads = n - self.kv_num_heads = n2 - self.head_size = h - - + batch_size: int = 0 + sequence_length: int = 0 + kv_sequence_length: int = 0 + past_sequence_length: int = 0 + num_heads: int = 0 + kv_num_heads: int = 0 + head_size: int = 0 + has_position_ids: bool = False + has_attention_bias: bool = False + + +@dataclass class PromptConfig: - batch_size = 0 - q_sequence_length = 0 - kv_sequence_length = 0 - buffer_sequence_length = 0 - num_heads = 0 - kv_num_heads = 0 - head_size = 0 - - def __init__(self, b, sq, skv, sb, n, n2, h): - self.batch_size = b - self.q_sequence_length = sq - self.kv_sequence_length = skv - self.buffer_sequence_length = sb - self.num_heads = n - self.kv_num_heads = n2 - self.head_size = h + batch_size: int = 0 + q_sequence_length: int = 0 + kv_sequence_length: int = 0 + buffer_sequence_length: int = 0 + num_heads: int = 0 + kv_num_heads: int = 0 + head_size: int = 0 + has_position_ids: bool = False + has_attention_bias: 
bool = False # LLaMA Microsoft model @@ -173,6 +162,8 @@ def create_group_query_attention_graph_prompt( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -278,6 +269,24 @@ def create_group_query_attention_graph_prompt( ), ] + if config.has_position_ids: + graph_input += [ + helper.make_tensor_value_info( + "position_ids", + TensorProto.INT64, + [config.batch_size, config.kv_sequence_length], + ), + ] + + if config.has_attention_bias: + graph_input += [ + helper.make_tensor_value_info( + "attention_bias", + ORT_TYPE, + [config.batch_size, 1, config.kv_sequence_length, config.kv_sequence_length], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -334,6 +343,7 @@ def create_group_query_attention_graph_prompt( ) model = helper.make_model(graph) + return model.SerializeToString() @@ -365,6 +375,8 @@ def create_group_query_attention_graph_past( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -467,6 +479,22 @@ def create_group_query_attention_graph_past( ), ] + if config.has_position_ids: + graph_input += [ + helper.make_tensor_value_info( + "position_ids", TensorProto.INT64, [config.batch_size, config.sequence_length] + ), + ] + + if config.has_attention_bias: + graph_input += [ + helper.make_tensor_value_info( + "attention_bias", + ORT_TYPE, + [config.batch_size, 1, config.sequence_length, present_kv_seqlen], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -681,6 +709,8 @@ def gqa_prompt_func( cos=None, sin=None, seqlens_k=None, + position_ids=None, + attention_bias=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -699,9 +729,17 @@ def gqa_prompt_func( softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None + + if config.has_position_ids: + assert position_ids is not None + + if config.has_attention_bias: + assert attention_bias is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) @@ -713,6 +751,7 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } + sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() @@ -726,6 +765,15 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + 
io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -767,6 +815,15 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) @@ -790,6 +847,8 @@ def gqa_past_func( cos=None, sin=None, seqlens_k=None, + position_ids=None, + attention_bias=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -812,6 +871,13 @@ def gqa_past_func( q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() + + if config.has_position_ids: + assert position_ids is not None + + if config.has_attention_bias: + assert attention_bias is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) @@ -839,6 +905,15 @@ def gqa_past_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -887,6 +962,15 @@ def gqa_past_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) @@ -1056,6 +1140,41 @@ def attention_qkvpacked_ref( ) +def get_custom_attention_bias(batch_size, sequence_length, total_seq_len, seqlens_k=None, past=False): + if past: + assert seqlens_k is not None + attention_bias = torch.zeros((batch_size, 1, sequence_length, total_seq_len), dtype=TORCH_TYPE) + for b in range(batch_size): + total_seq_len = 
seqlens_k[b] + 1 + past_seq_len = total_seq_len - sequence_length + + # Configure bias + for i in range(sequence_length): + for j in range(past_seq_len + i + 1, total_seq_len): + attention_bias[b][0][i][j] = -5000 + else: + attention_bias = torch.rand(batch_size, 1, sequence_length, total_seq_len, dtype=TORCH_TYPE) + attention_bias = torch.triu(attention_bias, diagonal=1) + + return attention_bias + + +def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=False): + if past: + assert seqlens_k is not None + position_ids_data = [] + for b in range(batch_size): + total_seq_len = seqlens_k[b] + 1 + past_seq_len = total_seq_len - sequence_length + position_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) + + position_ids = torch.tensor(data=position_ids_data, dtype=torch.int64) + else: + position_ids = torch.zeros((batch_size, sequence_length), dtype=torch.int64) + + return position_ids + + def parity_check_gqa_prompt( config, causal=True, @@ -1087,6 +1206,7 @@ def parity_check_gqa_prompt( dtype=TORCH_TYPE, requires_grad=False, ) + v = torch.randn( config.batch_size, config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, @@ -1154,6 +1274,19 @@ def parity_check_gqa_prompt( cos, sin = None, None q_ro, k_ro = q, new_k + position_ids = ( + get_custom_position_ids(config.batch_size, config.kv_sequence_length, seqlens_k=None, past=False) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( + config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False + ) + if config.has_attention_bias + else None + ) + rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") @@ -1184,6 +1317,7 @@ def parity_check_gqa_prompt( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function + # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) out, present_k, present_v = gqa_prompt_func( @@ -1195,7 +1329,9 @@ def parity_check_gqa_prompt( None, cos, sin, - cache_seqlens, + cache_seqlens - 1, + position_ids, + attention_bias, left_window_size, past_format, True, @@ -1213,7 +1349,9 @@ def parity_check_gqa_prompt( new_v, cos, sin, - cache_seqlens, + cache_seqlens - 1, + position_ids, + attention_bias, left_window_size, past_format, True, @@ -1262,6 +1400,10 @@ def parity_check_gqa_prompt( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1347,6 +1489,19 @@ def parity_check_gqa_prompt_no_buff( q_ro, k_ro = q, k_cache_ref k_cache_ref = k_ro + position_ids = ( + get_custom_position_ids(config.batch_size, config.kv_sequence_length, seqlens_k=None, past=False) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( + config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False + ) + if config.has_attention_bias + else None + ) + brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded @@ -1371,6 +1526,7 @@ def 
parity_check_gqa_prompt_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function + # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) out, present_k, present_v = gqa_prompt_func( @@ -1383,6 +1539,8 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, + position_ids, + attention_bias, left_window_size, past_format, False, @@ -1401,6 +1559,8 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, + position_ids, + attention_bias, left_window_size, past_format, False, @@ -1449,6 +1609,10 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1589,6 +1753,19 @@ def parity_check_gqa_past( cache_seqlens += config.sequence_length - 1 + position_ids = ( + get_custom_position_ids(config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( + config.batch_size, config.sequence_length, config.kv_sequence_length, seqlens_k=cache_seqlens, past=True + ) + if config.has_attention_bias + else None + ) + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1602,6 +1779,8 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, + position_ids, + attention_bias, past_format, True, left_window_size, @@ -1620,6 +1799,8 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, + position_ids, + attention_bias, past_format, True, left_window_size, @@ -1668,6 +1849,10 @@ def parity_check_gqa_past( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1814,6 +1999,23 @@ def parity_check_gqa_past_no_buff( cache_seqlens += config.sequence_length - 1 + position_ids = ( + get_custom_position_ids(config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( + config.batch_size, + config.sequence_length, + config.kv_sequence_length + config.sequence_length, + seqlens_k=cache_seqlens, + past=True, + ) + if config.has_attention_bias + else None + ) + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1827,6 +2029,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, + position_ids, + attention_bias, past_format, False, window_size=left_window_size, @@ -1845,6 +2049,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, + position_ids, + attention_bias, past_format, False, window_size=left_window_size, @@ -1889,6 +2095,10 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1916,6 +2126,11 @@ def test_gqa_no_past(self): (8000, 8000), ] ) + pos_ids_attn_bias = ( + [(False, False), (True, True)] + if pipeline_mode + else [(False, False), (True, True), (False, True), (True, False)] + ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 
128, 160, 192, 224, 256] for b in batches: @@ -1927,30 +2142,41 @@ def test_gqa_no_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_prompt( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) + for has_position_ids, has_attention_bias in pos_ids_attn_bias: + config = PromptConfig( + b, + sq, + skv, + sq + skv + 8, + n, + n2, + h, + has_position_ids, + has_attention_bias, + ) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") @@ -1972,6 +2198,11 @@ def test_gqa_past(self): # (128, 128), ] ) + pos_ids_attn_bias = ( + [(False, False), (True, True)] + if pipeline_mode + else [(False, False), (True, True), (False, True), (True, False)] + ) num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) @@ -1984,35 +2215,38 @@ def test_gqa_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) + for has_position_ids, has_attention_bias in pos_ids_attn_bias: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config( + b, s, s2, sp, n, n2, h, has_position_ids, has_attention_bias + ) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + 
rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") @@ -2034,6 +2268,11 @@ def test_gqa_interactive_one_batch(self): # (128, 128), ] ) + pos_ids_attn_bias = ( + [(False, False), (True, True)] + if pipeline_mode + else [(False, False), (True, True), (False, True), (True, False)] + ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) @@ -2044,30 +2283,31 @@ def test_gqa_interactive_one_batch(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - config = Config(b, s, s2, -1, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) + for has_position_ids, has_attention_bias in pos_ids_attn_bias: + config = Config(b, s, s2, -1, n, n2, h, has_position_ids, has_attention_bias) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) if __name__ == "__main__":
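For context, a minimal PyTorch sketch (not part of the diff) of the per-head computation the new attention_bias input feeds into: the bias is added to the scaled QxK' logits before softmax, which is what ApplyAttentionBias does inside ComputeAttentionProbs. The function name and shapes below are illustrative, and causal masking plus the local-window handling are omitted.

```python
import torch

def scores_with_attention_bias(q, k, attention_bias, scale):
    # q: (sequence_length, head_size); k: (total_sequence_length, head_size)
    # attention_bias: (sequence_length, total_sequence_length), already broadcast over batch/head
    logits = (q @ k.transpose(-1, -2)) * scale  # QxK'
    logits = logits + attention_bias            # optional bias added before softmax
    return torch.softmax(logits, dim=-1)

# Example invocation with random data and a zero bias.
probs = scores_with_attention_bias(
    torch.randn(4, 64), torch.randn(10, 64), torch.zeros(4, 10), scale=64 ** -0.5
)
```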