Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions onnxruntime/core/providers/cpu/llm/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ void make_copy<MLFloat16, MLFloat16>(MLFloat16* mask_data, const MLFloat16* mask
template <>
void make_copy<float, bool>(float* mask_data, const bool* mask_index, size_t size) {
for (size_t i = 0; i < size; ++i) {
mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits<float>::lowest();
mask_data[i] = mask_index[i] ? 0.0f : negative_infinity<float>();
}
}

template <>
void make_copy<MLFloat16, bool>(MLFloat16* mask_data, const bool* mask_index, size_t size) {
for (size_t i = 0; i < size; ++i) {
mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits<MLFloat16>::lowest();
mask_data[i] = mask_index[i] ? MLFloat16(0.f) : negative_infinity<MLFloat16>();
}
}

Expand Down Expand Up @@ -236,7 +236,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
mask_data = static_cast<T*>(allocated_ptr);
for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits<T>::lowest();
mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity<T>();
}
}
delete_mask_data = true;
Expand All @@ -262,7 +262,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
for (int i = 0; i < n_iter; ++i) {
for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits<T>::lowest();
mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity<T>();
}
}
}
Expand Down Expand Up @@ -317,7 +317,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
}

// handling GQA
std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads;
std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads;
std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki;
const T* k = K + k_input_chunk_length * ki;

if (nullptr != present_key) {
Expand Down Expand Up @@ -347,7 +348,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
alpha,
Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
parameters.head_size * parameters.q_num_heads, // lda
transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k,
transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb
beta,
output,
Expand Down Expand Up @@ -555,7 +556,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
// handling GQA
std::ptrdiff_t batch_i = i / num_heads;
std::ptrdiff_t head_i = i % num_heads;
std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads;
std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads;
std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi;
const T* v = V + v_input_chunk_length * vi;

if (nullptr != present_value) {
Expand All @@ -579,15 +581,15 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
// V is transposed but not QK. We use GemmEx with a different value for ldb.
math::GemmEx<T, ThreadPool>(CblasNoTrans,
CblasNoTrans,
sequence_length, // M
v_head_size, // N
total_sequence_length, // K
1.f, // alpha
attention_probs + attention_probs_offset, // QK
total_sequence_length, // lda
transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
0.f, // beta
sequence_length, // M
v_head_size, // N
total_sequence_length, // K
1.f, // alpha
attention_probs + attention_probs_offset, // QK
total_sequence_length, // lda
transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
0.f, // beta
output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
v_head_size * num_heads, // ldc
nullptr);
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/cpu/llm/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@

namespace onnxruntime {

template <typename T>
inline T negative_infinity() {
return -std::numeric_limits<T>::infinity();
}

template <>
inline MLFloat16 negative_infinity() {
return MLFloat16(-std::numeric_limits<float>::infinity());
}

template <typename T>
class AttentionBase : public OpKernel {
public:
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,24 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
// Please make no more changes to the list
static const ORTCHAR_T* immutable_broken_tests[] =
{
// pending ONNX update
ORT_TSTR("attention_3d_gqa"),
ORT_TSTR("attention_3d_gqa_attn_mask"),
ORT_TSTR("attention_3d_gqa_causal"),
ORT_TSTR("attention_3d_gqa_scaled"),
ORT_TSTR("attention_3d_gqa_softcap"),
ORT_TSTR("attention_3d_gqa_with_past_and_present"),
ORT_TSTR("attention_4d_gqa"),
ORT_TSTR("attention_4d_gqa_attn_mask"),
ORT_TSTR("attention_4d_gqa_causal"),
ORT_TSTR("attention_4d_gqa_scaled"),
ORT_TSTR("attention_4d_gqa_softcap"),
ORT_TSTR("attention_4d_gqa_with_past_and_present"),
ORT_TSTR("attention_4d_diff_heads_mask4d_padded_kv"),
ORT_TSTR("attention_4d_gqa_with_past_and_present_fp16"),
ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal"),
ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal"),
// unsupported case
ORT_TSTR("AvgPool1d"),
ORT_TSTR("AvgPool1d_stride"),
ORT_TSTR("AvgPool2d"),
Expand Down
Loading
Loading