diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index cb7823f06b4c2..8a5117bcf1e46 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -429,7 +429,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Inputs (5 - 10)
 
 <dl>
-<dt><tt>input_ids</tt> : I</dt>
-<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
+<dt><tt>input_ids</tt> : F</dt>
+<dd>The sequence used as a prompt for the generation (token IDs), or the float input features for a model like Whisper. Shape is (batch_size, sequence_length) for token IDs, with 3 dimensions for float input features</dd>
 <dt><tt>max_length</tt> : I</dt>
 <dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
@@ -466,7 +466,9 @@ This version of the operator has been available since version 1 of the 'com.micr
 
 <dl>
 <dt><tt>T</tt> : tensor(float)</dt>
-<dd>Constrain input and output types to float tensors.</dd>
+<dd>Constrain to float tensors.</dd>
+<dt><tt>F</tt> : tensor(float), tensor(int32)</dt>
+<dd>Constrain input type to float or int tensors.</dd>
 <dt><tt>I</tt> : tensor(int32)</dt>
 <dd>Constrain to integer types</dd>
 <dt><tt>M</tt> : tensor(int32)</dt>
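Since `input_ids` now uses the relaxed `F` constraint, the same BeamSearch input slot can carry either int32 token IDs or float features such as Whisper's log-mel spectrogram. A minimal sketch of feeding a float tensor through the ONNX Runtime C++ API; the model filename and the 80x3000 feature shape are illustrative assumptions, not part of this change:

```cpp
#include <array>
#include <vector>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "beamsearch-demo");
  Ort::SessionOptions options;
  // Hypothetical model whose BeamSearch node has model_type = 2 (Whisper).
  Ort::Session session(env, "whisper_beamsearch.onnx", options);

  // Whisper-style input: (batch_size, feature_size, sequence_length).
  std::array<int64_t, 3> shape{1, 80, 3000};
  std::vector<float> features(1 * 80 * 3000, 0.0f);

  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input_ids = Ort::Value::CreateTensor<float>(
      mem_info, features.data(), features.size(), shape.data(), shape.size());

  // session.Run(...) would follow, also binding max_length, num_beams, etc.
  // as int32 tensors per the schema below.
  (void)input_ids;
  return 0;
}
```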
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 050d84b19cc97..336ef560a9fa3 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -419,7 +419,7 @@ Do not modify directly.*
 |**Operator Domain:** *com.microsoft*||||
 |Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
 |AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br> **T1** = tensor(int32)|
-|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
 |BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
 |BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
 |CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -786,7 +786,7 @@ Do not modify directly.*
 | |
 |**Operator Domain:** *com.microsoft*||||
 |Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
-|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T2** = tensor(bool)|
 |BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index de2513789c508..a201a2f5d8edd 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -34,6 +34,7 @@
 #include "contrib_ops/cpu/transformers/beam_search_scorer.h"
 #include "contrib_ops/cpu/transformers/beam_search_impl_gpt.h"
 #include "contrib_ops/cpu/transformers/beam_search_impl_t5.h"
+#include "contrib_ops/cpu/transformers/subgraph_whisper_encoder.h"
 #include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h"
 
 using namespace ONNX_NAMESPACE;
@@ -62,7 +63,8 @@ void BeamSearch::Init(const OpKernelInfo& info) {
 
-  // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
+  // Model type could be 0 (GPT-2), 1 (encoder-decoder like T5), or 2 (float inputs like Whisper)
   ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
-              parameters_.model_type == IGenerationParameters::kModelTypeT5);
+              parameters_.model_type == IGenerationParameters::kModelTypeT5 ||
+              parameters_.model_type == IGenerationParameters::kModelTypeWhisper);
 
   ONNX_NAMESPACE::GraphProto proto;
@@ -148,6 +150,36 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
                                         t5_decoder_subgraph_->num_layers);
     }
-  }
+  } else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
+    if (attribute_name == "encoder") {
+      ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
+                  "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+      t5_encoder_subgraph_ = std::make_unique<WhisperEncoderSubgraph>(node,
+                                                                      attribute_name,
+                                                                      subgraph_session_state.GetGraphViewer());
+      ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
+      encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
+
+      if (parameters_.decoder_start_token_id < 0) {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
+                      "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+      } else {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
+                      "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+      }
+    } else if (attribute_name == "decoder") {
+      ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
+                  "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+      t5_decoder_subgraph_ = std::make_unique<T5DecoderSubgraph>(node,
+                                                                 attribute_name,
+                                                                 subgraph_session_state.GetGraphViewer());
+      ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state));
+      decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager();
+      parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size,
+                                        t5_decoder_subgraph_->num_heads,
+                                        t5_decoder_subgraph_->head_size,
+                                        t5_decoder_subgraph_->num_layers);
+    }
+  }
 
   return Status::OK();
 }
@@ -224,45 +256,94 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
     ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute.");
     ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
 
-    // Subgraph has constraint that the output is either float or float16
-    if (!t5_decoder_subgraph_->IsOutputFloat16()) {
-      BeamSearchT5<float> impl{
-          *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
-          *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
-          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
-          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
-          process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
-          init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
-          device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
-          device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
-          create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
-          update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
-          expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
-          expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
-          expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
-      ORT_RETURN_IF_ERROR(impl.Initialize());
-
-      return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
-    } else {
-      BeamSearchT5<MLFloat16> impl{
-          *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
-          *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
-          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
-          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
-          process_logits_fp16_func_,
-          init_beam_state_fp16_func_,
-          device_copy_func_,
-          device_copy_int32_func_,
-          create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
-          update_decoder_feeds_fp16_func_,
-          expand_buffer_int32_func_,
-          expand_buffer_float_func_,
-          expand_buffer_float16_func_};
-
-      ORT_RETURN_IF_ERROR(impl.Initialize());
-
-      return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+
+    if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
+      // Subgraph has constraint that the output is either float or float16
+      if (!t5_decoder_subgraph_->IsOutputFloat16()) {
+        BeamSearchT5<float> impl{
+            *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+            *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+            add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+            topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+            process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
+            init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
+            device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
+            device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
+            create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
+            update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
+            expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
+            expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
+            expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
+        ORT_RETURN_IF_ERROR(impl.Initialize());
+
+        return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+      } else {
+        BeamSearchT5<MLFloat16> impl{
+            *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+            *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+            add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+            topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+            process_logits_fp16_func_,
+            init_beam_state_fp16_func_,
+            device_copy_func_,
+            device_copy_int32_func_,
+            create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
+            update_decoder_feeds_fp16_func_,
+            expand_buffer_int32_func_,
+            expand_buffer_float_func_,
+            expand_buffer_float16_func_};
+
+        ORT_RETURN_IF_ERROR(impl.Initialize());
+
+        return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+      }
+    }
+
+    // Whisper reuses the T5 implementation; only the CreateEncoderInputs helper is
+    // swapped so that float input features (Whisper shapes) are accepted.
+    if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
+      // Subgraph has constraint that the output is either float or float16
+      if (!t5_decoder_subgraph_->IsOutputFloat16()) {
+        BeamSearchT5<float> impl{
+            *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+            *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+            add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+            topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+            process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
+            init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
+            device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
+            device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
+            create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
+            update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
+            expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
+            expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
+            expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
+        ORT_RETURN_IF_ERROR(impl.Initialize());
+
+        return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+      } else {
+        BeamSearchT5<MLFloat16> impl{
+            *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+            *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+            add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+            topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+            process_logits_fp16_func_,
+            init_beam_state_fp16_func_,
+            device_copy_func_,
+            device_copy_int32_func_,
+            create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
+            update_decoder_feeds_fp16_func_,
+            expand_buffer_int32_func_,
+            expand_buffer_float_func_,
+            expand_buffer_float16_func_};
+
+        ORT_RETURN_IF_ERROR(impl.Initialize());
+
+        return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+      }
     }
+
+    // Model type was validated in Init(), so this is unreachable unless a new
+    // IGenerationParameters model type is added without handling it here.
+    ORT_THROW("Model type is not supported.");
 }
 
 }  // namespace transformers
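The long initializer lists in Compute all repeat one pattern: each device helper is an overridable std::function (an execution provider such as CUDA can register its own), falling back to the CPU default when empty; the Whisper path only swaps the create-encoder-inputs default. A condensed, self-contained sketch of that pattern with simplified types:

```cpp
#include <functional>
#include <iostream>

// Simplified stand-in for a device helper slot.
using CreateEncoderInputsFunc = std::function<const char*()>;

const char* CpuCreateEncoderInputs() { return "cpu: CreateEncoderInputs"; }
const char* CpuCreateWhisperEncoderInputs() { return "cpu: CreateWhisperEncoderInputs"; }

int main() {
  // Empty: no execution-provider override was registered.
  CreateEncoderInputsFunc override_func;

  // Same "override or CPU default" selection Compute performs per helper;
  // T5 keeps the token-ID helper, Whisper swaps in the float-feature one.
  CreateEncoderInputsFunc t5 =
      override_func ? override_func : CreateEncoderInputsFunc(CpuCreateEncoderInputs);
  CreateEncoderInputsFunc whisper =
      override_func ? override_func : CreateEncoderInputsFunc(CpuCreateWhisperEncoderInputs);

  std::cout << t5() << "\n" << whisper() << "\n";
  return 0;
}
```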
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index bd3a72e989af0..f79f9b1dbf1cf 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -31,7 +31,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
   ORT_ENFORCE(context != nullptr);
   const Tensor* input_ids = context->Input<Tensor>(0);
   const auto& dims = input_ids->Shape().GetDims();
-  ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
+  if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
+    ORT_ENFORCE(dims.size() == 3, "input_features shall have 3 dimensions. Got ", dims.size());
+  } else {
+    ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
+  }
 
   batch_size = static_cast<int>(dims[0]);
   // For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
index c6e267d26e6df..1ac01d34209d1 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
@@ -124,10 +124,15 @@ class GenerateBase {
                        const Tensor* attention_mask,
                        const Tensor* presence_mask) const {
     const auto& dims = input_ids->Shape().GetDims();
-    if (dims.size() != 2) {
+    if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
+      if (dims.size() != 3) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Input 'input_features' is expected to have 3 dimensions, got ", dims.size());
+      }
+    } else if (dims.size() != 2) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "Input 'input_ids' is expected to have 2 dimensions, got ", dims.size());
     }
 
     if (vocab_mask != nullptr) {  // vocab_mask is optional
       const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
@@ -174,7 +179,12 @@ class GenerateBase {
 
     if (attention_mask != nullptr) {
       const auto& dims_attn = attention_mask->Shape().GetDims();
-      if (dims_attn.size() != 2) {
+      if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
+        if (dims_attn.size() != 3) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                 "Input 'attention_mask' is expected to have 3 dimensions, got ", dims_attn.size());
+        }
+      } else if (dims_attn.size() != 2) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                                "Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size());
       }
@@ -183,7 +193,6 @@ class GenerateBase {
                               "Input 'attention_mask' is expected to have same shape as input_ids");
       }
     }
-
     if (presence_mask != nullptr) {
       const auto& dims_presence = presence_mask->Shape().GetDims();
       if (dims_presence.size() != 2) {
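Both validation sites encode the same contract: GPT-2 and T5 prompts are 2-D (batch_size, sequence_length) token IDs, while Whisper's input is a 3-D float tensor. A self-contained sketch of the rank check; the 80x3000 log-mel layout is an assumption from Whisper's reference preprocessing, not something this diff enforces:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int kModelTypeWhisper = 2;  // matches generation_shared.h

// Mirrors the rank checks in ParseFromInputs and ValidateInputs.
bool InputRankIsValid(int model_type, const std::vector<int64_t>& dims) {
  const std::size_t expected_rank = (model_type == kModelTypeWhisper) ? 3 : 2;
  return dims.size() == expected_rank;
}

int main() {
  std::vector<int64_t> gpt_ids{4, 32};              // (batch_size, sequence_length)
  std::vector<int64_t> whisper_feats{4, 80, 3000};  // assumed (batch_size, feature_size, frames)

  std::cout << InputRankIsValid(0, gpt_ids) << "\n";                        // 1
  std::cout << InputRankIsValid(kModelTypeWhisper, whisper_feats) << "\n";  // 1
  return 0;
}
```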
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
index e63f4b377726f..e971da11f9dd5 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
@@ -831,6 +831,69 @@ Status UpdateDecoderFeeds(
   return Status::OK();
 }
 
+//------------------------------------------------
+// Modified encoder functions for the Whisper model
+//------------------------------------------------
+
+Status CreateWhisperEncoderInputs(
+    const Tensor* original_encoder_input_features,
+    const OrtValue* attn_mask_value,
+    int pad_token_id,
+    int start_token_id,
+    AllocatorPtr allocator,
+    OrtValue& encoder_input_features,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids) {
+  const TensorShape& input_features_shape = original_encoder_input_features->Shape();
+  ORT_ENFORCE(input_features_shape.NumDimensions() == 3);
+  const int64_t batch_size = input_features_shape[0];
+
+  // Element type used for the attention mask and decoder input IDs.
+  auto element_type = DataTypeImpl::GetType<int32_t>();
+
+  // Reuse the buffer of the original float input features; the shape is
+  // (batch_size, feature_size, sequence_length) here and is expanded to
+  // (batch_size * num_beams, ...) later. The const_cast is safe because this
+  // function does not modify the content.
+  Tensor::InitOrtValue(DataTypeImpl::GetType<float>(),
+                       input_features_shape,
+                       const_cast<Tensor*>(original_encoder_input_features)->MutableData<float>(),
+                       allocator->Info(),
+                       encoder_input_features);
+
+  if (attn_mask_value != nullptr) {
+    const Tensor& attention_mask = attn_mask_value->Get<Tensor>();
+    Tensor::InitOrtValue(element_type, input_features_shape,
+                         const_cast<Tensor*>(&attention_mask)->MutableData<int32_t>(),
+                         allocator->Info(), encoder_attention_mask);
+  } else {
+    Tensor::InitOrtValue(element_type, input_features_shape, allocator, encoder_attention_mask);
+
+    // Float features carry no pad token to detect (pad_token_id applies to
+    // integer inputs only), so default to attending everywhere by filling
+    // the mask with ones.
+    ORT_UNUSED_PARAMETER(pad_token_id);
+    int32_t* mask_data = encoder_attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
+    for (int64_t i = 0; i < input_features_shape.Size(); i++) {
+      mask_data[i] = 1;
+    }
+  }
+
+  // decoder_input_ids is optional.
+  if (start_token_id >= 0) {
+    // Fill decoder_input_ids with the start token ID; shape is (batch_size, 1).
+    int64_t dims[] = {batch_size, 1};
+    TensorShape decoder_input_ids_shape(&dims[0], 2);
+    Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, decoder_input_ids);
+    int32_t* data = decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
+    for (int64_t i = 0; i < batch_size; i++, data++) {
+      *data = start_token_id;
+    }
+  }
+
+  return Status::OK();
+}
+
 //------------------------------------------------
 // Explicit template instantiations of functions
 //------------------------------------------------
@@ -950,4 +1013,4 @@ template Status ExpandBuffer(
 }  // namespace GenerationCpuDeviceHelper
 }  // namespace contrib
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
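The decoder_input_ids block above materializes a (batch_size, 1) int32 tensor holding decoder_start_token_id. The same fill, sketched with plain vectors; 50258 as Whisper's <|startoftranscript|> token is an assumption for illustration:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the fill loop in CreateWhisperEncoderInputs: a (batch_size, 1)
// buffer where every row starts from decoder_start_token_id.
std::vector<int32_t> MakeDecoderInputIds(int64_t batch_size, int32_t start_token_id) {
  return std::vector<int32_t>(static_cast<std::size_t>(batch_size), start_token_id);
}

int main() {
  // 50258: assumed value of Whisper's <|startoftranscript|> token.
  auto ids = MakeDecoderInputIds(4, 50258);
  return ids.size() == 4 ? 0 : 1;
}
```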
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
index 3ad7be76a1800..66a1cea083a31 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
@@ -306,6 +306,20 @@ Status UpdateDecoderFeeds(
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper);
 
+// ---------------------------------------------------------------
+// Functions for encoder-decoder models with float inputs, like Whisper
+// ---------------------------------------------------------------
+
+Status CreateWhisperEncoderInputs(
+    const Tensor* original_encoder_input_features,
+    const OrtValue* attn_mask_value,
+    int pad_token_id,
+    int start_token_id,
+    AllocatorPtr allocator,
+    OrtValue& encoder_input_features,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids);
+
 // ---------------------------------------------------------------
 // Utility Functions
 // ---------------------------------------------------------------
@@ -323,4 +337,4 @@ Status ExpandBuffer(
 }  // namespace GenerationCpuDeviceHelper
 }  // namespace contrib
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 3faba9a856273..79e9f04fad1e2 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -129,9 +129,10 @@ class IBeamScorer {
 struct IGenerationParameters {
   static constexpr int kModelTypeGpt = 0;
   static constexpr int kModelTypeT5 = 1;
+  static constexpr int kModelTypeWhisper = 2;
 
   // Parameters from node attributes
-  int model_type;  // 0 for GPT-2; 1 for encoder-decoder like T5
+  int model_type;  // 0 for GPT-2; 1 for encoder-decoder like T5; 2 for float-input models like Whisper
   int eos_token_id;
   int pad_token_id;
   int decoder_start_token_id;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 0460841ae155a..6c744b627d364 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -52,7 +52,7 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
                 "kFirstPastInputIndex currently only supports 2 or 3");
   ORT_RETURN_IF(num_subgraph_inputs < 4 + first_past_input_index_ ||
                     (num_subgraph_inputs - first_past_input_index_) % 4 != 0,
-                "number of outputs expected to be kFirstPastInputIndex + 4 * layers, got:", num_subgraph_inputs);
+                "number of inputs expected to be kFirstPastInputIndex + 4 * layers, got:", num_subgraph_inputs);
   ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0,
                 "number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs);
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
index a2e2e9842097a..a79f677f5a043 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
@@ -40,7 +40,7 @@ class T5EncoderSubgraph : public Subgraph {
     return first_present_output_index_;
   }
 
- private:
+ protected:
   int first_present_output_index_;
 };
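The private-to-protected change above exists so the Whisper subgraph class in the next file can subclass T5EncoderSubgraph, reuse its state, and override only Validate. A minimal standalone illustration of that pattern:

```cpp
#include <iostream>

class EncoderBase {
 public:
  virtual ~EncoderBase() = default;
  virtual const char* Validate() const { return "base: token-ID inputs"; }

 protected:  // was private: subclasses may now read shared state
  int first_present_output_index_ = 2;
};

class FloatInputEncoder : public EncoderBase {
 public:
  // Override validation only; the inherited index stays shared.
  const char* Validate() const override {
    return first_present_output_index_ == 2 ? "derived: float inputs" : "bad state";
  }
};

int main() {
  FloatInputEncoder encoder;
  std::cout << encoder.Validate() << "\n";
  return 0;
}
```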
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc
new file mode 100644
index 0000000000000..8b2dde9518335
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc
@@ -0,0 +1,102 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/framework/framework_common.h"
+#include "core/framework/session_state.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/framework/utils.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/common/gsl.h"
+#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
+#include "contrib_ops/cpu/transformers/subgraph_whisper_encoder.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+
+/* Whisper encoder subgraph (it also runs the first decoder step, in which
+   decoder_input_ids are filled with the start token ID).
+
+   Inputs:
+      encoder_input_features: float (B, feature_size, encode_sequence_length)
+      encoder_attention_mask: int32 (same shape as encoder_input_features)
+      decoder_input_ids: int32 (B, 1)
+
+   Outputs:
+      logits: (B, 1, vocab_size)
+      encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
+
+      present_key_self_0: (B, num_heads, 1, head_size)
+      present_value_self_0: (B, num_heads, 1, head_size)
+      ... (for each self attention layer)
+
+      present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
+      present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
+      ... (for each cross attention layer)
+
+   Note:
+      B = batch_size * num_beams, since the inputs are expanded before the
+      encoder runs. Ideally we could use B = batch_size and expand the outputs
+      by a factor of num_beams instead. Tensors are float or float16 unless
+      another data type is specified above.
+*/
+
+Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
+                                        const std::vector<const NodeArg*>& subgraph_outputs) {
+  ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs);
+
+  ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);
+  ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0,
+                "number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs);
+
+  ORT_RETURN_IF(subgraph_inputs[0]->Name() != "encoder_input_ids",
+                "encoder subgraph input 0 shall be named as encoder_input_ids, got: ", subgraph_inputs[0]->Name());
+  ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
+                "encoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name());
+  ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids",
+                "encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name());
+
+  ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
+                "encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
+  ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states",
+                "encoder subgraph output 1 shall be named as encoder_hidden_states, got: ", subgraph_outputs[1]->Name());
+  ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0",
+                "encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name());
+  ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0",
+                "encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name());
+
+  const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape();
+  const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
+
+  // Save parameters related to the subgraph.
+  ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
+  num_layers = (static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) / 4;
+
+  constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
+  constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
+  constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
+
+  ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type,
+                "encoder subgraph input 0 (encoder_input_ids) shall have float32 type");
+  ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+                "encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
+  ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+                "encoder subgraph input 2 (decoder_input_ids) shall have int32 type");
+
+  auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
+  ORT_RETURN_IF(output_type != float32_type && output_type != float16_type,
+                "encoder subgraph output 0 (logits) shall be float or float16 data type");
+
+  for (int i = 1; i < num_subgraph_outputs; i++) {
+    ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != output_type,
+                  "encoder subgraph outputs 1, 2, ... shall have the same data type");
+  }
+
+  is_output_float16_ = (output_type == float16_type);
+
+  return Status::OK();
+}
+
+}  // namespace transformers
+}  // namespace contrib
+}  // namespace onnxruntime
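The output-count checks follow from the layout in the header comment: two fixed outputs (logits, encoder_hidden_states) plus four present tensors per layer (self K/V and cross K/V). A small sketch of the arithmetic Validate relies on; the 6-layer example is illustrative, not derived from this diff:

```cpp
#include <cassert>

// Outputs 0 and 1 are logits and encoder_hidden_states; everything after
// comes in groups of {present_key_self, present_value_self,
// present_key_cross, present_value_cross}, one group per layer.
constexpr int kFirstPresentOutputIndex = 2;

constexpr int NumLayers(int num_subgraph_outputs) {
  return (num_subgraph_outputs - kFirstPresentOutputIndex) / 4;
}

int main() {
  static_assert(NumLayers(6) == 1, "smallest legal subgraph has one layer");
  static_assert(NumLayers(26) == 6, "e.g. a 6-layer decoder exposes 26 outputs");
  assert((26 - kFirstPresentOutputIndex) % 4 == 0);  // the divisibility Validate enforces
  return 0;
}
```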
shall have same data type"); + } + + is_output_float16_ = (output_type == float16_type); + + return Status::OK(); +} +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h new file mode 100644 index 0000000000000..c48f3f10e5f5f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/transformers/subgraph_base.h" +#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +// A class for whisper encoder subgraph with validation to support float inputs. +class WhisperEncoderSubgraph : public T5EncoderSubgraph { + public: + WhisperEncoderSubgraph( + const onnxruntime::Node& node_in, + const std::string& attribute_name, + const GraphViewer& subgraph_in) : T5EncoderSubgraph(node_in, attribute_name, subgraph_in) {} + + Status Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) override; +}; +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5ad939c8b8711..728de78de31a6 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -17,6 +17,7 @@ #include "core/graph/op.h" #include "core/mlas/inc/mlas.h" #include "core/graph/contrib_ops/onnx_function_util.h" +#include "contrib_ops/cpu/transformers/beam_search_parameters.h" #include "onnx/defs/function.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build @@ -421,8 +422,19 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { } auto& input_ids_shape = getInputShape(ctx, 0); auto& input_ids_dims = input_ids_shape.dim(); - if (input_ids_dims.size() != 2) { - fail_shape_inference("Inputs 0 shall be 2 dimensions"); + auto model_type_attr = ctx.getAttribute("model_type"); + int64_t model_type = model_type_attr ? static_cast(model_type_attr->i()) : -1; + if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) { + if (input_ids_dims.size() != 3) + { + fail_shape_inference("Inputs 0 shall be 3 dimensions in whisper graph"); + } + if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value() && input_ids_dims[2].has_dim_value())) { + return; + } + } + else if (input_ids_dims.size() != 2) { + fail_shape_inference("Inputs 0 shall be 2 dimensions", model_type); } if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) { return; @@ -1071,7 +1083,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Size of the vocabulary. " "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) - .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") + .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. 
Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") @@ -1092,7 +1104,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) - .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T", {"tensor(float)"}, "Constrain to float tensors.") + .TypeConstraint("F", {"tensor(float)", "tensor(int32)"}, "Constrain input type to float or int tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 2e709d3235545..437a540390e34 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -349,6 +349,5 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end())); } } - } // namespace test } // namespace onnxruntime