diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index 0d5117709c18a..bfa450f4287f8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -280,6 +280,18 @@ class GQAAttentionBase {
                                         output, static_cast<int>(present_buffer_sequence_length), nullptr);
         }
 
+        // Pre-allocate buffer for attention mask to avoid allocating it for every processed token
+        float* attention_bias_thread_fp32 = nullptr;
+        if (attention_bias_thread != nullptr) {
+          if constexpr (!std::is_same_v<U, float>) {
+            static_assert(std::is_same_v<T, MLFloat16> && std::is_same_v<U, MLFloat16>);
+
+            size_t bytes = attention_total_seqlen * sizeof(float);
+            attention_bias_thread_fp32 = static_cast<float*>(allocator->Alloc(bytes));
+          }
+        }
+        BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator));
+
         // compute Softmax
         U* output_softmax = output;
         for (size_t seq = 0; seq < sequence_length; seq++) {
@@ -316,9 +328,6 @@
                                  static_cast<size_t>(window_size));
             } else {
               static_assert(std::is_same_v<T, MLFloat16> && std::is_same_v<U, MLFloat16>);
-              size_t bytes = window_size * sizeof(float);
-              auto attention_bias_thread_fp32 = static_cast<float*>(allocator->Alloc(bytes));
-              BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator));
               MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size);
               ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32,
                                  static_cast<size_t>(window_size));
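
For context (not part of the patch): the change hoists the fp16-to-fp32 bias conversion buffer out of the per-token softmax loop, so the allocator is hit once per thread block instead of once per processed row. Below is a minimal standalone sketch of that pattern; the names (`Half`, `convert_half_to_float`, `apply_bias`) are illustrative stand-ins, not the actual ORT/MLAS APIs used in the diff.

```cpp
// Standalone sketch of hoisting a per-iteration scratch allocation out of a loop.
// Helper names are hypothetical placeholders, not ONNX Runtime code.
#include <cstddef>
#include <cstdint>
#include <vector>

using Half = uint16_t;  // placeholder for an fp16 storage type

// Placeholder conversion: a real implementation would decode IEEE fp16 bits.
static void convert_half_to_float(const Half* src, float* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
}

static void apply_bias(float* scores, const float* bias, size_t n) {
  for (size_t i = 0; i < n; ++i) scores[i] += bias[i];
}

// Before: a conversion buffer is allocated inside the per-token loop.
void apply_bias_per_row_slow(float* scores, const Half* bias,
                             size_t seq_len, size_t total_len) {
  for (size_t seq = 0; seq < seq_len; ++seq) {
    std::vector<float> bias_fp32(total_len);  // allocation on every iteration
    convert_half_to_float(bias, bias_fp32.data(), total_len);
    apply_bias(scores + seq * total_len, bias_fp32.data(), total_len);
  }
}

// After: allocate once, sized for the largest window, and reuse it each iteration.
void apply_bias_per_row_fast(float* scores, const Half* bias,
                             size_t seq_len, size_t total_len) {
  std::vector<float> bias_fp32(total_len);  // single up-front allocation
  for (size_t seq = 0; seq < seq_len; ++seq) {
    convert_half_to_float(bias, bias_fp32.data(), total_len);
    apply_bias(scores + seq * total_len, bias_fp32.data(), total_len);
  }
}
```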