Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Inputs (5 - 10)

<dl>
<dt><tt>input_ids</tt> : I</dt>
<dt><tt>input_ids</tt> : F</dt>
<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
<dt><tt>max_length</tt> : I</dt>
<dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
Expand Down Expand Up @@ -466,7 +466,9 @@ This version of the operator has been available since version 1 of the 'com.micr

<dl>
<dt><tt>T</tt> : tensor(float)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dd>Constrain to float tensors.</dd>
<dt><tt>F</tt> : tensor(float), tensor(int32)</dt>
<dd>Constrain input type to float or int32 tensors.</dd>
<dt><tt>I</tt> : tensor(int32)</dt>
<dd>Constrain to integer tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
Expand Down
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
Expand Down Expand Up @@ -786,7 +786,7 @@ Do not modify directly.*
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
Expand Down
160 changes: 121 additions & 39 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/beam_search_impl_gpt.h"
#include "contrib_ops/cpu/transformers/beam_search_impl_t5.h"
#include "contrib_ops/cpu/transformers/subgraph_whisper_encoder.h"
#include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h"

using namespace ONNX_NAMESPACE;
Expand Down Expand Up @@ -62,7 +63,8 @@ void BeamSearch::Init(const OpKernelInfo& info) {

// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
parameters_.model_type == IGenerationParameters::kModelTypeT5);
parameters_.model_type == IGenerationParameters::kModelTypeT5 ||
parameters_.model_type == IGenerationParameters::kModelTypeWhisper);

ONNX_NAMESPACE::GraphProto proto;

Expand Down Expand Up @@ -148,6 +150,37 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
t5_decoder_subgraph_->num_layers);
}
}
else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
if (attribute_name == "encoder") {
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
t5_encoder_subgraph_ = std::make_unique<WhisperEncoderSubgraph>(node,
attribute_name,
subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();

if (parameters_.decoder_start_token_id < 0) {
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
"Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
} else {
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
"Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
}
} else if (attribute_name == "decoder") {
ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
t5_decoder_subgraph_ = std::make_unique<T5DecoderSubgraph>(node,
attribute_name,
subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state));
decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager();
parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size,
t5_decoder_subgraph_->num_heads,
t5_decoder_subgraph_->head_size,
t5_decoder_subgraph_->num_layers);
}
}

return Status::OK();
}
Expand Down Expand Up @@ -224,45 +257,94 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute.");
ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");

// Subgraph has constraint that the output is either float or float16
if (!t5_decoder_subgraph_->IsOutputFloat16()) {
BeamSearchT5<float> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
ORT_RETURN_IF_ERROR(impl.Initialize());

return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
BeamSearchT5<MLFloat16> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_fp16_func_,
expand_buffer_int32_func_,
expand_buffer_float_func_,
expand_buffer_float16_func_};

ORT_RETURN_IF_ERROR(impl.Initialize());

return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);

if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
// Subgraph has constraint that the output is either float or float16
if (!t5_decoder_subgraph_->IsOutputFloat16()) {
BeamSearchT5<float> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
ORT_RETURN_IF_ERROR(impl.Initialize());

return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
BeamSearchT5<MLFloat16> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_fp16_func_,
expand_buffer_int32_func_,
expand_buffer_float_func_,
expand_buffer_float16_func_};

ORT_RETURN_IF_ERROR(impl.Initialize());

return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
}
}

// Change the CreateEncoderInputs function for Whisper shapes
if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
// Subgraph has constraint that the output is either float or float16
if (!t5_decoder_subgraph_->IsOutputFloat16()) {
BeamSearchT5<float> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
ORT_RETURN_IF_ERROR(impl.Initialize());

return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
BeamSearchT5<MLFloat16> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
update_decoder_feeds_fp16_func_,
expand_buffer_int32_func_,
expand_buffer_float_func_,
expand_buffer_float16_func_};

ORT_RETURN_IF_ERROR(impl.Initialize());

return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
}
}

// Model type not supported in IGenerationParameters
ORT_THROW("Model type is not supported.");
}

} // namespace transformers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
ORT_ENFORCE(context != nullptr);
const Tensor* input_ids = context->Input<Tensor>(0);
const auto& dims = input_ids->Shape().GetDims();
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
if (this->model_type == IGenerationParameters::kModelTypeWhisper){
ORT_ENFORCE(dims.size() == 3, "input_features shall have 3 dimensions. Got ", dims.size());
}
else {
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
}
batch_size = static_cast<int>(dims[0]);

// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
Expand Down
20 changes: 16 additions & 4 deletions onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,17 @@ class GenerateBase {
const Tensor* attention_mask,
const Tensor* presence_mask) const {
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper){
if (dims.size() != 3){
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'input_features' is expected to have 3 dimensions, got ", dims.size());
}

}
else if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'input_ids' is expected to have 2 dimensions, got ", dims.size());
}
}

if (vocab_mask != nullptr) { // vocab_mask is optional
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
Expand Down Expand Up @@ -174,7 +181,13 @@ class GenerateBase {

if (attention_mask != nullptr) {
const auto& dims_attn = attention_mask->Shape().GetDims();
if (dims_attn.size() != 2) {
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
if (dims_attn.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have 3 dimensions, got ", dims_attn.size());
}
}
else if (dims_attn.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size());
}
Expand All @@ -183,7 +196,6 @@ class GenerateBase {
"Input 'attention_mask' is expected to have same shape as input_ids");
}
}

if (presence_mask != nullptr) {
const auto& dims_presence = presence_mask->Shape().GetDims();
if (dims_presence.size() != 2) {
Expand Down
Loading