From d7f5aa1af3dc8dd818efa68a37931c8d861008ec Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Tue, 4 Mar 2025 17:10:47 +0100 Subject: [PATCH 01/25] Scalar support for custom position ids and mask in GQA --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 76 ++++++--- .../cpu/bert/group_query_attention.cc | 18 +- .../core/graph/contrib_ops/bert_defs.cc | 10 ++ .../test/python/transformers/test_gqa_cpu.py | 160 ++++++++++++++---- 4 files changed, 200 insertions(+), 64 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 70d66e534ee8a..39df9bd5302bc 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -50,6 +50,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH + const Tensor* attention_mask, // Causal attention mask to apply before const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor @@ -87,13 +88,16 @@ class GQAAttentionBase { const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + const T* attention_mask_data = attention_mask != nullptr ? attention_mask->Data() : nullptr; + const size_t attention_mask_total_seqlen = attention_mask != nullptr ? attention_mask->Shape()[2] : 0; + bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, - sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_mask_data, batch_size, + sequence_length, attention_mask_total_seqlen, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) @@ -104,9 +108,10 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, - sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_mask_data, + batch_size, sequence_length, attention_mask_total_seqlen, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, + past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, + allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -130,8 +135,10 @@ class GQAAttentionBase { const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. 
Its size is BxNxLxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const T* attention_mask, const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) + const size_t attention_mask_total_seqlen, // max total seqlen in batch used for attention last dim const size_t past_buffer_sequence_length, // sequence length of past state const size_t present_buffer_sequence_length, // sequence length of present state const size_t head_size, // head size of self-attention @@ -189,6 +196,9 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; + const ptrdiff_t attention_mask_offset = SafeInt(batch_index) * sequence_length * attention_mask_total_seqlen; + const T* attention_mask_batch = attention_mask != nullptr ? attention_mask + attention_mask_offset : nullptr; + const T* k; if (packed_qkv) { k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); @@ -242,7 +252,27 @@ class GQAAttentionBase { U* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { size_t seq_causal_length = past_seqlen + seq + 1; - if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { + + const bool should_apply_local_window = local_window_size_ > 0 && + seq_causal_length > static_cast(local_window_size_) + 1; + + const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ - 1 : 0; + const size_t window_size = should_apply_local_window ? local_window_size_ + 1 : seq_causal_length; + + // Apply custom attention mask if there is any + // TODO: Vectorize this addition + if (attention_mask_batch != nullptr) { + for (size_t idx = start_offset; idx < start_offset + window_size; idx++) { + if constexpr (std::is_same::value) { + output_softmax[idx] += static_cast(attention_mask_batch[idx]); + } else { + output_softmax[idx] = MLFloat16(output_softmax[idx].ToFloat() + attention_mask_batch[idx].ToFloat()); + } + } + } + + // Mask everything before local window, if local window should be applied + if (should_apply_local_window) { for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { if constexpr (std::is_same::value) { output_softmax[total_seq_id] = 0.f; @@ -250,27 +280,17 @@ class GQAAttentionBase { output_softmax[total_seq_id] = MLFloat16::FromBits(static_cast(0)); } } - if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax + seq_causal_length - local_window_size_ - 1, - local_window_size_ + 1, static_cast(softcap_)); - } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, - local_window_size_ + 1, nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, - local_window_size_ + 1, nullptr); - } + } + + // Calculate softmax + if (softcap_ > 0.f) { + ComputeAttentionSoftcapInplace(output_softmax + start_offset, static_cast(window_size), + static_cast(softcap_)); + } + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); } else { - if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax, static_cast(seq_causal_length), - static_cast(softcap_)); - } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax, 1, 
static_cast(seq_causal_length), nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); - } + ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); } // set causal [seq_causal_length, total_seqlen) to 0.f @@ -283,6 +303,10 @@ class GQAAttentionBase { } output_softmax += present_buffer_sequence_length; + + if (attention_mask_batch != nullptr) { + attention_mask_batch += attention_mask_total_seqlen; + } } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 8f662cd388c6d..4ce383d7652fa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -53,6 +53,9 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); + const Tensor* custom_pos_ids = context->Input(9); + const Tensor* custom_causal_attention_mask = context->Input(10); + GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, key, @@ -130,7 +133,12 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Generate position ids const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length; std::vector pos_ids(pos_ids_size); - if (parameters.is_first_prompt) { + const int64_t* pos_ids_data = pos_ids.data(); + + if (custom_pos_ids != nullptr) { + ORT_RETURN_IF_NOT(pos_ids_size == custom_pos_ids->Shape()[0] * custom_pos_ids->Shape()[1]); + pos_ids_data = custom_pos_ids->Data(); + } else if (parameters.is_first_prompt) { pos_ids[0] = static_cast(0); } else { // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1. @@ -146,6 +154,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } } } + // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; @@ -165,7 +174,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, - pos_ids.data(), cos_cache->Data(), + pos_ids_data, cos_cache->Data(), sin_cache->Data(), q_rotary, rotary_interleaved_)); rotary_params.num_heads = kv_num_heads_; @@ -174,7 +183,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride; } ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, - pos_ids.data(), cos_cache->Data(), + pos_ids_data, cos_cache->Data(), sin_cache->Data(), k_rotary, rotary_interleaved_)); // Pack V into rotary QKV buffer if (packed_qkv) { @@ -192,9 +201,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? 
nullptr : V.Get().Data(), - past_key, past_value, output, present_k, present_v, + custom_causal_attention_mask, past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index ecc8cb091b1b6..e724aab81bb8e 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1128,6 +1128,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "2D tensor with shape (max_sequence_length, head_size / 2).", "T", OpSchema::Optional) + .Input(9, + "custom_pos_ids", + "2D tensor with shape (batch_size, sequence_length).", + "tensor(int64)", + OpSchema::Optional) + .Input(10, + "custom_causal_attention_mask", + "3D tensor with shape (batch_size, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 77b4b326bf645..73ee12cc3ec2f 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -9,6 +9,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # ------------------------------------------------------------------------- +from dataclasses import dataclass import math import random import unittest @@ -35,48 +36,34 @@ RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3 ATOL = RTOL +DO_TREE_ATTENTION = True + class Formats: BSNH = 0 BNSH = 1 +@dataclass class Config: - batch_size = 0 - sequence_length = 0 - kv_sequence_length = 0 - past_sequence_length = 0 - num_heads = 0 - kv_num_heads = 0 - head_size = 0 - - def __init__(self, b, s, s2, sp, n, n2, h): - self.batch_size = b - self.sequence_length = s - self.kv_sequence_length = s2 - self.past_sequence_length = sp - self.num_heads = n - self.kv_num_heads = n2 - self.head_size = h + batch_size: int = 0 + sequence_length: int = 0 + kv_sequence_length: int = 0 + past_sequence_length: int = 0 + num_heads: int = 0 + kv_num_heads: int = 0 + head_size: int = 0 +@dataclass class PromptConfig: - batch_size = 0 - q_sequence_length = 0 - kv_sequence_length = 0 - buffer_sequence_length = 0 - num_heads = 0 - kv_num_heads = 0 - head_size = 0 - - def __init__(self, b, sq, skv, sb, n, n2, h): - self.batch_size = b - self.q_sequence_length = sq - self.kv_sequence_length = skv - self.buffer_sequence_length = sb - self.num_heads = n - self.kv_num_heads = n2 - self.head_size = h + batch_size: int = 0 + q_sequence_length: int = 0 + kv_sequence_length: int = 0 + buffer_sequence_length: int = 0 + num_heads: int = 0 + kv_num_heads: int = 0 + head_size: int = 0 # LLaMA Microsoft model @@ -173,6 +160,8 @@ def create_group_query_attention_graph_prompt( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", + "custom_pos_ids" if DO_TREE_ATTENTION else "", + "custom_causal_attention_mask" if DO_TREE_ATTENTION else "" ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -278,6 +267,20 @@ def create_group_query_attention_graph_prompt( ), ] + if DO_TREE_ATTENTION: + graph_input += [ + helper.make_tensor_value_info( + "custom_pos_ids", + TensorProto.INT64, + [1, 1] + ), + helper.make_tensor_value_info( + "custom_causal_attention_mask", + ORT_TYPE, + [config.batch_size, 
config.kv_sequence_length, config.kv_sequence_length] + ) + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -334,11 +337,13 @@ def create_group_query_attention_graph_prompt( ) model = helper.make_model(graph) + return model.SerializeToString() def create_group_query_attention_graph_past( config, + seqlens_k, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1, @@ -352,6 +357,7 @@ def create_group_query_attention_graph_past( present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) + max_seqlen_in_batch = seqlens_k.max().item() + 1 nodes = [ helper.make_node( "GroupQueryAttention", @@ -365,6 +371,9 @@ def create_group_query_attention_graph_past( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", + "custom_pos_ids" if DO_TREE_ATTENTION else "", + "custom_causal_attention_mask" if DO_TREE_ATTENTION else "" + ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -467,6 +476,20 @@ def create_group_query_attention_graph_past( ), ] + if DO_TREE_ATTENTION: + graph_input += [ + helper.make_tensor_value_info( + "custom_pos_ids", + TensorProto.INT64, + [config.batch_size, config.sequence_length] + ), + helper.make_tensor_value_info( + "custom_causal_attention_mask", + ORT_TYPE, + [config.batch_size, config.sequence_length, max_seqlen_in_batch] + ) + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -699,9 +722,22 @@ def gqa_prompt_func( softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None + + # Construct position ids and attention mask data + custom_pos_ids = torch.tensor(data=[[0]], dtype=torch.int64) + custom_causal_attention_mask = torch.rand(config.batch_size, config.kv_sequence_length, config.kv_sequence_length, dtype=TORCH_TYPE) + custom_causal_attention_mask = torch.triu(custom_causal_attention_mask, diagonal=1) + + # print(f"Tree causal attention mask shape {custom_causal_attention_mask.shape}") + # print(f"Seqlens_k: {seqlens_k}") + # print(f"Q shape: {q.shape}") + # print(f"total_sequence_length: {config.q_sequence_length}") + # print(f"kv_sequence_length: {config.kv_sequence_length}") + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) @@ -713,6 +749,7 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } + sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() @@ -726,6 +763,13 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if DO_TREE_ATTENTION: + ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() + ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) + io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + 
io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -767,6 +811,13 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if DO_TREE_ATTENTION: + ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() + ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) + io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) @@ -800,6 +851,7 @@ def gqa_past_func( assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( config, + seqlens_k, past_kv_format, share_buffer, local_window_size=window_size, @@ -812,6 +864,29 @@ def gqa_past_func( q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() + + # Construct position ids and mask data + custom_pos_ids_data = [] + max_seqlen_in_batch = seqlens_k.max().item() + 1 + custom_causal_attention_mask = torch.zeros((config.batch_size, config.sequence_length, max_seqlen_in_batch), dtype=TORCH_TYPE) + for b in range(config.batch_size): + total_seq_len = seqlens_k[b] + 1 + past_seq_len = total_seq_len - config.sequence_length; + custom_pos_ids_data.append(list(range(past_seq_len, past_seq_len + config.sequence_length))) + + # Configure mask + for i in range(config.sequence_length): + for j in range(past_seq_len + i + 1, max_seqlen_in_batch): + custom_causal_attention_mask[b][i][j] = -5000 + + custom_pos_ids = torch.tensor(data=custom_pos_ids_data, dtype=torch.int64) + + # print(custom_causal_attention_mask[0]) + # print(f"Tree causal attention mask shape: {custom_causal_attention_mask.shape}") + # print(f"Seqlens_k: {seqlens_k}") + # print(f"Q shape: {q.shape}") + # print(f"past k shape: {past_k.shape}") + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) @@ -839,6 +914,13 @@ def gqa_past_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if DO_TREE_ATTENTION: + ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() + ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) + io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -887,6 +969,13 @@ def gqa_past_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + if DO_TREE_ATTENTION: + ort_inputs["custom_pos_ids"] = 
custom_pos_ids.detach().cpu().numpy() + ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) + io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) @@ -1087,6 +1176,7 @@ def parity_check_gqa_prompt( dtype=TORCH_TYPE, requires_grad=False, ) + v = torch.randn( config.batch_size, config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, @@ -1195,7 +1285,7 @@ def parity_check_gqa_prompt( None, cos, sin, - cache_seqlens, + cache_seqlens - 1, left_window_size, past_format, True, @@ -1213,7 +1303,7 @@ def parity_check_gqa_prompt( new_v, cos, sin, - cache_seqlens, + cache_seqlens - 1, left_window_size, past_format, True, @@ -1530,6 +1620,7 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + # TODO(derdeljan): Move random seqlens_k selection to config cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1752,6 +1843,7 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + # TODO(derdeljan): Move random seqlens_k selection to config cache_seqlens = torch.randint( 0, config.kv_sequence_length, From 15172c3bc5ec399a5f09a7660a3e8fc5ae8bf603 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 6 Mar 2025 11:53:54 +0100 Subject: [PATCH 02/25] Vectorized attention mask application for fp32 --- cmake/onnxruntime_mlas.cmake | 1 + .../contrib_ops/cpu/bert/attention_helper.h | 13 +++ .../contrib_ops/cpu/bert/gqa_attention_base.h | 15 ++-- onnxruntime/core/mlas/inc/mlas.h | 10 +++ onnxruntime/core/mlas/lib/eltwise_add.cpp | 86 +++++++++++++++++++ 5 files changed, 118 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/eltwise_add.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 15864a0198161..16637d0da00a2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -27,6 +27,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/activate.cpp ${MLAS_SRC_DIR}/logistic.cpp ${MLAS_SRC_DIR}/tanh.cpp + ${MLAS_SRC_DIR}/eltwise_add.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp ${MLAS_SRC_DIR}/quantize.cpp diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 188fc6e43b5b5..f09477ff04a55 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -31,6 +31,19 @@ void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) { MlasComputeSoftcap(scores, scores, sequence_length, softcap); } +template +void ApplyAttentionMask(T* softmax_logits, const T* attention_mask, int N) { + MlasEltwiseAdd(softmax_logits, attention_mask, softmax_logits, N); + + // for (int i = 0; i < N; i++) { + // if constexpr (std::is_same::value) { + // softmax_logits[i] += static_cast(attention_mask[i]); + // } else { + // softmax_logits[i] = MLFloat16(softmax_logits[i].ToFloat() + attention_mask[i].ToFloat()); + // } + // } +} + template void 
PrepareMask(const int32_t* mask_index, gsl::span mask_index_dims, diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 39df9bd5302bc..3f2bb1b5d42c1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -260,14 +260,15 @@ class GQAAttentionBase { const size_t window_size = should_apply_local_window ? local_window_size_ + 1 : seq_causal_length; // Apply custom attention mask if there is any - // TODO: Vectorize this addition if (attention_mask_batch != nullptr) { - for (size_t idx = start_offset; idx < start_offset + window_size; idx++) { - if constexpr (std::is_same::value) { - output_softmax[idx] += static_cast(attention_mask_batch[idx]); - } else { - output_softmax[idx] = MLFloat16(output_softmax[idx].ToFloat() + attention_mask_batch[idx].ToFloat()); - } + if constexpr (std::is_same_v) { + ApplyAttentionMask(output_softmax + start_offset, attention_mask_batch + start_offset, + static_cast(window_size)); + } else { + // TODO: Handle the case where U and T are different types + // The only case where U and T can be of different types is when U is float32 and T is float16 + // In that case, allocate a float32 scratch buffer and upcast the mask to it + // Then, invoke the ApplyAttentionMask with the temporary buffer being passed in } } diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 1401e27ca77e5..8033eab8262a0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1030,6 +1030,16 @@ MlasComputeSoftcap( T cap ); +template +void +MLASCALL +MlasEltwiseAdd( + const T* left, + const T* right, + T* output, + size_t N + ); + template void MLASCALL diff --git a/onnxruntime/core/mlas/lib/eltwise_add.cpp b/onnxruntime/core/mlas/lib/eltwise_add.cpp new file mode 100644 index 0000000000000..e8338833cb0c3 --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise_add.cpp @@ -0,0 +1,86 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_add.cpp + +Abstract: + + This module implements routines to compute eltwise addition of two vectors. + +--*/ + +#include "mlasi.h" +#include + +template <> +void +MLASCALL +MlasEltwiseAdd( + const float* left, + const float* right, + float* output, + size_t N +) { + // std::cout << "running vectorized mlas kernel for fp32 addition" << std::endl; + // for (size_t i = 0; i < N; i++) { + // output[i] = left[i] + right[i]; + // } + + while (N > 0) { + MLAS_FLOAT32X4 LeftVec, RightVec; + + if (N >= 4) { + LeftVec = MlasLoadFloat32x4(left); + RightVec = MlasLoadFloat32x4(right); + } else { +#if defined(MLAS_SSE2_INTRINSICS) + // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle + // and use zeroes for the upper elements. 
+ LeftVec = _mm_load_ss(left); + RightVec = _mm_load_ss(right); +#elif defined(MLAS_LSX_INTRINSICS) + LeftVec = (MLAS_FLOAT32X4)__lsx_vldrepl_w(left, 0); + RightVec = (MLAS_FLOAT32X4)__lsx_vldrepl_w(right, 0); +#else + LeftVec = MlasBroadcastFloat32x4(left); + RightVec = MlasBroadcastFloat32x4(right); +#endif + } + + MLAS_FLOAT32X4 ResultVec = MlasAddFloat32x4(LeftVec, RightVec); + + if (N >= 4) { + MlasStoreFloat32x4(output, ResultVec); + + left += 4; + right += 4; + output += 4; + N -= 4; + } else { + MlasStoreLaneFloat32x4<0>(output, ResultVec); + + left += 1; + right += 1; + output += 1; + N -= 1; + } + } +} + + +template <> +void +MLASCALL +MlasEltwiseAdd( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N +) { + MLAS_THROW_EX(std::runtime_error, "Mlas eltwise add for fp16 is not yet supported."); +} From d7eae786f3329ff0fb5cd79a045f0a76d4929fa4 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 6 Mar 2025 15:20:55 +0100 Subject: [PATCH 03/25] Vectorized attention mask application for fp16 --- cmake/onnxruntime_mlas.cmake | 10 +- .../mlas/lib/{eltwise_add.cpp => eltwise.cpp} | 20 +-- onnxruntime/core/mlas/lib/eltwise.h | 37 ++++++ .../core/mlas/lib/eltwise_kernel_neon.cpp | 32 +++++ .../core/mlas/lib/eltwise_kernel_neon.h | 28 +++++ .../mlas/lib/eltwise_kernel_neon_fp16.cpp | 118 ++++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 5 + onnxruntime/core/mlas/lib/platform.cpp | 1 + 8 files changed, 241 insertions(+), 10 deletions(-) rename onnxruntime/core/mlas/lib/{eltwise_add.cpp => eltwise.cpp} (78%) create mode 100644 onnxruntime/core/mlas/lib/eltwise.h create mode 100644 onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp create mode 100644 onnxruntime/core/mlas/lib/eltwise_kernel_neon.h create mode 100644 onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 16637d0da00a2..87387d4f281ed 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -27,7 +27,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/activate.cpp ${MLAS_SRC_DIR}/logistic.cpp ${MLAS_SRC_DIR}/tanh.cpp - ${MLAS_SRC_DIR}/eltwise_add.cpp + ${MLAS_SRC_DIR}/eltwise.h + ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp ${MLAS_SRC_DIR}/quantize.cpp @@ -102,6 +103,9 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/softmax_kernel_neon.h ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ) set(mlas_platform_preprocess_srcs @@ -388,6 +392,8 @@ else() ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/softmax_kernel_neon.h ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") @@ -410,6 +416,7 @@ else() ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -424,6 +431,7 @@ else() 
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/core/mlas/lib/eltwise_add.cpp b/onnxruntime/core/mlas/lib/eltwise.cpp similarity index 78% rename from onnxruntime/core/mlas/lib/eltwise_add.cpp rename to onnxruntime/core/mlas/lib/eltwise.cpp index e8338833cb0c3..a9b319804d4c0 100644 --- a/onnxruntime/core/mlas/lib/eltwise_add.cpp +++ b/onnxruntime/core/mlas/lib/eltwise.cpp @@ -6,16 +6,19 @@ Licensed under the MIT License. Module Name: - eltwise_add.cpp + eltwise.cpp Abstract: - This module implements routines to compute eltwise addition of two vectors. + This module implements routines to compute eltwise operations on two vectors. + + Currently supported eltwise operations: + - Add --*/ #include "mlasi.h" -#include +#include "eltwise.h" template <> void @@ -26,11 +29,6 @@ MlasEltwiseAdd( float* output, size_t N ) { - // std::cout << "running vectorized mlas kernel for fp32 addition" << std::endl; - // for (size_t i = 0; i < N; i++) { - // output[i] = left[i] + right[i]; - // } - while (N > 0) { MLAS_FLOAT32X4 LeftVec, RightVec; @@ -82,5 +80,9 @@ MlasEltwiseAdd( MLAS_FP16* output, size_t N ) { - MLAS_THROW_EX(std::runtime_error, "Mlas eltwise add for fp16 is not yet supported."); + const auto* dispatch = GetMlasPlatform().EltwiseDispatch; + if (dispatch == nullptr || dispatch->Add_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Add_Fp16 is not supported."); + } + dispatch->Add_Fp16(left, right, output, N); } diff --git a/onnxruntime/core/mlas/lib/eltwise.h b/onnxruntime/core/mlas/lib/eltwise.h new file mode 100644 index 0000000000000..582899a2db24d --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + eltwise operations. + +--*/ +#pragma once + +#include "mlasi.h" + +struct MLAS_ELTWISE_DISPATCH { + /** + * @brief Compute the element-wise addition of the two given vectors + * @param left Address of the left operand + * @param right Address of the right operand + * @param output Address of the output array. Could be the same as the input array. + * @param N Number of elements in the input arrays + */ + typedef void(Add_Fp16_Fn)( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N + ); + + Add_Fp16_Fn* Add_Fp16 = nullptr; +}; diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp new file mode 100644 index 0000000000000..02ad05b2bbd7c --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.cpp + +Abstract: + + This module implements the eltwise kernels for ARM NEON. 
+ +--*/ + +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon = []() { + MLAS_ELTWISE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.Add_Fp16 = eltwise_neon::Add_Kernel_Fp16; + } +#endif + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h new file mode 100644 index 0000000000000..f9eb7b1ed81f7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + eltwise operations on ARM cpu. + +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N); + +} // namespace eltwise_neon diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp new file mode 100644 index 0000000000000..80f95ee14c3d0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp @@ -0,0 +1,118 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 eltwise kernels for ARM NEON. + +--*/ +#include +#include + +#include "fp16_common.h" +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N) { + const auto* left_fp16 = reinterpret_cast(left); + const auto* right_fp16 = reinterpret_cast(right); + auto* output_fp16 = reinterpret_cast<_mlas_fp16_*>(output); + + while (N >= 32) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + auto l2 = MlasLoadFloat16x8(left_fp16 + 16); + auto l3 = MlasLoadFloat16x8(left_fp16 + 24); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + auto r2 = MlasLoadFloat16x8(right_fp16 + 16); + auto r3 = MlasLoadFloat16x8(right_fp16 + 24); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + auto o2 = MlasAddFloat16(l2, r2); + auto o3 = MlasAddFloat16(l3, r3); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + MlasStoreFloat16x8(output_fp16 + 16, o2); + MlasStoreFloat16x8(output_fp16 + 24, o3); + + left_fp16 += 32; + right_fp16 += 32; + output_fp16 += 32; + N -= 32; + } + + if (N & 16) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + + left_fp16 += 16; + right_fp16 += 16; + output_fp16 += 16; + N -= 16; + } + + if (N & 8) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto r0 = MlasLoadFloat16x8(right_fp16); + auto o0 = MlasAddFloat16(l0, r0); + MlasStoreFloat16x8(output_fp16, o0); + + left_fp16 += 8; + right_fp16 += 8; + output_fp16 += 8; + N -= 8; + } + + if (N & 4) { + auto l0 = 
MlasLoadFloat16x4(left_fp16); + auto r0 = MlasLoadFloat16x4(right_fp16); + auto o0 = MlasAddFLoat16(l0, r0); + MlasStoreFloat16x4(output_fp16, o0); + + left_fp16 += 4; + right_fp16 += 4; + output_fp16 += 4; + N -= 4; + } + + if (N == 3) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 3); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 3); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 3); + } else if (N == 2) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 2); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 2); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 2); + } else if (N == 1) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 1); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 1); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 1); + } +} + +} // namespace eltwise_neon diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0681b49252495..8e704b5b94801 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1070,6 +1070,10 @@ extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon; struct MLAS_SOFTMAX_DISPATCH; extern const MLAS_SOFTMAX_DISPATCH MlasSoftmaxDispatchNeon; +// eltwise dispatch structure +struct MLAS_ELTWISE_DISPATCH; +extern const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon; + // // Quantized depthwise convolution kernels. // @@ -1233,6 +1237,7 @@ struct MLAS_PLATFORM { const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; const MLAS_SOFTMAX_DISPATCH* SoftmaxDispatch{nullptr}; + const MLAS_ELTWISE_DISPATCH* EltwiseDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 582c1ab944b98..312a624fd160c 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -547,6 +547,7 @@ Return Value: this->RopeDispatch = &MlasRopeDispatchNeon; this->HGemmDispatch = &MlasHGemmDispatchNeon; this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; + this->EltwiseDispatch = &MlasEltwiseDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. 
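For reference, a minimal usage sketch of the MlasEltwiseAdd entry point introduced in the two patches above, mirroring how ApplyAttentionMask() in gqa_attention_base.h adds one mask row onto the matching row of attention logits in place. The wrapper name and include path below are illustrative assumptions, not part of the patch; only the MlasEltwiseAdd call itself comes from the new MLAS API.

    // Illustrative sketch only: apply one attention-mask row to a row of logits.
    #include "core/mlas/inc/mlas.h"  // assumed include path inside onnxruntime

    static void AddMaskRowInPlace(float* logits, const float* mask_row, size_t window_size) {
        // The output pointer may alias the left operand; the eltwise kernel is
        // documented to accept in-place use (see eltwise.h above).
        MlasEltwiseAdd(logits, mask_row, logits, window_size);
    }

The fp16 overload added in the same patches follows the identical calling convention with MLAS_FP16 pointers, dispatched to the NEON kernel when fp16 acceleration is available.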
From 9d244dd7432d69a9a79baf51b77e834cda82577f Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 6 Mar 2025 16:11:52 +0100 Subject: [PATCH 04/25] Add mask upscale to fp32 if the platform doesn't support fp16 --- onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 3f2bb1b5d42c1..bfbe32ec77513 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -265,10 +265,12 @@ class GQAAttentionBase { ApplyAttentionMask(output_softmax + start_offset, attention_mask_batch + start_offset, static_cast(window_size)); } else { - // TODO: Handle the case where U and T are different types - // The only case where U and T can be of different types is when U is float32 and T is float16 - // In that case, allocate a float32 scratch buffer and upcast the mask to it - // Then, invoke the ApplyAttentionMask with the temporary buffer being passed in + size_t bytes = window_size * sizeof(float); + auto attention_mask_batch_fp32 = static_cast(allocator->Alloc(bytes)); + BufferUniquePtr scratch_buffer(attention_mask_batch_fp32, BufferDeleter(allocator)); + + MlasConvertHalfToFloatBuffer(attention_mask_batch + start_offset, attention_mask_batch_fp32, window_size); + ApplyAttentionMask(output_softmax, attention_mask_batch_fp32, static_cast(window_size)); } } From 8faee66136d69a2943c56b6b445d5bbb50bcadb0 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 6 Mar 2025 17:02:31 +0100 Subject: [PATCH 05/25] Fix typo in fp16 eltwise kernels --- onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp index 80f95ee14c3d0..97b88983e16a0 100644 --- a/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp @@ -88,7 +88,7 @@ void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* o if (N & 4) { auto l0 = MlasLoadFloat16x4(left_fp16); auto r0 = MlasLoadFloat16x4(right_fp16); - auto o0 = MlasAddFLoat16(l0, r0); + auto o0 = MlasAddFloat16(l0, r0); MlasStoreFloat16x4(output_fp16, o0); left_fp16 += 4; From 147d19b6e680980cf90f813a7786ece5df66e08e Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 7 Mar 2025 11:34:30 +0100 Subject: [PATCH 06/25] Add validation for custom attention parameters --- .../cpu/bert/group_query_attention.cc | 9 +++- .../cpu/bert/group_query_attention_helper.h | 48 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 4ce383d7652fa..4a22ecf5339fc 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -52,7 +52,6 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); - const Tensor* custom_pos_ids = context->Input(9); const Tensor* custom_causal_attention_mask = context->Input(10); @@ -72,6 +71,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { scale_, 
softcap_)); + const int32_t* seqlens_k_data = seqlens_k->Data(); + const int32_t max_seqlens_k = *std::max_element(seqlens_k_data, seqlens_k_data + parameters.batch_size); + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(custom_pos_ids, + custom_causal_attention_mask, + max_seqlens_k, + parameters)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int present_kv_seqlen = parameters.seqlen_present_kv_cache; @@ -136,7 +142,6 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const int64_t* pos_ids_data = pos_ids.data(); if (custom_pos_ids != nullptr) { - ORT_RETURN_IF_NOT(pos_ids_size == custom_pos_ids->Shape()[0] * custom_pos_ids->Shape()[1]); pos_ids_data = custom_pos_ids->Data(); } else if (parameters.is_first_prompt) { pos_ids[0] = static_cast(0); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 4cc5a4228dc8c..c9dfb5267d38f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -288,6 +288,54 @@ Status CheckInputs(const T* query, return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap); } + +template +Status CheckCustomAttentionInputs(const T* custom_pos_ids, + const T* custom_causal_attention_mask, + const int max_seqlens_k, + const GroupQueryAttentionParameters& parameters) +{ + if (custom_pos_ids != nullptr) { + const auto& pos_ids_shape = custom_pos_ids->Shape(); + if (parameters.is_first_prompt) { + if (pos_ids_shape[0] != 1 || pos_ids_shape[1] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Shape of custom_pos_ids must be [1, 1] when processing the prompt"); + } + } else { + if (pos_ids_shape[0] != parameters.batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "custom_pos_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); + } + + if (pos_ids_shape[1] < parameters.sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "custom_pos_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); + } + } + } + + if (custom_causal_attention_mask != nullptr) { + const auto& mask_shape = custom_causal_attention_mask->Shape(); + if (mask_shape[0] != parameters.batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "custom_causal_attention_mask dimension 0 must be equal to the batch size, got ", mask_shape[0]); + } + + if (mask_shape[1] != parameters.sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "custom_causal_attention_mask dimension 1 must be equal to the sequence length, got ", mask_shape[1]); + } + + if (mask_shape[2] < max_seqlens_k + 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "custom_causal_attention_mask dimension 2 must be atleast max(seqlens_k) + 1, got ", mask_shape[2]); + } + } + + return Status::OK(); +} + } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime From 4b1262eb5a129be100979130ca61ea9f3b8ac25d Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 7 Mar 2025 15:51:03 +0100 Subject: [PATCH 07/25] Add mlas unit test for eltwise kernels --- .../test/mlas/unittest/test_eltwise.cpp | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 
onnxruntime/test/mlas/unittest/test_eltwise.cpp diff --git a/onnxruntime/test/mlas/unittest/test_eltwise.cpp b/onnxruntime/test/mlas/unittest/test_eltwise.cpp new file mode 100644 index 0000000000000..cc44a296ff330 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_eltwise.cpp @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/eltwise.h" + +class MlasEltwiseAddTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputLeft; + MatrixGuardBuffer BufferInputRight; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputLeftFp16; + MatrixGuardBuffer BufferInputRightFp16; + MatrixGuardBuffer BufferOutputFp16; + + void Test(size_t N, float MinimumValue, float MaximumValue) { + float* InputLeft = BufferInputLeft.GetBuffer(N); + float* InputRight = BufferInputRight.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = distribution(generator); + InputRight[n] = distribution(generator); + } + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = InputLeft[n] + InputRight[n]; + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 1e-6f; + constexpr float RelativeTolerance = 1e-6f; + + for (size_t n = 0; n < N; n++) { + float diff = std::fabs(Output[n] - OutputReference[n]); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) + << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; + } + } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + void TestFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* InputLeft = BufferInputLeftFp16.GetBuffer(N); + MLAS_FP16* InputRight = BufferInputRightFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = MLAS_FP16(distribution(generator)); + InputRight[n] = MLAS_FP16(distribution(generator)); + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + for (size_t n = 0; n < N; n++) { + float inLeft = InputLeft[n].ToFloat(); + float inRight = InputRight[n].ToFloat(); + float ref = inLeft + inRight; + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << inLeft << ", " << inRight << ", got: " << out << ", expecting: " << ref + << ", r-diff: " << diff / std::fabs(ref); + } + } + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Eltwise_Add"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { + Test(n, -10.f, 10.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -17.f, 11.f); +#endif // 
defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); From f7a07881f1fefffb0de07265aa4743980bca69d3 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 7 Mar 2025 15:52:37 +0100 Subject: [PATCH 08/25] Refactor python unit GQA tests --- .../test/python/transformers/test_gqa_cpu.py | 319 +++++++++++------- 1 file changed, 197 insertions(+), 122 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 73ee12cc3ec2f..f25bc1bce3047 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -36,8 +36,6 @@ RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3 ATOL = RTOL -DO_TREE_ATTENTION = True - class Formats: BSNH = 0 @@ -144,6 +142,7 @@ def create_group_query_attention_graph_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -160,8 +159,8 @@ def create_group_query_attention_graph_prompt( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "custom_pos_ids" if DO_TREE_ATTENTION else "", - "custom_causal_attention_mask" if DO_TREE_ATTENTION else "" + "custom_pos_ids" if do_custom_tree_attention else "", + "custom_causal_attention_mask" if do_custom_tree_attention else "" ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -267,7 +266,7 @@ def create_group_query_attention_graph_prompt( ), ] - if DO_TREE_ATTENTION: + if do_custom_tree_attention: graph_input += [ helper.make_tensor_value_info( "custom_pos_ids", @@ -352,6 +351,7 @@ def create_group_query_attention_graph_past( packed=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -371,8 +371,8 @@ def create_group_query_attention_graph_past( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "custom_pos_ids" if DO_TREE_ATTENTION else "", - "custom_causal_attention_mask" if DO_TREE_ATTENTION else "" + "custom_pos_ids" if do_custom_tree_attention else "", + "custom_causal_attention_mask" if do_custom_tree_attention else "" ], ["output", "present_key", "present_value"], @@ -476,7 +476,7 @@ def create_group_query_attention_graph_past( ), ] - if DO_TREE_ATTENTION: + if do_custom_tree_attention: graph_input += [ helper.make_tensor_value_info( "custom_pos_ids", @@ -704,12 +704,15 @@ def gqa_prompt_func( cos=None, sin=None, seqlens_k=None, + custom_pos_ids=None, + custom_causal_attention_mask=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -721,22 +724,16 @@ def gqa_prompt_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if 
share_buffer else None - # Construct position ids and attention mask data - custom_pos_ids = torch.tensor(data=[[0]], dtype=torch.int64) - custom_causal_attention_mask = torch.rand(config.batch_size, config.kv_sequence_length, config.kv_sequence_length, dtype=TORCH_TYPE) - custom_causal_attention_mask = torch.triu(custom_causal_attention_mask, diagonal=1) - - # print(f"Tree causal attention mask shape {custom_causal_attention_mask.shape}") - # print(f"Seqlens_k: {seqlens_k}") - # print(f"Q shape: {q.shape}") - # print(f"total_sequence_length: {config.q_sequence_length}") - # print(f"kv_sequence_length: {config.kv_sequence_length}") + if do_custom_tree_attention: + assert custom_pos_ids is not None + assert custom_causal_attention_mask is not None if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) @@ -764,7 +761,7 @@ def gqa_prompt_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if DO_TREE_ATTENTION: + if do_custom_tree_attention: ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) @@ -812,7 +809,7 @@ def gqa_prompt_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if DO_TREE_ATTENTION: + if do_custom_tree_attention: ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) @@ -841,12 +838,15 @@ def gqa_past_func( cos=None, sin=None, seqlens_k=None, + custom_pos_ids=None, + custom_causal_attention_mask=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False ): assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( @@ -860,32 +860,15 @@ def gqa_past_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() - # Construct position ids and mask data - custom_pos_ids_data = [] - max_seqlen_in_batch = seqlens_k.max().item() + 1 - custom_causal_attention_mask = torch.zeros((config.batch_size, config.sequence_length, max_seqlen_in_batch), dtype=TORCH_TYPE) - for b in range(config.batch_size): - total_seq_len = seqlens_k[b] + 1 - past_seq_len = total_seq_len - config.sequence_length; - custom_pos_ids_data.append(list(range(past_seq_len, past_seq_len + config.sequence_length))) - - # Configure mask - for i in range(config.sequence_length): - for j in range(past_seq_len + i + 1, max_seqlen_in_batch): - custom_causal_attention_mask[b][i][j] = -5000 - - custom_pos_ids = torch.tensor(data=custom_pos_ids_data, dtype=torch.int64) - - # print(custom_causal_attention_mask[0]) - # print(f"Tree causal attention mask shape: {custom_causal_attention_mask.shape}") - # print(f"Seqlens_k: {seqlens_k}") - # print(f"Q shape: {q.shape}") - # print(f"past k shape: {past_k.shape}") + if do_custom_tree_attention: + assert custom_pos_ids is not None + assert custom_causal_attention_mask is not None if new_k is 
not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) @@ -915,7 +898,7 @@ def gqa_past_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if DO_TREE_ATTENTION: + if do_custom_tree_attention: ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) @@ -970,7 +953,7 @@ def gqa_past_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if DO_TREE_ATTENTION: + if do_custom_tree_attention: ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) @@ -1145,6 +1128,31 @@ def attention_qkvpacked_ref( ) +def get_custom_attention_inputs(batch_size, sequence_length, seqlens_k=None, past=False): + if past: + assert seqlens_k is not None + custom_pos_ids_data = [] + max_seqlen_in_batch = seqlens_k.max().item() + 1 + custom_causal_attention_mask = torch.zeros((batch_size, sequence_length, max_seqlen_in_batch), dtype=TORCH_TYPE) + for b in range(batch_size): + total_seq_len = seqlens_k[b] + 1 + past_seq_len = total_seq_len - sequence_length; + custom_pos_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) + + # Configure mask + for i in range(sequence_length): + for j in range(past_seq_len + i + 1, max_seqlen_in_batch): + custom_causal_attention_mask[b][i][j] = -5000 + + custom_pos_ids = torch.tensor(data=custom_pos_ids_data, dtype=torch.int64) + else: + custom_pos_ids = torch.tensor(data=[[0]], dtype=torch.int64) + custom_causal_attention_mask = torch.rand(batch_size, sequence_length, sequence_length, dtype=TORCH_TYPE) + custom_causal_attention_mask = torch.triu(custom_causal_attention_mask, diagonal=1) + + return custom_pos_ids, custom_causal_attention_mask + + def parity_check_gqa_prompt( config, causal=True, @@ -1155,6 +1163,7 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1244,6 +1253,13 @@ def parity_check_gqa_prompt( cos, sin = None, None q_ro, k_ro = q, new_k + if do_custom_tree_attention: + custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( + config.batch_size, config.kv_sequence_length, past=False) + else: + custom_pos_ids = None + custom_causal_attention_mask = None + rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") @@ -1286,12 +1302,15 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, + custom_pos_ids, + custom_causal_attention_mask, left_window_size, past_format, True, rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention ) else: out, present_k, present_v = gqa_prompt_func( @@ -1304,12 +1323,15 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, + custom_pos_ids, + custom_causal_attention_mask, left_window_size, past_format, True, rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, + 
do_custom_tree_attention=do_custom_tree_attention ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1338,6 +1360,8 @@ def parity_check_gqa_prompt( softcap, " smooth_softmax:", use_smooth_softmax, + " custom_tree_attention:", + do_custom_tree_attention, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1369,6 +1393,7 @@ def parity_check_gqa_prompt_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1437,6 +1462,13 @@ def parity_check_gqa_prompt_no_buff( q_ro, k_ro = q, k_cache_ref k_cache_ref = k_ro + if do_custom_tree_attention: + custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( + config.batch_size, config.kv_sequence_length, past=False) + else: + custom_pos_ids = None + custom_causal_attention_mask = None + brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded @@ -1473,6 +1505,8 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, + custom_pos_ids, + custom_causal_attention_mask, left_window_size, past_format, False, @@ -1491,6 +1525,8 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, + custom_pos_ids, + custom_causal_attention_mask, left_window_size, past_format, False, @@ -1525,6 +1561,8 @@ def parity_check_gqa_prompt_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, + " custom_tree_attention:", + do_custom_tree_attention, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1556,6 +1594,7 @@ def parity_check_gqa_past( packed=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1620,7 +1659,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # TODO(derdeljan): Move random seqlens_k selection to config cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1680,6 +1718,13 @@ def parity_check_gqa_past( cache_seqlens += config.sequence_length - 1 + if do_custom_tree_attention: + custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( + config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + else: + custom_pos_ids = None + custom_causal_attention_mask = None + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1693,12 +1738,15 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, + custom_pos_ids, + custom_causal_attention_mask, past_format, True, left_window_size, rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention ) else: out, present_k, present_v = gqa_past_func( @@ -1711,12 +1759,15 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, + custom_pos_ids, + custom_causal_attention_mask, past_format, True, left_window_size, rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1747,6 +1798,8 @@ def parity_check_gqa_past( softcap, " smooth_softmax:", use_smooth_softmax, + " custom_tree_attention:", + do_custom_tree_attention, " 
B:", config.batch_size, " S:", @@ -1776,6 +1829,7 @@ def parity_check_gqa_past_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, + do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1843,7 +1897,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # TODO(derdeljan): Move random seqlens_k selection to config cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1906,6 +1959,13 @@ def parity_check_gqa_past_no_buff( cache_seqlens += config.sequence_length - 1 + if do_custom_tree_attention: + custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( + config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + else: + custom_pos_ids = None + custom_causal_attention_mask = None + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1919,6 +1979,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, + custom_pos_ids, + custom_causal_attention_mask, past_format, False, window_size=left_window_size, @@ -1937,6 +1999,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, + custom_pos_ids, + custom_causal_attention_mask, past_format, False, window_size=left_window_size, @@ -1967,6 +2031,8 @@ def parity_check_gqa_past_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, + " custom_tree_attention:", + do_custom_tree_attention, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -2019,30 +2085,33 @@ def test_gqa_no_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_prompt( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) + for do_custom_tree_attention in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention + ) + self.assertTrue(all_close) def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") @@ -2076,35 +2145,38 @@ def test_gqa_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_past( - config, - local=local, - 
past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) + for do_custom_tree_attention in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + do_custom_tree_attention=do_custom_tree_attention + ) + self.assertTrue(all_close) def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") @@ -2136,30 +2208,33 @@ def test_gqa_interactive_one_batch(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - config = Config(b, s, s2, -1, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=RTOL, - atol=ATOL, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) + for do_custom_tree_attention in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + do_custom_tree_attention=do_custom_tree_attention + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=RTOL, + atol=ATOL, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + do_custom_tree_attention=do_custom_tree_attention + ) + self.assertTrue(all_close) if __name__ == "__main__": From 9dec0564104cec8c9b13a631de0c3ad3c7feb0fb Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 7 Mar 2025 15:57:37 +0100 Subject: [PATCH 09/25] Cleanup comments --- onnxruntime/contrib_ops/cpu/bert/attention_helper.h | 10 +--------- onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | 4 ++-- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index f09477ff04a55..043cf12ad00f5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ 
-31,17 +31,9 @@ void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) { MlasComputeSoftcap(scores, scores, sequence_length, softcap); } -template +template void ApplyAttentionMask(T* softmax_logits, const T* attention_mask, int N) { MlasEltwiseAdd(softmax_logits, attention_mask, softmax_logits, N); - - // for (int i = 0; i < N; i++) { - // if constexpr (std::is_same::value) { - // softmax_logits[i] += static_cast(attention_mask[i]); - // } else { - // softmax_logits[i] = MLFloat16(softmax_logits[i].ToFloat() + attention_mask[i].ToFloat()); - // } - // } } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index bfbe32ec77513..24d5a9d93ea46 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -50,7 +50,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH - const Tensor* attention_mask, // Causal attention mask to apply before + const Tensor* attention_mask, // Causal attention mask to apply before const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor @@ -135,7 +135,7 @@ class GQAAttentionBase { const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor - const T* attention_mask, + const T* attention_mask, // optional causal attention mask const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) const size_t attention_mask_total_seqlen, // max total seqlen in batch used for attention last dim From 5d23817e140e17df63d8e44611d593d3b7999dbe Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Sat, 8 Mar 2025 00:40:36 +0100 Subject: [PATCH 10/25] Fix CI pipeline errors --- docs/OperatorKernels.md | 2 +- onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | 3 ++- onnxruntime/test/python/transformers/test_gqa_cpu.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 84b9c7c9fc174..c3318752dad55 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1399,7 +1399,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 24d5a9d93ea46..18c99de0354b5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -89,7 +89,8 @@ class GQAAttentionBase { T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; const T* attention_mask_data = attention_mask != nullptr ? attention_mask->Data() : nullptr; - const size_t attention_mask_total_seqlen = attention_mask != nullptr ? attention_mask->Shape()[2] : 0; + const size_t attention_mask_total_seqlen = + attention_mask != nullptr ? static_cast(attention_mask->Shape()[2]) : static_cast(0); bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index f25bc1bce3047..db2e81cca42b9 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -9,10 +9,10 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # ------------------------------------------------------------------------- -from dataclasses import dataclass import math import random import unittest +from dataclasses import dataclass import numpy import torch @@ -1136,7 +1136,7 @@ def get_custom_attention_inputs(batch_size, sequence_length, seqlens_k=None, pas custom_causal_attention_mask = torch.zeros((batch_size, sequence_length, max_seqlen_in_batch), dtype=TORCH_TYPE) for b in range(batch_size): total_seq_len = seqlens_k[b] + 1 - past_seq_len = total_seq_len - sequence_length; + past_seq_len = total_seq_len - sequence_length custom_pos_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) # Configure mask From 42e83d683eec1f1e340eef8ada0c149557365e5d Mon Sep 17 00:00:00 2001 From: derdeljan-msft Date: Sat, 8 Mar 2025 00:54:08 +0100 Subject: [PATCH 11/25] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../cpu/bert/group_query_attention_helper.h | 5 +- .../test/python/transformers/test_gqa_cpu.py | 65 +++++++++---------- 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index c9dfb5267d38f..82a85474e9e67 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -289,12 +289,11 @@ Status CheckInputs(const T* query, return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap); } -template +template Status CheckCustomAttentionInputs(const T* custom_pos_ids, const T* custom_causal_attention_mask, const int max_seqlens_k, - const GroupQueryAttentionParameters& parameters) -{ + const GroupQueryAttentionParameters& parameters) { if (custom_pos_ids != nullptr) { const auto& pos_ids_shape = custom_pos_ids->Shape(); if (parameters.is_first_prompt) { diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index db2e81cca42b9..dcf6be48e1c2e 100644 --- 
a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -160,7 +160,7 @@ def create_group_query_attention_graph_prompt( "cos_cache" if rotary else "", "sin_cache" if rotary else "", "custom_pos_ids" if do_custom_tree_attention else "", - "custom_causal_attention_mask" if do_custom_tree_attention else "" + "custom_causal_attention_mask" if do_custom_tree_attention else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -268,16 +268,12 @@ def create_group_query_attention_graph_prompt( if do_custom_tree_attention: graph_input += [ - helper.make_tensor_value_info( - "custom_pos_ids", - TensorProto.INT64, - [1, 1] - ), + helper.make_tensor_value_info("custom_pos_ids", TensorProto.INT64, [1, 1]), helper.make_tensor_value_info( "custom_causal_attention_mask", ORT_TYPE, - [config.batch_size, config.kv_sequence_length, config.kv_sequence_length] - ) + [config.batch_size, config.kv_sequence_length, config.kv_sequence_length], + ), ] graph_output = [ @@ -351,7 +347,7 @@ def create_group_query_attention_graph_past( packed=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False + do_custom_tree_attention=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -372,8 +368,7 @@ def create_group_query_attention_graph_past( "cos_cache" if rotary else "", "sin_cache" if rotary else "", "custom_pos_ids" if do_custom_tree_attention else "", - "custom_causal_attention_mask" if do_custom_tree_attention else "" - + "custom_causal_attention_mask" if do_custom_tree_attention else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -479,15 +474,13 @@ def create_group_query_attention_graph_past( if do_custom_tree_attention: graph_input += [ helper.make_tensor_value_info( - "custom_pos_ids", - TensorProto.INT64, - [config.batch_size, config.sequence_length] + "custom_pos_ids", TensorProto.INT64, [config.batch_size, config.sequence_length] ), helper.make_tensor_value_info( "custom_causal_attention_mask", ORT_TYPE, - [config.batch_size, config.sequence_length, max_seqlen_in_batch] - ) + [config.batch_size, config.sequence_length, max_seqlen_in_batch], + ), ] graph_output = [ @@ -712,7 +705,7 @@ def gqa_prompt_func( rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False + do_custom_tree_attention=False, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -724,7 +717,7 @@ def gqa_prompt_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) @@ -846,7 +839,7 @@ def gqa_past_func( rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False + do_custom_tree_attention=False, ): assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( @@ -860,7 +853,7 @@ def gqa_past_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() @@ -1255,7 +1248,8 @@ def parity_check_gqa_prompt( if do_custom_tree_attention: custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( - config.batch_size, 
config.kv_sequence_length, past=False) + config.batch_size, config.kv_sequence_length, past=False + ) else: custom_pos_ids = None custom_causal_attention_mask = None @@ -1310,7 +1304,7 @@ def parity_check_gqa_prompt( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1331,7 +1325,7 @@ def parity_check_gqa_prompt( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1464,7 +1458,8 @@ def parity_check_gqa_prompt_no_buff( if do_custom_tree_attention: custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( - config.batch_size, config.kv_sequence_length, past=False) + config.batch_size, config.kv_sequence_length, past=False + ) else: custom_pos_ids = None custom_causal_attention_mask = None @@ -1720,7 +1715,8 @@ def parity_check_gqa_past( if do_custom_tree_attention: custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( - config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True + ) else: custom_pos_ids = None custom_causal_attention_mask = None @@ -1746,7 +1742,7 @@ def parity_check_gqa_past( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) else: out, present_k, present_v = gqa_past_func( @@ -1767,7 +1763,7 @@ def parity_check_gqa_past( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1961,7 +1957,8 @@ def parity_check_gqa_past_no_buff( if do_custom_tree_attention: custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( - config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True + ) else: custom_pos_ids = None custom_causal_attention_mask = None @@ -2097,7 +2094,7 @@ def test_gqa_no_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) all_close = parity_check_gqa_prompt_no_buff( @@ -2109,7 +2106,7 @@ def test_gqa_no_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) @@ -2160,7 +2157,7 @@ def test_gqa_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) all_close = parity_check_gqa_past_no_buff( @@ -2174,7 +2171,7 @@ def test_gqa_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - 
do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) @@ -2220,7 +2217,7 @@ def test_gqa_interactive_one_batch(self): rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) all_close = parity_check_gqa_past_no_buff( @@ -2232,7 +2229,7 @@ def test_gqa_interactive_one_batch(self): rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - do_custom_tree_attention=do_custom_tree_attention + do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) From bc0d69b934b608905a40d5fa874621465fff43fc Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Sat, 8 Mar 2025 11:51:12 +0100 Subject: [PATCH 12/25] Fix docs pipeline build --- docs/ContribOperators.md | 6 +++++- docs/OperatorKernels.md | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 274531faaf717..1d46d0ae4a95e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2551,7 +2551,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
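(For context on this attribute: assuming the tanh-based capping that `ComputeAttentionSoftcapInplace` in this series appears to delegate to, a softcap value `c > 0` replaces each raw attention score `s` with `c * tanh(s / c)` before softmax, while `c = 0` leaves the scores untouched; this is a reading of the surrounding code, not wording from the spec itself.)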
-#### Inputs (7 - 9) +#### Inputs (7 - 11)
query : T
@@ -2572,6 +2572,10 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache (optional) : T
2D tensor with shape (max_sequence_length, head_size / 2).
+
custom_pos_ids (optional) : tensor(int64)
+
2D tensor with shape (batch_size, sequence_length).
+
custom_causal_attention_mask (optional) : T
+
3D tensor with shape (batch_size, sequence_length, total_sequence_length)
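To make the two new optional inputs concrete, here is a minimal sketch, not taken from the patch itself, of how a caller might build them for the token-generation case; it mirrors the `get_custom_attention_inputs` helper added to `test_gqa_cpu.py` earlier in this series, and the batch size, sequence length, and `seqlens_k` values are made up for illustration. The mask is additive: zeros where attention is allowed and a large negative value (the tests use -5000) where it is not.

```python
# Minimal sketch: build custom_pos_ids and custom_causal_attention_mask for token
# generation, mirroring get_custom_attention_inputs() in test_gqa_cpu.py.
# batch_size, sequence_length and seqlens_k below are illustrative values only.
import numpy as np

batch_size, sequence_length = 2, 3
seqlens_k = np.array([6, 9], dtype=np.int32)      # per-batch total_sequence_length - 1
total_seqlen = int(seqlens_k.max()) + 1

# One absolute position per new token: shape (batch_size, sequence_length).
custom_pos_ids = np.stack(
    [np.arange(sk + 1 - sequence_length, sk + 1, dtype=np.int64) for sk in seqlens_k]
)

# Additive mask, shape (batch_size, sequence_length, total_sequence_length):
# 0 where attention is allowed, a large negative value where it is not.
custom_causal_attention_mask = np.zeros(
    (batch_size, sequence_length, total_seqlen),
    dtype=np.float32,  # use the same float dtype as the operator's other T inputs
)
for b in range(batch_size):
    past_len = int(seqlens_k[b]) + 1 - sequence_length
    for i in range(sequence_length):
        custom_causal_attention_mask[b, i, past_len + i + 1 :] = -5000.0

# The two arrays are then fed (or io-bound) as the optional "custom_pos_ids" and
# "custom_causal_attention_mask" inputs alongside the usual GroupQueryAttention inputs.
```

For the first prompt, the tests instead pass `custom_pos_ids = [[0]]` (shape `[1, 1]`) and a `(batch_size, kv_sequence_length, kv_sequence_length)` mask; a later commit in this series renames these inputs to `pos_ids` and `attention_mask` and widens the mask to the 4D `(batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)` layout, so the same construction simply gains a head dimension of size 1.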
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index c3318752dad55..818297b91eddf 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -520,7 +520,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| From ab60cbc6cde968a00fbaecbd2260772f6929b88b Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Sat, 8 Mar 2025 23:31:32 +0100 Subject: [PATCH 13/25] Fix docs pipeline build --- docs/OperatorKernels.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 818297b91eddf..d6a151d0011b2 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -922,7 +922,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**<br>
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1399,7 +1399,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**<br>
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| From 4e0ca5c27a5f30b1eca855c72021229327273833 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Tue, 11 Mar 2025 00:28:42 +0100 Subject: [PATCH 14/25] Fix first batch of PR comments --- onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | 3 ++- onnxruntime/test/mlas/unittest/test_eltwise.cpp | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 18c99de0354b5..e868a576e8708 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -261,6 +261,7 @@ class GQAAttentionBase { const size_t window_size = should_apply_local_window ? local_window_size_ + 1 : seq_causal_length; // Apply custom attention mask if there is any + // TODO (#23982): Implement masking during softmax / softcap computation in GQA CPU operator if (attention_mask_batch != nullptr) { if constexpr (std::is_same_v) { ApplyAttentionMask(output_softmax + start_offset, attention_mask_batch + start_offset, @@ -286,7 +287,7 @@ class GQAAttentionBase { } } - // Calculate softmax + // Calculate softcap / softmax if (softcap_ > 0.f) { ComputeAttentionSoftcapInplace(output_softmax + start_offset, static_cast(window_size), static_cast(softcap_)); diff --git a/onnxruntime/test/mlas/unittest/test_eltwise.cpp b/onnxruntime/test/mlas/unittest/test_eltwise.cpp index cc44a296ff330..c4d4b9c0eb317 100644 --- a/onnxruntime/test/mlas/unittest/test_eltwise.cpp +++ b/onnxruntime/test/mlas/unittest/test_eltwise.cpp @@ -15,7 +15,7 @@ class MlasEltwiseAddTest : public MlasTestBase { MatrixGuardBuffer BufferInputRightFp16; MatrixGuardBuffer BufferOutputFp16; - void Test(size_t N, float MinimumValue, float MaximumValue) { + void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { float* InputLeft = BufferInputLeft.GetBuffer(N); float* InputRight = BufferInputRight.GetBuffer(N); float* Output = BufferOutput.GetBuffer(N); @@ -26,7 +26,7 @@ class MlasEltwiseAddTest : public MlasTestBase { for (size_t n = 0; n < N; n++) { InputLeft[n] = distribution(generator); - InputRight[n] = distribution(generator); + InputRight[n] = ScalarValue.value_or(distribution(generator)); } for (size_t n = 0; n < N; n++) { @@ -47,7 +47,7 @@ class MlasEltwiseAddTest : public MlasTestBase { #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) - void TestFp16(size_t N, float MinimumValue, float MaximumValue) { + void TestFp16(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { MLAS_FP16* InputLeft = BufferInputLeftFp16.GetBuffer(N); MLAS_FP16* InputRight = BufferInputRightFp16.GetBuffer(N); MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); @@ -57,7 +57,7 @@ class MlasEltwiseAddTest : public MlasTestBase { for (size_t n = 0; n < N; n++) { InputLeft[n] = MLAS_FP16(distribution(generator)); - InputRight[n] = MLAS_FP16(distribution(generator)); + InputRight[n] = MLAS_FP16(ScalarValue.value_or(distribution(generator))); } MlasEltwiseAdd(InputLeft, InputRight, Output, N); @@ -88,8 +88,10 @@ class MlasEltwiseAddTest : public MlasTestBase { void ExecuteShort(void) override { for (size_t n = 1; n < 128; n++) { Test(n, -10.f, 10.f); + Test(n, -10.f, 10.f, -5000.f); #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) TestFp16(n, -17.f, 11.f); + TestFp16(n, -17.f, 11.f, -5000.f); #endif // 
defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) } } From 949118f5b6890aec8bfb6d641aed5c28c4c48d40 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 13 Mar 2025 01:32:45 +0100 Subject: [PATCH 15/25] Fix PR comments 1. Rename new inputs to pos_ids and attention_mask 2. Make attention mask 4D tensor of shape (B or 1, H or 1, S, T) 3. Handl new attention mask offset and broadcasting 4. Update python test to reflect new op input names and shape --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 89 ++++++----- .../cpu/bert/group_query_attention.cc | 24 +-- .../cpu/bert/group_query_attention_helper.h | 35 ++-- .../core/graph/contrib_ops/bert_defs.cc | 7 +- .../test/python/transformers/test_gqa_cpu.py | 150 +++++++++--------- 5 files changed, 160 insertions(+), 145 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index e868a576e8708..405327794164c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -50,7 +50,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH - const Tensor* attention_mask, // Causal attention mask to apply before + const Tensor* attention_mask, // Attention mask to apply before softmax / softcap const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor @@ -89,17 +89,17 @@ class GQAAttentionBase { T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; const T* attention_mask_data = attention_mask != nullptr ? attention_mask->Data() : nullptr; - const size_t attention_mask_total_seqlen = - attention_mask != nullptr ? static_cast(attention_mask->Shape()[2]) : static_cast(0); + auto attention_mask_shape = attention_mask != nullptr ? attention_mask->Shape().GetDims() : gsl::span{}; bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_mask_data, batch_size, - sequence_length, attention_mask_total_seqlen, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_mask_data, + batch_size, sequence_length, attention_mask_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -110,9 +110,9 @@ class GQAAttentionBase { is_prompt, tp, allocator); } else { ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_mask_data, - batch_size, sequence_length, attention_mask_total_seqlen, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, - past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, - allocator); + batch_size, sequence_length, attention_mask_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -132,24 +132,24 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = Softmax(attention_probs) // If T is float32, U is float32. If T is float16, U could be float16 or float32. template - void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - const int32_t* seqlens_k, // total - 1 sequence lengths tensor - const T* attention_mask, // optional causal attention mask - const size_t batch_size, // batch size of self-attention - const size_t sequence_length, // sequence length of self-attention (S) - const size_t attention_mask_total_seqlen, // max total seqlen in batch used for attention last dim - const size_t past_buffer_sequence_length, // sequence length of past state - const size_t present_buffer_sequence_length, // sequence length of present state - const size_t head_size, // head size of self-attention - const T* past_key, // past key only - T* present_key, // present key only - const bool past_present_share_buffer, // whether present key and value share the same buffer - const bool packed_qkv, // whether Q, K, V are packed - const bool is_prompt, // whether it is prompt - ThreadPool* tp, // thread pool - AllocatorPtr allocator) const { // allocator for temporary buffer + void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const T* attention_mask, // optional attention mask + const size_t batch_size, // batch size of self-attention + const size_t sequence_length, // sequence length of self-attention (S) + const gsl::span attention_mask_shape, // shape of the attention mask + const size_t past_buffer_sequence_length, // sequence length of past state + const size_t present_buffer_sequence_length, // sequence length of present state + const size_t head_size, // head size of self-attention + const T* past_key, // past key only + T* present_key, // present key only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt + ThreadPool* tp, // thread pool + AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); @@ -197,8 +197,23 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; - const ptrdiff_t attention_mask_offset = SafeInt(batch_index) * sequence_length * attention_mask_total_seqlen; - const T* attention_mask_batch = attention_mask != nullptr ? attention_mask + attention_mask_offset : nullptr; + // Compute attention mask offset based on the batch and head indexes + // Attention mask is of shape (B or 1, H or 1, S, T) so handle broadcasting + const T* attention_mask_thread = nullptr; + ptrdiff_t attention_total_seqlen = 0; + if (attention_mask != nullptr) { + ptrdiff_t attention_mask_offset = 0; + attention_total_seqlen = static_cast(attention_mask_shape[3]); + const ptrdiff_t attention_matrix_size = SafeInt(attention_mask_shape[2]) * attention_total_seqlen; + if (attention_mask_shape[0] != 1) { + attention_mask_offset += SafeInt(batch_index) * attention_mask_shape[1] * attention_matrix_size; + } + if (attention_mask_shape[1] != 1) { + attention_mask_offset += SafeInt(head_index) * attention_matrix_size; + } + + attention_mask_thread = attention_mask + attention_mask_offset; + } const T* k; if (packed_qkv) { @@ -262,17 +277,17 @@ class GQAAttentionBase { // Apply custom attention mask if there is any // TODO (#23982): Implement masking during softmax / softcap computation in GQA CPU operator - if (attention_mask_batch != nullptr) { + if (attention_mask_thread != nullptr) { if constexpr (std::is_same_v) { - ApplyAttentionMask(output_softmax + start_offset, attention_mask_batch + start_offset, + ApplyAttentionMask(output_softmax + start_offset, attention_mask_thread + start_offset, static_cast(window_size)); } else { size_t bytes = window_size * sizeof(float); - auto attention_mask_batch_fp32 = static_cast(allocator->Alloc(bytes)); - BufferUniquePtr scratch_buffer(attention_mask_batch_fp32, BufferDeleter(allocator)); + auto attention_mask_thread_fp32 = static_cast(allocator->Alloc(bytes)); + BufferUniquePtr scratch_buffer(attention_mask_thread_fp32, BufferDeleter(allocator)); - MlasConvertHalfToFloatBuffer(attention_mask_batch + start_offset, attention_mask_batch_fp32, window_size); - ApplyAttentionMask(output_softmax, attention_mask_batch_fp32, static_cast(window_size)); + MlasConvertHalfToFloatBuffer(attention_mask_thread + start_offset, attention_mask_thread_fp32, window_size); + ApplyAttentionMask(output_softmax, attention_mask_thread_fp32, static_cast(window_size)); } } @@ -309,8 +324,8 @@ class GQAAttentionBase { output_softmax += present_buffer_sequence_length; - if (attention_mask_batch != nullptr) { - attention_mask_batch += attention_mask_total_seqlen; + if (attention_mask_thread != nullptr) { + attention_mask_thread += attention_total_seqlen; } } } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 4a22ecf5339fc..5227dd194aa51 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -52,8 +52,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); - const Tensor* custom_pos_ids = context->Input(9); - const Tensor* custom_causal_attention_mask = 
context->Input(10); + const Tensor* pos_ids = context->Input(9); + const Tensor* attention_mask = context->Input(10); GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -73,8 +73,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const int32_t* seqlens_k_data = seqlens_k->Data(); const int32_t max_seqlens_k = *std::max_element(seqlens_k_data, seqlens_k_data + parameters.batch_size); - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(custom_pos_ids, - custom_causal_attention_mask, + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(pos_ids, + attention_mask, max_seqlens_k, parameters)); @@ -138,13 +138,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { auto* tp = context->GetOperatorThreadPool(); // Generate position ids const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length; - std::vector pos_ids(pos_ids_size); - const int64_t* pos_ids_data = pos_ids.data(); + std::vector default_pos_ids(pos_ids_size); + const int64_t* pos_ids_data = default_pos_ids.data(); - if (custom_pos_ids != nullptr) { - pos_ids_data = custom_pos_ids->Data(); + if (pos_ids != nullptr) { + pos_ids_data = pos_ids->Data(); } else if (parameters.is_first_prompt) { - pos_ids[0] = static_cast(0); + default_pos_ids[0] = static_cast(0); } else { // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { @@ -152,9 +152,9 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const int past_seqlen = total_seqlen - sequence_length; for (int s = 0; s < sequence_length; s++) { if (past_seqlen + s < total_seqlen) { - pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; + default_pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; } else { - pos_ids[b * sequence_length + s] = static_cast(1); + default_pos_ids[b * sequence_length + s] = static_cast(1); } } } @@ -209,7 +209,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? 
nullptr : V.Get().Data(), - custom_causal_attention_mask, past_key, past_value, output, present_k, present_v, + attention_mask, past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 82a85474e9e67..f66cf06e9174a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -290,45 +290,50 @@ Status CheckInputs(const T* query, } template -Status CheckCustomAttentionInputs(const T* custom_pos_ids, - const T* custom_causal_attention_mask, +Status CheckCustomAttentionInputs(const T* pos_ids, + const T* attention_mask, const int max_seqlens_k, const GroupQueryAttentionParameters& parameters) { - if (custom_pos_ids != nullptr) { - const auto& pos_ids_shape = custom_pos_ids->Shape(); + if (pos_ids != nullptr) { + const auto& pos_ids_shape = pos_ids->Shape(); if (parameters.is_first_prompt) { if (pos_ids_shape[0] != 1 || pos_ids_shape[1] != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Shape of custom_pos_ids must be [1, 1] when processing the prompt"); + "Shape of pos_ids must be [1, 1] when processing the prompt"); } } else { if (pos_ids_shape[0] != parameters.batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "custom_pos_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); + "pos_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); } if (pos_ids_shape[1] < parameters.sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "custom_pos_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); + "pos_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); } } } - if (custom_causal_attention_mask != nullptr) { - const auto& mask_shape = custom_causal_attention_mask->Shape(); - if (mask_shape[0] != parameters.batch_size) { + if (attention_mask != nullptr) { + const auto& mask_shape = attention_mask->Shape(); + if ((mask_shape[0] != parameters.batch_size) && (mask_shape[0] != 1)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "custom_causal_attention_mask dimension 0 must be equal to the batch size, got ", mask_shape[0]); + "attention_mask dimension 0 must be equal to the batch size or 1, got ", mask_shape[0]); } - if (mask_shape[1] != parameters.sequence_length) { + if ((mask_shape[1] != parameters.num_heads) && (mask_shape[1] != 1)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "custom_causal_attention_mask dimension 1 must be equal to the sequence length, got ", mask_shape[1]); + "attention_mask dimension 1 must be equal to the num heads or 1, got ", mask_shape[1]); } - if (mask_shape[2] < max_seqlens_k + 1) { + if (mask_shape[2] != parameters.sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "custom_causal_attention_mask dimension 2 must be atleast max(seqlens_k) + 1, got ", mask_shape[2]); + "attention_mask dimension 2 must be equal to the sequence length, got ", mask_shape[2]); + } + + if (mask_shape[3] < max_seqlens_k + 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "attention_mask dimension 3 must be atleast max(seqlens_k) + 1, got ", mask_shape[3]); } } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e724aab81bb8e..062c627495c0d 
100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1129,13 +1129,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(9, - "custom_pos_ids", + "pos_ids", "2D tensor with shape (batch_size, sequence_length).", "tensor(int64)", OpSchema::Optional) .Input(10, - "custom_causal_attention_mask", - "3D tensor with shape (batch_size, sequence_length, total_sequence_length)", + "attention_mask", + "4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length). " + "The op only supports local and causal attention.", "T", OpSchema::Optional) .Output(0, diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index dcf6be48e1c2e..2506b505993be 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -30,7 +30,7 @@ GREEN = "\033[32m" RESET = "\033[0m" -ORT_TYPE = TensorProto.FLOAT +ORT_TYPE = TensorProto.FLOAT16 TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32 NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32 RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3 @@ -159,8 +159,8 @@ def create_group_query_attention_graph_prompt( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "custom_pos_ids" if do_custom_tree_attention else "", - "custom_causal_attention_mask" if do_custom_tree_attention else "", + "pos_ids" if do_custom_tree_attention else "", + "attention_mask" if do_custom_tree_attention else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -268,11 +268,11 @@ def create_group_query_attention_graph_prompt( if do_custom_tree_attention: graph_input += [ - helper.make_tensor_value_info("custom_pos_ids", TensorProto.INT64, [1, 1]), + helper.make_tensor_value_info("pos_ids", TensorProto.INT64, [1, 1]), helper.make_tensor_value_info( - "custom_causal_attention_mask", + "attention_mask", ORT_TYPE, - [config.batch_size, config.kv_sequence_length, config.kv_sequence_length], + [config.batch_size, 1, config.kv_sequence_length, config.kv_sequence_length], ), ] @@ -367,8 +367,8 @@ def create_group_query_attention_graph_past( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "custom_pos_ids" if do_custom_tree_attention else "", - "custom_causal_attention_mask" if do_custom_tree_attention else "", + "pos_ids" if do_custom_tree_attention else "", + "attention_mask" if do_custom_tree_attention else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -473,13 +473,11 @@ def create_group_query_attention_graph_past( if do_custom_tree_attention: graph_input += [ + helper.make_tensor_value_info("pos_ids", TensorProto.INT64, [config.batch_size, config.sequence_length]), helper.make_tensor_value_info( - "custom_pos_ids", TensorProto.INT64, [config.batch_size, config.sequence_length] - ), - helper.make_tensor_value_info( - "custom_causal_attention_mask", + "attention_mask", ORT_TYPE, - [config.batch_size, config.sequence_length, max_seqlen_in_batch], + [config.batch_size, 1, config.sequence_length, max_seqlen_in_batch], ), ] @@ -697,8 +695,8 @@ def gqa_prompt_func( cos=None, sin=None, seqlens_k=None, - custom_pos_ids=None, - custom_causal_attention_mask=None, + pos_ids=None, + attention_mask=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -725,8 +723,8 @@ 
def gqa_prompt_func( past_v = v.clone() if share_buffer else None if do_custom_tree_attention: - assert custom_pos_ids is not None - assert custom_causal_attention_mask is not None + assert pos_ids is not None + assert attention_mask is not None if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) @@ -755,10 +753,10 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() - ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) - io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() + ort_inputs["attention_mask"] = attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) + io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( @@ -803,10 +801,10 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() - ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) - io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() + ort_inputs["attention_mask"] = attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) + io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -831,8 +829,8 @@ def gqa_past_func( cos=None, sin=None, seqlens_k=None, - custom_pos_ids=None, - custom_causal_attention_mask=None, + pos_ids=None, + attention_mask=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -860,8 +858,8 @@ def gqa_past_func( past_v = v.clone() if do_custom_tree_attention: - assert custom_pos_ids is not None - assert custom_causal_attention_mask is not None + assert pos_ids is not None + assert attention_mask is not None if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) @@ -892,10 +890,10 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() - ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) - io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() + ort_inputs["attention_mask"] = attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) + io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( @@ -947,10 +945,10 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", 
ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["custom_pos_ids"] = custom_pos_ids.detach().cpu().numpy() - ort_inputs["custom_causal_attention_mask"] = custom_causal_attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("custom_pos_ids", ort_inputs["custom_pos_ids"]) - io_binding.bind_cpu_input("custom_causal_attention_mask", ort_inputs["custom_causal_attention_mask"]) + ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() + ort_inputs["attention_mask"] = attention_mask.detach().cpu().numpy() + io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) + io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) @@ -1124,26 +1122,26 @@ def attention_qkvpacked_ref( def get_custom_attention_inputs(batch_size, sequence_length, seqlens_k=None, past=False): if past: assert seqlens_k is not None - custom_pos_ids_data = [] + pos_ids_data = [] max_seqlen_in_batch = seqlens_k.max().item() + 1 - custom_causal_attention_mask = torch.zeros((batch_size, sequence_length, max_seqlen_in_batch), dtype=TORCH_TYPE) + attention_mask = torch.zeros((batch_size, 1, sequence_length, max_seqlen_in_batch), dtype=TORCH_TYPE) for b in range(batch_size): total_seq_len = seqlens_k[b] + 1 past_seq_len = total_seq_len - sequence_length - custom_pos_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) + pos_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) # Configure mask for i in range(sequence_length): for j in range(past_seq_len + i + 1, max_seqlen_in_batch): - custom_causal_attention_mask[b][i][j] = -5000 + attention_mask[b][0][i][j] = -5000 - custom_pos_ids = torch.tensor(data=custom_pos_ids_data, dtype=torch.int64) + pos_ids = torch.tensor(data=pos_ids_data, dtype=torch.int64) else: - custom_pos_ids = torch.tensor(data=[[0]], dtype=torch.int64) - custom_causal_attention_mask = torch.rand(batch_size, sequence_length, sequence_length, dtype=TORCH_TYPE) - custom_causal_attention_mask = torch.triu(custom_causal_attention_mask, diagonal=1) + pos_ids = torch.tensor(data=[[0]], dtype=torch.int64) + attention_mask = torch.rand(batch_size, 1, sequence_length, sequence_length, dtype=TORCH_TYPE) + attention_mask = torch.triu(attention_mask, diagonal=1) - return custom_pos_ids, custom_causal_attention_mask + return pos_ids, attention_mask def parity_check_gqa_prompt( @@ -1247,12 +1245,10 @@ def parity_check_gqa_prompt( q_ro, k_ro = q, new_k if do_custom_tree_attention: - custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( - config.batch_size, config.kv_sequence_length, past=False - ) + pos_ids, attention_mask = get_custom_attention_inputs(config.batch_size, config.kv_sequence_length, past=False) else: - custom_pos_ids = None - custom_causal_attention_mask = None + pos_ids = None + attention_mask = None rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") @@ -1296,8 +1292,8 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, left_window_size, past_format, True, @@ -1317,8 +1313,8 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, left_window_size, past_format, True, @@ -1457,12 +1453,10 @@ def 
parity_check_gqa_prompt_no_buff( k_cache_ref = k_ro if do_custom_tree_attention: - custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( - config.batch_size, config.kv_sequence_length, past=False - ) + pos_ids, attention_mask = get_custom_attention_inputs(config.batch_size, config.kv_sequence_length, past=False) else: - custom_pos_ids = None - custom_causal_attention_mask = None + pos_ids = None + attention_mask = None brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") @@ -1500,8 +1494,8 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, left_window_size, past_format, False, @@ -1520,8 +1514,8 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, left_window_size, past_format, False, @@ -1714,12 +1708,12 @@ def parity_check_gqa_past( cache_seqlens += config.sequence_length - 1 if do_custom_tree_attention: - custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( + pos_ids, attention_mask = get_custom_attention_inputs( config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True ) else: - custom_pos_ids = None - custom_causal_attention_mask = None + pos_ids = None + attention_mask = None # ORT function if packed: @@ -1734,8 +1728,8 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, past_format, True, left_window_size, @@ -1755,8 +1749,8 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, past_format, True, left_window_size, @@ -1956,12 +1950,12 @@ def parity_check_gqa_past_no_buff( cache_seqlens += config.sequence_length - 1 if do_custom_tree_attention: - custom_pos_ids, custom_causal_attention_mask = get_custom_attention_inputs( + pos_ids, attention_mask = get_custom_attention_inputs( config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True ) else: - custom_pos_ids = None - custom_causal_attention_mask = None + pos_ids = None + attention_mask = None # Flash function if packed: @@ -1976,8 +1970,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, past_format, False, window_size=left_window_size, @@ -1996,8 +1990,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, - custom_pos_ids, - custom_causal_attention_mask, + pos_ids, + attention_mask, past_format, False, window_size=left_window_size, From 62d39a5afaaea58e988332452c8d53f343a9f3c4 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 13 Mar 2025 01:48:57 +0100 Subject: [PATCH 16/25] Linter fix --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 405327794164c..820646d6df809 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -132,24 +132,24 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = Softmax(attention_probs) // If T is float32, U is float32. If T is float16, U could be float16 or float32. 
template - void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - const int32_t* seqlens_k, // total - 1 sequence lengths tensor - const T* attention_mask, // optional attention mask - const size_t batch_size, // batch size of self-attention - const size_t sequence_length, // sequence length of self-attention (S) - const gsl::span attention_mask_shape, // shape of the attention mask - const size_t past_buffer_sequence_length, // sequence length of past state - const size_t present_buffer_sequence_length, // sequence length of present state - const size_t head_size, // head size of self-attention - const T* past_key, // past key only - T* present_key, // present key only - const bool past_present_share_buffer, // whether present key and value share the same buffer - const bool packed_qkv, // whether Q, K, V are packed - const bool is_prompt, // whether it is prompt - ThreadPool* tp, // thread pool - AllocatorPtr allocator) const { // allocator for temporary buffer + void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const T* attention_mask, // optional attention mask + const size_t batch_size, // batch size of self-attention + const size_t sequence_length, // sequence length of self-attention (S) + const gsl::span attention_mask_shape, // shape of the attention mask + const size_t past_buffer_sequence_length, // sequence length of past state + const size_t present_buffer_sequence_length, // sequence length of present state + const size_t head_size, // head size of self-attention + const T* past_key, // past key only + T* present_key, // present key only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt + ThreadPool* tp, // thread pool + AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); From 0349678c7786a5fc80ee6d19f1b00d1d4f9dffea Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 13 Mar 2025 07:42:37 +0100 Subject: [PATCH 17/25] Update attention_mask input description --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 3 +-- onnxruntime/test/python/transformers/test_gqa_cpu.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 062c627495c0d..0a0f36ab7658e 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1135,8 +1135,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(10, "attention_mask", - "4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length). " - "The op only supports local and causal attention.", + "4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)." 
"T", OpSchema::Optional) .Output(0, diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 2506b505993be..0301e13065618 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -30,7 +30,7 @@ GREEN = "\033[32m" RESET = "\033[0m" -ORT_TYPE = TensorProto.FLOAT16 +ORT_TYPE = TensorProto.FLOAT TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32 NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32 RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3 From 0865ddbfa2db9e1b68ceec4fec9fcd4d8f99bb77 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 13 Mar 2025 09:27:58 +0100 Subject: [PATCH 18/25] Fix build break --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0a0f36ab7658e..7033cbc1abef4 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1135,7 +1135,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(10, "attention_mask", - "4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)." + "4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).", "T", OpSchema::Optional) .Output(0, From 55e09c9a4408708d970869bdc83a7902d3690482 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 13 Mar 2025 11:02:31 +0100 Subject: [PATCH 19/25] Fix docs gen CI pipeline --- docs/ContribOperators.md | 6 +++--- docs/OperatorKernels.md | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1d46d0ae4a95e..a2ced1571d0bc 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2572,10 +2572,10 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache (optional) : T
2D tensor with shape (max_sequence_length, head_size / 2).
-custom_pos_ids (optional) : tensor(int64)
+pos_ids (optional) : tensor(int64)
2D tensor with shape (batch_size, sequence_length).
-custom_causal_attention_mask (optional) : T
-3D tensor with shape (batch_size, sequence_length, total_sequence_length)
+attention_mask (optional) : T
+4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).
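For illustration only (not part of the patch): a minimal PyTorch sketch of how these two optional inputs can be populated for the token-generation case, mirroring the get_custom_attention_inputs helper added to test_gqa_cpu.py. The function name is made up for this sketch; the -5000 fill value comes from the unit test, and any sufficiently large negative bias works because the mask is added to QxK' before softmax. (Later patches in this series rename the inputs to position_ids and attention_bias.)

import torch

def make_custom_gqa_inputs(batch_size, sequence_length, seqlens_k, dtype=torch.float32):
    # seqlens_k holds total_sequence_length - 1 per batch entry, as in the GQA contract.
    total_seq_len = int(seqlens_k.max().item()) + 1
    pos_ids = torch.zeros((batch_size, sequence_length), dtype=torch.int64)
    mask = torch.zeros((batch_size, 1, sequence_length, total_seq_len), dtype=dtype)
    for b in range(batch_size):
        past_len = int(seqlens_k[b].item()) + 1 - sequence_length
        pos_ids[b] = torch.arange(past_len, past_len + sequence_length)
        for i in range(sequence_length):
            # Positions after token i get a large negative bias so softmax
            # effectively zeroes them out; 0 keeps a position visible.
            mask[b, 0, i, past_len + i + 1:] = -5000.0
    return pos_ids, mask

# Example: batch of 2, 3 new tokens each, past lengths 4 and 2.
seqlens_k = torch.tensor([6, 4], dtype=torch.int32)
pos_ids, attention_mask = make_custom_gqa_inputs(2, 3, seqlens_k)

Zeros keep a position attendable; the large negative entries suppress future positions, which is how the tests emulate causal masking on top of the kernel's own local-window handling.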
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d6a151d0011b2..e8e98f3ddd00a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -520,7 +520,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* pos_ids:**tensor(int64)**
*in* attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -922,7 +922,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -+|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* pos_ids:**tensor(int64)**
*in* attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1399,7 +1399,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* custom_pos_ids:**tensor(int64)**
*in* custom_causal_attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* pos_ids:**tensor(int64)**
*in* attention_mask:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| From e3bc338cbd1b56aeea5d606102fa21b8a1e836f0 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 13 Mar 2025 20:05:19 +0100 Subject: [PATCH 20/25] Apply attention mask after softcap --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 820646d6df809..d32b6dd257f68 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -50,7 +50,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH - const Tensor* attention_mask, // Attention mask to apply before softmax / softcap + const Tensor* attention_mask, // Attention mask to apply before softmax const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor @@ -275,22 +275,6 @@ class GQAAttentionBase { const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ - 1 : 0; const size_t window_size = should_apply_local_window ? local_window_size_ + 1 : seq_causal_length; - // Apply custom attention mask if there is any - // TODO (#23982): Implement masking during softmax / softcap computation in GQA CPU operator - if (attention_mask_thread != nullptr) { - if constexpr (std::is_same_v) { - ApplyAttentionMask(output_softmax + start_offset, attention_mask_thread + start_offset, - static_cast(window_size)); - } else { - size_t bytes = window_size * sizeof(float); - auto attention_mask_thread_fp32 = static_cast(allocator->Alloc(bytes)); - BufferUniquePtr scratch_buffer(attention_mask_thread_fp32, BufferDeleter(allocator)); - - MlasConvertHalfToFloatBuffer(attention_mask_thread + start_offset, attention_mask_thread_fp32, window_size); - ApplyAttentionMask(output_softmax, attention_mask_thread_fp32, static_cast(window_size)); - } - } - // Mask everything before local window, if local window should be applied if (should_apply_local_window) { for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { @@ -302,11 +286,27 @@ class GQAAttentionBase { } } - // Calculate softcap / softmax if (softcap_ > 0.f) { ComputeAttentionSoftcapInplace(output_softmax + start_offset, static_cast(window_size), static_cast(softcap_)); } + + // Apply custom attention mask if there is any + // TODO (#23982): Implement masking during softmax computation in GQA CPU operator + if (attention_mask_thread != nullptr) { + if constexpr (std::is_same_v) { + ApplyAttentionMask(output_softmax + start_offset, attention_mask_thread + start_offset, + static_cast(window_size)); + } else { + size_t bytes = window_size * sizeof(float); + auto attention_mask_thread_fp32 = static_cast(allocator->Alloc(bytes)); + BufferUniquePtr scratch_buffer(attention_mask_thread_fp32, BufferDeleter(allocator)); + + MlasConvertHalfToFloatBuffer(attention_mask_thread + start_offset, attention_mask_thread_fp32, window_size); + ApplyAttentionMask(output_softmax, attention_mask_thread_fp32, static_cast(window_size)); + } + } + if (use_smooth_softmax_) { ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); } else { From 
757af32c7101569830825cdf320d399b142a6e13 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Thu, 13 Mar 2025 20:06:03 +0100 Subject: [PATCH 21/25] Cleanup mlas eltwise module --- onnxruntime/core/mlas/lib/eltwise.cpp | 29 ++++--------------- onnxruntime/core/mlas/lib/eltwise.h | 2 +- .../core/mlas/lib/eltwise_kernel_neon.cpp | 2 +- .../core/mlas/lib/eltwise_kernel_neon.h | 2 +- .../mlas/lib/eltwise_kernel_neon_fp16.cpp | 2 +- 5 files changed, 10 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/mlas/lib/eltwise.cpp b/onnxruntime/core/mlas/lib/eltwise.cpp index a9b319804d4c0..f63d71b40bfbb 100644 --- a/onnxruntime/core/mlas/lib/eltwise.cpp +++ b/onnxruntime/core/mlas/lib/eltwise.cpp @@ -10,9 +10,9 @@ Module Name: Abstract: - This module implements routines to compute eltwise operations on two vectors. + This module implements routines to compute element-wise operations on two vectors. - Currently supported eltwise operations: + Currently supported element-wise operations: - Add --*/ @@ -30,29 +30,12 @@ MlasEltwiseAdd( size_t N ) { while (N > 0) { - MLAS_FLOAT32X4 LeftVec, RightVec; - if (N >= 4) { - LeftVec = MlasLoadFloat32x4(left); - RightVec = MlasLoadFloat32x4(right); - } else { -#if defined(MLAS_SSE2_INTRINSICS) - // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle - // and use zeroes for the upper elements. - LeftVec = _mm_load_ss(left); - RightVec = _mm_load_ss(right); -#elif defined(MLAS_LSX_INTRINSICS) - LeftVec = (MLAS_FLOAT32X4)__lsx_vldrepl_w(left, 0); - RightVec = (MLAS_FLOAT32X4)__lsx_vldrepl_w(right, 0); -#else - LeftVec = MlasBroadcastFloat32x4(left); - RightVec = MlasBroadcastFloat32x4(right); -#endif - } + MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left); + MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right); - MLAS_FLOAT32X4 ResultVec = MlasAddFloat32x4(LeftVec, RightVec); + MLAS_FLOAT32X4 ResultVec = MlasAddFloat32x4(LeftVec, RightVec); - if (N >= 4) { MlasStoreFloat32x4(output, ResultVec); left += 4; @@ -60,7 +43,7 @@ MlasEltwiseAdd( output += 4; N -= 4; } else { - MlasStoreLaneFloat32x4<0>(output, ResultVec); + *output = *left + *right; left += 1; right += 1; diff --git a/onnxruntime/core/mlas/lib/eltwise.h b/onnxruntime/core/mlas/lib/eltwise.h index 582899a2db24d..a8345c499f6b7 100644 --- a/onnxruntime/core/mlas/lib/eltwise.h +++ b/onnxruntime/core/mlas/lib/eltwise.h @@ -11,7 +11,7 @@ Module Name: Abstract: This module includes kernel function prototypes and helper functions for - eltwise operations. + element-wise operations. --*/ #pragma once diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp index 02ad05b2bbd7c..415c1281c808e 100644 --- a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.cpp @@ -10,7 +10,7 @@ Module Name: Abstract: - This module implements the eltwise kernels for ARM NEON. + This module implements the element-wise kernels for ARM NEON. --*/ diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h index f9eb7b1ed81f7..d99a3e97c21f2 100644 --- a/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon.h @@ -11,7 +11,7 @@ Module Name: Abstract: This module includes function declarations and common helper functions for - eltwise operations on ARM cpu. + element-wise operations on ARM cpu. 
--*/ diff --git a/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp index 97b88983e16a0..decbdb576d5cd 100644 --- a/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/eltwise_kernel_neon_fp16.cpp @@ -10,7 +10,7 @@ Module Name: Abstract: - This module implements the fp16 eltwise kernels for ARM NEON. + This module implements the fp16 element-wise kernels for ARM NEON. --*/ #include From 0c268c94cc1bb430d734d6c837315599adb84505 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 14 Mar 2025 00:07:35 +0100 Subject: [PATCH 22/25] Fix PR comments 1. Rename attention_mask -> attention_bias 2. Make last dim of attention_bias to be total_sequence_length 3. Rename pos_ids -> position_ids 4. Disallow custom position_ids when processing the first prompt 5. Add static assert for fp32 bias upscale --- .../contrib_ops/cpu/bert/attention_helper.h | 2 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 63 ++++---- .../cpu/bert/group_query_attention.cc | 19 +-- .../cpu/bert/group_query_attention_helper.h | 51 +++--- .../core/graph/contrib_ops/bert_defs.cc | 6 +- .../test/python/transformers/test_gqa_cpu.py | 151 +++++++++--------- 6 files changed, 140 insertions(+), 152 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 043cf12ad00f5..ac32a4445f3ca 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -32,7 +32,7 @@ void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) { } template -void ApplyAttentionMask(T* softmax_logits, const T* attention_mask, int N) { +void ApplyAttentionBias(T* softmax_logits, const T* attention_mask, int N) { MlasEltwiseAdd(softmax_logits, attention_mask, softmax_logits, N); } diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index d32b6dd257f68..ff6cb8edc0231 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -50,7 +50,7 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH - const Tensor* attention_mask, // Attention mask to apply before softmax + const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor @@ -88,16 +88,16 @@ class GQAAttentionBase { const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; - const T* attention_mask_data = attention_mask != nullptr ? attention_mask->Data() : nullptr; - auto attention_mask_shape = attention_mask != nullptr ? attention_mask->Shape().GetDims() : gsl::span{}; + const T* attention_bias_data = attention_bias != nullptr ? attention_bias->Data() : nullptr; + auto attention_bias_shape = attention_bias != nullptr ? attention_bias->Shape().GetDims() : gsl::span{}; bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? 
Q + num_heads_ * sequence_length * head_size : K; if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_mask_data, - batch_size, sequence_length, attention_mask_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); @@ -109,8 +109,8 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_mask_data, - batch_size, sequence_length, attention_mask_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); @@ -136,10 +136,10 @@ class GQAAttentionBase { const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor - const T* attention_mask, // optional attention mask + const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) - const gsl::span attention_mask_shape, // shape of the attention mask + const gsl::span attention_bias_shape, // shape of the attention bias const size_t past_buffer_sequence_length, // sequence length of past state const size_t present_buffer_sequence_length, // sequence length of present state const size_t head_size, // head size of self-attention @@ -197,22 +197,22 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; - // Compute attention mask offset based on the batch and head indexes - // Attention mask is of shape (B or 1, H or 1, S, T) so handle broadcasting - const T* attention_mask_thread = nullptr; + // Compute attention bias offset based on the batch and head indexes + // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting + const T* attention_bias_thread = nullptr; ptrdiff_t attention_total_seqlen = 0; - if (attention_mask != nullptr) { - ptrdiff_t attention_mask_offset = 0; - attention_total_seqlen = static_cast(attention_mask_shape[3]); - const ptrdiff_t attention_matrix_size = SafeInt(attention_mask_shape[2]) * attention_total_seqlen; - if (attention_mask_shape[0] != 1) { - attention_mask_offset += SafeInt(batch_index) * attention_mask_shape[1] * attention_matrix_size; + if (attention_bias != nullptr) { + ptrdiff_t attention_bias_offset = 0; + attention_total_seqlen = static_cast(attention_bias_shape[3]); + const ptrdiff_t attention_matrix_size = sequence_length * attention_total_seqlen; + if (attention_bias_shape[0] != 1) { + attention_bias_offset += SafeInt(batch_index) * attention_bias_shape[1] * attention_matrix_size; } - if (attention_mask_shape[1] != 1) { - attention_mask_offset += SafeInt(head_index) * attention_matrix_size; + if (attention_bias_shape[1] != 1) { + 
attention_bias_offset += SafeInt(head_index) * attention_matrix_size; } - attention_mask_thread = attention_mask + attention_mask_offset; + attention_bias_thread = attention_bias + attention_bias_offset; } const T* k; @@ -291,19 +291,20 @@ class GQAAttentionBase { static_cast(softcap_)); } - // Apply custom attention mask if there is any - // TODO (#23982): Implement masking during softmax computation in GQA CPU operator - if (attention_mask_thread != nullptr) { + // Add attention bias to QxK' if provided + // TODO (#23982): Implement bias addition during softmax computation in GQA CPU operator + if (attention_bias_thread != nullptr) { if constexpr (std::is_same_v) { - ApplyAttentionMask(output_softmax + start_offset, attention_mask_thread + start_offset, + ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread + start_offset, static_cast(window_size)); } else { + static_assert(std::is_same_v && std::is_same_v); size_t bytes = window_size * sizeof(float); - auto attention_mask_thread_fp32 = static_cast(allocator->Alloc(bytes)); - BufferUniquePtr scratch_buffer(attention_mask_thread_fp32, BufferDeleter(allocator)); + auto attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); + BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); - MlasConvertHalfToFloatBuffer(attention_mask_thread + start_offset, attention_mask_thread_fp32, window_size); - ApplyAttentionMask(output_softmax, attention_mask_thread_fp32, static_cast(window_size)); + MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size); + ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast(window_size)); } } @@ -324,8 +325,8 @@ class GQAAttentionBase { output_softmax += present_buffer_sequence_length; - if (attention_mask_thread != nullptr) { - attention_mask_thread += attention_total_seqlen; + if (attention_bias_thread != nullptr) { + attention_bias_thread += attention_total_seqlen; } } } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 5227dd194aa51..964ff99138116 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -52,8 +52,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); - const Tensor* pos_ids = context->Input(9); - const Tensor* attention_mask = context->Input(10); + const Tensor* position_ids = context->Input(9); + const Tensor* attention_bias = context->Input(10); GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -71,11 +71,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { scale_, softcap_)); - const int32_t* seqlens_k_data = seqlens_k->Data(); - const int32_t max_seqlens_k = *std::max_element(seqlens_k_data, seqlens_k_data + parameters.batch_size); - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(pos_ids, - attention_mask, - max_seqlens_k, + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, + attention_bias, parameters)); const int batch_size = parameters.batch_size; @@ -141,10 +138,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { std::vector 
default_pos_ids(pos_ids_size); const int64_t* pos_ids_data = default_pos_ids.data(); - if (pos_ids != nullptr) { - pos_ids_data = pos_ids->Data(); - } else if (parameters.is_first_prompt) { + if (parameters.is_first_prompt) { default_pos_ids[0] = static_cast(0); + } else if (position_ids != nullptr) { + pos_ids_data = position_ids->Data(); } else { // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { @@ -209,7 +206,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), - attention_mask, past_key, past_value, output, present_k, present_v, + attention_bias, past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f66cf06e9174a..eeca28eef08fe 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -290,50 +290,47 @@ Status CheckInputs(const T* query, } template -Status CheckCustomAttentionInputs(const T* pos_ids, - const T* attention_mask, - const int max_seqlens_k, +Status CheckCustomAttentionInputs(const T* position_ids, + const T* attention_bias, const GroupQueryAttentionParameters& parameters) { - if (pos_ids != nullptr) { - const auto& pos_ids_shape = pos_ids->Shape(); + if (position_ids != nullptr) { if (parameters.is_first_prompt) { - if (pos_ids_shape[0] != 1 || pos_ids_shape[1] != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Shape of pos_ids must be [1, 1] when processing the prompt"); - } - } else { - if (pos_ids_shape[0] != parameters.batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "pos_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); - } + "Position ids input is not allowed when processing the first prompt"); + } - if (pos_ids_shape[1] < parameters.sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "pos_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); - } + const auto& pos_ids_shape = position_ids->Shape(); + if (pos_ids_shape[0] != parameters.batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "pos_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); + } + + if (pos_ids_shape[1] < parameters.sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "pos_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); } } - if (attention_mask != nullptr) { - const auto& mask_shape = attention_mask->Shape(); - if ((mask_shape[0] != parameters.batch_size) && (mask_shape[0] != 1)) { + if (attention_bias != nullptr) { + const auto& attn_bias_shape = attention_bias->Shape(); + if ((attn_bias_shape[0] != parameters.batch_size) && (attn_bias_shape[0] != 1)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "attention_mask dimension 0 must be equal to the batch size or 1, got ", mask_shape[0]); + "attention_bias dimension 0 must be equal to the batch size or 1, got ", attn_bias_shape[0]); } - if ((mask_shape[1] != parameters.num_heads) && (mask_shape[1] != 1)) { + if ((attn_bias_shape[1] != parameters.num_heads) 
&& (attn_bias_shape[1] != 1)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "attention_mask dimension 1 must be equal to the num heads or 1, got ", mask_shape[1]); + "attention_bias dimension 1 must be equal to the num heads or 1, got ", attn_bias_shape[1]); } - if (mask_shape[2] != parameters.sequence_length) { + if (attn_bias_shape[2] != parameters.sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "attention_mask dimension 2 must be equal to the sequence length, got ", mask_shape[2]); + "attention_bias dimension 2 must be equal to the sequence length, got ", attn_bias_shape[2]); } - if (mask_shape[3] < max_seqlens_k + 1) { + if (attn_bias_shape[3] != parameters.total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "attention_mask dimension 3 must be atleast max(seqlens_k) + 1, got ", mask_shape[3]); + "attention_bias dimension 3 must be equal to total_sequence_length, got ", attn_bias_shape[3]); } } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7033cbc1abef4..e0ac80c6666a9 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1129,13 +1129,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(9, - "pos_ids", + "position_ids", "2D tensor with shape (batch_size, sequence_length).", "tensor(int64)", OpSchema::Optional) .Input(10, - "attention_mask", - "4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).", + "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Output(0, diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 0301e13065618..74c78658f3408 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -159,8 +159,8 @@ def create_group_query_attention_graph_prompt( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "pos_ids" if do_custom_tree_attention else "", - "attention_mask" if do_custom_tree_attention else "", + "", + "attention_bias" if do_custom_tree_attention else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -268,9 +268,8 @@ def create_group_query_attention_graph_prompt( if do_custom_tree_attention: graph_input += [ - helper.make_tensor_value_info("pos_ids", TensorProto.INT64, [1, 1]), helper.make_tensor_value_info( - "attention_mask", + "attention_bias", ORT_TYPE, [config.batch_size, 1, config.kv_sequence_length, config.kv_sequence_length], ), @@ -338,7 +337,6 @@ def create_group_query_attention_graph_prompt( def create_group_query_attention_graph_past( config, - seqlens_k, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1, @@ -353,7 +351,6 @@ def create_group_query_attention_graph_past( present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) - max_seqlen_in_batch = seqlens_k.max().item() + 1 nodes = [ helper.make_node( "GroupQueryAttention", @@ -367,8 +364,8 @@ def create_group_query_attention_graph_past( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "pos_ids" if do_custom_tree_attention else "", - "attention_mask" if do_custom_tree_attention else "", + "position_ids" if 
do_custom_tree_attention else "", + "attention_bias" if do_custom_tree_attention else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -473,11 +470,13 @@ def create_group_query_attention_graph_past( if do_custom_tree_attention: graph_input += [ - helper.make_tensor_value_info("pos_ids", TensorProto.INT64, [config.batch_size, config.sequence_length]), helper.make_tensor_value_info( - "attention_mask", + "position_ids", TensorProto.INT64, [config.batch_size, config.sequence_length] + ), + helper.make_tensor_value_info( + "attention_bias", ORT_TYPE, - [config.batch_size, 1, config.sequence_length, max_seqlen_in_batch], + [config.batch_size, 1, config.sequence_length, present_kv_seqlen], ), ] @@ -695,8 +694,7 @@ def gqa_prompt_func( cos=None, sin=None, seqlens_k=None, - pos_ids=None, - attention_mask=None, + attention_bias=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -723,8 +721,7 @@ def gqa_prompt_func( past_v = v.clone() if share_buffer else None if do_custom_tree_attention: - assert pos_ids is not None - assert attention_mask is not None + assert attention_bias is not None if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) @@ -753,10 +750,8 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() - ort_inputs["attention_mask"] = attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) - io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( @@ -801,10 +796,8 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() - ort_inputs["attention_mask"] = attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) - io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -829,8 +822,8 @@ def gqa_past_func( cos=None, sin=None, seqlens_k=None, - pos_ids=None, - attention_mask=None, + position_ids=None, + attention_bias=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -842,7 +835,6 @@ def gqa_past_func( assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( config, - seqlens_k, past_kv_format, share_buffer, local_window_size=window_size, @@ -858,8 +850,8 @@ def gqa_past_func( past_v = v.clone() if do_custom_tree_attention: - assert pos_ids is not None - assert attention_mask is not None + assert position_ids is not None + assert attention_bias is not None if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) @@ -890,10 +882,10 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() - ort_inputs["attention_mask"] = 
attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) - io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( @@ -945,10 +937,10 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) if do_custom_tree_attention: - ort_inputs["pos_ids"] = pos_ids.detach().cpu().numpy() - ort_inputs["attention_mask"] = attention_mask.detach().cpu().numpy() - io_binding.bind_cpu_input("pos_ids", ort_inputs["pos_ids"]) - io_binding.bind_cpu_input("attention_mask", ort_inputs["attention_mask"]) + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) @@ -1119,29 +1111,28 @@ def attention_qkvpacked_ref( ) -def get_custom_attention_inputs(batch_size, sequence_length, seqlens_k=None, past=False): +def get_custom_attention_inputs(batch_size, sequence_length, total_seq_len, seqlens_k=None, past=False): if past: assert seqlens_k is not None - pos_ids_data = [] - max_seqlen_in_batch = seqlens_k.max().item() + 1 - attention_mask = torch.zeros((batch_size, 1, sequence_length, max_seqlen_in_batch), dtype=TORCH_TYPE) + position_ids_data = [] + attention_bias = torch.zeros((batch_size, 1, sequence_length, total_seq_len), dtype=TORCH_TYPE) for b in range(batch_size): total_seq_len = seqlens_k[b] + 1 past_seq_len = total_seq_len - sequence_length - pos_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) + position_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) - # Configure mask + # Configure bias for i in range(sequence_length): - for j in range(past_seq_len + i + 1, max_seqlen_in_batch): - attention_mask[b][0][i][j] = -5000 + for j in range(past_seq_len + i + 1, total_seq_len): + attention_bias[b][0][i][j] = -5000 - pos_ids = torch.tensor(data=pos_ids_data, dtype=torch.int64) + position_ids = torch.tensor(data=position_ids_data, dtype=torch.int64) else: - pos_ids = torch.tensor(data=[[0]], dtype=torch.int64) - attention_mask = torch.rand(batch_size, 1, sequence_length, sequence_length, dtype=TORCH_TYPE) - attention_mask = torch.triu(attention_mask, diagonal=1) + position_ids = None + attention_bias = torch.rand(batch_size, 1, sequence_length, total_seq_len, dtype=TORCH_TYPE) + attention_bias = torch.triu(attention_bias, diagonal=1) - return pos_ids, attention_mask + return position_ids, attention_bias def parity_check_gqa_prompt( @@ -1245,10 +1236,11 @@ def parity_check_gqa_prompt( q_ro, k_ro = q, new_k if do_custom_tree_attention: - pos_ids, attention_mask = get_custom_attention_inputs(config.batch_size, config.kv_sequence_length, past=False) + _, attention_bias = get_custom_attention_inputs( + config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False + ) else: - pos_ids = None - attention_mask = None + attention_bias = None 
rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") @@ -1292,8 +1284,7 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, - pos_ids, - attention_mask, + attention_bias, left_window_size, past_format, True, @@ -1313,8 +1304,7 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, - pos_ids, - attention_mask, + attention_bias, left_window_size, past_format, True, @@ -1453,10 +1443,11 @@ def parity_check_gqa_prompt_no_buff( k_cache_ref = k_ro if do_custom_tree_attention: - pos_ids, attention_mask = get_custom_attention_inputs(config.batch_size, config.kv_sequence_length, past=False) + _, attention_bias = get_custom_attention_inputs( + config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False + ) else: - pos_ids = None - attention_mask = None + attention_bias = None brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") @@ -1494,8 +1485,7 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, - pos_ids, - attention_mask, + attention_bias, left_window_size, past_format, False, @@ -1514,8 +1504,7 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, - pos_ids, - attention_mask, + attention_bias, left_window_size, past_format, False, @@ -1708,12 +1697,12 @@ def parity_check_gqa_past( cache_seqlens += config.sequence_length - 1 if do_custom_tree_attention: - pos_ids, attention_mask = get_custom_attention_inputs( - config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True + position_ids, attention_bias = get_custom_attention_inputs( + config.batch_size, config.sequence_length, config.kv_sequence_length, seqlens_k=cache_seqlens, past=True ) else: - pos_ids = None - attention_mask = None + position_ids = None + attention_bias = None # ORT function if packed: @@ -1728,8 +1717,8 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, - pos_ids, - attention_mask, + position_ids, + attention_bias, past_format, True, left_window_size, @@ -1749,8 +1738,8 @@ def parity_check_gqa_past( cos, sin, cache_seqlens, - pos_ids, - attention_mask, + position_ids, + attention_bias, past_format, True, left_window_size, @@ -1950,12 +1939,16 @@ def parity_check_gqa_past_no_buff( cache_seqlens += config.sequence_length - 1 if do_custom_tree_attention: - pos_ids, attention_mask = get_custom_attention_inputs( - config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True + position_ids, attention_bias = get_custom_attention_inputs( + config.batch_size, + config.sequence_length, + config.kv_sequence_length + config.sequence_length, + seqlens_k=cache_seqlens, + past=True, ) else: - pos_ids = None - attention_mask = None + position_ids = None + attention_bias = None # Flash function if packed: @@ -1970,8 +1963,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, - pos_ids, - attention_mask, + position_ids, + attention_bias, past_format, False, window_size=left_window_size, @@ -1990,8 +1983,8 @@ def parity_check_gqa_past_no_buff( cos, sin, cache_seqlens, - pos_ids, - attention_mask, + position_ids, + attention_bias, past_format, False, window_size=left_window_size, From c36a9cfd854d0979c0afec49aae6da51a05098d2 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 14 Mar 2025 00:32:26 +0100 Subject: [PATCH 23/25] Fix position_ids handling for the first prompt --- 
.../cpu/bert/group_query_attention.cc | 6 ++--- .../cpu/bert/group_query_attention_helper.h | 9 ++----- .../core/graph/contrib_ops/bert_defs.cc | 3 ++- .../test/python/transformers/test_gqa_cpu.py | 25 ++++++++++++++++--- 4 files changed, 28 insertions(+), 15 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 964ff99138116..9c7530f0126bb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -138,10 +138,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { std::vector default_pos_ids(pos_ids_size); const int64_t* pos_ids_data = default_pos_ids.data(); - if (parameters.is_first_prompt) { - default_pos_ids[0] = static_cast(0); - } else if (position_ids != nullptr) { + if (position_ids != nullptr) { pos_ids_data = position_ids->Data(); + } else if (parameters.is_first_prompt) { + default_pos_ids[0] = static_cast(0); } else { // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index eeca28eef08fe..7bffd768c8f7c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -294,20 +294,15 @@ Status CheckCustomAttentionInputs(const T* position_ids, const T* attention_bias, const GroupQueryAttentionParameters& parameters) { if (position_ids != nullptr) { - if (parameters.is_first_prompt) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Position ids input is not allowed when processing the first prompt"); - } - const auto& pos_ids_shape = position_ids->Shape(); if (pos_ids_shape[0] != parameters.batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "pos_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); + "position_ids dimension 0 must be equal to the batch size, got ", pos_ids_shape[0]); } if (pos_ids_shape[1] < parameters.sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "pos_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); + "position_ids dimension 1 must be atleast sequence length, got ", pos_ids_shape[1]); } } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e0ac80c6666a9..54e1ab0684f50 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1130,7 +1130,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(9, "position_ids", - "2D tensor with shape (batch_size, sequence_length).", + "2D tensor with shape (batch_size, sequence_length). 
When processing the first prompt the kernel ", + "uses only the first element", "tensor(int64)", OpSchema::Optional) .Input(10, diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 74c78658f3408..5887f41e9f990 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -159,7 +159,7 @@ def create_group_query_attention_graph_prompt( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "", + "position_ids" if do_custom_tree_attention else "", "attention_bias" if do_custom_tree_attention else "", ], ["output", "present_key", "present_value"], @@ -268,6 +268,11 @@ def create_group_query_attention_graph_prompt( if do_custom_tree_attention: graph_input += [ + helper.make_tensor_value_info( + "position_ids", + TensorProto.INT64, + [config.batch_size, config.kv_sequence_length], + ), helper.make_tensor_value_info( "attention_bias", ORT_TYPE, @@ -694,6 +699,7 @@ def gqa_prompt_func( cos=None, sin=None, seqlens_k=None, + position_ids=None, attention_bias=None, window_size=-1, past_kv_format=Formats.BSNH, @@ -721,6 +727,7 @@ def gqa_prompt_func( past_v = v.clone() if share_buffer else None if do_custom_tree_attention: + assert position_ids is not None assert attention_bias is not None if new_k is not None: @@ -751,7 +758,9 @@ def gqa_prompt_func( if do_custom_tree_attention: ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( @@ -797,7 +806,9 @@ def gqa_prompt_func( if do_custom_tree_attention: ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -1128,7 +1139,7 @@ def get_custom_attention_inputs(batch_size, sequence_length, total_seq_len, seql position_ids = torch.tensor(data=position_ids_data, dtype=torch.int64) else: - position_ids = None + position_ids = torch.zeros((batch_size, sequence_length), dtype=torch.int64) attention_bias = torch.rand(batch_size, 1, sequence_length, total_seq_len, dtype=TORCH_TYPE) attention_bias = torch.triu(attention_bias, diagonal=1) @@ -1236,10 +1247,11 @@ def parity_check_gqa_prompt( q_ro, k_ro = q, new_k if do_custom_tree_attention: - _, attention_bias = get_custom_attention_inputs( + position_ids, attention_bias = get_custom_attention_inputs( config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False ) else: + position_ids = None attention_bias = None rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") @@ -1284,6 +1296,7 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, + position_ids, attention_bias, left_window_size, past_format, @@ -1304,6 +1317,7 @@ def parity_check_gqa_prompt( cos, sin, cache_seqlens - 1, + position_ids, attention_bias, left_window_size, past_format, @@ -1443,10 +1457,11 @@ def parity_check_gqa_prompt_no_buff( k_cache_ref = k_ro if 
do_custom_tree_attention: - _, attention_bias = get_custom_attention_inputs( + position_ids, attention_bias = get_custom_attention_inputs( config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False ) else: + position_ids = None attention_bias = None brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") @@ -1485,6 +1500,7 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, + position_ids, attention_bias, left_window_size, past_format, @@ -1504,6 +1520,7 @@ def parity_check_gqa_prompt_no_buff( cos, sin, cache_seqlens - 1, + position_ids, attention_bias, left_window_size, past_format, From 86a7737f678036af148a75413c59df441edccdc1 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 14 Mar 2025 00:44:04 +0100 Subject: [PATCH 24/25] Fix build break --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 54e1ab0684f50..718dd9a4397b5 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1130,7 +1130,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(9, "position_ids", - "2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel ", + "2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel " "uses only the first element", "tensor(int64)", OpSchema::Optional) From 56fe768303a94068b46096145442d0b1e60c8d95 Mon Sep 17 00:00:00 2001 From: Dusan Erdeljan Date: Fri, 14 Mar 2025 11:10:56 +0100 Subject: [PATCH 25/25] Fix PR comments and fix docs gen CI pipeline --- docs/ContribOperators.md | 8 +- docs/OperatorKernels.md | 6 +- .../test/python/transformers/test_gqa_cpu.py | 228 ++++++++++++------ 3 files changed, 157 insertions(+), 85 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index a2ced1571d0bc..f85ed1e5f146c 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2572,10 +2572,10 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
 <dt><tt>sin_cache</tt> (optional) : T</dt>
 <dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
-<dt><tt>pos_ids</tt> (optional) : tensor(int64)</dt>
-<dd>2D tensor with shape (batch_size, sequence_length).</dd>
-<dt><tt>attention_mask</tt> (optional) : T</dt>
-<dd>4D tensor with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).</dd>
+<dt><tt>position_ids</tt> (optional) : tensor(int64)</dt>
+<dd>2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element</dd>
+<dt><tt>attention_bias</tt> (optional) : T</dt>
+<dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
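
The two optional inputs documented above can be built the same way the updated helpers in test_gqa_cpu.py build them. A minimal illustrative sketch, assuming a float32 model: the shapes and the -5000.0 masking value mirror the test helpers in this series, while build_custom_gqa_inputs itself is a hypothetical name and not part of the patch:

    import torch

    def build_custom_gqa_inputs(batch_size, seq_len, total_seq_len, past_seq_lens=None):
        # position_ids: (batch_size, seq_len), int64.
        # attention_bias: (batch_size, 1, seq_len, total_seq_len), added to QxK' before softmax.
        if past_seq_lens is None:
            # First prompt: the kernel reads only position_ids[0][0], so zeros suffice,
            # and a causal upper-triangular penalty stands in for the bias.
            position_ids = torch.zeros((batch_size, seq_len), dtype=torch.int64)
            attention_bias = torch.triu(
                torch.full((batch_size, 1, seq_len, total_seq_len), -5000.0), diagonal=1
            )
        else:
            # Token generation / continuous decoding: positions continue from the KV cache,
            # and everything past the causal frontier gets a large negative bias.
            position_ids = torch.stack(
                [torch.arange(p, p + seq_len, dtype=torch.int64) for p in past_seq_lens]
            )
            attention_bias = torch.zeros((batch_size, 1, seq_len, total_seq_len))
            for b, past_len in enumerate(past_seq_lens):
                for i in range(seq_len):
                    attention_bias[b, 0, i, past_len + i + 1:] = -5000.0
        return position_ids, attention_bias

position_ids must be int64 and attention_bias must match the query dtype; as the schema text above states, only the first element of position_ids is consulted when processing the first prompt.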
+
 #### Outputs

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e8e98f3ddd00a..1dd145463367b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -520,7 +520,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
 **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* pos_ids:**tensor(int64)**<br> *in* attention_mask:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
 **T3** = tensor(int64)|
@@ -922,7 +922,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
 *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* pos_ids:**tensor(int64)**<br> *in* attention_mask:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
 |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
 *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1399,7 +1399,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
 **T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* pos_ids:**tensor(int64)**<br> *in* attention_mask:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 5887f41e9f990..1239affcc04de 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -51,6 +51,8 @@ class Config: num_heads: int = 0 kv_num_heads: int = 0 head_size: int = 0 + has_position_ids: bool = False + has_attention_bias: bool = False @dataclass @@ -62,6 +64,8 @@ class PromptConfig: num_heads: int = 0 kv_num_heads: int = 0 head_size: int = 0 + has_position_ids: bool = False + has_attention_bias: bool = False # LLaMA Microsoft model @@ -142,7 +146,6 @@ def create_group_query_attention_graph_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -159,8 +162,8 @@ def create_group_query_attention_graph_prompt( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "position_ids" if do_custom_tree_attention else "", - "attention_bias" if do_custom_tree_attention else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -266,13 +269,17 @@ def create_group_query_attention_graph_prompt( ), ] - if do_custom_tree_attention: + if config.has_position_ids: graph_input += [ helper.make_tensor_value_info( "position_ids", TensorProto.INT64, [config.batch_size, config.kv_sequence_length], ), + ] + + if config.has_attention_bias: + graph_input += [ helper.make_tensor_value_info( "attention_bias", ORT_TYPE, @@ -350,7 +357,6 @@ def create_group_query_attention_graph_past( packed=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -369,8 +375,8 @@ def create_group_query_attention_graph_past( "total_sequence_length", "cos_cache" if rotary else "", "sin_cache" if rotary else "", - "position_ids" if do_custom_tree_attention else "", - "attention_bias" if do_custom_tree_attention else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", @@ -473,11 +479,15 @@ def create_group_query_attention_graph_past( ), ] - if do_custom_tree_attention: + if config.has_position_ids: graph_input += [ helper.make_tensor_value_info( "position_ids", TensorProto.INT64, [config.batch_size, config.sequence_length] ), + ] + + if config.has_attention_bias: + graph_input += [ helper.make_tensor_value_info( "attention_bias", ORT_TYPE, @@ -707,7 +717,6 @@ def gqa_prompt_func( rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -719,15 +728,16 @@ def gqa_prompt_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None - if do_custom_tree_attention: + if config.has_position_ids: assert position_ids is not None + + if 
config.has_attention_bias: assert attention_bias is not None if new_k is not None: @@ -756,12 +766,14 @@ def gqa_prompt_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if do_custom_tree_attention: - ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + if config.has_position_ids: ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() - io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -804,12 +816,14 @@ def gqa_prompt_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if do_custom_tree_attention: - ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + if config.has_position_ids: ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() - io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) @@ -841,7 +855,6 @@ def gqa_past_func( rotary_interleaved=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, ): assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( @@ -854,14 +867,15 @@ def gqa_past_func( packed=new_k is None, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() - if do_custom_tree_attention: + if config.has_position_ids: assert position_ids is not None + + if config.has_attention_bias: assert attention_bias is not None if new_k is not None: @@ -892,10 +906,12 @@ def gqa_past_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if do_custom_tree_attention: + if config.has_position_ids: ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() - ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) @@ -947,10 +963,12 @@ def gqa_past_func( io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - if do_custom_tree_attention: + if config.has_position_ids: ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() - ort_inputs["attention_bias"] = 
attention_bias.detach().cpu().numpy() io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) @@ -1122,28 +1140,39 @@ def attention_qkvpacked_ref( ) -def get_custom_attention_inputs(batch_size, sequence_length, total_seq_len, seqlens_k=None, past=False): +def get_custom_attention_bias(batch_size, sequence_length, total_seq_len, seqlens_k=None, past=False): if past: assert seqlens_k is not None - position_ids_data = [] attention_bias = torch.zeros((batch_size, 1, sequence_length, total_seq_len), dtype=TORCH_TYPE) for b in range(batch_size): total_seq_len = seqlens_k[b] + 1 past_seq_len = total_seq_len - sequence_length - position_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) # Configure bias for i in range(sequence_length): for j in range(past_seq_len + i + 1, total_seq_len): attention_bias[b][0][i][j] = -5000 + else: + attention_bias = torch.rand(batch_size, 1, sequence_length, total_seq_len, dtype=TORCH_TYPE) + attention_bias = torch.triu(attention_bias, diagonal=1) + + return attention_bias + + +def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=False): + if past: + assert seqlens_k is not None + position_ids_data = [] + for b in range(batch_size): + total_seq_len = seqlens_k[b] + 1 + past_seq_len = total_seq_len - sequence_length + position_ids_data.append(list(range(past_seq_len, past_seq_len + sequence_length))) position_ids = torch.tensor(data=position_ids_data, dtype=torch.int64) else: position_ids = torch.zeros((batch_size, sequence_length), dtype=torch.int64) - attention_bias = torch.rand(batch_size, 1, sequence_length, total_seq_len, dtype=TORCH_TYPE) - attention_bias = torch.triu(attention_bias, diagonal=1) - return position_ids, attention_bias + return position_ids def parity_check_gqa_prompt( @@ -1156,7 +1185,6 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1246,13 +1274,18 @@ def parity_check_gqa_prompt( cos, sin = None, None q_ro, k_ro = q, new_k - if do_custom_tree_attention: - position_ids, attention_bias = get_custom_attention_inputs( + position_ids = ( + get_custom_position_ids(config.batch_size, config.kv_sequence_length, seqlens_k=None, past=False) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False ) - else: - position_ids = None - attention_bias = None + if config.has_attention_bias + else None + ) rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") @@ -1284,6 +1317,7 @@ def parity_check_gqa_prompt( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function + # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) out, present_k, present_v = gqa_prompt_func( @@ -1304,7 +1338,6 @@ def parity_check_gqa_prompt( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1325,7 +1358,6 @@ def 
parity_check_gqa_prompt( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1354,8 +1386,6 @@ def parity_check_gqa_prompt( softcap, " smooth_softmax:", use_smooth_softmax, - " custom_tree_attention:", - do_custom_tree_attention, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1370,6 +1400,10 @@ def parity_check_gqa_prompt( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1387,7 +1421,6 @@ def parity_check_gqa_prompt_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1456,13 +1489,18 @@ def parity_check_gqa_prompt_no_buff( q_ro, k_ro = q, k_cache_ref k_cache_ref = k_ro - if do_custom_tree_attention: - position_ids, attention_bias = get_custom_attention_inputs( + position_ids = ( + get_custom_position_ids(config.batch_size, config.kv_sequence_length, seqlens_k=None, past=False) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( config.batch_size, config.kv_sequence_length, config.q_sequence_length, seqlens_k=None, past=False ) - else: - position_ids = None - attention_bias = None + if config.has_attention_bias + else None + ) brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") @@ -1488,6 +1526,7 @@ def parity_check_gqa_prompt_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function + # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) out, present_k, present_v = gqa_prompt_func( @@ -1556,8 +1595,6 @@ def parity_check_gqa_prompt_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, - " custom_tree_attention:", - do_custom_tree_attention, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1572,6 +1609,10 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1589,7 +1630,6 @@ def parity_check_gqa_past( packed=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1713,13 +1753,18 @@ def parity_check_gqa_past( cache_seqlens += config.sequence_length - 1 - if do_custom_tree_attention: - position_ids, attention_bias = get_custom_attention_inputs( + position_ids = ( + get_custom_position_ids(config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( config.batch_size, config.sequence_length, config.kv_sequence_length, seqlens_k=cache_seqlens, past=True ) - else: - position_ids = None - attention_bias = None + if config.has_attention_bias + else None + ) # ORT function if packed: @@ -1742,7 +1787,6 @@ def parity_check_gqa_past( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) else: out, present_k, 
present_v = gqa_past_func( @@ -1763,7 +1807,6 @@ def parity_check_gqa_past( rotary_interleaved, softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1794,8 +1837,6 @@ def parity_check_gqa_past( softcap, " smooth_softmax:", use_smooth_softmax, - " custom_tree_attention:", - do_custom_tree_attention, " B:", config.batch_size, " S:", @@ -1808,6 +1849,10 @@ def parity_check_gqa_past( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1825,7 +1870,6 @@ def parity_check_gqa_past_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - do_custom_tree_attention=False, rtol=RTOL, atol=ATOL, ): @@ -1955,17 +1999,22 @@ def parity_check_gqa_past_no_buff( cache_seqlens += config.sequence_length - 1 - if do_custom_tree_attention: - position_ids, attention_bias = get_custom_attention_inputs( + position_ids = ( + get_custom_position_ids(config.batch_size, config.sequence_length, seqlens_k=cache_seqlens, past=True) + if config.has_position_ids + else None + ) + attention_bias = ( + get_custom_attention_bias( config.batch_size, config.sequence_length, config.kv_sequence_length + config.sequence_length, seqlens_k=cache_seqlens, past=True, ) - else: - position_ids = None - attention_bias = None + if config.has_attention_bias + else None + ) # Flash function if packed: @@ -2032,8 +2081,6 @@ def parity_check_gqa_past_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, - " custom_tree_attention:", - do_custom_tree_attention, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -2048,6 +2095,10 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, " h:", config.head_size, + " has_position_ids:", + config.has_position_ids, + " has_attention_bias:", + config.has_attention_bias, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2075,6 +2126,11 @@ def test_gqa_no_past(self): (8000, 8000), ] ) + pos_ids_attn_bias = ( + [(False, False), (True, True)] + if pipeline_mode + else [(False, False), (True, True), (False, True), (True, False)] + ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] for b in batches: @@ -2086,8 +2142,18 @@ def test_gqa_no_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: - for do_custom_tree_attention in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + for has_position_ids, has_attention_bias in pos_ids_attn_bias: + config = PromptConfig( + b, + sq, + skv, + sq + skv + 8, + n, + n2, + h, + has_position_ids, + has_attention_bias, + ) past_kv_format = Formats.BNSH all_close = parity_check_gqa_prompt( config, @@ -2098,7 +2164,6 @@ def test_gqa_no_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) all_close = parity_check_gqa_prompt_no_buff( @@ -2110,7 +2175,6 @@ def test_gqa_no_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) @@ -2134,6 +2198,11 @@ def test_gqa_past(self): 
# (128, 128), ] ) + pos_ids_attn_bias = ( + [(False, False), (True, True)] + if pipeline_mode + else [(False, False), (True, True), (False, True), (True, False)] + ) num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) @@ -2146,9 +2215,11 @@ def test_gqa_past(self): for packed in [False, True]: for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: - for do_custom_tree_attention in [False, True]: + for has_position_ids, has_attention_bias in pos_ids_attn_bias: sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) + config = Config( + b, s, s2, sp, n, n2, h, has_position_ids, has_attention_bias + ) past_kv_format = Formats.BNSH all_close = parity_check_gqa_past( config, @@ -2161,7 +2232,6 @@ def test_gqa_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) all_close = parity_check_gqa_past_no_buff( @@ -2175,7 +2245,6 @@ def test_gqa_past(self): packed=packed, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) @@ -2199,6 +2268,11 @@ def test_gqa_interactive_one_batch(self): # (128, 128), ] ) + pos_ids_attn_bias = ( + [(False, False), (True, True)] + if pipeline_mode + else [(False, False), (True, True), (False, True), (True, False)] + ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) @@ -2209,8 +2283,8 @@ def test_gqa_interactive_one_batch(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - for do_custom_tree_attention in [False, True]: - config = Config(b, s, s2, -1, n, n2, h) + for has_position_ids, has_attention_bias in pos_ids_attn_bias: + config = Config(b, s, s2, -1, n, n2, h, has_position_ids, has_attention_bias) past_kv_format = Formats.BNSH all_close = parity_check_gqa_past( config, @@ -2221,7 +2295,6 @@ def test_gqa_interactive_one_batch(self): rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close) all_close = parity_check_gqa_past_no_buff( @@ -2233,7 +2306,6 @@ def test_gqa_interactive_one_batch(self): rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - do_custom_tree_attention=do_custom_tree_attention, ) self.assertTrue(all_close)
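
For completeness, a hedged sketch of how a caller might bind the renamed optional inputs on CPU, condensing the io_binding pattern used throughout the updated tests; the model file, tensor shapes, and head configuration are placeholders, and only the input and output names come from this patch:

    import numpy as np
    import onnxruntime as ort

    # Hypothetical model whose GroupQueryAttention node exposes optional inputs 9 and 10
    # under the graph input names introduced by this patch.
    sess = ort.InferenceSession("gqa_model.onnx", providers=["CPUExecutionProvider"])
    io_binding = sess.io_binding()

    batch_size, seq_len, total_seq_len = 1, 4, 4
    position_ids = np.zeros((batch_size, seq_len), dtype=np.int64)
    bias_2d = np.triu(np.full((seq_len, total_seq_len), -5000.0, dtype=np.float32), k=1)
    attention_bias = np.broadcast_to(bias_2d, (batch_size, 1, seq_len, total_seq_len)).copy()

    io_binding.bind_cpu_input("position_ids", position_ids)
    io_binding.bind_cpu_input("attention_bias", attention_bias)
    # query, past_key, past_value, seqlens_k, total_sequence_length, and the rotary caches
    # would be bound the same way before running; they are omitted here.
    io_binding.bind_output("output")

    sess.run_with_iobinding(io_binding)
    output = io_binding.copy_outputs_to_cpu()[0]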