Merged
46 commits
bc95327
Work in progress
kunal-vaishnavi Feb 11, 2023
d95d4a3
Work in progress 2
kunal-vaishnavi Feb 11, 2023
dc0c918
Work in progress 3
kunal-vaishnavi Feb 17, 2023
aa36390
Work in progress 4
kunal-vaishnavi Feb 18, 2023
65ea436
Work in progress 5
kunal-vaishnavi Feb 18, 2023
c3b2564
Work in progress 6
kunal-vaishnavi Feb 22, 2023
8c30983
Work in progress 7
kunal-vaishnavi Feb 27, 2023
b3d1e26
Work in progress 8
kunal-vaishnavi Mar 1, 2023
2a24376
Work in progress 9
kunal-vaishnavi Mar 9, 2023
8aea1da
Work in progress 10
kunal-vaishnavi Mar 10, 2023
5de0331
Work in progress 11
kunal-vaishnavi Mar 10, 2023
dedd007
Work in progress 12
kunal-vaishnavi Mar 13, 2023
a7bff6b
Work in progress 13
kunal-vaishnavi Mar 15, 2023
2fa2201
Cleaning up comments
kunal-vaishnavi Mar 15, 2023
53416bf
Cleaning up more comments
kunal-vaishnavi Mar 15, 2023
ea23e01
Merge branch 'main' into dev
kunal-vaishnavi Mar 15, 2023
bf7f23f
Merge branch 'microsoft:main' into dev
kunal-vaishnavi Mar 15, 2023
7e7b19f
Fixing few issues after merging with main
kunal-vaishnavi Mar 17, 2023
4bf560a
Fix multihead attention flag
kunal-vaishnavi Mar 18, 2023
5ef69a5
Changing attention fusion in decoder with past to multihead attention…
kunal-vaishnavi Mar 22, 2023
09235ba
Fix separating present KV into present K and present V
kunal-vaishnavi Mar 24, 2023
911768c
Adding test cases, fusion changes, and kernel changes
kunal-vaishnavi Apr 11, 2023
f8389eb
Removing commented out code
kunal-vaishnavi Apr 11, 2023
96e061c
Remove QKV format assert
kunal-vaishnavi Apr 11, 2023
c106f32
Remove condition for memory efficient attention
kunal-vaishnavi Apr 11, 2023
b2f3d99
Adding onnx test files
kunal-vaishnavi Apr 11, 2023
406a5d9
Merge branch 'main' into dev
kunal-vaishnavi Apr 11, 2023
4003653
Add ORT return if error
kunal-vaishnavi Apr 11, 2023
d1aaa56
Fix allocator naming and casting
kunal-vaishnavi Apr 12, 2023
9b341cf
Fix casting and remove extra parameter
kunal-vaishnavi Apr 12, 2023
ee32f88
Fix CodeQL scan errors and convert value to float
kunal-vaishnavi Apr 12, 2023
7d36cae
Fix test cases
kunal-vaishnavi Apr 13, 2023
33299d1
Fix more test cases
kunal-vaishnavi Apr 13, 2023
0e5d42c
Add whisper folder to build
kunal-vaishnavi Apr 13, 2023
8c2b2a4
Adding format changes suggested by linter
kunal-vaishnavi Apr 14, 2023
2b002ab
Remove extra parenthesis
kunal-vaishnavi Apr 14, 2023
0388430
Adding more format changes suggested by linter
kunal-vaishnavi Apr 14, 2023
2a94bdf
Adding space and comma suggestions from linter
kunal-vaishnavi Apr 14, 2023
97aaedb
Fix allocator initialization
kunal-vaishnavi Apr 14, 2023
dbda09d
Remove commented out line
kunal-vaishnavi Apr 14, 2023
b65e668
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Apr 15, 2023
52e34c8
Remove packed qkv and simplify calculating present kv
kunal-vaishnavi Apr 18, 2023
a75c121
Add changes suggested by new linter
kunal-vaishnavi Apr 18, 2023
70eab06
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Apr 18, 2023
6868e94
Add changes suggested by new C++ linter
kunal-vaishnavi Apr 18, 2023
bc17d24
Merge branch 'main' into kvaishnavi/whisper
kunal-vaishnavi Apr 18, 2023
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -426,6 +426,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
file(GLOB onnxruntime_python_transformers_testdata_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/*.onnx"
)
file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
)
endif()

file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
@@ -523,6 +526,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
COMMAND ${CMAKE_COMMAND} -E copy
${ONNXRUNTIME_ROOT}/__init__.py
@@ -661,6 +665,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_testdata_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_testdata_whisper}
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
)
endif()

1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -442,6 +442,7 @@ Do not modify directly.*
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcMaxPool|*in* x:**T**<br> *out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
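The new table row registers a float32 CPU kernel for the com.microsoft MultiHeadAttention contrib op, including the optional past_key/past_value inputs and present_key/present_value outputs that carry the decoder KV cache. As a rough orientation, here is a minimal sketch of the shape bookkeeping this implies; the (B, S, N*H) / (B, N, P, H) layout is my reading of the contrib-op convention and the helper below is illustrative, not code from this PR:

#include <array>
#include <cstdint>

// Assumed shapes (illustrative only):
//   query:                     (B, S, N*H)
//   key / value:               (B, L, N*H)
//   past_key / past_value:     (B, N, P, H)
//   present_key / present_value: (B, N, P+L, H)
std::array<int64_t, 4> PresentKvShape(int64_t batch_size, int64_t num_heads,
                                      int64_t past_sequence_length,
                                      int64_t kv_sequence_length, int64_t head_size) {
  // The cache grows by the number of new key/value tokens: T = P + L.
  return {batch_size, num_heads, past_sequence_length + kv_sequence_length, head_size};
}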
7 changes: 4 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -80,7 +80,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
auto* packed_weights_data = static_cast<uint8_t*>(alloc->AllocArray(packb_size, loop_len));

// Initialize memory to 0 as there could be some padding associated with pre-packed
// buffer memory and we don not want it uninitialized and generate different hashes
// buffer memory and we do not want it uninitialized and generate different hashes
// if and when we try to cache this pre-packed buffer for sharing between sessions.
memset(packed_weights_data, 0, packed_weights_data_size);
packed_weights_[qkv_index] = BufferUniquePtr(packed_weights_data, BufferDeleter(std::move(alloc)));
@@ -328,8 +328,9 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
}

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past, output,
batch_size, sequence_length,
return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
output, nullptr /* present_key */, nullptr /* present_value */,
batch_size, sequence_length, sequence_length,
parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
relative_position_bias, context);
}
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -27,7 +27,7 @@ class AttentionBase {
const Tensor* past,
int batch_size,
int head_size,
int sequence_length,
int kv_sequence_length,
int& past_sequence_length) const;

protected:
85 changes: 58 additions & 27 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -20,27 +20,35 @@ class AttentionCPUBase : public AttentionBase {

template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxNxSxH
const T* V, // V value with size BxNxSxH_v
const T* K, // K data with shape BxNxLxH
const T* V, // V value with size BxNxLxH_v
const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
const Tensor* past, // past state
const Tensor* past_key, // past K input tensor (if not using past state)
const Tensor* past_value, // past V input tensor (if not using past state)
Tensor* output, // output tensor
Tensor* present_key, // present K output tensor (if separating present KV)
Tensor* present_value, // present V output tensor (if separating present KV)
int batch_size, // batch size (B)
int sequence_length, // sequence length (S)
int sequence_length, // sequence length of Q (S)
int kv_sequence_length, // sequence length of K or V (L)
int qk_head_size, // head size of Q or K (H)
int v_head_size, // head size of V (H_v)
int v_hidden_size, // hidden size of V (D_v)
const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT
OpKernelContext* context) const {
const int kv_sequence_length = sequence_length;

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

auto* tp = context->GetOperatorThreadPool();

int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, v_head_size, sequence_length, past_sequence_length);
Tensor* present = nullptr;
if (present_key == nullptr && present_value == nullptr) {
present = GetPresent(context, past, batch_size, v_head_size, kv_sequence_length, past_sequence_length);
} else if (past_key != nullptr && past_value != nullptr) {
past_sequence_length = static_cast<int>(past_key->Shape().GetDims()[2]);
}

// Total sequence length including that of past state: T = P + L
const int total_sequence_length = past_sequence_length + kv_sequence_length;
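When the caches arrive as separate past_key/past_value tensors, the branch above skips GetPresent and reads the past length directly from past_key's third dimension. Below is a minimal, self-contained sketch of the resulting arithmetic with assumed decoder-step sizes (the numbers are illustrative, not taken from this PR):

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed shapes: past_key is (B, N, P, H) = (1, 8, 16, 64); one new token, so S = L = 1.
  const int past_sequence_length = 16;  // P, read from past_key->Shape().GetDims()[2] in the kernel
  const int kv_sequence_length = 1;     // L, length of the new K/V
  const int head_size = 64;             // H

  // Total sequence length including the past state: T = P + L.
  const int total_sequence_length = past_sequence_length + kv_sequence_length;

  // Per batch-head chunk sizes used when concatenating past and new K/V.
  const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size;    // P x H
  const size_t kv_input_chunk_length = static_cast<size_t>(kv_sequence_length) * head_size;  // L x H
  const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length;             // T x H

  std::printf("T = %d, PxH = %zu, LxH = %zu, TxH = %zu\n",
              total_sequence_length, past_chunk_length, kv_input_chunk_length, present_chunk_length);
  return 0;
}

With these sizes the output is T = 17, PxH = 1024, LxH = 64, TxH = 1088, i.e. the per-head element count each ConcatStateChunk call below writes into present_key or present_value.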
@@ -66,6 +74,10 @@ class AttentionCPUBase : public AttentionBase {
: gsl::span<const int64_t>{};
const T* past_data = past != nullptr ? past->Data<T>() : nullptr;
T* present_data = present != nullptr ? present->MutableData<T>() : nullptr;
const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;

const T* relative_position_bias_data = nullptr;
if (relative_position_bias != nullptr) {
Expand All @@ -74,9 +86,9 @@ class AttentionCPUBase : public AttentionBase {

ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
mask_index_data, mask_index_dims, static_cast<T*>(mask_data), has_unidirectional,
batch_size, sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size,
past_data, present_data, tp, relative_position_bias_data);
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
present_data, present_key_data, tp, relative_position_bias_data);

// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
@@ -86,8 +98,8 @@ class AttentionCPUBase : public AttentionBase {
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data),
static_cast<T*>(attention_probs), V,
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
v_head_size, v_hidden_size,
past_data, present_data, tp);
v_head_size, v_hidden_size, past_data, past_value_data,
present_data, present_value_data, tp);

return Status::OK();
}
@@ -106,27 +118,39 @@ class AttentionCPUBase : public AttentionBase {
T* mask_data, // buffer for mask data.
bool has_unidirectional, // has unidirectional mask
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention
int sequence_length, // sequence length of self-attention (S)
int kv_sequence_length, // sequence length of cross-attention (L)
int past_sequence_length, // sequence length of past state
int head_size, // head size of self-attention
const T* past, // past state
const T* past_key, // past key only (if not using past state)
T* present, // present state
T* present_key, // present key only (if not using present state)
ThreadPool* tp, // thread pool
const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT
) const {
const int total_sequence_length = past_sequence_length + sequence_length; // T = P + L
const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // P x H
const size_t input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // L x H
const size_t present_chunk_length = past_chunk_length + input_chunk_length; // T x H
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // P x H
const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
const size_t kv_input_chunk_length = static_cast<size_t>(kv_sequence_length) * head_size; // L x H
const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H

{
// mask_data is nullptr when mask_index is nullptr and not unidirectional, otherwise its shape is BxSxT
if (mask_data != nullptr) {
PrepareMask(mask_index, mask_index_dims, mask_data,
has_unidirectional, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
} else { // no any mask
size_t bytes = static_cast<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
memset(attention_probs, 0, bytes);
const int memset_loop_len = batch_size * num_heads_;
const double memset_cost = static_cast<double>(sequence_length) * total_sequence_length;

ThreadPool::TryParallelFor(tp, memset_loop_len, memset_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
T* output = attention_probs + output_offset;
memset(output, 0, static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
}
});
}

const int loop_len = batch_size * num_heads_;
@@ -150,10 +174,12 @@ class AttentionCPUBase : public AttentionBase {
static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
}

const T* k = K + input_chunk_length * i;
const T* k = K + kv_input_chunk_length * i;
if (nullptr != present) {
// Concatenate past_K and K : (BxNx)PxH, (BxNx)LxH -> (BxNx)TxH
k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i);
} else if (nullptr != present_key) {
k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, i);
}

// Compute Q*K' + AttentionMask
@@ -162,7 +188,7 @@ class AttentionCPUBase : public AttentionBase {
// B: K' (B x N x) T x H (B x N x) H x T H x T
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
Q + input_chunk_length * i, k, 1.0,
Q + q_input_chunk_length * i, k, 1.0,
output, nullptr);

// Fix unidirectional mask to be parity with huggingface implementation.
@@ -184,7 +210,7 @@ class AttentionCPUBase : public AttentionBase {
});
}

// attention_probs(B, N, S, T) = Softmax(attention_probs)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
{
const int N = batch_size * num_heads_ * sequence_length;
const int D = total_sequence_length;
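Put together, the hunks above make ComputeAttentionProbs build the key sequence as concat(past_key, K) when the caches are split, then compute a scaled Q·Kᵀ followed by a softmax over the total length T for each batch-head pair. Here is a naive single-head reference of that math, written as a standalone sketch under assumed row-major layouts; the mask, relative_position_bias, and any custom scale are omitted, and this is not the kernel's code:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// probs[s * T + t] = softmax over t of ( (1/sqrt(H)) * dot(Q[s], K_total[t]) ),
// where K_total is past_K (P x H) concatenated with K (L x H), so T = P + L.
std::vector<float> AttentionProbsRef(const std::vector<float>& q,       // S x H
                                     const std::vector<float>& past_k,  // P x H
                                     const std::vector<float>& k,       // L x H
                                     int S, int P, int L, int H) {
  const int T = P + L;
  const float scale = 1.0f / std::sqrt(static_cast<float>(H));

  // Concatenate past and new keys along the sequence axis (what ConcatStateChunk does per chunk).
  std::vector<float> k_total(past_k);
  k_total.insert(k_total.end(), k.begin(), k.end());

  std::vector<float> probs(static_cast<size_t>(S) * T, 0.0f);
  for (int s = 0; s < S; ++s) {
    // Scaled dot products against every cached and new key.
    float max_score = -std::numeric_limits<float>::infinity();
    for (int t = 0; t < T; ++t) {
      float dot = 0.0f;
      for (int h = 0; h < H; ++h) dot += q[s * H + h] * k_total[t * H + h];
      probs[s * T + t] = scale * dot;
      max_score = std::max(max_score, probs[s * T + t]);
    }
    // Numerically stable softmax over the T axis.
    float sum = 0.0f;
    for (int t = 0; t < T; ++t) {
      probs[s * T + t] = std::exp(probs[s * T + t] - max_score);
      sum += probs[s * T + t];
    }
    for (int t = 0; t < T; ++t) probs[s * T + t] /= sum;
  }
  return probs;
}

The kernel reaches the same result with one Gemm per batch-head chunk (CblasNoTrans x CblasTrans over the S x H and T x H operands) and parallelizes the loop over batch_size * num_heads_ chunks.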
@@ -204,12 +230,15 @@ class AttentionCPUBase : public AttentionBase {
int v_head_size, // head size of V (H_v)
int v_hidden_size, // hidden size of V (D_v)
const T* past, // past state
const T* past_value, // past value only (if not using past state)
T* present, // present state
T* present_value, // present value only (if not using present state)
ThreadPool* tp) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v
const ptrdiff_t input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v
const ptrdiff_t present_chunk_length = past_chunk_length + input_chunk_length; // T x H_v
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v
const ptrdiff_t q_input_chunk_length = SafeInt<ptrdiff_t>(sequence_length) * v_head_size; // S x H_v
const ptrdiff_t kv_input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v
const ptrdiff_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H_v

// Move the pointer of past and present to start of v values.
if (nullptr != past) {
@@ -224,13 +253,15 @@ class AttentionCPUBase : public AttentionBase {

ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const T* v = V + input_chunk_length * i;
const T* v = V + kv_input_chunk_length * i;
if (nullptr != present) {
// Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
} else if (nullptr != present_value) {
v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
}

T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + input_chunk_length * i;
T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
attention_probs + attention_probs_offset,
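The value path in ComputeVxAttentionScore mirrors the key path: when present_value is provided, past_value and the new V are concatenated per batch-head chunk and the output is attention_probs (S x T) multiplied by V_total (T x H_v). A matching standalone sketch under the same assumptions as above (threading and the fused past/present layout are omitted; illustrative only, not the kernel's code):

#include <vector>

// out[s * Hv + h] = sum over t of probs[s * T + t] * V_total[t * Hv + h],
// with V_total = concat(past_V (P x Hv), V (L x Hv)) and T = P + L.
std::vector<float> AttentionOutputRef(const std::vector<float>& probs,   // S x T
                                      const std::vector<float>& past_v,  // P x Hv
                                      const std::vector<float>& v,       // L x Hv
                                      int S, int P, int L, int Hv) {
  const int T = P + L;

  // Concatenate cached and new values along the sequence axis.
  std::vector<float> v_total(past_v);
  v_total.insert(v_total.end(), v.begin(), v.end());

  std::vector<float> out(static_cast<size_t>(S) * Hv, 0.0f);
  for (int s = 0; s < S; ++s)
    for (int t = 0; t < T; ++t)
      for (int h = 0; h < Hv; ++h)
        out[s * Hv + h] += probs[s * T + t] * v_total[t * Hv + h];
  return out;
}

Feeding AttentionProbsRef's result into AttentionOutputRef gives a per-head reference that both the new split-cache path and the existing fused-present path should reproduce, up to threading and blocking differences.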