From bc5eed258de7b253d6e2eba75298b6ff8ddaa6d7 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 12 Mar 2025 11:08:11 -0700 Subject: [PATCH 1/6] init --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../cpu/sparse/sparse_attention.cc | 24 +++++++----- .../cpu/sparse/sparse_attention_base.h | 38 ++++++++++++++----- 3 files changed, 45 insertions(+), 19 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index c742cd1e95bdd..345b5e793a764 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -24,6 +24,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); @@ -299,6 +300,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc index e337f41a8688d..66fd30cf96b0b 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc @@ -21,16 +21,20 @@ using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { namespace contrib { -ONNX_OPERATOR_TYPED_KERNEL_EX( - SparseAttention, - kMSDomain, - 1, - float, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - SparseAttention); +#define REGISTER_KERNEL_TYPED(T) + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SparseAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + SparseAttention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) template SparseAttention::SparseAttention(const OpKernelInfo& info) : OpKernel(info), SparseAttentionBase(info) { diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index 2c719b3724106..435fad59d956d 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -67,7 +67,10 @@ class SparseAttentionBase { int present_buffer_sequence_length = static_cast(present_key->Shape().GetDims()[2]); // Allocate a buffer to store Softmax(QK) - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * parameters.total_sequence_length * sizeof(T); + bool attention_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) 
&& + MlasGQASupported(CblasNoTrans, CblasNoTrans); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * parameters.total_sequence_length * + (attention_mlas_supported ? sizeof(T) : sizeof(float)); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -77,21 +80,38 @@ class SparseAttentionBase { auto* tp = context->GetOperatorThreadPool(); const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - ComputeAttentionProbs( + const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + if (attention_mlas_supported) { + ComputeAttentionProbs( static_cast(attention_probs), Q, k, total_key_lengths->Data(), batch_size, sequence_length, parameters.total_sequence_length, past_buffer_sequence_length, present_buffer_sequence_length, head_size, past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, block_row_indices->Data(), block_col_indices->Data(), parameters, tp); - // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) - const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - ComputeVxAttentionScore( + ComputeVxAttentionScore( output->MutableData(), static_cast(attention_probs), v, total_key_lengths->Data(), batch_size, sequence_length, parameters.total_sequence_length, past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + } else { + ComputeAttentionProbs( + static_cast(attention_probs), Q, k, total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, + past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, + block_row_indices->Data(), block_col_indices->Data(), parameters, tp); + + ComputeVxAttentionScore( + output->MutableData(), static_cast(attention_probs), v, + total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, + past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + } + return Status::OK(); } @@ -100,9 +120,9 @@ class SparseAttentionBase { // Helper function to compute the attention probs. 
It does 2 things: // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) - template + template void ComputeAttentionProbs( - T* attention_probs, // output buffer with size BxNxSxT + U* attention_probs, // output buffer with size BxNxSxT const T* Q, // query start pointer const T* K, // key start pointer const int32_t* total_key_lengths, // total key sequence lengths (past + new) @@ -299,9 +319,9 @@ class SparseAttentionBase { }); } - template + template void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Softmax of Q*K' with size BxNxSxT + const U* attention_probs, // Softmax of Q*K' with size BxNxSxT const T* V, // v value with size BxN_kvxSxH const int32_t* total_key_lengths, // total sequence lengths int batch_size, // batch size From f7ebd2dcade820e77f42b384f121f571fc4897dc Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 12 Mar 2025 12:15:58 -0700 Subject: [PATCH 2/6] finished softmax(q x k') --- .../cpu/sparse/sparse_attention_base.h | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index 435fad59d956d..aed126f882e4d 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -193,7 +193,7 @@ class SparseAttentionBase { const int total_seq_len = total_key_lengths[batch_index]; const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length; - T* output = attention_probs + output_offset; + U* output = attention_probs + output_offset; const T* k; if (packed_qkv) { @@ -225,14 +225,34 @@ class SparseAttentionBase { DUMP_CPU_TENSOR("Q", q, sequence_length, head_size); DUMP_CPU_TENSOR("K", k, total_seq_len, head_size); - math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, alpha, q, - head_size, k, head_size, 0.0f /*bata*/, output, total_seq_len, - nullptr); + if constexpr (std::is_same::value) { + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, alpha, q, + head_size, k, head_size, 0.0f /*bata*/, output, total_seq_len, + nullptr); + } else if constexpr (std::is_same::value) { + MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, + q, head_size, k, head_size, output, total_seq_len, + MLFloat16(alpha).val, static_cast(0) /*beta*/, nullptr); + } else { + size_t bytes = head_size * (sequence_length + total_seq_len) * sizeof(float); + auto q_k_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator)); + + float* q_fp32 = static_cast(q_k_fp32); + MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length); + + float* k_fp32 = q_fp32 + head_size * sequence_length; + MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seq_len); + + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, + alpha, q_fp32, head_size, k_fp32, head_size, 0.0f /*bata*/, + output, total_seq_len, nullptr); + } DUMP_CPU_TENSOR("QK", output, sequence_length, total_seq_len); // Compute Softmax for causal and output result in place. 
- T* output_softmax = output; + U* output_softmax = output; int layout_id = head_index % parameters.num_sparse_layout; bool is_sparse_layout = layout_has_sparse[layout_id]; @@ -244,7 +264,11 @@ class SparseAttentionBase { int causal_length = past_seq_len + q_id + 1; ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr); for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) { - output_softmax[remain_seq_id] = 0.f; + if constexpr (std::is_same::value) { + output_softmax[remain_seq_id] = 0.f; + } else { + output_softmax[remain_seq_id] = MLFloat16::FromBits(static_cast(0)); + } } output_softmax += total_seq_len; } @@ -298,14 +322,23 @@ class SparseAttentionBase { // Update inline according to attention mask. if (has_sparse) { for (int s = 0; s < causal_length; s++) { - if (mask[s] == 0) - output_softmax[s] = std::numeric_limits::lowest(); + if (mask[s] == 0) { + if constexpr (std::is_same::value) { + output_softmax[s] = std::numeric_limits::lowest(); + } else { + output_softmax[s] = MLFloat16::FromBits(static_cast(0xFBFF)); + } + } } } ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr); for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) { - output_softmax[remain_seq_id] = 0.f; + if constexpr (std::is_same::value) { + output_softmax[remain_seq_id] = 0.f; + } else { + output_softmax[remain_seq_id] = MLFloat16::FromBits(static_cast(0)); + } } output_softmax += total_seq_len; From 40d1360784bce68ffc329bd13e9ab6ebc71f6bc6 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 12 Mar 2025 15:14:09 -0700 Subject: [PATCH 3/6] finished softmax * v --- .../cpu/sparse/sparse_attention_base.h | 43 +++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index aed126f882e4d..7155475474fa8 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -394,6 +394,13 @@ class SparseAttentionBase { unit_cost.bytes_stored += bytes_to_copy_value; } + size_t output_fp32_bytes = 0; + if constexpr (std::is_same::value && std::is_same::value) { + output_fp32_bytes = SafeInt(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float); + } + auto output_fp32 = allocator->Alloc(output_fp32_bytes); + BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator)); + DUMP_CPU_TENSOR_INIT(); ThreadPool::TryParallelFor( @@ -429,14 +436,42 @@ class SparseAttentionBase { DUMP_CPU_TENSOR("attention_probs", attention_probs + attention_probs_offset, sequence_length, total_seq_len); - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seq_len, - 1.f, /*alpha*/ - attention_probs + attention_probs_offset, total_seq_len, v, - head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); + if constexpr (std::is_same::value) { + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seq_len, + 1.f, /*alpha*/ + attention_probs + attention_probs_offset, total_seq_len, v, + head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); + } else if constexpr (std::is_same::value) { + MlasGemm(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seq_len, + attention_probs + attention_probs_offset, total_seq_len, + v, head_size, output_current, hidden_size, + MLFloat16(1.0f).val, static_cast(0) /*beta*/, nullptr); + } else { + size_t 
bytes = head_size * total_seq_len * sizeof(float); + auto v_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator)); + + float* v_fp32_ptr = static_cast(v_fp32); + MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seq_len); + + float* output_fp32_current = static_cast(output_fp32) + + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seq_len, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + total_seq_len, v_fp32_ptr, + head_size, 0.0f /*beta*/, output_fp32_current, + hidden_size, nullptr); + } DUMP_CPU_TENSOR("out", attention_probs + attention_probs_offset, sequence_length, head_size); } }); + + if constexpr (std::is_same::value && std::is_same::value) { + MlasConvertFloatToHalfBuffer(static_cast(output_fp32), + output, + SafeInt(sequence_length) * batch_size * num_heads_ * head_size); + } } }; From 84da21dcfb957a0d8e48360c093c543825183eb9 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 12 Mar 2025 15:28:17 -0700 Subject: [PATCH 4/6] fix linting --- .../cpu/sparse/sparse_attention.cc | 20 ++++----- .../cpu/sparse/sparse_attention_base.h | 41 +++++++++---------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc index 66fd30cf96b0b..2ad1ec1d6b328 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc @@ -22,16 +22,16 @@ namespace onnxruntime { namespace contrib { #define REGISTER_KERNEL_TYPED(T) - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SparseAttention, \ - kMSDomain, \ - 1, \ - T, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ - SparseAttention); +ONNX_OPERATOR_TYPED_KERNEL_EX( + SparseAttention, + kMSDomain, + 1, + T, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + SparseAttention); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index 7155475474fa8..da997c2532ef7 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -84,35 +84,34 @@ class SparseAttentionBase { if (attention_mlas_supported) { ComputeAttentionProbs( - static_cast(attention_probs), Q, k, total_key_lengths->Data(), - batch_size, sequence_length, parameters.total_sequence_length, - past_buffer_sequence_length, present_buffer_sequence_length, head_size, - past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, - block_row_indices->Data(), block_col_indices->Data(), parameters, tp); + static_cast(attention_probs), Q, k, total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, + past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, + block_row_indices->Data(), block_col_indices->Data(), parameters, tp); ComputeVxAttentionScore( - output->MutableData(), static_cast(attention_probs), v, - total_key_lengths->Data(), - batch_size, sequence_length, 
parameters.total_sequence_length, - past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, - past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + output->MutableData(), static_cast(attention_probs), v, + total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, + past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); } else { ComputeAttentionProbs( - static_cast(attention_probs), Q, k, total_key_lengths->Data(), - batch_size, sequence_length, parameters.total_sequence_length, - past_buffer_sequence_length, present_buffer_sequence_length, head_size, - past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, - block_row_indices->Data(), block_col_indices->Data(), parameters, tp); + static_cast(attention_probs), Q, k, total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, + past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, + block_row_indices->Data(), block_col_indices->Data(), parameters, tp); ComputeVxAttentionScore( - output->MutableData(), static_cast(attention_probs), v, - total_key_lengths->Data(), - batch_size, sequence_length, parameters.total_sequence_length, - past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, - past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + output->MutableData(), static_cast(attention_probs), v, + total_key_lengths->Data(), + batch_size, sequence_length, parameters.total_sequence_length, + past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, + past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); } - return Status::OK(); } From cc0583fd91c51f9530d247f64288acf84f4be9a3 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 12 Mar 2025 22:50:51 +0000 Subject: [PATCH 5/6] fix build --- .../cpu/sparse/sparse_attention.cc | 22 +++++++++---------- .../cpu/sparse/sparse_attention_base.h | 14 +++++++----- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc index 2ad1ec1d6b328..469084e7b4491 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc @@ -21,17 +21,17 @@ using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { namespace contrib { -#define REGISTER_KERNEL_TYPED(T) -ONNX_OPERATOR_TYPED_KERNEL_EX( - SparseAttention, - kMSDomain, - 1, - T, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - SparseAttention); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SparseAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + SparseAttention); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h 
b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index da997c2532ef7..0c483f21246e5 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -88,28 +88,28 @@ class SparseAttentionBase { batch_size, sequence_length, parameters.total_sequence_length, past_buffer_sequence_length, present_buffer_sequence_length, head_size, past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, - block_row_indices->Data(), block_col_indices->Data(), parameters, tp); + block_row_indices->Data(), block_col_indices->Data(), parameters, tp, allocator); ComputeVxAttentionScore( output->MutableData(), static_cast(attention_probs), v, total_key_lengths->Data(), batch_size, sequence_length, parameters.total_sequence_length, past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, - past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp, allocator); } else { ComputeAttentionProbs( static_cast(attention_probs), Q, k, total_key_lengths->Data(), batch_size, sequence_length, parameters.total_sequence_length, past_buffer_sequence_length, present_buffer_sequence_length, head_size, past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv, - block_row_indices->Data(), block_col_indices->Data(), parameters, tp); + block_row_indices->Data(), block_col_indices->Data(), parameters, tp, allocator); ComputeVxAttentionScore( output->MutableData(), static_cast(attention_probs), v, total_key_lengths->Data(), batch_size, sequence_length, parameters.total_sequence_length, past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size, - past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp); + past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp, allocator); } return Status::OK(); @@ -138,7 +138,8 @@ class SparseAttentionBase { const int32_t* block_row_indices, // block row indices const int32_t* block_col_indices, // block column indices SparseAttentionParameters& parameters, // parameters - ThreadPool* tp) const { // thread pool + ThreadPool* tp, // thread pool + AllocatorPtr allocator) const { const bool is_prompt = (total_sequence_length == sequence_length); const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size @@ -367,7 +368,8 @@ class SparseAttentionBase { T* present_value, // present value only bool past_present_share_buffer, // whether past_key and present_key share the buffer bool packed_qkv, // whether Q, K, V are packed - ThreadPool* tp) const { + ThreadPool* tp, + AllocatorPtr allocator) const { const bool is_prompt = sequence_length == total_sequence_length; const ptrdiff_t packed_batch_stride = packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size From e320a976932fb4e91e3fe09589742747349ba1b8 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Thu, 13 Mar 2025 11:07:51 -0700 Subject: [PATCH 6/6] fix bot comments --- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/sparse/sparse_attention_base.h | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8d256a2088279..60d9e8e747eeb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -551,7 +551,7 @@ Do not modify directly.* |Sampling|*in* input_ids:**I**
*in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float)|
+|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
 |TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br>
*out* Y:**T**|1+|**T** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index 0c483f21246e5..cccaec0b16ce5 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -234,15 +234,15 @@ class SparseAttentionBase { q, head_size, k, head_size, output, total_seq_len, MLFloat16(alpha).val, static_cast(0) /*beta*/, nullptr); } else { - size_t bytes = head_size * (sequence_length + total_seq_len) * sizeof(float); + size_t bytes = static_cast(head_size) * (sequence_length + total_seq_len) * sizeof(float); auto q_k_fp32 = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator)); float* q_fp32 = static_cast(q_k_fp32); - MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length); + MlasConvertHalfToFloatBuffer(q, q_fp32, static_cast(head_size) * sequence_length); float* k_fp32 = q_fp32 + head_size * sequence_length; - MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seq_len); + MlasConvertHalfToFloatBuffer(k, k_fp32, static_cast(head_size) * total_seq_len); math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, alpha, q_fp32, head_size, k_fp32, head_size, 0.0f /*bata*/, @@ -448,12 +448,12 @@ class SparseAttentionBase { v, head_size, output_current, hidden_size, MLFloat16(1.0f).val, static_cast(0) /*beta*/, nullptr); } else { - size_t bytes = head_size * total_seq_len * sizeof(float); + size_t bytes = static_cast(head_size) * total_seq_len * sizeof(float); auto v_fp32 = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator)); float* v_fp32_ptr = static_cast(v_fp32); - MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seq_len); + MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, static_cast(head_size) * total_seq_len); float* output_fp32_current = static_cast(output_fp32) + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
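
Note on the fp32 fallback path added by this series: when MlasGQASupported reports no fp16 GEMM kernel for the required transpose combinations, both ComputeAttentionProbs and ComputeVxAttentionScore follow the same pattern, widen the fp16 operands into a float scratch buffer, run the GEMMs and softmax in fp32, and narrow the result back to fp16 once at the end. The sketch below is only an illustration of that pattern, not the ONNX Runtime/MLAS code: it assumes a compiler with the _Float16 extension (recent GCC/Clang), it skips scaling, masking and softmax, and every helper name in it (GemmF32, Widen, Narrow, AttentionHeadFallback) is made up for the example.

// fp16 attention fallback sketch: keep all arithmetic and the intermediate
// "probs" buffer in fp32; only the kernel inputs/outputs are fp16.
// Assumes the _Float16 compiler extension; names are illustrative, not the MLAS API.
#include <cstddef>
#include <vector>

using half = _Float16;

// Naive row-major GEMM: C(m x n) = A(m x k) * B(k x n), fp32 accumulation.
static void GemmF32(const float* A, const float* B, float* C,
                    std::size_t m, std::size_t n, std::size_t k) {
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = acc;
    }
}

// Widen a half buffer to float (the role MlasConvertHalfToFloatBuffer plays in the patch).
static std::vector<float> Widen(const half* src, std::size_t n) {
  std::vector<float> dst(n);
  for (std::size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
  return dst;
}

// Narrow a float buffer back to half (the role MlasConvertFloatToHalfBuffer plays in the patch).
static void Narrow(const float* src, half* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dst[i] = static_cast<half>(src[i]);
}

// One head, no scaling/softmax/masking shown: probs = Q x K', out = probs x V.
void AttentionHeadFallback(const half* Q,   // S x H, row-major
                           const half* Kt,  // H x T, i.e. K already transposed
                           const half* V,   // T x H
                           half* out,       // S x H
                           std::size_t S, std::size_t T, std::size_t H) {
  std::vector<float> q = Widen(Q, S * H), kt = Widen(Kt, H * T), v = Widen(V, T * H);
  std::vector<float> probs(S * T), out_f32(S * H);
  GemmF32(q.data(), kt.data(), probs.data(), S, T, H);       // Q x K'
  // ...scale by 1/sqrt(H), apply the causal/sparse mask and softmax on probs here...
  GemmF32(probs.data(), v.data(), out_f32.data(), S, H, T);  // probs x V
  Narrow(out_f32.data(), out, S * H);                        // single fp32 -> fp16 conversion at the end
}

In the patches themselves, the attention_probs scratch buffer is sized for float up front whenever fp16 GEMM is unavailable, and MlasConvertFloatToHalfBuffer converts the whole output buffer once after the parallel loop rather than per head.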