diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 12fae5ccf0983..7962662ff6088 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -139,13 +139,19 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state)); encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager(); - if (parameters_->decoder_start_token_id < 0) { - ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2, - "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty"); + if (!t5_encoder_subgraph_->HasLogitsOutput()) { + // New format requires start token id. + ORT_ENFORCE(parameters_->decoder_start_token_id >= 0); } else { - ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3, - "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available"); + if (parameters_->decoder_start_token_id < 0) { + ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2, + "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty"); + } else { + ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3, + "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available"); + } } + } else if (attribute_name == "decoder") { ORT_ENFORCE(t5_decoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index b67d003eaceeb..c9646cf0fab2e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -51,7 +51,13 @@ class BeamSearchT5 : public BeamSearchBase { expand_buffer_int32_func_(expand_buffer_int32_func), expand_buffer_float_func_(expand_buffer_float_func), expand_buffer_float16_func_(expand_buffer_float16_func), - create_beam_scorer_func_(create_beam_scorer_func) {} + create_beam_scorer_func_(create_beam_scorer_func) { + // When decoder uses encoder_hidden_state, make sure the encoder outputs it. + if (decoder_subgraph_.UseEncoderHiddenState()) { + ORT_ENFORCE(encoder_subgraph_.subgraph_output_names[1] == "encoder_hidden_states"); + } + ORT_ENFORCE(encoder_subgraph_.num_layers == decoder_subgraph_.num_layers); + } #ifdef USE_CUDA Status InitializeCuda( @@ -160,7 +166,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches this->create_encoder_inputs_func_, this->add_to_feeds_func_, buffer, - decoder_input_ids, + decoder_input_ids, // new format does not use decoder_input_ids in encoder, it is still initialized here when decoder_start_token_id >= 0. this->ort_stream_)); #ifdef DEBUG_NODE_INPUTS_OUTPUTS @@ -233,35 +239,47 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches std::vector decoder_fetches; - if (current_length + 1 < parameters->max_length) { + // When encoder outputs logits (in old format), we need get the next token from logits. 
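Editorial aside: the two encoder layouts handled above can be told apart offline against an exported model. A minimal Python sketch (the model path is hypothetical) using the same signals as the C++ check, namely whether output 0 is "logits" and how many inputs the subgraph takes:

    import onnx

    def detect_t5_encoder_format(path):
        # Old format: output 0 is "logits" and the subgraph takes 2 or 3 inputs
        # (3 when decoder_input_ids is present). New format: no logits output,
        # exactly 2 inputs, and decoder_start_token_id must be set on the op.
        graph = onnx.load(path).graph
        has_logits = len(graph.output) > 0 and graph.output[0].name == "logits"
        if has_logits:
            return f"old format with {len(graph.input)} inputs"
        return "new format (decoder_start_token_id attribute is required)"

    print(detect_t5_encoder_format("onnx_models/t5-small_encoder.onnx"))  # hypothetical path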
+ if (current_length + 1 < parameters->max_length && encoder_subgraph_.HasLogitsOutput()) { ++iteration_counter; - ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0], + const OrtValue& logits = encoder_fetches[0]; + ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits, beam_next_tokens, beam_state, cpu_state, iteration_counter)); ++current_length; // Increase sequence length after a new token is generated. + } - ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(this->cpu_allocator_, - ReinterpretAsSpan(beam_next_tokens), - this->implicit_inputs_, - encoder_feeds, - encoder_fetches, - decoder_feeds, - this->device_copy_int32_func_, - this->expand_buffer_int32_func_, - this->expand_buffer_float_func_, - this->expand_buffer_float16_func_, - parameters->num_beams, - this->ort_stream_, - decoder_subgraph_.UseSequenceAsInputIds(), - current_length, - cpu_state.sequences, - parameters->max_length, - decoder_subgraph_.has_decoder_masked_attention_, - this->cuda_device_prop_ != nullptr)); + if (current_length < parameters->max_length) { + // when no logits, copy sequence (filled with start token IDs) to input_ids for decoder. + bool copy_sequence_to_input_ids = decoder_subgraph_.UseSequenceAsInputIds() || !encoder_subgraph_.HasLogitsOutput(); + if (copy_sequence_to_input_ids) { + ORT_ENFORCE(current_length == cpu_state.sequences.GetSequenceLength()); + } + + // Generate inputs for next decoder subgraph call. + ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds( + this->cpu_allocator_, + ReinterpretAsSpan(beam_next_tokens), + this->implicit_inputs_, + encoder_feeds, + encoder_fetches, + decoder_feeds, + this->device_copy_int32_func_, + this->expand_buffer_int32_func_, + this->expand_buffer_float_func_, + this->expand_buffer_float16_func_, + parameters->num_beams, + this->ort_stream_, + copy_sequence_to_input_ids, + cpu_state.sequences, + parameters->max_length, + decoder_subgraph_.has_decoder_masked_attention_, + this->cuda_device_prop_ != nullptr)); if (decoder_subgraph_.past_present_share_buffer_) { + // Configure buffer sharing of past and present kv cache. 
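Editorial aside: the branch above boils down to this. In the old format the encoder already produced logits, so the first sampled token per beam becomes the next decoder input; in the new format no logits exist yet, so the whole sequence (currently just the start token) is copied into input_ids. A small illustrative sketch, not the operator's actual implementation:

    def first_decoder_input_ids(encoder_has_logits, start_token_id, batch_beam_size, next_tokens=None):
        # Old format: one freshly sampled token per beam, shape (batch_beam_size, 1).
        if encoder_has_logits:
            assert next_tokens is not None and len(next_tokens) == batch_beam_size
            return [[token] for token in next_tokens]
        # New format: copy the current sequence, which so far holds only the start token.
        return [[start_token_id] for _ in range(batch_beam_size)]

    print(first_decoder_input_ids(False, start_token_id=0, batch_beam_size=4))
    print(first_decoder_input_ids(True, start_token_id=0, batch_beam_size=2, next_tokens=[42, 7]))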
decoder_fetches.reserve(static_cast(decoder_subgraph_.GetFirstPresentOutputIndex()) + 2 * static_cast(decoder_subgraph_.num_layers)); decoder_fetches.resize(decoder_subgraph_.GetFirstPresentOutputIndex(), OrtValue()); @@ -299,14 +317,19 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches while (current_length < parameters->max_length) { iteration_counter++; + #ifdef DEBUG_GENERATION - auto cur_len = std::to_string(current_length); - dumper->Print("***CurrentLength", cur_len, true); + dumper->Print(::onnxruntime::MakeString("Iteration=", iteration_counter, + ", CurrentLength=", current_length, + ", num_layers=", decoder_subgraph_.num_layers, + ", decoder_feeds=", decoder_feeds.size(), + ", start_token_id=", parameters->decoder_start_token_id)); for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) { dumper->Print("decoder_feeds", i, true); dumper->Print("", decoder_feeds[i]); } + for (int i = 0; i < decoder_subgraph_.num_layers; i++) { int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i; int self_value_idx = self_key_idx + 1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index 7757435990a65..537d066b264a1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -36,12 +36,9 @@ Subgraph::Subgraph( auto& subgraph_inputs = subgraph.GetInputs(); auto& subgraph_outputs = subgraph.GetOutputs(); - // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ... - // outputs: logits, present_0, present_1, ... num_subgraph_inputs = static_cast(subgraph_inputs.size()); num_subgraph_outputs = static_cast(subgraph_outputs.size()); - // CheckSubgraph will verify inputs and outputs later. subgraph_input_names.reserve(num_subgraph_inputs); for (int i = 0; i < num_subgraph_inputs; ++i) { subgraph_input_names.push_back(subgraph_inputs[i]->Name()); @@ -68,10 +65,9 @@ Status Subgraph::Setup(const SessionState& session_state, InlinedVector feed_names; feed_names.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); - // Use the first output (logits) to find device location. + // Use the first output to find device location. const OrtDevice& default_location = utils::FindDeviceForValue(subgraph_session_state, subgraph_output_names[0]); - // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. 
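Editorial aside: the per-layer index arithmetic used in the debug dump above follows the feed layout documented in subgraph_t5_decoder.cc: self-attention caches first (key then value per layer), then cross-attention caches. A sketch of that layout, with first_past_input_index = 2 used purely as an example value:

    def past_kv_feed_indices(first_past_input_index, num_layers):
        # Self-attention caches come first, followed by cross-attention caches.
        indices = {}
        for i in range(num_layers):
            self_key = first_past_input_index + 2 * i
            indices[f"past_key_self_{i}"] = self_key
            indices[f"past_value_self_{i}"] = self_key + 1
            cross_key = first_past_input_index + 2 * num_layers + 2 * i
            indices[f"past_key_cross_{i}"] = cross_key
            indices[f"past_value_cross_{i}"] = cross_key + 1
        return indices

    print(past_kv_feed_indices(first_past_input_index=2, num_layers=2))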
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap(); @@ -174,13 +170,15 @@ Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shap } // Logits shape is like (batch_size, seq_len, vocabulary_size) - ORT_RETURN_IF(logits_shape->dim_size() != 3, - "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size()); + if (logits_shape != nullptr) { + ORT_RETURN_IF(logits_shape->dim_size() != 3, + "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size()); - ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, - "subgraph past state dimension 2 shall have a positive value for vocabulary size"); + ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for vocabulary size"); - this->vocab_size = static_cast(logits_shape->dim(2).dim_value()); + this->vocab_size = static_cast(logits_shape->dim(2).dim_value()); + } return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 997beb198f450..09bce9828aa33 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -6,11 +6,12 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" -#include #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/utils/dump_tensor.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" #include "contrib_ops/cpu/transformers/sequences.h" +#include +#include namespace onnxruntime { namespace contrib { @@ -20,9 +21,9 @@ namespace transformers { Inputs: input_ids: int32 (B, 1) - encoder_input_ids: int32 (B, encode_sequence_length) (optional) + encoder_input_ids: int32 (B, encode_sequence_length) (optional for old format; removed in new format) encoder_attention_mask: int32 (B, encode_sequence_length) - encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional) + encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional for old format; removed in new format) past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) @@ -141,14 +142,23 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i } // Create inputs for decoder from the following data sources: -// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens) -// encoder fetches: logits, -// encoder_hidden_states, -// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ... -// decoder_feeds: input_ids, -// encoder_attention_mask, -// encoder_hidden_states, -// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ... +// New format: +// encoder feeds: encoder_input_ids, encoder_attention_mask +// encoder fetches: present_key_cross_0, present_value_cross_0, ... +// decoder_feeds: input_ids, encoder_attention_mask, +// present_key_self_0, present_value_self_0, ..., +// present_key_cross_0, present_value_cross_0, ... 
+// past_seq_len (optional), num_beams (optional), cache_indirection (optional) +// +// Old format: +// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens) +// encoder fetches: logits, encoder_hidden_states, +// present_key_self_0, present_value_self_0, ..., +// present_key_cross_0, present_value_cross_0, ... +// decoder_feeds: input_ids, encoder_input_ids (optional), encoder_attention_mask, encoder_hidden_states (optional), +// present_key_self_0, present_value_self_0, ..., +// present_key_cross_0, present_value_cross_0, ... +// past_seq_len (optional), num_beams (optional), cache_indirection (optional) Status T5DecoderSubgraph::CreateInitialFeeds( AllocatorPtr cpu_allocator, gsl::span beam_next_tokens, @@ -162,8 +172,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, int num_beam, Stream* stream, - bool use_sequence_as_input_ids, - int cur_len, + bool copy_sequence_to_input_ids, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len, bool need_cache_indir, @@ -173,34 +182,30 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // Allocate subgraph inputs from same device as inputs of encoder subgraph. AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get().Location()); + int batch_beam_size = static_cast(encoder_fetches[0].Get().Shape()[0]) * num_beam; + // Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP). - int batch_beam_size = static_cast(beam_next_tokens.size()); - int sequence_length = !use_sequence_as_input_ids ? 1 : cur_len; + int sequence_length = !copy_sequence_to_input_ids ? 1 : sequences.GetSequenceLength(); int64_t dims[] = {batch_beam_size, sequence_length}; TensorShape input_ids_shape(&dims[0], 2); OrtValue input_ids; Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); - int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); - AllocatorPtr buffer_allocator = std::make_shared(); - size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); - size_t total_size_bytes = total_size * sizeof(int); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); - int* seq_copy_ptr = seq_copy.get(); - - if (!use_sequence_as_input_ids_) { + + // Prepare data for input_ids. + if (!copy_sequence_to_input_ids) { // use next tokens for input_ids. ORT_RETURN_IF_ERROR(device_copy_int32_func( input_ids.GetMutable()->MutableDataAsSpan(), beam_next_tokens, stream, DeviceCopyDirection::hostToDevice)); - } else { + } else { // use whole sequences for input_ids. 
+ int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); if (use_cuda) { auto sequences_buffer = sequences.GetCurrentDeviceSequences(); for (int i = 0; i < batch_beam_size; i++) { - size_t batch_beam_stride = static_cast(i) * static_cast(sequences.GetMaxLength()); - int seq_size = sequences.GetSequenceLength(); - gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); - gsl::span temp_input(input_ids_data + static_cast(i) * seq_size, seq_size); + size_t offset = static_cast(i) * static_cast(sequences.GetMaxLength()); + gsl::span sequence = sequences_buffer.subspan(offset, sequence_length); + gsl::span temp_input(input_ids_data + static_cast(i) * sequence_length, sequence_length); ORT_RETURN_IF_ERROR(device_copy_int32_func( temp_input, sequence, @@ -208,12 +213,19 @@ Status T5DecoderSubgraph::CreateInitialFeeds( DeviceCopyDirection::deviceToDevice)); } } else { - const size_t cur_len_bytes = cur_len * sizeof(int); + size_t total_size = static_cast(sequence_length) * static_cast(batch_beam_size); + size_t total_size_bytes = total_size * sizeof(int); + AllocatorPtr buffer_allocator = std::make_shared(); + // TODO: not need extra buffer. Copy directly to input_ids_data instead like the user_cuda above. + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); + int* seq_copy_ptr = seq_copy.get(); + + const size_t sequence_bytes = sequence_length * sizeof(int); for (int i = 0; i < batch_beam_size; i++) { gsl::span sequence = sequences.GetSequence(i); const int32_t* sequence_data = sequence.data(); - ptrdiff_t seq_index = static_cast(i) * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes); + ptrdiff_t seq_index = static_cast(i) * sequence_length; + memcpy(seq_copy_ptr + seq_index, sequence_data, sequence_bytes); } gsl::span temp_input(input_ids_data, total_size); gsl::span temp_sequence(seq_copy_ptr, total_size); @@ -227,9 +239,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // The ordering is the same as used in Setup. decoder_feeds.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); + + // input 0: input_ids decoder_feeds.push_back(input_ids); - if (has_encoder_input_ids_) { + if (has_encoder_input_ids_) { // encoder_input_ids is optional // The encoder_input_ids is copied from the first input of encoder. OrtValue expanded_encoder_input_ids; ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream, @@ -251,70 +265,66 @@ Status T5DecoderSubgraph::CreateInitialFeeds( expanded_decoder_attention_masks, false, 0 /*max_sequence_length*/)); - decoder_feeds.push_back(expanded_decoder_attention_masks); if (!past_present_share_buffer_) { past_present_share_buffer_max_seq_len = 0; } - // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output - // of encoder. - // When first_past_input_index_ == 2, the past states are copied from the second output of encoder. - // TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions. - // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds? 
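Editorial aside: the CPU branch above packs the first sequence_length tokens of every beam into a contiguous int32 buffer before the device copy. A NumPy equivalent of that packing (shapes follow the comments in this file; the buffer contents below are illustrative):

    import numpy as np

    def pack_sequences_into_input_ids(sequences_buffer, sequence_length):
        # sequences_buffer: (batch_beam_size, max_length) int32; only the first
        # sequence_length tokens of each row are valid, like Sequences::GetSequence.
        return np.ascontiguousarray(sequences_buffer[:, :sequence_length], dtype=np.int32)

    buffer = np.zeros((4, 8), dtype=np.int32)   # 4 beams, max_length 8
    buffer[:, 0] = 0                            # first token = decoder_start_token_id
    print(pack_sequences_into_input_ids(buffer, sequence_length=1).shape)  # (4, 1)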
- for (size_t j = static_cast(2) - has_hidden_state_; j < encoder_fetches.size(); j++) { - if (j == 1) { - ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false"); - OrtValue expanded_hidden_states; - if (is_output_float16_) { - ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream, - encoder_fetches[j], - num_beam, - allocator, - expanded_hidden_states, - false, - 0 /*max_sequence_length*/)); - } else { - ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, - encoder_fetches[j], - num_beam, - allocator, - expanded_hidden_states, - false, - 0 /*max_sequence_length*/)); - } - decoder_feeds.push_back(expanded_hidden_states); - } else { +// macro to expand encoder outputs and append to decoder feeds. +#define ADD_DECODER_FEED(encoder_output, is_dynamic_kv_cache) \ + OrtValue expanded; \ + if (is_output_float16_) { \ + ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream, encoder_output, num_beam, allocator, expanded, false, \ + is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \ + } else { \ + ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, encoder_output, num_beam, allocator, expanded, false, \ + is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \ + } \ + decoder_feeds.push_back(expanded); + + // The encoder_hidden_states is copied from the second output of encoder. + if (has_hidden_state_) { + ADD_DECODER_FEED(encoder_fetches[1], false); + } + + // New format of encoder has only cross outputs. + bool is_new_format = (static_cast(encoder_fetches.size()) == 2 * num_layers); + if (is_new_format) { + for (int i = 0; i < 2 * num_layers; i++) { + // cross shape is (batch_size, num_heads, encode_sequence_length, head_size) + const TensorShape& cross_shape = encoder_fetches[0].Get().Shape(); + ORT_ENFORCE(cross_shape.NumDimensions() == 4); + + // Shape for kv cache: (batch_size * num_beam, num_heads, max_seq_len, head_size) + int64_t cache_dims[4] = {0}; + cross_shape.CopyDims(cache_dims, cross_shape.NumDimensions()); + cache_dims[0] *= num_beam; + cache_dims[2] = past_present_share_buffer_max_seq_len; + TensorShape expanded_shape(&cache_dims[0], cross_shape.NumDimensions()); + + MLDataType element_type = encoder_fetches[0].Get().DataType(); + OrtValue past; + Tensor::InitOrtValue(element_type, expanded_shape, allocator, past); + decoder_feeds.push_back(past); + } + + // Add cross inputs from encoder output. + for (size_t j = 0; j < encoder_fetches.size(); j++) { + ADD_DECODER_FEED(encoder_fetches[j], false); + } + } else { + // present_* output of encoder are added as decoder inputs. + for (size_t j = 2; j < encoder_fetches.size(); j++) { // past key/value for cross attention does not need to be initialized with max_seq_len since they are static. - bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast(num_layers); - - OrtValue expanded_cache; - if (is_output_float16_) { - ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream, - encoder_fetches[j], - num_beam, - allocator, - expanded_cache, - false, - use_max_seq_len ? past_present_share_buffer_max_seq_len : 0)); - } else { - ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, - encoder_fetches[j], - num_beam, - allocator, - expanded_cache, - false, - use_max_seq_len ? 
past_present_share_buffer_max_seq_len : 0)); - } - decoder_feeds.push_back(expanded_cache); + bool is_dynamic_kv_cache = (j - first_past_input_index_) < 2 * static_cast(num_layers); + ADD_DECODER_FEED(encoder_fetches[j], is_dynamic_kv_cache); } } - // TODO: This part shares the similar logic with CreateInitialFeeds() in subgraph_gpt.cc. We should refactor it. if (past_present_share_buffer_) { - // Past sequence length feed - ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1)); + // Past sequence length set to 0 + ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, is_new_format ? 0 : 1)); // Add beam search specific inputs if (need_cache_indir) { const int64_t batch_size = static_cast(batch_beam_size / num_beam); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index b5d727b67924c..87782d47cdbe1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -45,7 +45,6 @@ class T5DecoderSubgraph : public Subgraph { int num_beam, Stream* stream, bool use_sequence_as_input_ids, - int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len = -1, bool need_cache_indir = false, @@ -72,6 +71,10 @@ class T5DecoderSubgraph : public Subgraph { return use_sequence_as_input_ids_; } + inline bool UseEncoderHiddenState() const { + return has_hidden_state_; + } + protected: int first_past_input_index_; int first_present_output_index_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index d59db4afac2c2..a54c0d960980c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -15,70 +15,97 @@ namespace transformers { /* T5 Encoder Subgraph (It also contains decoder initialization where decoder_input_ids are filled with start token ID). - Inputs: + New format: + Inputs: encoder_input_ids: int32 (B, encode_sequence_length) encoder_attention_mask: int32 (B, encode_sequence_length) - decoder_input_ids: int32 (B, 1) Outputs: - logits: (B, 1, vocab_size) - encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) - - present_key_self_0: (B, num_heads, 1, head_size) - present_value_self_0: (B, num_heads, 1, head_size) - ... (for each self attention layer) - present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) ... (for each cross attention layer) - Note: - Here, B = batch_size * num_beams since we expand the inputs. - Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams. - Data type of input or output is float or float16 if not specified. + Old format: + Inputs: + encoder_input_ids: int32 (B, encode_sequence_length) + encoder_attention_mask: int32 (B, encode_sequence_length) + decoder_input_ids: int32 (B, 1) + + Outputs: + logits: (B, 1, vocab_size) + encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + + present_key_self_0: (B, num_heads, 1, head_size) + present_value_self_0: (B, num_heads, 1, head_size) + ... (for each self attention layer) + + present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + ... 
(for each cross attention layer) + + Note: + Here, B = batch_size * num_beams since we expand the inputs. + Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams. + Data type of input or output is float or float16 if not specified. */ Status T5EncoderSubgraph::Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) { - ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs); - - ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs); - ORT_RETURN_IF((static_cast(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0, - "number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs); + constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32; + constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT; + constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16; + ORT_RETURN_IF(num_subgraph_inputs != 2 && num_subgraph_inputs != 3, "expect 2 or 3 inputs, got:", num_subgraph_inputs); ORT_RETURN_IF(subgraph_inputs[0]->Name() != "encoder_input_ids", "encoder subgraph input 0 shall be named as encoder_input_ids, got: ", subgraph_inputs[0]->Name()); ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask", "encoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name()); - ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids", - "encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name()); - - ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", - "encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name()); - ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states", - "encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name()); - ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0", - "encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name()); - ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0", - "encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name()); - - const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape(); - const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); - - // Save parameters related to the subgraph. 
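Editorial aside: the output naming convention that Validate relies on can be written down compactly; the same lists also drive verify_t5_encoder_decoder_init_subgraph in convert_generation.py further below. A small helper showing the expected output names for both formats:

    def expected_encoder_output_names(num_layers, new_format):
        if new_format:
            # New format: only cross-attention KV caches.
            names = []
            for i in range(num_layers):
                names += [f"present_key_cross_{i}", f"present_value_cross_{i}"]
            return names
        # Old format: logits and encoder_hidden_states, then self KV, then cross KV.
        names = ["logits", "encoder_hidden_states"]
        for i in range(num_layers):
            names += [f"present_key_self_{i}", f"present_value_self_{i}"]
        for i in range(num_layers):
            names += [f"present_key_cross_{i}", f"present_value_cross_{i}"]
        return names

    print(expected_encoder_output_names(1, new_format=True))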
- ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); - num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / 4; - - constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32; - constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT; - constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16; ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type, "encoder subgraph input 0 (encoder_input_ids) shall have int32 type"); ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, "encoder subgraph input 1 (encoder_attention_mask) shall have int32 type"); - ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type, - "encoder subgraph input 2 (decoder_input_ids) shall have int32 type"); + + if (num_subgraph_inputs == 2) { + ORT_RETURN_IF(num_subgraph_outputs < 2 || num_subgraph_outputs % 2 != 0, + "number of outputs expected to be 2 * layers, got:", num_subgraph_outputs); + + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "present_key_cross_0", + "encoder subgraph output 0 shall be named as present_key_cross_0, got: ", subgraph_outputs[0]->Name()); + ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_value_cross_0", + "encoder subgraph output 1 shall be named as present_value_cross_0, got: ", subgraph_outputs[1]->Name()); + + // Deduce num_heads and head_size parameters from shape of graph outputs + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[0]->Shape(); + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = nullptr; + ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); + + num_layers = num_subgraph_outputs / 2; + } else { + ORT_RETURN_IF(num_subgraph_outputs < 6 || (num_subgraph_outputs - first_present_output_index_) % 4 != 0, + "number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs); + + ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids", + "encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name()); + ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "encoder subgraph input 2 (decoder_input_ids) shall have int32 type"); + + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", + "encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name()); + ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states", + "encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name()); + ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0", + "encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name()); + ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0", + "encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name()); + + // Deduce num_heads, head_size and vocab_size from shape of graph outputs + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape(); + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); + ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); + + num_layers = (num_subgraph_outputs - first_present_output_index_) / 4; + } auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type(); 
ORT_RETURN_IF(output_type != float32_type && output_type != float16_type, @@ -86,7 +113,7 @@ Status T5EncoderSubgraph::Validate(const std::vector& subgraph_i for (int i = 1; i < num_subgraph_outputs; i++) { ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != output_type, - "encoder subgraph outputs 1, 2, ... shall have same data type"); + "encoder subgraph outputs shall have same data type"); } is_output_float16_ = (output_type == float16_type); @@ -120,7 +147,6 @@ Status T5EncoderSubgraph::CreateInitialFeeds( } ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr"); - // TODO(tianleiwu): expand the outputs instead of inputs to save computation. OrtValue encoder_input_ids; OrtValue encoder_attention_mask; ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids, @@ -136,9 +162,10 @@ Status T5EncoderSubgraph::CreateInitialFeeds( AllocatorPtr default_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeDefault)); AllocatorPtr pinned_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeCPU)); const OrtMemoryInfo& location = default_allocator->Info(); + ORT_RETURN_IF_ERROR(add_to_feeds_func( ort_stream, - {encoder_input_ids, encoder_attention_mask, decoder_input_ids}, + num_subgraph_inputs == 2 ? std::initializer_list{encoder_input_ids, encoder_attention_mask} : std::initializer_list{encoder_input_ids, encoder_attention_mask, decoder_input_ids}, feeds, buffer, default_allocator, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h index a79f677f5a043..33fd522bdfd82 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h @@ -16,7 +16,11 @@ class T5EncoderSubgraph : public Subgraph { const onnxruntime::Node& node_in, const std::string& attribute_name, const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) { - first_present_output_index_ = 2; + has_logits_output_ = num_subgraph_outputs > 0 && subgraph_output_names[0] == "logits"; + + // Old format: The first output is logits, the second one is encoder_hidden_states. + // New format: No logits and encoder_hidden_states. All outputs are cross. + first_present_output_index_ = HasLogitsOutput() ? 2 : 0; } // Create inputs for first inference of subgraph. 
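Editorial aside: since the encoder now accepts either two or three feeds, a caller running the exported subgraph directly has to build the feed dictionary conditionally. A sketch with onnxruntime (the model path and start token below are placeholders):

    import numpy as np
    import onnxruntime as ort

    def run_t5_encoder(session, encoder_input_ids, encoder_attention_mask, decoder_start_token_id=None):
        feeds = {
            "encoder_input_ids": encoder_input_ids,
            "encoder_attention_mask": encoder_attention_mask,
        }
        # Old format only: a third input holding the start token for every row.
        if "decoder_input_ids" in {i.name for i in session.get_inputs()}:
            assert decoder_start_token_id is not None
            feeds["decoder_input_ids"] = np.full(
                (encoder_input_ids.shape[0], 1), decoder_start_token_id, dtype=np.int32)
        return session.run(None, feeds)

    # session = ort.InferenceSession("t5-small_encoder.onnx", providers=["CPUExecutionProvider"])
    # outputs = run_t5_encoder(session, ids, mask, decoder_start_token_id=0)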
@@ -36,11 +40,18 @@ class T5EncoderSubgraph : public Subgraph { Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; +#ifdef DEBUG_GENERATION int GetFirstPresentOutputIndex() const { return first_present_output_index_; } +#endif + + bool HasLogitsOutput() const { + return has_logits_output_; + } protected: + bool has_logits_output_; int first_present_output_index_; }; diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 2a210729112d7..3dd2c2ef945ec 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -88,61 +88,62 @@ def create_onnxruntime_session( enable_mlas_gemm_fastmath_arm64_bfloat16=False, provider_options={}, # map execution provider name to its option # noqa: B006 ): - session = None - try: - sess_options = onnxruntime.SessionOptions() + sess_options = onnxruntime.SessionOptions() - if enable_all_optimization: - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - else: - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC + if enable_all_optimization: + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + else: + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - if enable_profiling: - sess_options.enable_profiling = True + if enable_profiling: + sess_options.enable_profiling = True - if num_threads > 0: - sess_options.intra_op_num_threads = num_threads - logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}") + if num_threads > 0: + sess_options.intra_op_num_threads = num_threads + logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}") - if verbose: - sess_options.log_severity_level = 0 - else: - sess_options.log_severity_level = 4 - - logger.debug(f"Create session for onnx model: {onnx_model_path}") - if use_gpu: - if provider == "dml": - providers = ["DmlExecutionProvider", "CPUExecutionProvider"] - elif provider == "rocm": - providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] - elif provider == "migraphx": - providers = [ - "MIGraphXExecutionProvider", - "ROCMExecutionProvider", - "CPUExecutionProvider", - ] - elif provider == "cuda": - providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] - elif provider == "tensorrt": - providers = [ - "TensorrtExecutionProvider", - "CUDAExecutionProvider", - "CPUExecutionProvider", - ] - else: - providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + if verbose: + sess_options.log_severity_level = 0 + else: + sess_options.log_severity_level = 4 + + if provider in onnxruntime.get_available_providers(): + providers = [provider] + elif use_gpu: + if provider == "dml": + providers = ["DmlExecutionProvider", "CPUExecutionProvider"] + elif provider == "rocm": + providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] + elif provider == "migraphx": + providers = [ + "MIGraphXExecutionProvider", + "ROCMExecutionProvider", + "CPUExecutionProvider", + ] + elif provider == "cuda" or provider is None: + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + elif provider == "tensorrt": + providers = [ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ] else: - providers = ["CPUExecutionProvider"] + raise RuntimeError(f"The execution provider 
is not supported: {provider}") + else: + providers = ["CPUExecutionProvider"] - if provider_options: - providers = [(name, provider_options[name]) if name in provider_options else name for name in providers] + if provider_options: + providers = [(name, provider_options[name]) if name in provider_options else name for name in providers] - if enable_mlas_gemm_fastmath_arm64_bfloat16: - sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1") + if enable_mlas_gemm_fastmath_arm64_bfloat16: + sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1") + session = None + try: session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers) except Exception: - logger.error("Exception", exc_info=True) # noqa: G201 + logger.exception(f"Failed to create session for {onnx_model_path} with providers={providers}") return session diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 045910ea20828..8eb2afb3db896 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -16,19 +16,17 @@ python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu -p fp16 --use_sln_strict_mode Example 4: convert T5 model with beam search in two steps: - cd ./models/t5 - python convert_to_onnx.py -m t5-small - cd ../.. - python convert_generation.py -m t5-small --model_type t5 \ - --decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \ - --encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \ - --output ./models/t5/onnx_models/t5_small_beam_search.onnx + python -m models.t5.convert_to_onnx -m t5-small + python convert_generation.py -m t5-small --model_type t5 \ + --decoder_onnx ./onnx_models/t5-small_decoder.onnx \ + --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder.onnx \ + --output ./onnx_models/t5_small_beam_search.onnx Example 5: convert T5 model with beam search. All in one step: - python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx + python convert_generation.py -m t5-small --model_type t5 --output t5_small_beam_search.onnx Example 6: convert T5 model with beam search containing specific cuda optimizations. All in one step: - python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx \ + python convert_generation.py -m t5-small --model_type t5 --output t5_small_beam_search.onnx \ --use_gpu --past_present_share_buffer --use_decoder_masked_attention Example 7: convert MT5 model with external data file like mt5-base-beamsearch.onnx.data in below example. 
@@ -68,11 +66,23 @@ T5Tokenizer, ) -from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers -from onnxruntime.transformers.models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx +from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, +) +from onnxruntime.transformers.models.gpt2.convert_to_onnx import ( + main as convert_gpt2_to_onnx, +) from onnxruntime.transformers.models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS -from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models -from onnxruntime.transformers.models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS +from onnxruntime.transformers.models.t5.convert_to_onnx import ( + export_onnx_models as export_t5_onnx_models, +) +from onnxruntime.transformers.models.t5.t5_helper import ( + PRETRAINED_MT5_MODELS, + PRETRAINED_T5_MODELS, +) logger = logging.getLogger("") @@ -162,9 +172,9 @@ def parse_arguments(argv: list[str] | None = None) -> argparse.Namespace: "-p", "--precision", required=False, - type=Precision, - default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16], + type=str, + default=Precision.FLOAT32.value, + choices=[Precision.FLOAT32.value, Precision.FLOAT16.value], help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision", ) @@ -189,7 +199,11 @@ def parse_arguments(argv: list[str] | None = None) -> argparse.Namespace: output_group.set_defaults(use_external_data_format=False) output_group.add_argument( - "-s", "--run_shape_inference", required=False, action="store_true", help="run shape inference" + "-s", + "--run_shape_inference", + required=False, + action="store_true", + help="run shape inference", ) output_group.set_defaults(run_shape_inference=False) @@ -223,6 +237,14 @@ def parse_arguments(argv: list[str] | None = None) -> argparse.Namespace: ) output_group.set_defaults(disable_shared_initializers=False) + output_group.add_argument( + "--encoder_decoder_init", + required=False, + action="store_true", + help="Add decoder initialization to encoder for T5 model. This is legacy format that will be deprecated.", + ) + output_group.set_defaults(encoder_decoder_init=False) + model_group = parser.add_argument_group("Beam search parameters that stored in the output model") model_group.add_argument( @@ -426,7 +448,10 @@ def parse_arguments(argv: list[str] | None = None) -> argparse.Namespace: test_group.set_defaults(use_sln_strict_mode=False) test_group.add_argument( - "--use_gpu", required=False, action="store_true", help="use GPU for inference. Required for fp16." + "--use_gpu", + required=False, + action="store_true", + help="use GPU for inference. Required for fp16.", ) test_group.set_defaults(use_gpu=False) @@ -490,7 +515,7 @@ def gpt2_to_onnx(args: argparse.Namespace): args.decoder_onnx, "--optimize_onnx", "--precision", - "fp32" if args.precision == Precision.FLOAT32 else "fp16", + args.precision, "--test_runs", "1", "--test_cases", @@ -508,7 +533,7 @@ def gpt2_to_onnx(args: argparse.Namespace): arguments.extend(["--op_block_list"]) arguments.extend(args.op_block_list) - if args.precision == Precision.FLOAT16: + if args.precision == Precision.FLOAT16.value: assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. 
Please add --use_gpu" # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision') # Need change cuda kernel to support a combination of fp32 logits and fp16 past state. @@ -527,20 +552,21 @@ def t5_to_onnx(args: argparse.Namespace): args (argparse.Namespace): arguments parsed from command line """ paths = export_t5_onnx_models( - args.model_name_or_path, - args.cache_dir, - Path(args.output).parent, + model_name_or_path=args.model_name_or_path, + cache_dir=args.cache_dir, + output_dir=Path(args.output).parent, use_gpu=args.use_gpu, use_external_data_format=args.use_external_data_format, - optimize_onnx=(args.precision != Precision.FLOAT16), + optimize_onnx=(args.precision != Precision.FLOAT16.value), precision=args.precision, verbose=False, use_decoder_start_token=False, - merge_encoder_and_decoder_init=True, overwrite=True, disable_auto_mixed_precision=False, use_int32_inputs=True, model_type=args.model_type, + encoder_decoder_init=args.encoder_decoder_init, + force_fp16_io=(args.precision == Precision.FLOAT16.value), # required by BeamSearch op implementation. ) logger.debug(f"onnx model for encoder: {paths[0]}") @@ -693,7 +719,7 @@ def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision): ValueError: Output name is not expected. ValueError: Output data type is not expected. """ - is_float16 = precision == Precision.FLOAT16 + is_float16 = precision == Precision.FLOAT16.value input_count = len(graph.input) layer_count = input_count - 3 @@ -749,7 +775,7 @@ def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision): ValueError: Output name is not expected. ValueError: Output data type is not expected. """ - is_float16 = precision == Precision.FLOAT16 + is_float16 = precision == Precision.FLOAT16.value float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT input_count = len(graph.input) @@ -825,15 +851,20 @@ def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: P ValueError: Output name is not expected. ValueError: Output data type is not expected. """ - is_float16 = precision == Precision.FLOAT16 - layer_count = (len(graph.output) - 2) // 4 - assert layer_count >= 1 + is_float16 = precision == Precision.FLOAT16.value + new_format = "cross" in graph.output[0].name # Expect 3 inputs: # encoder_input_ids: int32 (B, encode_sequence_length) # encoder_attention_mask: int32 (B, encode_sequence_length) # decoder_input_ids: int32 (B, 1) - expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"] + expected_inputs = [ + "encoder_input_ids", + "encoder_attention_mask", + "decoder_input_ids", + ] + if new_format: + expected_inputs = expected_inputs[:2] if len(graph.input) != len(expected_inputs): raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") @@ -846,22 +877,41 @@ def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: P if input_type != expected_type: raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") - # Expected outputs: - # logits: (B, 1, vocab_size) - # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) - # present_key_self_0: (B, num_heads, 1, head_size) - # present_value_self_0: (B, num_heads, 1, head_size) - # ... (for each self attention layer) - # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) - # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) - # ... 
(for each cross attention layer) - expected_outputs = ["logits", "encoder_hidden_states"] - for i in range(layer_count): - expected_outputs.append(f"present_key_self_{i}") - expected_outputs.append(f"present_value_self_{i}") - for i in range(layer_count): - expected_outputs.append(f"present_key_cross_{i}") - expected_outputs.append(f"present_value_cross_{i}") + if new_format: + assert len(graph.output) % 2 == 0 + layer_count = len(graph.output) // 2 + assert layer_count >= 1 + + # Expected outputs: + # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + # ... (for each cross attention layer) + expected_outputs = [] + for i in range(layer_count): + expected_outputs.append(f"present_key_cross_{i}") + expected_outputs.append(f"present_value_cross_{i}") + else: + logger.warning("This format is deprecated. Please export T5 encoder in new format with only cross outputs.") + assert (len(graph.output) - 2) % 4 == 0 + layer_count = (len(graph.output) - 2) // 4 + assert layer_count >= 1 + + # Expected outputs: + # logits: (B, 1, vocab_size) + # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + # present_key_self_0: (B, num_heads, 1, head_size) + # present_value_self_0: (B, num_heads, 1, head_size) + # ... (for each self attention layer) + # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + # ... (for each cross attention layer) + expected_outputs = ["logits", "encoder_hidden_states"] + for i in range(layer_count): + expected_outputs.append(f"present_key_self_{i}") + expected_outputs.append(f"present_value_self_{i}") + for i in range(layer_count): + expected_outputs.append(f"present_key_cross_{i}") + expected_outputs.append(f"present_value_cross_{i}") if len(graph.output) != len(expected_outputs): raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. 
Got {len(graph.output)}") @@ -1116,6 +1166,7 @@ def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto): new_nodes = [] for node in subg.node: + new_node = node if node.op_type == "Attention": kwargs = kwargs_of(node) kwargs.update({"past_present_share_buffer": 1}) @@ -1125,8 +1176,8 @@ def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto): nis.extend([""]) if len(nis) < 7: nis.extend(["past_sequence_length"]) - node = onnx.helper.make_node("Attention", nis, node.output, name=node.name, **kwargs) # noqa: PLW2901 - new_nodes.extend([node]) + new_node = onnx.helper.make_node("Attention", nis, node.output, name=node.name, **kwargs) + new_nodes.extend([new_node]) subg.ClearField("node") subg.node.extend(new_nodes) return subg @@ -1152,7 +1203,9 @@ def update_decoder_subgraph_use_decoder_masked_attention( new_inputs.extend( [ onnx.helper.make_tensor_value_info( - "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"] + "cache_indirection", + onnx.TensorProto.INT32, + shape=["batch_size", "beam_width", "max_seq_len"], ) ] ) @@ -1203,7 +1256,11 @@ def update_decoder_subgraph_use_decoder_masked_attention( nis.extend(["cache_indirection"]) node = onnx.helper.make_node( # noqa: PLW2901 - "DecoderMaskedSelfAttention", nis, node.output, name=node.name, **kwargs + "DecoderMaskedSelfAttention", + nis, + node.output, + name=node.name, + **kwargs, ) new_nodes.extend([node]) subg.ClearField("node") @@ -1573,7 +1630,11 @@ def replace_mha_with_dmmha(model: OnnxModel, past_seq_len_name: str): def replace_mha_with_gqa( - model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1 + model: OnnxModel, + attn_mask: str, + kv_num_heads: int = 0, + world_size: int = 1, + window_size: int = -1, ): # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes # @@ -1635,7 +1696,14 @@ def replace_mha_with_gqa( to=TensorProto.INT32, ) model.model.graph.node.extend( - [reduce_sum_node, sub_node, seqlen_k_cast_node, shape_node, gather_node, total_seqlen_cast_node] + [ + reduce_sum_node, + sub_node, + seqlen_k_cast_node, + shape_node, + gather_node, + total_seqlen_cast_node, + ] ) # Replace MultiHeadAttention with GroupQueryAttention @@ -1776,14 +1844,14 @@ def replace_mha_with_gqa( node.input[7], # past_value seqlen_k_cast_node.output[0], # seqlens_k (for attention mask) total_seqlen_cast_node.output[0], # total_seq_len (for attention mask) - q_rotary.input[2] if q_rotary is not None else "", # cos_cache (for rotary embeddings) - q_rotary.input[3] if q_rotary is not None else "", # sin_cache (for rotary embeddings) + (q_rotary.input[2] if q_rotary is not None else ""), # cos_cache (for rotary embeddings) + (q_rotary.input[3] if q_rotary is not None else ""), # sin_cache (for rotary embeddings) ], outputs=node.output, name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), domain="com.microsoft", num_heads=num_heads // world_size, - kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size, + kv_num_heads=(num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size), local_window_size=window_size, do_rotary=int(q_rotary is not None and k_rotary is not None), rotary_interleaved=interleaved, @@ -1831,7 +1899,9 @@ def update_decoder_subgraph_output_cross_attention(subg: GraphProto): node.attribute.extend([onnx.helper.make_attribute("output_qk", 1)]) cross_attention = onnx.helper.make_tensor_value_info( - 
cross_attention_out_name, TensorProto.FLOAT, [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim] + cross_attention_out_name, + TensorProto.FLOAT, + [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim], ) subg.output.extend([cross_attention]) if num_layer_output_qk != num_layers: @@ -1935,7 +2005,11 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelP kwargs["past_present_share_buffer"] = 1 node = onnx.helper.make_node( # noqa: PLW2901 - "DecoderMaskedMultiHeadAttention", nis, node.output, name=node.name, **kwargs + "DecoderMaskedMultiHeadAttention", + nis, + node.output, + name=node.name, + **kwargs, ) if node not in nodes_to_remove: @@ -1968,7 +2042,9 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelP new_inputs.extend( [ onnx.helper.make_tensor_value_info( - "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"] + "cache_indirection", + onnx.TensorProto.INT32, + shape=["batch_size", "beam_width", "max_seq_len"], ) ] ) @@ -2020,7 +2096,7 @@ def pack_qkv_for_decoder_masked_mha(model_proto: ModelProto): matmul_node_name = onnx_model.create_node_name("MatMul", name_prefix="MatMul_QKV") weight = onnx.helper.make_tensor( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT if q_weight.data_type == 1 else TensorProto.FLOAT16, + data_type=(TensorProto.FLOAT if q_weight.data_type == 1 else TensorProto.FLOAT16), dims=[qkv_weight.shape[0], qkv_weight.shape[1]], vals=qkv_weight.flatten().tolist(), ) @@ -2074,12 +2150,18 @@ def update_input_shapes_for_gpt2_decoder_model(decoder_onnx_path: str, use_exter # Update dim_value to be 1 shape_dim_proto.dim_value = 1 - OnnxModel.save(decoder_model_proto, decoder_onnx_path, save_as_external_data=use_external_data_format) + OnnxModel.save( + decoder_model_proto, + decoder_onnx_path, + save_as_external_data=use_external_data_format, + ) return True def generate_gpt2_init_decoder( - decoder_onnx_path: str, init_decoder_onnx_path: str, use_external_data_format: bool = True + decoder_onnx_path: str, + init_decoder_onnx_path: str, + use_external_data_format: bool = True, ) -> bool: """Generates the initial decoder GPT2 subgraph and saves it for downstream use. The initial decoder model will be saved to init_decoder_onnx_path. 
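Editorial aside on the replace_mha_with_gqa change earlier in this hunk: the num_heads/kv_num_heads arithmetic determines the head counts handed to GroupQueryAttention after tensor-parallel sharding, with kv_num_heads == 0 meaning "same as num_heads". A tiny sketch of that calculation:

    def gqa_head_counts(num_heads, kv_num_heads=0, world_size=1):
        # kv_num_heads == 0: keep as many KV heads as query heads (plain MHA expressed as GQA).
        q_heads = num_heads // world_size
        kv_heads = (num_heads // world_size) if kv_num_heads == 0 else (kv_num_heads // world_size)
        return q_heads, kv_heads

    print(gqa_head_counts(num_heads=32, kv_num_heads=8, world_size=2))  # (16, 4)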
@@ -2152,7 +2234,16 @@ def generate_gpt2_init_decoder( # Normalization Node is : LayerNormalization logits_matmul_to_residual_add_path = gpt2_init_decoder_model.match_parent_path( logits_matmul_node, - ["LayerNormalization", "Add", "Add", "MatMul", "FastGelu", "MatMul", "LayerNormalization", "Add"], + [ + "LayerNormalization", + "Add", + "Add", + "MatMul", + "FastGelu", + "MatMul", + "LayerNormalization", + "Add", + ], [0, 0, 1, 0, 0, 0, 0, 0], ) @@ -2183,7 +2274,9 @@ def generate_gpt2_init_decoder( if not is_skiplayernorm_path: residual_add_to_attention_parent_index = 0 residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path( - residual_add_node, ["Add", "Cast", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0, 0] + residual_add_node, + ["Add", "Cast", "MatMul", "Attention"], + [residual_add_to_attention_parent_index, 0, 0, 0], ) # Try other parent index of the residual Add node @@ -2199,42 +2292,54 @@ def generate_gpt2_init_decoder( if residual_add_to_attention_path is None: residual_add_to_attention_parent_index = 0 residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path( - residual_add_node, ["Add", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0] + residual_add_node, + ["Add", "MatMul", "Attention"], + [residual_add_to_attention_parent_index, 0, 0], ) # Try without the Casts before and after the MatMuls and other parent index of the residual Add node if residual_add_to_attention_path is None: residual_add_to_attention_parent_index = 1 residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path( - residual_add_node, ["Add", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0] + residual_add_node, + ["Add", "MatMul", "Attention"], + [residual_add_to_attention_parent_index, 0, 0], ) # SkipLayerNormalization path else: residual_add_to_attention_parent_index = 0 residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path( - residual_add_node, ["Cast", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0] + residual_add_node, + ["Cast", "MatMul", "Attention"], + [residual_add_to_attention_parent_index, 0, 0], ) # Try other parent index of the residual Add node if residual_add_to_attention_path is None: residual_add_to_attention_parent_index = 1 residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path( - residual_add_node, ["Cast", "MatMul", "Attention"], [residual_add_to_attention_parent_index, 0, 0] + residual_add_node, + ["Cast", "MatMul", "Attention"], + [residual_add_to_attention_parent_index, 0, 0], ) # Try without the Casts before and after the MatMuls if residual_add_to_attention_path is None: residual_add_to_attention_parent_index = 0 residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path( - residual_add_node, ["MatMul", "Attention"], [residual_add_to_attention_parent_index, 0] + residual_add_node, + ["MatMul", "Attention"], + [residual_add_to_attention_parent_index, 0], ) # Try without the Casts before and after the MatMuls and other parent index of the residual Add node if residual_add_to_attention_path is None: residual_add_to_attention_parent_index = 1 residual_add_to_attention_path = gpt2_init_decoder_model.match_parent_path( - residual_add_node, ["MatMul", "Attention"], [residual_add_to_attention_parent_index, 0] + residual_add_node, + ["MatMul", "Attention"], + [residual_add_to_attention_parent_index, 0], ) # TODO(hasesh): Are there more permutations to try before returning ? 
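Editorial aside: the fallback ladder above repeatedly calls match_parent_path with different op-type lists and parent indices. For readers unfamiliar with that helper, here is a simplified re-implementation of the idea (the real OnnxModel method takes more options); it walks producers upward through the graph and returns the matched parents or None:

    def match_parent_path(graph, node, parent_op_types, parent_input_indices):
        # Map every tensor name to the node that produces it.
        producers = {out: n for n in graph.node for out in n.output}
        path, current = [], node
        for op_type, idx in zip(parent_op_types, parent_input_indices):
            if idx >= len(current.input):
                return None
            parent = producers.get(current.input[idx])
            if parent is None or parent.op_type != op_type:
                return None
            path.append(parent)
            current = parent
        return path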
@@ -2252,7 +2357,9 @@ def generate_gpt2_init_decoder( # SkipLayerNormalization path else: add_before_residual_add = gpt2_init_decoder_model.match_parent( - residual_add_node, "SkipLayerNormalization", residual_add_to_add_parent_index + residual_add_node, + "SkipLayerNormalization", + residual_add_to_add_parent_index, ) if add_before_residual_add is None: @@ -2342,7 +2449,11 @@ def generate_gpt2_init_decoder( gpt2_init_decoder_model.topological_sort() # Save the init decoder model - OnnxModel.save(init_decoder_model_proto, init_decoder_onnx_path, save_as_external_data=use_external_data_format) + OnnxModel.save( + init_decoder_model_proto, + init_decoder_onnx_path, + save_as_external_data=use_external_data_format, + ) return True @@ -2383,7 +2494,10 @@ def make_dim_proto_numeric_t5(model, config): dim_proto.dim_value = dim_value -def convert_generation_model(args: argparse.Namespace, generation_type: GenerationType = GenerationType.BEAMSEARCH): +def convert_generation_model( + args: argparse.Namespace, + generation_type: GenerationType = GenerationType.BEAMSEARCH, +): """Convert model according to command line arguments. Args: @@ -2397,8 +2511,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati logger.info(f"**** past_present_share_buffer={past_present_share_buffer}") if len(args.op_block_list) == 1 and args.op_block_list[0] == "auto": - if is_gpt2 and args.precision == Precision.FLOAT16: - args.op_block_list = ["Add", "LayerNormalization", "SkipLayerNormalization", "FastGelu"] + if is_gpt2 and args.precision == Precision.FLOAT16.value: + args.op_block_list = [ + "Add", + "LayerNormalization", + "SkipLayerNormalization", + "FastGelu", + ] logger.info(f"**** Setting op_block_list to {args.op_block_list}") logger.info("**** use --op_block_list if you want to override the block operator list.") else: @@ -2434,9 +2553,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}") else: if not args.decoder_onnx: - onnx_filename = "{}_past_{}.onnx".format( - args.model_name_or_path, "fp16" if args.precision == Precision.FLOAT16 else "fp32" - ) + onnx_filename = f"{args.model_name_or_path}_past_{args.precision}.onnx" args.decoder_onnx = Path(Path(args.output).parent, onnx_filename).as_posix() logger.info(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...") @@ -2458,7 +2575,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati logits_matmul_weight_padded = False if ( not args.disable_pad_vocab_size - and args.precision == Precision.FLOAT16 + and args.precision == Precision.FLOAT16.value and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling) ): @@ -2481,14 +2598,14 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati ): logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. 
") - gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format( - "fp16" if args.precision == Precision.FLOAT16 else "fp32" - ) + gpt2_init_decoder_onnx_filename = f"gpt2_init_past_{args.precision}.onnx" gpt2_init_decoder_onnx_path = Path(Path(args.output).parent, gpt2_init_decoder_onnx_filename).as_posix() gpt2_init_decoder_generated = generate_gpt2_init_decoder( - args.decoder_onnx, gpt2_init_decoder_onnx_path, args.use_external_data_format + args.decoder_onnx, + gpt2_init_decoder_onnx_path, + args.use_external_data_format, ) if not gpt2_init_decoder_generated: @@ -2672,7 +2789,8 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.") shape_inference(args.encoder_decoder_init_onnx, args.use_external_data_format) encoder_model = onnx.load_model(args.encoder_decoder_init_onnx, load_external_data=True) - encoder_model.graph.name = f"{args.model_type} encoder and decoder init" + suffix = "encoder" if len(encoder_model.graph.input) == 2 else "encoder and decoder init" + encoder_model.graph.name = f"{args.model_type} {suffix}" verify_t5_encoder_decoder_init_subgraph(encoder_model.graph, args.precision) make_dim_proto_numeric_t5(encoder_model, config) @@ -2711,14 +2829,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati # ) # initializers.extend(moved_initializers) + assert config.decoder_start_token_id >= 0, "decoder_start_token_id should be >= 0" + node.attribute.extend( [ onnx.helper.make_attribute("encoder", encoder_model.graph), onnx.helper.make_attribute("decoder", decoder_model.graph), - onnx.helper.make_attribute( - "decoder_start_token_id", - config.decoder_start_token_id if len(encoder_model.graph.input) == 3 else -1, - ), + onnx.helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), ] ) else: @@ -2838,7 +2955,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati if args.output_sequences_scores: sequences_scores = onnx.helper.make_tensor_value_info( - "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] + "sequences_scores", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences"], ) graph_outputs.append(sequences_scores) @@ -2852,7 +2971,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati new_graph = onnx.helper.make_graph( [node], - f"{args.model_type} beam search" if not is_greedysearch else f"{args.model_type} greedy search", + (f"{args.model_type} beam search" if not is_greedysearch else f"{args.model_type} greedy search"), graph_inputs, graph_outputs, initializers, @@ -2912,7 +3031,7 @@ def test_torch_performance( if args.use_gpu and not torch.cuda.is_available(): raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.") - if args.precision == Precision.FLOAT16: + if args.precision == Precision.FLOAT16.value: model.half() device = torch.device("cuda:0" if args.use_gpu else "cpu") @@ -2961,7 +3080,11 @@ def create_attention_mask(input_ids, pad_token_id): return attention_mask -def test_gpt_model(args: argparse.Namespace, sentences: list[str] | None = None, is_greedy: bool = False): +def test_gpt_model( + args: argparse.Namespace, + sentences: list[str] | None = None, + is_greedy: bool = False, +): """Test GPT-2 model Args: @@ -3152,7 +3275,7 @@ def test_gpt_model(args: argparse.Namespace, sentences: list[str] | None = None, print("-" * 50) # Compare the 
generated text instead of word IDs since ORT pads to max sequence length but Torch not. is_same = torch_decoded_sequences == ort_decoded_sequences - print("Torch and ORT result is ", "same" if is_same else "different") + print("Torch and ORT result is", "same" if is_same else "different") output["parity"] = is_same if args.torch_performance: diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py index a0eff081675fe..5ce089712ccb1 100644 --- a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -51,6 +51,7 @@ def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes if not self.model.has_constant_input(div_node, 1.0): return + node_parent = mul_node else: # Div(1, RMS) can also be represented as Reciprocal(RMS) like # @@ -66,6 +67,7 @@ def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): # Mul --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node) # (B=2) (A/B=eps) (A/B=scale) # + return_indice = [] sim_ln_nodes = self.model.match_parent_path( node, ["Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean"], @@ -73,24 +75,50 @@ def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): output_name_to_node=output_name_to_node, return_indice=return_indice, ) - if sim_ln_nodes is None: - return - mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes - - pow_or_mul_node = self.model.get_parent(reduce_mean_node, 0, output_name_to_node) - if pow_or_mul_node is None or pow_or_mul_node.op_type not in ["Pow", "Mul"]: + if sim_ln_nodes is not None: + mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes + node_parent = mul_node + else: + # (root_input) --------------------------------+ + # | | + # v v + # Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node) + # (B=2) (A/B=eps) (A/B=scale) + # + # (root_input) --------------------------------+ + # | | | + # v v v + # Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node) + # (B=2) (A/B=eps) (A/B=scale) + # + return_indice = [] + sim_ln_nodes = self.model.match_parent_path( + node, + ["Div", "Sqrt", "Add", "ReduceMean"], + [None, 1, 0, None], + output_name_to_node=output_name_to_node, + return_indice=return_indice, + ) + if sim_ln_nodes is not None: + div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes + node_parent = div_node + else: + return + + reduce_mean_parent = self.model.get_parent(reduce_mean_node, 0, output_name_to_node) + if reduce_mean_parent is None or reduce_mean_parent.op_type not in ["Pow", "Mul"]: return - if pow_or_mul_node.op_type == "Pow": - if self.model.find_constant_input(pow_or_mul_node, 2.0) != 1: + if reduce_mean_parent.op_type == "Pow": + if self.model.find_constant_input(reduce_mean_parent, 2.0) != 1: return else: - assert pow_or_mul_node.op_type == "Mul" - if pow_or_mul_node[0] != pow_or_mul_node[1]: + assert reduce_mean_parent.op_type == "Mul" + if reduce_mean_parent[0] != reduce_mean_parent[1]: return - root_input = pow_or_mul_node.input[0] - if root_input != mul_node.input[0]: + root_input = reduce_mean_parent.input[0] + if root_input not in node_parent.input: return _i, epsilon = self.model.get_constant_input(add_node) @@ -113,7 +141,7 @@ def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): 
return self.nodes_to_remove.extend(sim_ln_nodes) - self.nodes_to_remove.append(pow_or_mul_node) + self.nodes_to_remove.append(reduce_mean_parent) self.nodes_to_remove.append(node) normalize_node = helper.make_node( diff --git a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py index 75887cc744081..f8b7dd80710ae 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py @@ -371,9 +371,6 @@ def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: model_size_in_MB = int(get_onnx_model_size(output_path, args.use_external_data_format) / 1024 / 1024) # noqa: N806 provider = args.provider - if args.provider == "migraphx": - provider = "MIGraphXExecutionProvider" - session = create_onnxruntime_session( output_path, args.use_gpu, provider, enable_all_optimization=True, verbose=args.verbose ) diff --git a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py index adf5206be8353..dd519e36cfa88 100755 --- a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py @@ -10,8 +10,15 @@ import os import torch -from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger +from benchmark_helper import ( + Precision, + create_onnxruntime_session, + prepare_environment, + setup_logger, +) +from onnx.shape_inference import infer_shapes_path from t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS, T5Helper +from transformers import MT5Config, T5Config logger = logging.getLogger("") @@ -70,9 +77,9 @@ def parse_arguments(): "-p", "--precision", required=False, - type=Precision, - default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16], + type=str, + default=Precision.FLOAT32.value, + choices=[Precision.FLOAT32.value, Precision.FLOAT16.value], help="Precision of model to run. fp32 for full precision, fp16 for half precision", ) @@ -104,17 +111,17 @@ def parse_arguments(): "--disable_auto_mixed_precision", required=False, action="store_true", - help="use pure fp16 instead of mixed precision", + help="do not use auto mixed precision conversion", ) parser.set_defaults(disable_auto_mixed_precision=False) parser.add_argument( - "--separate_encoder_and_decoder_init", + "--force_fp16_io", required=False, action="store_true", - help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + help="Force to convert all float inputs and outputs to fp16 when precision is fp16.", ) - parser.set_defaults(separate_encoder_and_decoder_init=False) + parser.set_defaults(force_fp16_io=False) parser.add_argument( "--use_int64_inputs", @@ -131,34 +138,52 @@ def parse_arguments(): help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", ) + parser.add_argument( + "--encoder_decoder_init", + required=False, + action="store_true", + help="Combine encoder and decoder kv cache initialization into one model. 
It is legacy format that will be deprecated.", + ) + parser.set_defaults(encoder_decoder_init=False) + args = parser.parse_args() return args def export_onnx_models( - model_name_or_path, - cache_dir, - output_dir, - use_gpu, - use_external_data_format, - optimize_onnx, - precision, - verbose, + model_name_or_path: str, + cache_dir: str, + output_dir: str, + use_gpu: bool = False, + use_external_data_format: bool = False, + optimize_onnx: bool = False, + precision: str = Precision.FLOAT32.value, + verbose: bool = False, use_decoder_start_token: bool = False, - merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, use_int32_inputs: bool = True, model_type: str = "t5", state_dict_path: str = "", + encoder_decoder_init: bool = False, + force_fp16_io: bool = False, + shape_infer_before_optimization: bool = False, ): + assert precision in [Precision.FLOAT32.value, Precision.FLOAT16.value], ( + f"Invalid precision: {precision}. Use 'fp32' or 'fp16'." + ) device = torch.device("cuda:0" if use_gpu else "cpu") models = T5Helper.load_model( - model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path + model_name_or_path, + cache_dir, + device, + model_type, + state_dict_path, + encoder_decoder_init=encoder_decoder_init, ) - config = models["decoder"].config + config: T5Config | MT5Config = models["decoder"].config if (not use_external_data_format) and (config.num_layers > 24): logger.info("Try use_external_data_format when model size > 2GB") @@ -191,8 +216,20 @@ def export_onnx_models( else: logger.info(f"Skip exporting: existed ONNX model {onnx_path}") - # Optimize ONNX graph. Note that we have not implemented graph optimization for T5 yet. - if optimize_onnx or precision != Precision.FLOAT32: + # Optimize ONNX graph. + # The precision shall be compared with string value. It is because the Precision enum loaded from local file + # (like by transformers test in CI pipeline) are not same as Precision enum from package. 
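    # (Illustrative note, not part of this patch.) Enum equality is identity-based, so a Precision
    # class loaded from a local copy of benchmark_helper never compares equal to the one from the
    # installed package, even when the member values match. A minimal repro of the pitfall, assuming
    # the enum values are the strings "fp32"/"fp16":
    #
    #   from enum import Enum
    #   class PrecisionA(Enum): FLOAT16 = "fp16"   # stand-in for the package copy
    #   class PrecisionB(Enum): FLOAT16 = "fp16"   # stand-in for the locally loaded copy
    #   assert PrecisionA.FLOAT16 != PrecisionB.FLOAT16                        # different classes, never equal
    #   assert PrecisionA.FLOAT16.value == PrecisionB.FLOAT16.value == "fp16"  # string values still match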
+ if optimize_onnx or precision != Precision.FLOAT32.value: + onnx_shape_path = None + if shape_infer_before_optimization: + onnx_shape_path = T5Helper.get_onnx_path( + output_dir, + model_name_or_path, + suffix=filename_suffix + "_shape", + new_folder=False, + ) + infer_shapes_path(onnx_path, onnx_shape_path) + output_path = T5Helper.get_onnx_path( output_dir, model_name_or_path, @@ -203,30 +240,35 @@ def export_onnx_models( if overwrite or not os.path.exists(output_path): logger.info(f"Optimizing model to {output_path}") T5Helper.optimize_onnx( - onnx_path, + onnx_shape_path or onnx_path, output_path, - precision == Precision.FLOAT16, + precision == Precision.FLOAT16.value, config.num_heads, config.hidden_size, use_external_data_format, auto_mixed_precision=not disable_auto_mixed_precision, use_gpu=use_gpu, + force_fp16_io=force_fp16_io, ) else: - logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") + logger.info(f"Skip optimizing: existed ONNX model {output_path}") else: output_path = onnx_path ort_session = create_onnxruntime_session( output_path, use_gpu=use_gpu, - provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"], + verbose=verbose, ) + if ort_session is None: + break with torch.no_grad(): max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs) logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}") - if max_diff > 1e-4: + + # The threshold cannot apply to fp16 model, which need a larger threshold. + if precision == Precision.FLOAT32.value and max_diff > 1e-4: logger.warning("PyTorch and OnnxRuntime results are NOT close") output_paths.append(output_path) @@ -245,15 +287,12 @@ def main(): output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output) prepare_environment(cache_dir, output_dir, args.use_gpu) - if args.precision != Precision.FLOAT32: + if args.precision != Precision.FLOAT32.value: assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx" - if args.precision == Precision.FLOAT16: + if args.precision == Precision.FLOAT16.value: assert args.use_gpu, "fp16 requires --use_gpu" - if args.optimize_onnx: - logger.warning("Graph optimization for T5 is not implemented yet.") - output_paths = export_onnx_models( args.model_name_or_path, cache_dir, @@ -264,11 +303,12 @@ def main(): args.precision, args.verbose, args.use_decoder_start_token, - not args.separate_encoder_and_decoder_init, args.overwrite, args.disable_auto_mixed_precision, not args.use_int64_inputs, args.model_type, + encoder_decoder_init=args.encoder_decoder_init, + force_fp16_io=args.force_fp16_io, ) logger.info(f"Done! Outputs: {output_paths}") diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py index c6b0f7ee3adc2..df3a416f2947c 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py @@ -1,24 +1,14 @@ # ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# ------------------------------------------------------------------------- import logging -import os import random -import tempfile -from pathlib import Path -import numpy -import onnx import torch -from onnx_model import OnnxModel -from torch_onnx_export_helper import torch_onnx_export from transformers import MT5Config, T5Config -from onnxruntime import InferenceSession - logger = logging.getLogger(__name__) @@ -41,7 +31,11 @@ def __init__(self, input_ids, attention_mask): @staticmethod def create_dummy( - batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False + batch_size: int, + sequence_length: int, + vocab_size: int, + device: torch.device, + use_int32_inputs: bool = False, ): # -> T5EncoderInputs """Create dummy inputs for T5 encoder. @@ -74,97 +68,3 @@ def create_dummy( def to_list(self) -> list: input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None] return input_list - - -class T5EncoderHelper: - @staticmethod - def export_onnx( - encoder: T5Encoder, - device: torch.device, - onnx_model_path: str, - verbose: bool = True, - use_external_data_format: bool = False, - use_int32_inputs: bool = False, - ): - """Export encoder to ONNX - - Args: - encoder (T5Encoder): encoder object - device (torch.device): device of encoder object - onnx_model_path (str): onnx path - verbose (bool, optional): print verbose information. Defaults to True. - use_external_data_format (bool, optional): use external data format or not. Defaults to False. - """ - config = encoder.config - encoder_inputs = T5EncoderInputs.create_dummy( - batch_size=2, - sequence_length=4, - vocab_size=config.vocab_size, - device=device, - use_int32_inputs=use_int32_inputs, - ) - - Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - - with tempfile.TemporaryDirectory() as tmp_dir_name: - temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx") - Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - torch_onnx_export( - encoder, - args=tuple(encoder_inputs.to_list()), - f=temp_onnx_model_path if use_external_data_format else onnx_model_path, - export_params=True, - input_names=["input_ids", "attention_mask"], - output_names=["hidden_states"], - dynamic_axes={ - "input_ids": {0: "batch_size", 1: "sequence_length"}, - "attention_mask": {0: "batch_size", 1: "sequence_length"}, - "hidden_states": {0: "batch_size", 1: "sequence_length"}, - }, - opset_version=12, - do_constant_folding=True, - use_external_data_format=use_external_data_format, - verbose=verbose, - ) - - if use_external_data_format: - model = onnx.load_model(temp_onnx_model_path, load_external_data=True) - OnnxModel.save( - model, - onnx_model_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - ) - - @staticmethod - def onnxruntime_inference(ort_session, inputs: T5EncoderInputs): - """Run inference of ONNX model.""" - ort_inputs = { - "input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()), - "attention_mask": numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()), - } - - return ort_session.run(None, ort_inputs) - - @staticmethod - def verify_onnx( - model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False - ): - """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" - inputs = T5EncoderInputs.create_dummy( - batch_size=4, - sequence_length=11, - vocab_size=model.config.vocab_size, - device=device, - use_int32_inputs=use_int32_inputs, - ) 
- input_list = inputs.to_list() - torch_outputs = model(*input_list) - - ort_outputs = T5EncoderHelper.onnxruntime_inference(ort_session, inputs) - - max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0])) - - logger.info(f"max_diff={max_diff}") - - return max_diff diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py index c76d7aabdf11a..98df18eab6064 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py @@ -1,8 +1,7 @@ # ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- import logging import os @@ -31,33 +30,40 @@ def __init__( self, encoder: torch.nn.Module, decoder: torch.nn.Module, - lm_head: torch.nn.Module, + lm_head: torch.nn.Linear, config: T5Config | MT5Config, decoder_start_token_id: int | None = None, + output_cross_only: bool = False, ): super().__init__() - self.config = config + self.config: T5Config | MT5Config = config self.t5_encoder = T5Encoder(encoder, config) self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id) + self.output_cross_only = output_cross_only def forward( self, encoder_input_ids: torch.Tensor, encoder_attention_mask: torch.Tensor, - decoder_input_ids: torch.Tensor = None, + decoder_input_ids: torch.Tensor | None = None, ): encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask) + lm_logits, past_self, past_cross = self.t5_decoder_init( decoder_input_ids, encoder_attention_mask, encoder_hidden_states ) - return lm_logits, encoder_hidden_states, past_self, past_cross + + if self.output_cross_only: + return past_cross + else: + return lm_logits, encoder_hidden_states, past_self, past_cross class T5EncoderDecoderInitInputs: def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None): self.encoder_input_ids: torch.LongTensor = encoder_input_ids self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask - self.decoder_input_ids: torch.LongTensor = decoder_input_ids + self.decoder_input_ids: torch.LongTensor | None = decoder_input_ids @staticmethod def create_dummy( @@ -108,9 +114,14 @@ def export_onnx( onnx_model_path (str): onnx path verbose (bool, optional): print verbose information. Defaults to True. use_external_data_format (bool, optional): use external data format or not. Defaults to False. + use_int32_inputs (bool, optional): use int32 instead of int64 for integer inputs. Defaults to False. """ assert isinstance(model, T5EncoderDecoderInit) + # Do not exclude decoder in torch onnx export so that cross can show up. + output_cross_only = model.output_cross_only + model.output_cross_only = False + inputs = T5EncoderDecoderInitInputs.create_dummy( model.config, batch_size=2, @@ -139,7 +150,7 @@ def export_onnx( input_names = ["encoder_input_ids", "encoder_attention_mask"] - # ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference. 
+ # ONNX exporter might mark dimension like 'present_value_self_1_dim_2' in shape inference. # We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value. sequence_length = "1" num_heads = str(model.config.num_heads) @@ -201,9 +212,12 @@ def export_onnx( verbose=verbose, ) + # Restore output_cross_only setting. + model.output_cross_only = output_cross_only + # Workaround as mentioned earlier: change numeric dim_param to dim_value - model = onnx.load(temp_onnx_model_path) - for tensor in model.graph.output: + exported_model: onnx.ModelProto = onnx.load(temp_onnx_model_path) + for tensor in exported_model.graph.output: for dim_proto in tensor.type.tensor_type.shape.dim: if dim_proto.HasField("dim_param") and dim_proto.dim_param in [ sequence_length, @@ -215,8 +229,50 @@ def export_onnx( dim_proto.Clear() dim_proto.dim_value = dim_value + if output_cross_only: + # Rewrite onnx graph to only keep present_[key|value]_cross_* outputs. + onnx_model = OnnxModel(exported_model) + output_name_to_node = onnx_model.output_name_to_node() + + for output in exported_model.graph.output: + if "cross" in output.name: + assert output.name in output_name_to_node + + transpose_node = output_name_to_node[output.name] + assert transpose_node and transpose_node.op_type == "Transpose" + + permutation = OnnxModel.get_node_attribute(transpose_node, "perm") + assert isinstance(permutation, list) + assert permutation == [0, 2, 1, 3] + + matched_nodes = onnx_model.match_parent_path( + transpose_node, + ["Reshape", "MatMul"], + [0, 0], + output_name_to_node, + ) + assert matched_nodes is not None + + reshape_node, matmul_node = matched_nodes + assert "encoder_hidden_states" in matmul_node.input + + if not onnx_model.get_initializer("cross_reshape_shape"): + shape_tensor = onnx.helper.make_tensor( + name="cross_reshape_shape", + data_type=onnx.TensorProto.INT64, + dims=[4], + vals=[0, 0, int(num_heads), int(head_size)], + raw=False, + ) + onnx_model.add_initializer(shape_tensor) + + reshape_node.input[1] = "cross_reshape_shape" + + cross_outputs = [output.name for output in exported_model.graph.output if "cross" in output.name] + onnx_model.prune_graph(cross_outputs, allow_remove_graph_inputs=True) + OnnxModel.save( - model, + exported_model, onnx_model_path, save_as_external_data=use_external_data_format, all_tensors_to_one_file=True, @@ -269,27 +325,34 @@ def verify_onnx( num_decoder_layers = model.config.num_decoder_layers - assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape - max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0])) - logger.debug(f"logits max_diff={max_diff}") - max_diff_all = max_diff - - assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape - max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1])) - logger.debug(f"encoder_hidden_states max_diff={max_diff}") - max_diff_all = max(max_diff_all, max_diff) - - for i in range(2 * num_decoder_layers): - max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i])) - logger.debug(f"self attention past state {i} max_diff={max_diff}") - - for i in range(2 * num_decoder_layers): - max_diff = numpy.amax( - numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i]) - ) - logger.debug(f"cross attention past state {i} max_diff={max_diff}") + if not model.output_cross_only: + assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape + max_diff = 
numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0])) + logger.debug(f"logits max_diff={max_diff}") + max_diff_all = max_diff + + assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape + max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1])) + logger.debug(f"encoder_hidden_states max_diff={max_diff}") max_diff_all = max(max_diff_all, max_diff) + for i in range(2 * num_decoder_layers): + max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i])) + logger.debug(f"self attention past state {i} max_diff={max_diff}") + + for i in range(2 * num_decoder_layers): + max_diff = numpy.amax( + numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i]) + ) + logger.debug(f"cross attention past state {i} max_diff={max_diff}") + max_diff_all = max(max_diff_all, max_diff) + else: + max_diff_all = -float("inf") + for i in range(2 * num_decoder_layers): + max_diff = numpy.amax(numpy.abs(torch_outputs[i].cpu().numpy() - ort_outputs[i])) + logger.debug(f"cross attention past state {i} max_diff={max_diff}") + max_diff_all = max(max_diff_all, max_diff) + test_cases_max_diff.append(max_diff_all) logger.info( f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}" diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py index d3f25e979887d..7552008f920e0 100755 --- a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py @@ -1,8 +1,7 @@ # ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- import logging import os @@ -12,8 +11,7 @@ from float16 import float_to_float16_max_diff from onnx_model import OnnxModel from optimizer import optimize_model -from t5_decoder import T5Decoder, T5DecoderHelper, T5DecoderInit -from t5_encoder import T5Encoder, T5EncoderHelper +from t5_decoder import T5Decoder, T5DecoderHelper from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration @@ -22,7 +20,13 @@ logger = logging.getLogger(__name__) PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"] -PRETRAINED_MT5_MODELS = ["google/mt5-small", "google/mt5-base", "google/mt5-large", "google/mt5-xl", "google/mt5-xxl"] +PRETRAINED_MT5_MODELS = [ + "google/mt5-small", + "google/mt5-base", + "google/mt5-large", + "google/mt5-xl", + "google/mt5-xxl", +] class T5Helper: @@ -60,18 +64,19 @@ def load_model( model_name_or_path: str, cache_dir: str, device: torch.device, - merge_encoder_and_decoder_init: bool = True, model_type: str = "t5", state_dict_path: str = "", - ) -> dict[str, torch.nn.Module]: + encoder_decoder_init: bool = False, + ) -> dict[str, T5EncoderDecoderInit | T5Decoder]: """Load model given a pretrained name or path, then build models for ONNX conversion. 
Args: model_name_or_path (str): pretrained model name or path cache_dir (str): cache directory device (torch.device): device to run the model - merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True. - is_mt5 (bool, optional): whether the model is MT5 instead of T5 + model_type (str, optional): model type "t5" or "mt5" + state_dict_path(str, optional): state dictionary path + encoder_decoder_init (bool, optional): combine encoder and decoder kv cache initialization into one model. Returns: Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. """ @@ -88,29 +93,21 @@ def load_model( decoder = T5Decoder(model.decoder, model.lm_head, model.config) decoder.eval().to(device) - if merge_encoder_and_decoder_init: - encoder_decoder_init = T5EncoderDecoderInit( - model.encoder, - model.decoder, - model.lm_head, - model.config, - decoder_start_token_id=None, - ) - return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder} - else: - encoder = T5Encoder(model.encoder, model.config) - encoder.eval().to(device) - decoder_init = T5DecoderInit(model.decoder, model.lm_head, model.config) - decoder_init.eval().to(device) - return { - "encoder": encoder, - "decoder": decoder, - "decoder_init": decoder_init, - } + encoder = T5EncoderDecoderInit( + model.encoder, + model.decoder, + model.lm_head, + model.config, + decoder_start_token_id=None, + output_cross_only=not encoder_decoder_init, + ) + + encoder_name = "encoder_decoder_init" if encoder_decoder_init else "encoder" + return {encoder_name: encoder, "decoder": decoder} @staticmethod def export_onnx( - model: T5Encoder | T5Decoder | T5DecoderInit | T5EncoderDecoderInit, + model: T5Decoder | T5EncoderDecoderInit, device: torch.device, onnx_model_path: str, verbose: bool = True, @@ -118,16 +115,7 @@ def export_onnx( use_decoder_input_ids: bool = True, use_int32_inputs: bool = False, ): - if isinstance(model, T5Encoder): - T5EncoderHelper.export_onnx( - model, - device, - onnx_model_path, - verbose, - use_external_data_format, - use_int32_inputs, - ) - elif isinstance(model, T5EncoderDecoderInit): + if isinstance(model, T5EncoderDecoderInit): T5EncoderDecoderInitHelper.export_onnx( model, device, @@ -150,21 +138,28 @@ def export_onnx( @staticmethod def auto_mixed_precision( onnx_model: OnnxModel, - op_block_list: list[str] = [ # noqa: B006 - "SimplifiedLayerNormalization", - "SkipSimplifiedLayerNormalization", - "Relu", - "Add", - ], + op_block_list: list[str] | None = None, + force_fp16_logits: bool = False, + use_symbolic_shape_infer: bool = True, ): """Convert model to mixed precision. It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically. Args: onnx_model (OnnxModel): optimized ONNX model - op_block_list (List[str], optional): . Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"] + op_block_list (List[str], optional): operators need to run in fp32. + force_fp16_logits (bool, optional): force logits and last MatMul node to be in float16. Defaults to False. + use_symbolic_shape_infer (bool, optional): use symbolic shape inference to convert float to float16. Defaults to True. 
Returns: parameters(dict): a dictionary of parameters used in float16 conversion """ + if op_block_list is None: + op_block_list = [ + "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", + "Relu", + "Add", + ] + op_full_set = {node.op_type for node in onnx_model.nodes()} fp32_op_set = set(op_block_list) fp16_op_set = op_full_set.difference(fp32_op_set) @@ -198,11 +193,38 @@ def auto_mixed_precision( keep_io_types = [] node_block_list = [] - if (not is_weight_fp16_precision) and (last_matmul_node is not None): + if (not is_weight_fp16_precision) and (last_matmul_node is not None) and not force_fp16_logits: # When original weight is float32 precision, keep logits and last MatMul in float32 could get better precision. keep_io_types = [logits_output_name] node_block_list = [last_matmul_node.name] + if "Add" not in op_block_list: + input_name_to_nodes = onnx_model.input_name_to_nodes() + fp32_add = 0 + changed = True + add_nodes = onnx_model.get_nodes_by_op_type("Add") + while changed: + changed = False + for node in add_nodes: + if node.name not in node_block_list: + parents = onnx_model.get_parents(node, output_name_to_node) + children = onnx_model.get_children(node, input_name_to_nodes) + blocked_children = [ + child for child in children if child.op_type in op_block_list or child in node_block_list + ] + blocked_parents = [ + parent for parent in parents if parent.op_type in op_block_list or parent in node_block_list + ] + # If any child or parent is in fp32, we place the Add node to fp32. + if (len(blocked_children) + len(blocked_parents)) > 0: + node_block_list.append(node.name) + fp32_add += 1 + changed = True + fp16_add = len(add_nodes) - fp32_add + logger.info(f"node counter of Add operator: fp32={fp32_add} fp16={fp16_add}") + + logger.info(f"node_block_list: {node_block_list}") + parameters = { "keep_io_types": keep_io_types, "op_block_list": op_block_list, @@ -211,7 +233,18 @@ def auto_mixed_precision( } logger.info(f"auto_mixed_precision parameters: {parameters}") - onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters) + if use_symbolic_shape_infer: + onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters) + else: + # Workaround when symbolic shape inference fails. + # Need enable shape_infer_before_optimization in convert_to_onnx.py as well. + from float16 import convert_float_to_float16 + + convert_float_to_float16( + onnx_model.model, + disable_shape_infer=True, + **parameters, + ) return parameters @@ -225,6 +258,7 @@ def optimize_onnx( use_external_data_format: bool = False, auto_mixed_precision: bool = True, use_gpu: bool = False, + force_fp16_io: bool = False, ): """Optimize ONNX model with an option to convert it to use mixed precision.""" @@ -233,38 +267,35 @@ def optimize_onnx( optimization_options = None if is_float16: optimization_options = FusionOptions("t5") - optimization_options.enable_skip_layer_norm = False + # SkipLayerNormalization is faster but might bring accuracy drop since it uses fp16 accumulation. 
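            # (Illustrative note, not part of this patch.) The assignment below ties the fusion to the
            # mixed-precision mode:
            #   auto_mixed_precision=True  -> enable_skip_layer_norm=False (keep the unfused norm)
            #   auto_mixed_precision=False -> enable_skip_layer_norm=True  (favor the faster fused op)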
+ optimization_options.enable_skip_layer_norm = not auto_mixed_precision m = optimize_model( onnx_model_path, model_type="t5", num_heads=num_attention_heads, hidden_size=hidden_size, - opt_level=2 if not use_external_data_format else 0, + opt_level=0, optimization_options=optimization_options, - use_gpu=False, - only_onnxruntime=not use_gpu, + use_gpu=use_gpu, ) if is_float16: if auto_mixed_precision: - T5Helper.auto_mixed_precision(m) + T5Helper.auto_mixed_precision(m, force_fp16_logits=force_fp16_io) else: - m.convert_model_float32_to_float16(cast_input_output=False) + m.convert_model_float32_to_float16(cast_input_output=force_fp16_io) m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) @staticmethod def verify_onnx( - model: T5Encoder | T5Decoder | T5DecoderInit | T5EncoderDecoderInit, + model: T5Decoder | T5EncoderDecoderInit, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool, ): """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" - if isinstance(model, T5Encoder): - return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs) - if isinstance(model, T5EncoderDecoderInit): return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index c0310b3e8c663..8add38b5a7d07 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -1183,11 +1183,21 @@ def graph_topological_sort(graph, is_deterministic=False): graph.ClearField("node") graph.node.extend(sorted_nodes) - def topological_sort(self, is_deterministic=False): + def topological_sort(self, is_deterministic=False, dump_model_on_failure=False): # TODO: support graph_topological_sort() in subgraphs # for graph in self.graphs(): # self.graph_topological_sort(graph) - OnnxModel.graph_topological_sort(self.model.graph, is_deterministic) + try: + OnnxModel.graph_topological_sort(self.model.graph, is_deterministic) + except RuntimeError as e: + if dump_model_on_failure: + logger.info( + "Failed to sort graph in topological order. Dumping model to _topo_sort_failed.onnx for debugging." + ) + OnnxModel.save( + self.model, "_topo_sort_failed.onnx", save_as_external_data=True, all_tensors_to_one_file=True + ) + raise e @staticmethod def save( diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 33dcc7795a465..de299a970ffd3 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -34,13 +34,13 @@ def __init__( num_heads, attention_mask, use_multi_head_attention=False, - search_op_types=["SkipSimplifiedLayerNormalization", "Add"], + search_op_types=["Softmax"], ) self.static_kv = 1 - def create_attention_node( + def make_attention_node( self, - mask_index: str, + mask_index: str | None, q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto, @@ -48,8 +48,8 @@ def create_attention_node( hidden_size: int, input: str, output: str, - add_qk_str: str, - scale: float | None = None, + attn_bias: str | None, + scale: float, ) -> NodeProto | None: """Create an Attention node. 
Args: @@ -122,14 +122,17 @@ def create_attention_node( attention_node_name + "_qkv_weight", "", ] - if mask_index is not None: + if mask_index: attention_inputs.append(mask_index) else: attention_inputs.append("") - if add_qk_str is not None: + if attn_bias: attention_inputs.append("") # no past - attention_inputs.append(add_qk_str) + attention_inputs.append(attn_bias) + + while attention_inputs and attention_inputs[-1] == "": + attention_inputs.pop() attention_node = helper.make_node( "Attention", @@ -153,50 +156,55 @@ def create_mha_node( query: str, key: str, value: str, - mask_index: str, - res_pos_bias: str, - past_key: str, - past_value: str, + mask_index: str | None, + attn_bias: str | None, + past_key: str | None, + past_value: str | None, output: str, - present_key: str, - present_value: str, + present_key: str | None, + present_value: str | None, num_heads: int, hidden_size: int, ) -> NodeProto | None: - assert num_heads > 0 + assert num_heads > 0 and hidden_size > 0 and query and key and value - if hidden_size > 0 and (hidden_size % num_heads) != 0: + if (hidden_size % num_heads) != 0: logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") return None attention_node_name = self.model.create_node_name("MultiHeadAttention") attention_inputs = [ query, - "" if key is None else key, # key - "" if value is None else value, # value + key, + value, "", # bias ] - if mask_index is not None: + + if mask_index: attention_inputs.append(mask_index) else: attention_inputs.append("") - if res_pos_bias is not None: - attention_inputs.append(res_pos_bias) + if attn_bias: + attention_inputs.append(attn_bias) else: attention_inputs.append("") - if past_key is not None: - assert past_value is not None + if past_key: + assert past_value attention_inputs.append(past_key) attention_inputs.append(past_value) + while attention_inputs and attention_inputs[-1] == "": + attention_inputs.pop() + attention_outputs = [output] - if present_key is not None: - assert present_value is not None + if present_key: + assert present_value attention_outputs.append(present_key) attention_outputs.append(present_value) + print(f"{attention_inputs=}, {attention_outputs=}, {attention_node_name=}") attention_node = helper.make_node( "MultiHeadAttention", inputs=attention_inputs, @@ -213,21 +221,23 @@ def create_mha_node( self.increase_counter("MultiHeadAttention") return attention_node - def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - self.fuse_t5_encoder(normalize_node, input_name_to_nodes, output_name_to_node) - self.fuse_t5_decoder(normalize_node, input_name_to_nodes, output_name_to_node) - - def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_node): - if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add": + def fuse(self, node, input_name_to_nodes, output_name_to_node): + if self.fuse_t5_encoder(node, input_name_to_nodes, output_name_to_node): return - qkv_nodes = self.model.match_parent_path( - normalize_node, ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], output_name_to_node + self.fuse_t5_decoder(node, input_name_to_nodes, output_name_to_node) + + def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node): + assert softmax_node.op_type == "Softmax" + qkv_nodes = self.model.match_child_path( + softmax_node, + ["MatMul", "Transpose", "Reshape"], + edges=[(0, 0), (0, 0), (0, 0)], + input_name_to_nodes=input_name_to_nodes, ) if qkv_nodes is None: - 
return - - _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes + return False + matmul_qkv, _, reshape_qkv = qkv_nodes qkv_shape_nodes = self.model.match_parent_path( reshape_qkv, @@ -236,7 +246,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no output_name_to_node, ) if qkv_shape_nodes is None: - return + return False input_shape_node = qkv_shape_nodes[-1] v_nodes = self.model.match_parent_path( @@ -246,7 +256,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no output_name_to_node, ) if v_nodes is None: - return + return False _, reshape_v, matmul_v = v_nodes # todo: check reshape_v parent nodes @@ -257,7 +267,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no output_name_to_node, ) if qk_nodes is None: - return + return False _, add_qk, matmul_qk = qk_nodes mask_nodes = self.model.match_parent_path( @@ -268,7 +278,9 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no ) is_pattern_for_one_graph_input = mask_nodes is None - if mask_nodes is None: + if mask_nodes is not None: + mul_node = mask_nodes[1] + else: # Pattern for SD3 and Flux. mask_nodes = self.model.match_parent_path( add_qk, @@ -276,15 +288,22 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no [1, 1, 0, 0, 1, 0], output_name_to_node, ) + + # If the model is not optimized by ORT, there might be an additional Cast node. if mask_nodes is None: - return + mask_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + [1, 1, 0, 0, 1, 0, 0], + output_name_to_node, + ) + if mask_nodes is None: + return False mul_node = mask_nodes[2] - else: - mul_node = mask_nodes[1] _, mul_val = self.model.get_constant_input(mul_node) if mul_val is None: - return + return False if mul_val != -10000: self.mask_filter_value = float(mul_val) @@ -327,7 +346,7 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no [1, 0, 0], ) if rpb_nodes is None: - return + return False res_pos_bias = rpb_nodes[-1].output[0] @@ -337,8 +356,8 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no [1, 0, 0], ) if k_nodes is None: - return - _, reshape_k, matmul_k = k_nodes + return False + _, _, matmul_k = k_nodes # todo: check reshape_k parent nodes q_nodes = self.model.match_parent_path( @@ -347,50 +366,50 @@ def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_no [0, 0, 0], ) if q_nodes is None: - return + return False - transpose_q, reshape_q, matmul_q = q_nodes + _, reshape_q, matmul_q = q_nodes # todo: check reshape_q parent nodes if matmul_q.input[0] != input_shape_node.input[0]: - return + return False q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q) - new_node = self.create_attention_node( + new_node = self.make_attention_node( mask_index, matmul_q, matmul_k, matmul_v, - q_num_heads, - q_hidden_size, - input_shape_node.input[0], - reshape_qkv.output[0], - res_pos_bias, - 1.0, + num_heads=q_num_heads, + hidden_size=q_hidden_size, + input=input_shape_node.input[0], + output=reshape_qkv.output[0], + attn_bias=res_pos_bias, + scale=1.0, ) if new_node is None: - return + return False self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.append(reshape_qkv) self.prune_graph = True + return True - def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, 
output_name_to_node): - if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add": - return + def fuse_t5_decoder(self, softmax_node, input_name_to_nodes, output_name_to_node): + assert softmax_node.op_type == "Softmax" - qkv_nodes = self.model.match_parent_path( - normalize_node, - ["MatMul", "Reshape", "Transpose", "MatMul"], - [1, 0, 0, 0], + qkv_nodes = self.model.match_child_path( + softmax_node, + ["MatMul", "Transpose", "Reshape"], + edges=[(0, 0), (0, 0), (0, 0)], + input_name_to_nodes=input_name_to_nodes, ) if qkv_nodes is None: return - - _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes + matmul_qkv, _transpose_qkv, reshape_qkv = qkv_nodes qkv_shape_nodes = self.model.match_parent_path( reshape_qkv, @@ -462,11 +481,17 @@ def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_no ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [1, 1, 0, 1, 0, 0], ) - if mask_nodes is None: - return - mul_node = mask_nodes[1] - if mask_nodes[1].op_type != "Mul": - return + if mask_nodes is not None: + mul_node = mask_nodes[1] + else: + mask_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + [1, 1, 0, 0, 1, 0, 0], + ) + if mask_nodes is None: + return + mul_node = mask_nodes[2] _, mul_val = self.model.get_constant_input(mul_node) if mul_val != -10000: @@ -474,22 +499,19 @@ def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_no mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) else: - rpb_nodes = self.model.match_parent_path( + matched_path_index, _, _ = self.model.match_parent_paths( add_qk, - ["Add", "Slice"], - [1, 0], + [ + (["Add", "Slice"], [1, 0]), + (["Add", "RelativePositionBias"], [1, 0]), + ], + output_name_to_node, ) - if rpb_nodes is not None: - res_pos_bias = add_qk.input[1] - else: - rpb_nodes = self.model.match_parent_path( - add_qk, - ["Add", "RelativePositionBias"], - [1, 0], - ) - if rpb_nodes is None: - return - res_pos_bias = add_qk.input[1] + if matched_path_index < 0: + logger.debug("Skip MultiHeadAttention fusion since attention bias pattern not matched") + return + + res_pos_bias = add_qk.input[1] key = None past_key = None @@ -608,56 +630,73 @@ def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_no past_key = None past_value = None + if not (key and value and q_num_heads > 0 and q_hidden_size > 0): + return + new_node = self.create_mha_node( - matmul_q.output[0], - key, - value, - mask_index, - res_pos_bias, - past_key, - past_value, - reshape_qkv.output[0], - present_key, - present_value, - q_num_heads, - q_hidden_size, + query=matmul_q.output[0], + key=key, + value=value, + mask_index=mask_index, + attn_bias=res_pos_bias, + past_key=past_key, + past_value=past_value, + output=reshape_qkv.output[0], + present_key=present_key, + present_value=present_value, + num_heads=q_num_heads, + hidden_size=q_hidden_size, ) - if new_node is None: - return - self.nodes_to_add.append(new_node) - self.node_name_to_graph_name[new_node.name] = self.this_graph_name + if new_node: + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name - self.nodes_to_remove.append(reshape_qkv) + # Since present_* is graph output, we need update the graph to avoid circular. 
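                # (Illustrative note, not part of this patch.) Example of the rewiring below, using a
                # hypothetical output name "present_key_0":
                #   before: old producer node -> "present_key_0" (graph output), also read by other nodes
                #   after:  old producer node -> "present_key_0_copy"; consumers read "present_key_0_copy";
                #           the new MultiHeadAttention node becomes the sole producer of "present_key_0".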
+ if present_key or present_value: + for graph_output in [present_key, present_value]: + if not (graph_output and self.model.find_graph_output(graph_output)): + print(f"{graph_output=} does not exist in graph output") + return + assert graph_output in output_name_to_node + output_name_to_node[graph_output].output[0] = graph_output + "_copy" + self.model.replace_input_of_all_nodes(graph_output, graph_output + "_copy") - self.prune_graph = True + self.nodes_to_remove.append(reshape_qkv) + self.prune_graph = False class FusionRelativePositionBiasBlock(Fusion): - def __init__(self, model: OnnxModel, max_distance: int): - super().__init__(model, "RelativePositionBias", ["Add", "Slice"]) - self.max_distance = max_distance - self.is_bidirectional = False + def __init__(self, model: OnnxModel): + super().__init__(model, "RelativePositionBias", ["Softmax"]) def fuse(self, node, input_name_to_nodes, output_name_to_node): - # TODO: Optimization opportunity: only last dimension of relative_position_bias is used in decoder. - # Cuda kernel can be optimized to only compute last dimension. - if node.op_type != "Add" and node.op_type != "Slice": - return - compute_bias_nodes = self.model.match_parent_path( - node, ["Unsqueeze", "Transpose", "Gather", "Where"], [0, 0, 0, 1], output_name_to_node + node, + ["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Where"], + [0, 1, 0, 0, 0, 0, 1], + output_name_to_node, ) + if compute_bias_nodes is None: compute_bias_nodes = self.model.match_parent_path( - node, ["Unsqueeze", "Transpose", "Gather", "Add", "Where"], [0, 0, 0, 1, 1], output_name_to_node + node, + ["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Add", "Where"], + [0, 1, 0, 0, 0, 0, 1, 1], + output_name_to_node, ) if compute_bias_nodes is None: return - gather = compute_bias_nodes[2] + gather = compute_bias_nodes[5] where = compute_bias_nodes[-1] - unsqueeze = compute_bias_nodes[0] + slice = compute_bias_nodes[2] + unsqueeze = compute_bias_nodes[3] + + # Current fusion will not remove the node until the graph is processed. + # This avoids to fuse it again when it is shared by multiple layers. + if unsqueeze in self.nodes_to_remove: + return compute_buckets_nodes = self.model.match_parent_path( where, @@ -668,12 +707,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if compute_buckets_nodes is None: return - # It is possible to deduce max_distance from a Div node: - # The value of self.model.get_constant_value(compute_buckets_nodes[-3].input[1]) is close to - # math.log(max_distance / (relative_attention_num_buckets // (4 if is_bidirectional else 2))) - # See https://github.com/huggingface/transformers/blob/608e163b527eaee41e650ffb9eb4c422d2679902/src/transformers/models/t5/modeling_t5.py#L397. - # Most t5 models use max_distance=128, so we hardcode it unitl we see a model with different value. - # TODO: maybe add a sanity check here. + # This value is to used to compute max_distance later. 
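            # (Illustrative note, not part of this patch.) With the Hugging Face defaults noted below
            # (max_distance=128, relative_attention_num_buckets=32), the constant read here is:
            #   encoder (bidirectional):  log(128 / (32 // 4)) = log(16) ~= 2.7726
            #   decoder (unidirectional): log(128 / (32 // 2)) = log(8)  ~= 2.0794
            # so round(exp(log_max) * (32 // (4 if is_bidirectional else 2))) recovers 128.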
+ log_max = self.model.get_constant_value(compute_buckets_nodes[-3].input[1]) div = compute_buckets_nodes[-1] @@ -683,21 +718,33 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): [0, 0, 0, 1, 0, 0, 0, 0], output_name_to_node, ) + + is_bidirectional = False if range_nodes is None: range_nodes = self.model.match_parent_path( div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0], output_name_to_node ) - self.is_bidirectional = True + is_bidirectional = True if range_nodes is None: return - range_node = range_nodes[-1] - self.nodes_to_remove.append(unsqueeze) - self.prune_graph = True + # Double check that the constant relative to max_distance and relative_attention_num_buckets. + # Most t5 models use max_distance=128, so we hardcode it unitl we see a model with different value. + + # The log_max is the value of the following formula: + # math.log(max_distance / (relative_attention_num_buckets // (4 if is_bidirectional else 2))) + # See https://github.com/huggingface/transformers/blob/608e163b527eaee41e650ffb9eb4c422d2679902/src/transformers/models/t5/modeling_t5.py#L397. + # Here is the value based on max_distance=128 and relative_attention_num_buckets=32: + max_distance = int(np.round(np.exp(log_max) * (32 // (4 if is_bidirectional else 2)))) + if max_distance != 128: + logger.warning( + f"max_distance is {max_distance}, which is different from the default value 128. " + "Please double check the model configuration." + ) node_name = self.model.create_node_name( - "RelativePositionBias", name_prefix="RelPosBias_" + ("encoder" if self.is_bidirectional else "decoder") + "RelativePositionBias", name_prefix="RelPosBias_" + ("encoder" if is_bidirectional else "decoder") ) table_weight_i = self.model.get_initializer(gather.input[0]) @@ -712,22 +759,64 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): vals=table_weight_t.tobytes(), raw=True, ) - self.model.add_initializer(bias_table, self.this_graph_name) + + # Relative position is like the following in encoder: + # seq_len + # | + # Range(0, *) + # / \ + # Unsqueeze(axes=0) Unsqueeze(axes=1) + # \ / + # Sub + # | + # Abs + # + # Relative position is like the following in decoder: + # past_seq_len seq_len + # \ / + # Add + # / \ + # Range(0, *) Range(0, *) + # \ / + # Sub + # Note that the graph will slice the attention bias to get last seq_len rows. + # + # In new version of transformers, the pattern of decoder is changed like the following + # + # total_seq_len Range(start=past_seq_len, end=total_seq_len) + # | | + # Range(0, *) Unsqueeze(axes=1) + # | | + # Unsqueeze(axes=0) Cast(to=int64) + # \ / + # Sub + # Currently, there is still Slice to get last seq_len rows so end result is same. + # But need to be careful that the shape of bias tensor is changed before Slice. + # + # RelativePositionBias operator requires query_length == key_length so we shall pass in total_seq_len. + # Here we get the end value of the Range node as length to pass to the RelativePositionBias node. + + # TODO: Optimization opportunity: change RelativePositionBias op to support query_length != key_length. + # only compute seq_len rows, then we can remove the Slice after the RelativePositionBias node. inputs = [bias_table.name, range_node.input[1], range_node.input[1]] - outputs = [unsqueeze.output[0]] + + # Use a new tensor name since the shape might be different as mentioned above. 
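        # (Illustrative note, not part of this patch.) Concrete shapes, assuming num_heads=8,
        # past_sequence_length=7 and sequence_length=1 in the decoder:
        #   fused RelativePositionBias output:                 [1, 8, 8, 8]  (query_length == key_length == total_seq_len)
        #   after the existing Slice (keep last seq_len rows): [1, 8, 1, 8]
        # which matches what the original Unsqueeze/Slice subgraph fed into the attention-bias Add.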
+ bias_output = node_name + "_rel_pos_bias" + slice.input[0] = bias_output + rpb_node = helper.make_node( "RelativePositionBias", inputs=inputs, - outputs=outputs, + outputs=[bias_output], name=node_name, ) rpb_node.domain = "com.microsoft" - rpb_node.attribute.extend([helper.make_attribute("max_distance", self.max_distance)]) - rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", self.is_bidirectional)]) - - self.nodes_to_add.append(rpb_node) + rpb_node.attribute.extend([helper.make_attribute("max_distance", max_distance)]) + rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", is_bidirectional)]) self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name + self.nodes_to_add.append(rpb_node) + self.prune_graph = True class T5OnnxModel(BertOnnxModel): @@ -744,7 +833,7 @@ def __init__(self, model, num_heads: int = 0, hidden_size: int = 0): self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask) self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self) self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self) - self.rpb_fusion = FusionRelativePositionBiasBlock(self, 128) + self.rpb_fusion = FusionRelativePositionBiasBlock(self) def fuse_attention(self): self.attention_fusion.apply() diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index 7a94519c92bc8..c5cf8a07f557d 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -28,6 +28,10 @@ from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper +def has_cuda_environment(): + return torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers() + + class TestBeamSearchGpt(unittest.TestCase): """Test BeamSearch for GPT-2 model""" @@ -49,7 +53,7 @@ def setUp(self): # "The selloff in tech shares deepened", # "Abortion rights take center stage", ] - self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers() + self.enable_cuda = has_cuda_environment() self.remove_onnx_files() def tearDown(self): @@ -176,112 +180,253 @@ def test_external_data(self): ) -class TestBeamSearchT5(unittest.TestCase): - """Test BeamSearch for T5 model""" +def get_tiny_t5_model_dir(): + """Get the path to the tiny T5 model directory.""" + # This function is used to get the path to the tiny T5 model directory. + # It is used in the TestBeamSearchT5 and TestBeamSearchT5Fp16 classes. - def setUp(self): - self.model_name = "t5-small" - self.decoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_decoder.onnx") - self.encoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_encoder_decoder_init.onnx") - self.beam_search_onnx_path = os.path.join(".", "onnx_models", "t5_small_beam_search.onnx") - self.default_arguments = [ - f"-m {self.model_name}", + # Path relative to the build\Release directory, where transformers test is launched in pipeline. + tiny_model_dir = os.path.join( + "testdata", + "transformers", + "tiny_t5", + ) + if os.path.exists(tiny_model_dir): + return os.path.normpath(tiny_model_dir) + + # The path is relative to the current file's directory. 
+ tiny_model_dir = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "testdata", + "transformers", + "tiny_t5", + ) + return os.path.normpath(tiny_model_dir) + + +use_tiny_model = True + + +class TestBeamSearchT5(unittest.TestCase): + """Test BeamSearch for T5 model with fp32 in CPU""" + + @classmethod + def setUpClass(cls): + tiny_model_dir = get_tiny_t5_model_dir() + model_name = "tiny_t5" if use_tiny_model and os.path.exists(tiny_model_dir) else "t5-small" + cls.model_name = tiny_model_dir if model_name == "tiny_t5" else "t5-small" + cls.decoder_onnx_path = os.path.join(".", "t5_onnx_models", f"{model_name}_decoder.onnx") + cls.encoder_onnx_path = os.path.join(".", "t5_onnx_models", f"{model_name}_encoder.onnx") + cls.beam_search_onnx_path = os.path.join(".", "t5_onnx_models", f"{model_name}_beam_search.onnx") + cls.default_arguments = [ + f"-m {cls.model_name}", "--model_type t5", - f"--decoder_onnx {self.decoder_onnx_path}", - f"--encoder_decoder_init_onnx {self.encoder_onnx_path}", - f"--output {self.beam_search_onnx_path}", + f"--decoder_onnx {cls.decoder_onnx_path}", + f"--encoder_decoder_init_onnx {cls.encoder_onnx_path}", + f"--output {cls.beam_search_onnx_path}", "--output_sequences_score", "--repetition_penalty 2.0", ] - self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers() + # Remove onnx files if existed for any reason. + cls.remove_onnx_files() - export_t5_onnx_models( - self.model_name, + # This is in class setup so that we only export t5 model once. + paths = export_t5_onnx_models( + cls.model_name, os.path.join(".", "cache_models"), - os.path.join(".", "onnx_models"), + os.path.join(".", "t5_onnx_models"), use_gpu=False, use_external_data_format=False, optimize_onnx=False, - precision=Precision.FLOAT32, + precision=Precision.FLOAT32.value, verbose=False, use_decoder_start_token=False, - merge_encoder_and_decoder_init=True, overwrite=True, disable_auto_mixed_precision=False, use_int32_inputs=True, ) + assert len(paths) == 2 - self.sentences = [ + cls.sentences = [ "translate English to French: The product is released", "summarize: research continues to show that pets bring real health benefits to their owners. 
Having a dog around can lead to lower levels of stress for both adults and kids.", ] - if os.path.exists(self.beam_search_onnx_path): - os.remove(self.beam_search_onnx_path) + @classmethod + def remove_onnx_files(cls, beam_search_onnx_only: bool = False): + if os.path.exists(cls.beam_search_onnx_path): + os.remove(cls.beam_search_onnx_path) + if os.path.exists(cls.beam_search_onnx_path + ".data"): + os.remove(cls.beam_search_onnx_path + ".data") - def tearDown(self): - self.remove_onnx_files() + if not beam_search_onnx_only: + if os.path.exists(cls.encoder_onnx_path): + os.remove(cls.encoder_onnx_path) + if os.path.exists(cls.decoder_onnx_path): + os.remove(cls.decoder_onnx_path) - def remove_onnx_files(self): - if os.path.exists(self.beam_search_onnx_path): - os.remove(self.beam_search_onnx_path) + @classmethod + def tearDownClass(cls): + # cls.remove_onnx_files() + pass - if os.path.exists(self.decoder_onnx_path): - os.remove(self.decoder_onnx_path) + def setUp(self): + pass - if os.path.exists(self.encoder_onnx_path): - os.remove(self.encoder_onnx_path) + def tearDown(self): + # self.remove_onnx_files(beam_search_onnx_only=True) + pass - def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True): - if append_arguments: - arguments = " ".join([*self.default_arguments, extra_arguments]).split() - else: - arguments = extra_arguments.split() + def run_beam_search(self, extra_arguments: str): + arguments = " ".join([*self.default_arguments, extra_arguments]).split() # Test CPU - result = run(arguments, sentences=self.sentences if sentences is None else sentences) + result = run(arguments) self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}") - # Test GPU - if self.enable_cuda: - if "--use_gpu" not in arguments: - arguments.append("--use_gpu") - result = run(arguments, sentences=self.sentences if sentences is None else sentences) - self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}") - - os.remove(self.beam_search_onnx_path) - - @pytest.mark.slow def test_return_sequences(self): for return_sequences in [1, 2]: self.run_beam_search(f"--num_return_sequences {return_sequences}") - @pytest.mark.slow def test_early_stopping(self): self.run_beam_search("--early_stopping") - @pytest.mark.slow def test_length_penalty(self): for length_penalty in [0.5, 2.0]: self.run_beam_search(f"--length_penalty {length_penalty}") - @pytest.mark.slow def test_no_repeat_ngram(self): for ngram_size in [1, 2]: self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}") - @pytest.mark.slow def test_custom_attention_mask(self): self.run_beam_search("--custom_attention_mask") - @pytest.mark.slow def test_external_data(self): - self.run_beam_search( - f"-m t5-small --model_type t5 -e --output {self.beam_search_onnx_path}", - sentences=None, - append_arguments=False, - ) + self.run_beam_search("-e") + + +@unittest.skipUnless( + has_cuda_environment(), + "skip since there is no cuda environment.", +) +class TestBeamSearchT5Fp16(unittest.TestCase): + """Test BeamSearch for T5 model with fp16 in GPU""" + + @classmethod + def setUpClass(cls): + tiny_model_dir = get_tiny_t5_model_dir() + tiny_model_dir = os.path.normpath(tiny_model_dir) + cls.model_name = "tiny_t5" if use_tiny_model and os.path.exists(tiny_model_dir) else "t5-small" + cls.model_id = tiny_model_dir if cls.model_name == "tiny_t5" else "t5-small" + cls.beam_search_onnx_path = os.path.join(".", "onnx_models", 
f"{cls.model_name}_beam_search_fp16.onnx") + cls.default_arguments = [ + f"-m {cls.model_id}", + "--model_type t5", + f"--output {cls.beam_search_onnx_path}", + "--min_length 2", + "--max_length 16", + "--use_gpu", + "-p fp16", + ] + + cls.sentences = [ + "translate English to French: The product is released", + "summarize: research continues to show that pets bring real health benefits to their owners. Having a dog around can lead to lower levels of stress for both adults and kids.", + ] + + cls.remove_onnx_files() + + @classmethod + def remove_onnx_files(cls): + model_name = cls.model_name + for file in [ + f"{model_name}_beam_search_fp16.onnx", + f"{model_name}_encoder.onnx", + f"{model_name}_encoder_fp16.onnx", + f"{model_name}_decoder.onnx", + f"{model_name}_decoder_fp16.onnx", + ]: + if os.path.exists(os.path.join(".", "onnx_models", file)): + os.remove(os.path.join(".", "onnx_models", file)) + if os.path.exists(os.path.join(".", "onnx_models", file + ".data")): + os.remove(os.path.join(".", "onnx_models", file + ".data")) + + def setUp(self): + pass + + def tearDown(self): + self.remove_onnx_files() + + def check_encoder_fusion(self): + model_name = self.model_name + onnx_path = os.path.join(".", "onnx_models", f"{model_name}_encoder_fp16.onnx") + + model = onnx.load_model(onnx_path, format=None, load_external_data=True) + from onnxruntime.transformers.onnx_model import OnnxModel + + onnx_model = OnnxModel(model) + op_counters = onnx_model.get_operator_statistics() + print("encoder ops", op_counters) + + expected_node_count = { + "RelativePositionBias": 1, + "SimplifiedLayerNormalization": 5 if use_tiny_model else 13, + "Attention": 2 if use_tiny_model else 6, + } + for key, value in expected_node_count.items(): + self.assertIn(key, op_counters, f"Expected {key} to be in op_counters") + self.assertEqual(op_counters[key], value, f"Expected {key} to be {value}, but got {op_counters[key]}") + + def check_decoder_fusion(self): + model_name = self.model_name + onnx_path = os.path.join(".", "onnx_models", f"{model_name}_decoder_fp16.onnx") + + model = onnx.load_model(onnx_path, format=None, load_external_data=True) + from onnxruntime.transformers.onnx_model import OnnxModel + + onnx_model = OnnxModel(model) + op_counters = onnx_model.get_operator_statistics() + print("decoder ops", op_counters) + + expected_node_count = { + "RelativePositionBias": 1, + "SimplifiedLayerNormalization": 7 if use_tiny_model else 19, + "MultiHeadAttention": 4 if use_tiny_model else 12, + } + for key, value in expected_node_count.items(): + self.assertIn(key, op_counters, f"Expected {key} to be in op_counters") + self.assertEqual(op_counters[key], value, f"Expected {key} to be {value}, but got {op_counters[key]}") + + def run_beam_search(self, extra_arguments: str): + arguments = " ".join([*self.default_arguments, extra_arguments]).split() + result = run(arguments) + self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}") + + def test_return_sequences(self): + for return_sequences in [1, 2]: + self.run_beam_search(f"--num_return_sequences {return_sequences}") + + def test_early_stopping(self): + self.run_beam_search("--early_stopping") + + def test_length_penalty(self): + for length_penalty in [0.5, 2.0]: + self.run_beam_search(f"--length_penalty {length_penalty}") + + def test_no_repeat_ngram(self): + for ngram_size in [1, 2]: + self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}") + + def test_external_data(self): + self.run_beam_search("-e") + + # 
Ensure fusion is done correctly. + self.check_encoder_fusion() + self.check_decoder_fusion() class TestBeamSearchWhisper(unittest.TestCase): @@ -294,7 +439,7 @@ def setUp(self): self.decoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_decoder.onnx") self.encoder_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_encoder.onnx") self.beam_search_onnx_path = os.path.join(".", self.onnx_folder, "whisper-tiny_beamsearch.onnx") - self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers() + self.enable_cuda = has_cuda_environment() self.base_arguments = [ "-m", diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/added_tokens.json b/onnxruntime/test/testdata/transformers/tiny_t5/added_tokens.json new file mode 100644 index 0000000000000..3f5132007c4fc --- /dev/null +++ b/onnxruntime/test/testdata/transformers/tiny_t5/added_tokens.json @@ -0,0 +1,102 @@ +{ + "": 32099, + "": 32089, + "": 32088, + "": 32087, + "": 32086, + "": 32085, + "": 32084, + "": 32083, + "": 32082, + "": 32081, + "": 32080, + "": 32098, + "": 32079, + "": 32078, + "": 32077, + "": 32076, + "": 32075, + "": 32074, + "": 32073, + "": 32072, + "": 32071, + "": 32070, + "": 32097, + "": 32069, + "": 32068, + "": 32067, + "": 32066, + "": 32065, + "": 32064, + "": 32063, + "": 32062, + "": 32061, + "": 32060, + "": 32096, + "": 32059, + "": 32058, + "": 32057, + "": 32056, + "": 32055, + "": 32054, + "": 32053, + "": 32052, + "": 32051, + "": 32050, + "": 32095, + "": 32049, + "": 32048, + "": 32047, + "": 32046, + "": 32045, + "": 32044, + "": 32043, + "": 32042, + "": 32041, + "": 32040, + "": 32094, + "": 32039, + "": 32038, + "": 32037, + "": 32036, + "": 32035, + "": 32034, + "": 32033, + "": 32032, + "": 32031, + "": 32030, + "": 32093, + "": 32029, + "": 32028, + "": 32027, + "": 32026, + "": 32025, + "": 32024, + "": 32023, + "": 32022, + "": 32021, + "": 32020, + "": 32092, + "": 32019, + "": 32018, + "": 32017, + "": 32016, + "": 32015, + "": 32014, + "": 32013, + "": 32012, + "": 32011, + "": 32010, + "": 32091, + "": 32009, + "": 32008, + "": 32007, + "": 32006, + "": 32005, + "": 32004, + "": 32003, + "": 32002, + "": 32001, + "": 32000, + "": 32090 +} diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/config.json b/onnxruntime/test/testdata/transformers/tiny_t5/config.json new file mode 100644 index 0000000000000..d649732da246f --- /dev/null +++ b/onnxruntime/test/testdata/transformers/tiny_t5/config.json @@ -0,0 +1,60 @@ +{ + "architectures": [ + "T5ForConditionalGeneration" + ], + "classifier_dropout": 0.0, + "d_ff": 16, + "d_kv": 4, + "d_model": 8, + "decoder_start_token_id": 0, + "dense_act_fn": "relu", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "relu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": false, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "n_positions": 512, + "num_decoder_layers": 2, + "num_heads": 2, + "num_layers": 2, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "task_specific_params": { + "summarization": { + "early_stopping": true, + "length_penalty": 2.0, + "max_length": 200, + "min_length": 30, + "no_repeat_ngram_size": 3, + "num_beams": 4, + "prefix": "summarize: " + }, + "translation_en_to_de": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to German: " + }, + "translation_en_to_fr": { + "early_stopping": true, 
+ "max_length": 300, + "num_beams": 4, + "prefix": "translate English to French: " + }, + "translation_en_to_ro": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to Romanian: " + } + }, + "torch_dtype": "float32", + "transformers_version": "4.42.4", + "use_cache": true, + "vocab_size": 1024 +} diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/generation_config.json b/onnxruntime/test/testdata/transformers/tiny_t5/generation_config.json new file mode 100644 index 0000000000000..6f2a63c77c1b9 --- /dev/null +++ b/onnxruntime/test/testdata/transformers/tiny_t5/generation_config.json @@ -0,0 +1,7 @@ +{ + "_from_model_config": true, + "decoder_start_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 0, + "transformers_version": "4.42.4" +} diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/model.safetensors b/onnxruntime/test/testdata/transformers/tiny_t5/model.safetensors new file mode 100644 index 0000000000000..1b90602ed0709 Binary files /dev/null and b/onnxruntime/test/testdata/transformers/tiny_t5/model.safetensors differ diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/special_tokens_map.json b/onnxruntime/test/testdata/transformers/tiny_t5/special_tokens_map.json new file mode 100644 index 0000000000000..17ade346a1042 --- /dev/null +++ b/onnxruntime/test/testdata/transformers/tiny_t5/special_tokens_map.json @@ -0,0 +1,125 @@ +{ + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/spiece.model b/onnxruntime/test/testdata/transformers/tiny_t5/spiece.model new file mode 100644 index 0000000000000..16ff05c4dd0f9 Binary files /dev/null and b/onnxruntime/test/testdata/transformers/tiny_t5/spiece.model differ diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/tiny_t5.py b/onnxruntime/test/testdata/transformers/tiny_t5/tiny_t5.py new file mode 100644 index 0000000000000..6a25cb89f6327 --- /dev/null +++ b/onnxruntime/test/testdata/transformers/tiny_t5/tiny_t5.py @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import os + +from sentencepiece import SentencePieceProcessor, SentencePieceTrainer +from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer + +hidden_size = 8 + +vocab_size = 1024 +save_directory = "tiny_t5" +model_name = "google-t5/t5-small" + +config = T5Config.from_pretrained(model_name) + +config.num_heads = 2 + +if vocab_size: + config.vocab_size = 1024 + +config.d_model = hidden_size +config.d_kv = hidden_size // config.num_heads +config.d_ff = hidden_size * 2 +config.num_layers = 2 +config.num_decoder_layers = config.num_layers + +model = T5ForConditionalGeneration(config) + +model.save_pretrained(save_directory) + +tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False) +tokenizer.save_pretrained(save_directory) + + +def update_tokenizer(sp_model_path: str, vocab_size: int): + sp = SentencePieceProcessor() + sp.Load(sp_model_path) + + # Export the vocabulary + with open("vocab.txt", "w", encoding="utf-8") as f: + for id in range(sp.GetPieceSize()): + piece = sp.IdToPiece(id) + score = sp.GetScore(id) + f.write(f"{piece}\t{score}\n") + + with open("vocab.txt", encoding="utf-8") as f: + vocab = [line.strip().split("\t") for line in f] + + # Sort by score in descending order and select top tokens + vocab_sorted = sorted(vocab, key=lambda x: float(x[1]), reverse=True) + pruned_vocab = vocab_sorted[:vocab_size] + + # Write the pruned vocabulary to a new file + with open("pruned_vocab.txt", "w", encoding="utf-8") as f: + for piece, score in pruned_vocab: + f.write(f"{piece}\t{score}\n") + + # Train a new SentencePiece model using the pruned vocabulary as a seed. + # Example corpus.txt can be found by searching "corpus.txt download" in search engine. + SentencePieceTrainer.Train( + f"--input=corpus.txt --model_prefix=spiece --vocab_size={vocab_size} --user_defined_symbols=pruned_vocab.txt" + ) + + # Load the new model + sp_new = SentencePieceProcessor() + sp_new.Load("spiece.model") + + # Test encoding and decoding + text = "This is an example sentence." + tokens = sp_new.EncodeAsPieces(text) + print(tokens) + + detokenized_text = sp_new.DecodePieces(tokens) + print(detokenized_text) + + # Replace the original model. 
+ os.replace("spiece.model", sp_model_path) + + +if vocab_size: + original_path = os.path.join(save_directory, "spiece.model") + update_tokenizer(original_path, vocab_size) diff --git a/onnxruntime/test/testdata/transformers/tiny_t5/tokenizer_config.json b/onnxruntime/test/testdata/transformers/tiny_t5/tokenizer_config.json new file mode 100644 index 0000000000000..da3a2f5a033d6 --- /dev/null +++ b/onnxruntime/test/testdata/transformers/tiny_t5/tokenizer_config.json @@ -0,0 +1,940 @@ +{ + "add_prefix_space": true, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32000": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32001": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32002": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32003": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32004": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32005": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32006": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32007": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32008": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32009": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32010": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32011": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32012": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32013": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32014": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32015": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32016": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32017": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32018": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + 
"special": true + }, + "32019": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32020": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32021": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32022": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32023": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32024": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32025": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32026": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32027": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32028": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32029": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32030": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32031": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32032": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32033": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32034": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32035": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32036": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32037": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32038": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32039": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32040": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32041": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32042": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32043": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32044": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": 
false, + "single_word": false, + "special": true + }, + "32045": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32046": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32047": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32048": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32049": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32050": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32051": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32052": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32053": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32054": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32055": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32056": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32057": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32058": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32059": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32060": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32061": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32062": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32063": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32064": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32065": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32066": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32067": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32068": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32069": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32070": { + "content": "", + "lstrip": false, + 
"normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32071": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32072": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32073": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32074": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32075": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32076": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32077": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32078": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32079": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32080": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32081": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32082": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32083": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32084": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32085": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32086": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32087": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32088": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32089": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32090": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32091": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32092": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32093": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32094": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32095": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32096": { + 
"content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32097": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32098": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32099": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "clean_up_tokenization_spaces": true, + "eos_token": "", + "extra_ids": 100, + "legacy": false, + "model_max_length": 512, + "pad_token": "", + "sp_model_kwargs": {}, + "tokenizer_class": "T5Tokenizer", + "unk_token": "" +} diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index 14aeff3df9c62..0fb37e3a1550a 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -8,5 +8,6 @@ torch coloredlogs==15.0 transformers==4.46.3 parameterized>=0.8.1 +sentencepiece psutil einops