diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index cb7823f06b4c2..8a5117bcf1e46 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -429,7 +429,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Inputs (5 - 10)
-- input_ids : I
+- input_ids : F
- The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)
- max_length : I
- The maximum length of the sequence to be generated. Shape is (1)
@@ -466,7 +466,9 @@ This version of the operator has been available since version 1 of the 'com.micr
- T : tensor(float)
-- Constrain input and output types to float tensors.
+- Constrain to float tensors.
+- F : tensor(float), tensor(int32)
+- Constrain input type to float or int tensors.
- I : tensor(int32)
- Constrain to integer types
- M : tensor(int32)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 050d84b19cc97..336ef560a9fa3 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -419,7 +419,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br> **T1** = tensor(int32)|
-|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -786,7 +786,7 @@ Do not modify directly.*
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
-|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index de2513789c508..a201a2f5d8edd 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -34,6 +34,7 @@
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/beam_search_impl_gpt.h"
#include "contrib_ops/cpu/transformers/beam_search_impl_t5.h"
+#include "contrib_ops/cpu/transformers/subgraph_whisper_encoder.h"
#include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h"
using namespace ONNX_NAMESPACE;
@@ -62,7 +63,8 @@ void BeamSearch::Init(const OpKernelInfo& info) {
-  // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
+  // Model_type could be 0 (GPT-2), 1 (encoder-decoder like T5), or 2 (Whisper)
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
- parameters_.model_type == IGenerationParameters::kModelTypeT5);
+ parameters_.model_type == IGenerationParameters::kModelTypeT5 ||
+ parameters_.model_type == IGenerationParameters::kModelTypeWhisper);
ONNX_NAMESPACE::GraphProto proto;
@@ -148,6 +150,37 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
t5_decoder_subgraph_->num_layers);
}
}
+ else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
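+    // Whisper reuses the T5 encoder/decoder beam search flow; only the encoder
+    // inputs differ (float audio features instead of int32 token IDs).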
+ if (attribute_name == "encoder") {
+ ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
+ "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+      t5_encoder_subgraph_ = std::make_unique<WhisperEncoderSubgraph>(node,
+                                                                      attribute_name,
+                                                                      subgraph_session_state.GetGraphViewer());
+ ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
+ encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
+
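+      // The optional third encoder input (decoder_input_ids) is present only
+      // when the decoder_start_token_id attribute is provided.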
+ if (parameters_.decoder_start_token_id < 0) {
+ ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
+ "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+ } else {
+ ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
+ "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+ }
+ } else if (attribute_name == "decoder") {
+ ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
+ "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+      t5_decoder_subgraph_ = std::make_unique<T5DecoderSubgraph>(node,
+                                                                 attribute_name,
+                                                                 subgraph_session_state.GetGraphViewer());
+ ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state));
+ decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager();
+ parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size,
+ t5_decoder_subgraph_->num_heads,
+ t5_decoder_subgraph_->head_size,
+ t5_decoder_subgraph_->num_layers);
+ }
+ }
return Status::OK();
}
@@ -224,45 +257,94 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute.");
ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
-  // Subgraph has constraint that the output is either float or float16
-  if (!t5_decoder_subgraph_->IsOutputFloat16()) {
-    BeamSearchT5<float> impl{
-        *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
-        *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
-        add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
-        topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
-        process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
-        init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
-        device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
-        device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
-        create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
-        update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
-        expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
-        expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
-        expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
-    ORT_RETURN_IF_ERROR(impl.Initialize());
-
-    return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
-  } else {
-    BeamSearchT5<MLFloat16> impl{
-        *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
-        *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
-        add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
-        topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
-        process_logits_fp16_func_,
-        init_beam_state_fp16_func_,
-        device_copy_func_,
-        device_copy_int32_func_,
-        create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
-        update_decoder_feeds_fp16_func_,
-        expand_buffer_int32_func_,
-        expand_buffer_float_func_,
-        expand_buffer_float16_func_};
-
-    ORT_RETURN_IF_ERROR(impl.Initialize());
-
-    return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+
+  if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
+    // Subgraph has constraint that the output is either float or float16
+    if (!t5_decoder_subgraph_->IsOutputFloat16()) {
+      BeamSearchT5<float> impl{
+          *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+          *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+          process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
+          init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
+          device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
+          device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
+          create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
+          update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
+          expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
+          expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
+          expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
+      ORT_RETURN_IF_ERROR(impl.Initialize());
+
+      return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+    } else {
+      BeamSearchT5<MLFloat16> impl{
+          *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+          *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+          process_logits_fp16_func_,
+          init_beam_state_fp16_func_,
+          device_copy_func_,
+          device_copy_int32_func_,
+          create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
+          update_decoder_feeds_fp16_func_,
+          expand_buffer_int32_func_,
+          expand_buffer_float_func_,
+          expand_buffer_float16_func_};
+
+      ORT_RETURN_IF_ERROR(impl.Initialize());
+
+      return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+    }
+  }
+
+  // Same as the T5 path above, except that CreateWhisperEncoderInputs is the
+  // default helper, so that the encoder feeds are built from float input features.
+  if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
+    // Subgraph has constraint that the output is either float or float16
+    if (!t5_decoder_subgraph_->IsOutputFloat16()) {
+      BeamSearchT5<float> impl{
+          *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+          *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+          process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
+          init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
+          device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
+          device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
+          create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
+          update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
+          expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
+          expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
+          expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>};
+      ORT_RETURN_IF_ERROR(impl.Initialize());
+
+      return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+    } else {
+      BeamSearchT5<MLFloat16> impl{
+          *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+          *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
+          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
+          process_logits_fp16_func_,
+          init_beam_state_fp16_func_,
+          device_copy_func_,
+          device_copy_int32_func_,
+          create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateWhisperEncoderInputs,
+          update_decoder_feeds_fp16_func_,
+          expand_buffer_int32_func_,
+          expand_buffer_float_func_,
+          expand_buffer_float16_func_};
+
+      ORT_RETURN_IF_ERROR(impl.Initialize());
+
+      return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
+    }
}
+
+  // Not reachable: model_type is validated against IGenerationParameters in Init().
+ ORT_THROW("Model type is not supported.");
}
} // namespace transformers
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index bd3a72e989af0..f79f9b1dbf1cf 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -31,7 +31,12 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
ORT_ENFORCE(context != nullptr);
const Tensor* input_ids = context->Input<Tensor>(0);
const auto& dims = input_ids->Shape().GetDims();
- ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
+  if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
+    ORT_ENFORCE(dims.size() == 3, "input_features shall have 3 dimensions. Got ", dims.size());
+  } else {
+    ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
+  }
batch_size = static_cast<int>(dims[0]);
// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
index c6e267d26e6df..1ac01d34209d1 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
@@ -124,10 +124,17 @@ class GenerateBase {
const Tensor* attention_mask,
const Tensor* presence_mask) const {
const auto& dims = input_ids->Shape().GetDims();
-    if (dims.size() != 2) {
+    if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
+      if (dims.size() != 3) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Input 'input_features' is expected to have 3 dimensions, got ", dims.size());
+      }
+    } else if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'input_ids' is expected to have 2 dimensions, got ", dims.size());
- }
+ }
if (vocab_mask != nullptr) { // vocab_mask is optional
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
@@ -174,7 +181,13 @@ class GenerateBase {
if (attention_mask != nullptr) {
const auto& dims_attn = attention_mask->Shape().GetDims();
-      if (dims_attn.size() != 2) {
+      if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
+        if (dims_attn.size() != 3) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                 "Input 'attention_mask' is expected to have 3 dimensions, got ", dims_attn.size());
+        }
+      } else if (dims_attn.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size());
}
@@ -183,7 +196,6 @@ class GenerateBase {
"Input 'attention_mask' is expected to have same shape as input_ids");
}
}
-
if (presence_mask != nullptr) {
const auto& dims_presence = presence_mask->Shape().GetDims();
if (dims_presence.size() != 2) {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
index e63f4b377726f..e971da11f9dd5 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
@@ -831,6 +831,80 @@ Status UpdateDecoderFeeds(
return Status::OK();
}
+//------------------------------------------------
+// Modified Encoder functions for Whisper Model
+//------------------------------------------------
+
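+// CreateWhisperEncoderInputs mirrors CreateEncoderInputs for T5, except that it
+// wraps a 3-D float input_features tensor instead of 2-D int32 input_ids.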
+Status CreateWhisperEncoderInputs(
+ const Tensor* original_encoder_input_features,
+ const OrtValue* attn_mask_value,
+ int pad_token_id,
+ int start_token_id,
+ AllocatorPtr allocator,
+ OrtValue& encoder_input_features,
+ OrtValue& encoder_attention_mask,
+ OrtValue& decoder_input_ids) {
+ const TensorShape& input_features_shape = original_encoder_input_features->Shape();
+ ORT_ENFORCE(input_features_shape.NumDimensions() == 3);
+ const int64_t& batch_size = input_features_shape[0];
+ const int64_t& sequence_length = input_features_shape[1];
+
+  // Allocate attention_mask based on the shape of input_features.
+  auto element_type = DataTypeImpl::GetType<int32_t>();
+
+  // Use the original encoder_input_features tensor directly (float for Whisper).
+  // Note that we will expand it to (batch_size * num_beams, ...) later.
+  // To avoid cloning the input, we use const_cast here since this function does not change its content.
+  Tensor::InitOrtValue(DataTypeImpl::GetType<float>(),
+                       input_features_shape,
+                       const_cast<Tensor*>(original_encoder_input_features)->MutableData<float>(),
+                       allocator->Info(),
+                       encoder_input_features);
+
+  if (attn_mask_value != nullptr) {
+    const Tensor& attention_mask = attn_mask_value->Get<Tensor>();
+    Tensor::InitOrtValue(element_type, input_features_shape, const_cast<Tensor*>(&attention_mask)->MutableData<int32_t>(),
+                         allocator->Info(), encoder_attention_mask);
+ } else {
+    auto mask_type = DataTypeImpl::GetType<int32_t>();
+ Tensor::InitOrtValue(mask_type, input_features_shape, allocator, encoder_attention_mask);
+
+ // Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
+    int32_t* mask_data = encoder_attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
+    const int32_t* word_id = original_encoder_input_features->Data<int32_t>();
+ int32_t* mask = mask_data;
+ for (int i = 0; i < batch_size; i++) {
+ int32_t abs_position = 0;
+ for (int j = 0; j < sequence_length; j++, word_id++, mask++) {
+ // T5Tokenizer might add one EOS pad token at the end.
+ // That EOS token shall have attention mask 1 even when EOS token is same as pad token.
+        // Here we only set the attention mask to 0 for left padding, to maintain parity with HuggingFace.
+ if (*word_id == pad_token_id && abs_position == 0) {
+ *mask = 0;
+ } else {
+ *mask = 1;
+ abs_position++;
+ }
+ }
+ }
+ }
+
+ // decoder_input_ids is optional.
+ if (start_token_id >= 0) {
+    // Fill decoder_input_ids with the start token ID.
+ int64_t dims[] = {batch_size, 1};
+ TensorShape decoder_input_ids_shape(&dims[0], 2);
+ Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, decoder_input_ids);
+    int32_t* data = decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
+ for (int i = 0; i < batch_size; i++, data++) {
+ *data = start_token_id;
+ }
+ }
+
+ return Status::OK();
+}
+
//------------------------------------------------
// Explicit template instantiations of functions
//------------------------------------------------
@@ -950,4 +1024,4 @@ template Status ExpandBuffer(
} // namespace GenerationCpuDeviceHelper
} // namespace contrib
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
index 3ad7be76a1800..66a1cea083a31 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
@@ -306,6 +306,21 @@ Status UpdateDecoderFeeds(
transformers::Sequences& sequences,
const transformers::IConsoleDumper* dumper);
+// ---------------------------------------------------------------
+// Functions for encoder-decoder models with float inputs (e.g. Whisper)
+// ---------------------------------------------------------------
+
+Status CreateWhisperEncoderInputs(
+    const Tensor* original_encoder_input_features,
+    const OrtValue* attn_mask_value,
+    int pad_token_id,
+    int start_token_id,
+    AllocatorPtr allocator,
+    OrtValue& encoder_input_features,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids);
+
// ---------------------------------------------------------------
// Utility Functions
// ---------------------------------------------------------------
@@ -323,4 +338,4 @@ Status ExpandBuffer(
} // namespace GenerationCpuDeviceHelper
} // namespace contrib
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 3faba9a856273..79e9f04fad1e2 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -129,9 +129,10 @@ class IBeamScorer {
struct IGenerationParameters {
static constexpr int kModelTypeGpt = 0;
static constexpr int kModelTypeT5 = 1;
+ static constexpr int kModelTypeWhisper = 2;
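+  // Whisper uses an encoder-decoder graph like T5, but its encoder input is float features.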
// Parameters from node attributes
- int model_type; // 0 for GPT-2; 1 for encoder-decoder like T5
+ int model_type; // 0 for GPT-2; 1 for encoder-decoder like T5; 2 for float inputs like Whisper
int eos_token_id;
int pad_token_id;
int decoder_start_token_id;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 0460841ae155a..6c744b627d364 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -52,7 +52,7 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i
"kFirstPastInputIndex currently only supports 2 or 3");
ORT_RETURN_IF(num_subgraph_inputs < 4 + first_past_input_index_ ||
(num_subgraph_inputs - first_past_input_index_) % 4 != 0,
- "number of outputs expected to be kFirstPastInputIndex + 4 * layers, got:", num_subgraph_inputs);
+ "number of inputs expected to be kFirstPastInputIndex + 4 * layers, got:", num_subgraph_inputs);
ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0,
"number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs);
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
index a2e2e9842097a..a79f677f5a043 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
@@ -40,7 +40,7 @@ class T5EncoderSubgraph : public Subgraph {
return first_present_output_index_;
}
- private:
+ protected:
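+  // Widened from private so that WhisperEncoderSubgraph can reuse this index.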
int first_present_output_index_;
};
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc
new file mode 100644
index 0000000000000..8b2dde9518335
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.cc
@@ -0,0 +1,99 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/framework/framework_common.h"
+#include "core/framework/session_state.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/framework/utils.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/common/gsl.h"
+#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
+#include "contrib_ops/cpu/transformers/subgraph_whisper_encoder.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+
+/* Whisper Encoder Subgraph (It also contains decoder initialization where decoder_input_ids are filled with start token ID).
+
+ Inputs:
+    encoder_input_features: float (B, encode_sequence_length, feature_size)
+ encoder_attention_mask: int32 (B, encode_sequence_length)
+ decoder_input_ids: int32 (B, 1)
+
+ Outputs:
+ logits: (B, 1, vocab_size)
+ encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
+
+ present_key_self_0: (B, num_heads, 1, head_size)
+ present_value_self_0: (B, num_heads, 1, head_size)
+ ... (for each self attention layer)
+
+ present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
+ present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
+ ... (for each cross attention layer)
+
+ Note:
+ Here, B = batch_size * num_beams since we expand the inputs.
+ Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams.
+ Data type of input or output is float or float16 if not specified.
+*/
+
+Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
+                                        const std::vector<const NodeArg*>& subgraph_outputs) {
+ ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs);
+
+ ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);
+  ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0,
+ "number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs);
+
+ ORT_RETURN_IF(subgraph_inputs[0]->Name() != "encoder_input_ids",
+ "encoder subgraph input 0 shall be named as encoder_input_ids, got: ", subgraph_inputs[0]->Name());
+ ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
+ "encoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name());
+ ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids",
+ "encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name());
+
+ ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
+ "encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
+ ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states",
+ "encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name());
+ ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0",
+ "encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name());
+ ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0",
+ "encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name());
+
+ const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape();
+ const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
+
+ // Save parameters related to the subgraph.
+ ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
+  num_layers = (static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) / 4;
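+  // Each decoder layer contributes 4 present outputs: self-attention K/V and cross-attention K/V.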
+
+ constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
+ constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
+ constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
+
+ ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type,
+ "encoder subgraph input 0 (encoder_input_features) shall have float32 type");
+ ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+ "encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
+ ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
+ "encoder subgraph input 2 (decoder_input_ids) shall have int32 type");
+
+ auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
+ ORT_RETURN_IF(output_type != float32_type && output_type != float16_type,
+ "encoder subgraph output 0 (logits) shall be float or float16 data type");
+
+ for (int i = 1; i < num_subgraph_outputs; i++) {
+ ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != output_type,
+ "encoder subgraph outputs 1, 2, ... shall have same data type");
+ }
+
+ is_output_float16_ = (output_type == float16_type);
+
+ return Status::OK();
+}
+} // namespace transformers
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h
new file mode 100644
index 0000000000000..c48f3f10e5f5f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_encoder.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "contrib_ops/cpu/transformers/subgraph_base.h"
+#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+
+// A class for the Whisper encoder subgraph, with validation updated to support float inputs.
+class WhisperEncoderSubgraph : public T5EncoderSubgraph {
+ public:
+ WhisperEncoderSubgraph(
+ const onnxruntime::Node& node_in,
+ const std::string& attribute_name,
+ const GraphViewer& subgraph_in) : T5EncoderSubgraph(node_in, attribute_name, subgraph_in) {}
+
+  Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
+                  const std::vector<const NodeArg*>& subgraph_outputs) override;
+};
+} // namespace transformers
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 5ad939c8b8711..728de78de31a6 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -17,6 +17,7 @@
#include "core/graph/op.h"
#include "core/mlas/inc/mlas.h"
#include "core/graph/contrib_ops/onnx_function_util.h"
+#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "onnx/defs/function.h"
// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from
// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build
@@ -421,8 +422,19 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
}
auto& input_ids_shape = getInputShape(ctx, 0);
auto& input_ids_dims = input_ids_shape.dim();
- if (input_ids_dims.size() != 2) {
- fail_shape_inference("Inputs 0 shall be 2 dimensions");
+ auto model_type_attr = ctx.getAttribute("model_type");
+  int64_t model_type = model_type_attr ? static_cast<int64_t>(model_type_attr->i()) : -1;
+ if (model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper) {
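+    // Whisper takes 3-D float input features rather than 2-D int32 token IDs.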
+    if (input_ids_dims.size() != 3) {
+ fail_shape_inference("Inputs 0 shall be 3 dimensions in whisper graph");
+ }
+ if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value() && input_ids_dims[2].has_dim_value())) {
+ return;
+ }
+  } else if (input_ids_dims.size() != 2) {
+    fail_shape_inference("Inputs 0 shall be 2 dimensions");
}
if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) {
return;
@@ -1071,7 +1083,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
"Size of the vocabulary. "
"If not provided, it will be inferred from the decoder subgraph's output shape",
AttributeProto::INT, static_cast(-1))
- .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
+ .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "F")
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
.Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I")
@@ -1092,7 +1104,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
"Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam."
"Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)",
"T", OpSchema::Optional)
- .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
+ .TypeConstraint("T", {"tensor(float)"}, "Constrain to float tensors.")
+ .TypeConstraint("F", {"tensor(float)", "tensor(int32)"}, "Constrain input type to float or int tensors.")
.TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc
index 2e709d3235545..437a540390e34 100644
--- a/onnxruntime/test/contrib_ops/beam_search_test.cc
+++ b/onnxruntime/test/contrib_ops/beam_search_test.cc
@@ -349,6 +349,5 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) {
ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
}
}
-
} // namespace test
} // namespace onnxruntime