diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index 4cc3533f45e79..c37ae8caa5eef 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -183,7 +183,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
         device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
         create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs,
-        update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>};
+        update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>,
+        expand_buffer_int32_func_ ? expand_buffer_int32_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<int32_t>,
+        expand_buffer_float_func_ ? expand_buffer_float_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<float>,
+        expand_buffer_float16_func_ ? expand_buffer_float16_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<MLFloat16>};

     ORT_RETURN_IF_ERROR(impl.Initialize());
     return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@@ -198,7 +201,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         device_copy_func_,
         device_copy_int32_func_,
         create_encoder_inputs_func_,
-        update_decoder_feeds_fp16_func_};
+        update_decoder_feeds_fp16_func_,
+        expand_buffer_int32_func_,
+        expand_buffer_float_func_,
+        expand_buffer_float16_func_};

     ORT_RETURN_IF_ERROR(impl.Initialize());
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
index 3a0de820106e8..f9cb7d66c585a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -74,9 +74,15 @@ class BeamSearch : public IControlFlowKernel {
   // device helpers for encoder-decoder model like T5
   void SetDeviceHelpers_EncoderDecoder(
       const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float>& update_decoder_feeds_func,
-      const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func) {
+      const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func) {
     update_decoder_feeds_func_ = update_decoder_feeds_func;
     update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
+    expand_buffer_int32_func_ = expand_buffer_int32_func;
+    expand_buffer_float_func_ = expand_buffer_float_func;
+    expand_buffer_float16_func_ = expand_buffer_float16_func;
   }

  private:
@@ -106,6 +112,10 @@ class BeamSearch : public IControlFlowKernel {
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float> update_decoder_feeds_func_;
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16> update_decoder_feeds_fp16_func_;

+  BeamSearchDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<float> expand_buffer_float_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16> expand_buffer_float16_func_;
+
   //------------------------------------------------------------
   // Subgraph and FeedsFetchesManager re-used for each subgraph execution.
   //------------------------------------------------------------
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index 2e6af56fc4736..7b163dd923a31 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -60,6 +60,48 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
   }
 }

+// TODO(wy): Dispatch it to avoid passing multiple functions to interface.
+template <typename T>
+Status ExpandBuffer(void* stream,
+                    const OrtValue& input,
+                    int num_beams,
+                    AllocatorPtr allocator,
+                    OrtValue& expanded,
+                    bool only_copy_shape) {
+  // Input shape (batch_size, xxx). The input is required with data type T.
+  // Output shape (batch_size * num_beams, xxx)
+  ORT_UNUSED_PARAMETER(stream);
+
+  const TensorShape& input_shape = input.Get<Tensor>().Shape();
+  const int64_t& batch_size = input_shape[0];
+  const int64_t& chunk_size = static_cast<int64_t>(input_shape.Size() / batch_size);
+
+  int64_t dims[4] = {0};
+  input_shape.CopyDims(dims, input_shape.NumDimensions());
+  dims[0] = batch_size * num_beams;
+  TensorShape expanded_shape(&dims[0], input_shape.NumDimensions());
+
+  MLDataType element_type = input.Get<Tensor>().DataType();
+  ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
+  Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
+
+  if (only_copy_shape) {
+    return Status::OK();
+  }
+
+  const T* input_data = input.Get<Tensor>().Data<T>();
+  T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
+  T* target = expanded_data;
+  for (int i = 0; i < batch_size; i++) {
+    for (int j = 0; j < num_beams; j++) {
+      memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
+      target += chunk_size;
+    }
+  }
+
+  return Status::OK();
+}
+
 Status CreateGptInputs(
     const Tensor* original_input_ids,
     int num_beams,
@@ -200,37 +242,45 @@ Status ProcessLogits(const OrtValue& logits,  //
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];
+  auto logits_batch_size = logits_shape[0];

   // Get logits for the last token:
   //   next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
   // When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1.
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
-  if (input_length > 1) {
+
+  if (input_length > 1 || logits_batch_size == batch_size) {
     const T* current_logits = logits_data + (input_length - 1) * vocab_size;
     for (int i = 0; i < batch_beam_size; i++) {
       gsl::span<const T> source(current_logits, vocab_size);
       gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
                                                       static_cast<gsl::index>(vocab_size));
       gsl::copy(source, target);
-      current_logits += input_length * vocab_size;
+      if (logits_batch_size == batch_beam_size) {
+        current_logits += input_length * vocab_size;
+      } else if (logits_batch_size == batch_size && i % num_beams == num_beams - 1) {
+        current_logits += input_length * vocab_size;
+      }
     }
   }

 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("logits", logits);
-  if (input_length > 1) {
+  if (input_length > 1 || logits_batch_size == batch_size) {
     dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
   }
 #endif

   // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
   gsl::span<T>& next_token_scores = beam_state->next_token_scores;
-  ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(batch_beam_size,  // rows
-                                    vocab_size,       // elements per row
-                                    input_length > 1 ? next_token_logits.data() : logits_data,
-                                    next_token_scores.data(),
-                                    true,
-                                    thread_pool));
+  ORT_RETURN_IF_ERROR(
+      SoftmaxCPU<T>(
+          batch_beam_size,  // rows
+          vocab_size,       // elements per row
+          (input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
+          next_token_scores.data(),
+          true,
+          thread_pool));

 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size);
@@ -456,13 +506,12 @@ Status UpdateGptFeeds(
 Status CreateEncoderInputs(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids) {
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids) {
   const TensorShape& input_ids_shape = original_encoder_input_ids->Shape();
   ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
   const int64_t& batch_size = input_ids_shape[0];
@@ -475,14 +524,12 @@ Status CreateEncoderInputs(
   // Current shape is (batch_size, sequence_length)
   // Note that we will expand it to (batch_size * num_beams, sequence_length) later.
   // To avoid cloning input_ids, we use const_cast here since this function does not change its content.
-  OrtValue encoder_input_ids;
   Tensor::InitOrtValue(element_type,
                        input_ids_shape,
                        const_cast<Tensor*>(original_encoder_input_ids)->MutableData<int32_t>(),
                        allocator->Info(),
                        encoder_input_ids);

-  OrtValue encoder_attention_mask;
   if (attn_mask_value != nullptr) {
     const Tensor& attention_mask = attn_mask_value->Get<Tensor>();
     Tensor::InitOrtValue(element_type, input_ids_shape,
                          const_cast<Tensor*>(&attention_mask)->MutableData<int32_t>(),
@@ -511,20 +558,14 @@ Status CreateEncoderInputs(
     }
   }

-  // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
-  // for encoder_input_ids and encoder_attention_mask
-  // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
-  ExpandInputs(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids);
-  ExpandInputs(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask);
-
   // decoder_input_ids is optional.
   if (start_token_id >= 0) {
-    // Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID
-    int64_t dims[] = {batch_size * num_beams, 1};
+    // Filled decoder_input_ids with start token ID
+    int64_t dims[] = {batch_size, 1};
     TensorShape decoder_input_ids_shape(&dims[0], 2);
-    Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids);
-    int32_t* data = expanded_decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
-    for (int i = 0; i < batch_size * num_beams; i++, data++) {
+    Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, decoder_input_ids);
+    int32_t* data = decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
+    for (int i = 0; i < batch_size; i++, data++) {
       *data = start_token_id;
     }
   }
@@ -602,7 +643,7 @@ Status UpdateDecoderFeeds(
   TensorShape input_ids_shape(&dims[0], 2);
   Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);

-  // TODO: decouple has_hidden_state with full input_ids
+  // TODO(wy): decouple has_hidden_state with full input_ids
   if (has_hidden_state) {
     gsl::copy(beam_next_tokens, input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>());
   } else {
@@ -709,6 +750,30 @@ template Status UpdateDecoderFeeds<float>(
 template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

+template Status ExpandBuffer<int32_t>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<float>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<MLFloat16>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
index 8cd7a0291af0c..ab18eec25cde0 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
@@ -107,13 +107,12 @@ using UpdateGptFeedsFunc = std::function<Status(
 using CreateEncoderInputsFunc = std::function<Status(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids)>;
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids)>;

 // Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5).
 template <typename T>
@@ -132,8 +131,18 @@ using UpdateDecoderFeedsFunc = std::function<Status(
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper)>;
+
+template <typename T>
+using ExpandBufferFunc = std::function<Status(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape)>;
 }  // namespace BeamSearchDeviceHelper

+
 // These are CPU specific device helper implementations
 namespace BeamSearchCpuDeviceHelper {
 Status TopK(
@@ -212,13 +221,12 @@ Status UpdateGptFeeds(
 Status CreateEncoderInputs(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids);
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids);

 // Update decoder inputs given decoder outputs of last iteration.
 template <typename T>
@@ -244,6 +252,15 @@ Status UpdateDecoderFeeds(
 template <typename T>
 void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

+template <typename T>
+Status ExpandBuffer(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
index 03c88b5aa8047..790e96a4476b5 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -95,7 +95,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
     this->sequences.Init(this->sequences_space, static_cast<int>(batch_beam_size), sequence_length, max_length);
   }

-  // Copy input_ids to sequences[0]
+  // Copy expanded input_ids to sequences[0]
   void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
                    size_t batch_beam_size,
                    int max_length,
@@ -109,6 +109,21 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
     }
   }

+  // Copy unexpanded input_ids to sequences[0]
+  void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
+                   size_t batch_beam_size,
+                   int beam_size,
+                   int max_length,
+                   int sequence_length) {
+    gsl::span<int32_t> sequences_0 = sequences_space;
+    for (size_t i = 0; i < batch_beam_size; i++) {
+      for (int j = 0; j < sequence_length; j++) {
+        const size_t index = SafeInt<gsl::index>(i) * max_length + j;
+        sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i / beam_size) * sequence_length + j];
+      }
+    }
+  }
+
  private:
   BufferUniquePtr final_beam_scores_buffer_;
   BufferUniquePtr sequence_lengths_buffer_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 6acbe857e5a72..5360bfd8f4cfd 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -33,7 +33,10 @@ class BeamSearchT5 : public BeamSearchBase<T> {
               const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
               const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
               const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
-              const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T>& update_decoder_feeds_func)
+              const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T>& update_decoder_feeds_func,
+              const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+              const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+              const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func)
       : BeamSearchBase<T>(context, decoder_session_state, thread_pool, cuda_stream, cuda_dumper, params,
                           topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
@@ -43,7 +46,10 @@ class BeamSearchT5 : public BeamSearchBase<T> {
         add_to_feeds_func_(add_to_feeds_func),
         init_beam_state_func_(init_beam_state_func),
         create_encoder_inputs_func_(create_encoder_inputs_func),
-        update_decoder_feeds_func_(update_decoder_feeds_func) {
+        update_decoder_feeds_func_(update_decoder_feeds_func),
+        expand_buffer_int32_func_(expand_buffer_int32_func),
+        expand_buffer_float_func_(expand_buffer_float_func),
+        expand_buffer_float16_func_(expand_buffer_float16_func) {
   }

   // Execute beam search in iterations util stopping criteria is reached.
@@ -62,6 +68,9 @@ class BeamSearchT5 : public BeamSearchBase<T> {
   BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T> update_decoder_feeds_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<float> expand_buffer_float_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16> expand_buffer_float16_func_;
 };

 template <typename T>
@@ -110,19 +119,18 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                                 this->IsCuda());

   IAllocatorUniquePtr<char> buffer;
-  OrtValue expanded_decoder_input_ids;  // Tensor in CPU, and it will be used to initialize sequence in cpu_state
+  OrtValue decoder_input_ids;  // Tensor in CPU, and it will be used to initialize sequence in cpu_state
   ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds(
       encoder_input_ids,
       encoder_attn_mask_value,
       this->implicit_inputs_,
-      parameters->num_beams,
       parameters->pad_token_id,
       parameters->decoder_start_token_id,
      encoder_feeds,
      this->create_encoder_inputs_func_,
      this->add_to_feeds_func_,
      buffer,
-     expanded_decoder_input_ids));
+     decoder_input_ids));

   ORT_RETURN_IF_ERROR(utils::ExecuteSubgraph(this->encoder_session_state_,
                                              encoder_feeds_fetches_manager,
@@ -150,9 +158,10 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
   // Initialize resources
   // ------------------------------------

-  // Copy expanded_decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
-  cpu_state.SetSequence(expanded_decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
+  // Copy decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
+  cpu_state.SetSequence(decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
                         static_cast<size_t>(parameters->BatchBeamSize()),
+                        parameters->num_beams,
                         parameters->max_length,
                         parameters->sequence_length);

@@ -211,6 +220,10 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                           encoder_fetches,
                           decoder_feeds,
                           this->device_copy_int32_func_,
+                          this->expand_buffer_int32_func_,
+                          this->expand_buffer_float_func_,
+                          this->expand_buffer_float16_func_,
+                          parameters->num_beams,
                           this->cuda_stream_));
   }
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index bb931255dd177..7918a333094c4 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -121,6 +121,10 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
     const std::vector<OrtValue>& encoder_fetches,
     std::vector<OrtValue>& decoder_feeds,
     const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
+    const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+    const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+    const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func,
+    int num_beam,
     void* stream) {
   ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");

@@ -144,13 +148,58 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   decoder_feeds.push_back(input_ids);

   // The encoder_attention_mask is copied from the second input of encoder.
-  decoder_feeds.push_back(encoder_feeds[1]);
+  OrtValue expanded_decoder_attention_masks;
+  ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
+                                               encoder_feeds[1],
+                                               num_beam,
+                                               allocator,
+                                               expanded_decoder_attention_masks,
+                                               false));
+
+  decoder_feeds.push_back(expanded_decoder_attention_masks);

   // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
   // of encoder.
   // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
   for (size_t j = 4 - first_past_input_index_; j < encoder_fetches.size(); j++) {
-    decoder_feeds.push_back(encoder_fetches[j]);
+    if (j == 1) {
+      ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expansion: has_hidden_state_ == false");
+      OrtValue expanded_hidden_states;
+      if (is_output_float16_) {
+        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
+                                                       encoder_fetches[j],
+                                                       num_beam,
+                                                       allocator,
+                                                       expanded_hidden_states,
+                                                       true));
+      } else {
+        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
+                                                     encoder_fetches[j],
+                                                     num_beam,
+                                                     allocator,
+                                                     expanded_hidden_states,
+                                                     true));
+      }
+      decoder_feeds.push_back(expanded_hidden_states);
+    } else {
+      OrtValue expanded_cache;
+      if (is_output_float16_) {
+        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
+                                                       encoder_fetches[j],
+                                                       num_beam,
+                                                       allocator,
+                                                       expanded_cache,
+                                                       false));
+      } else {
+        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
+                                                     encoder_fetches[j],
+                                                     num_beam,
+                                                     allocator,
+                                                     expanded_cache,
+                                                     false));
+      }
+      decoder_feeds.push_back(expanded_cache);
+    }
   }

   // Pass through implicit inputs.
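The ExpandBuffer helpers introduced by this patch all implement the same copy pattern: each batch entry is repeated num_beams times along the first axis, so row i of the input becomes rows [i * num_beams, (i + 1) * num_beams) of the output. A minimal stand-alone sketch of that semantics on plain vectors (ExpandRows and its parameter names are illustrative, not part of the patch):

#include <cstring>
#include <vector>

// Repeat each of the `rows` chunks of `chunk_size` elements `num_beams` times,
// mirroring ExpandBuffer's (batch_size, xxx) -> (batch_size * num_beams, xxx) copy loop.
template <typename T>
std::vector<T> ExpandRows(const std::vector<T>& input, int rows, int num_beams) {
  const size_t chunk_size = input.size() / rows;
  std::vector<T> expanded(input.size() * num_beams);
  T* target = expanded.data();
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < num_beams; j++) {
      std::memcpy(target, input.data() + i * chunk_size, sizeof(T) * chunk_size);
      target += chunk_size;
    }
  }
  return expanded;
}

// Example: rows = 2, num_beams = 3, chunk_size = 2:
// {a0, a1, b0, b1} -> {a0, a1, a0, a1, a0, a1, b0, b1, b0, b1, b0, b1}

Note the asymmetry in the decoder feed construction above: the hidden-states feed is expanded with only_copy_shape == true, so only the (batch_size * num_beams, ...) tensor is allocated without copying data, while the attention mask and the past-state caches are expanded with a full copy.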
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
index 108b1c298d759..edf7293a978c1 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
@@ -28,6 +28,10 @@ class T5DecoderSubgraph : public Subgraph {
       const std::vector<OrtValue>& encoder_fetches,
       std::vector<OrtValue>& decoder_feeds,
       const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func,
+      int num_beam,
       void* stream);

   Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
index 7574c31ec5b6e..153deaa8cb822 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
@@ -96,24 +96,23 @@ Status T5EncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i

 // Create inputs for first inference of subgraph.
 Status T5EncoderSubgraph::CreateInitialFeeds(
-    const Tensor& encoder_input_ids,
+    const Tensor& original_encoder_input_ids,
     const OrtValue* attn_mask_value,
     const std::vector<const OrtValue*>& implicit_inputs,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     std::vector<OrtValue>& feeds,
     const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
     const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
     IAllocatorUniquePtr<char>& buffer,
-    OrtValue& expanded_decoder_input_ids) {
+    OrtValue& decoder_input_ids) {
   ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");

   // The ordering is the same as used in Setup.
   feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));

   // Allocate subgraph inputs to be same device as encoder_input_ids.
-  AllocatorPtr cpu_allocator = session_state_->GetAllocator(encoder_input_ids.Location());
+  AllocatorPtr cpu_allocator = session_state_->GetAllocator(original_encoder_input_ids.Location());
   if (cpu_allocator == nullptr) {
     const IExecutionProvider* provider = GetProvider();
     cpu_allocator = provider->GetAllocator(0, OrtMemTypeDefault);
@@ -121,22 +120,21 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
   ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr");

   // TODO(tianleiwu): expand the outputs instead of inputs to save computation.
-  OrtValue expanded_encoder_input_ids;
-  OrtValue expanded_encoder_attention_mask;
-  ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&encoder_input_ids,
+  OrtValue encoder_input_ids;
+  OrtValue encoder_attention_mask;
+  ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids,
                                                  attn_mask_value,
-                                                 num_beams,
                                                  pad_token_id,
                                                  start_token_id,
                                                  cpu_allocator,
-                                                 expanded_encoder_input_ids,
-                                                 expanded_encoder_attention_mask,
-                                                 expanded_decoder_input_ids));
+                                                 encoder_input_ids,
+                                                 encoder_attention_mask,
+                                                 decoder_input_ids));

   const IExecutionProvider* provider = GetProvider();
   ORT_RETURN_IF_ERROR(add_to_feeds_func(
       provider,
-      {expanded_encoder_input_ids, expanded_encoder_attention_mask, expanded_decoder_input_ids},
+      {encoder_input_ids, encoder_attention_mask, decoder_input_ids},
       feeds,
       buffer));
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
index 83c9cb22c66a8..9c67f49621357 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
@@ -24,7 +24,6 @@ class T5EncoderSubgraph : public Subgraph {
       const Tensor& encoder_input_ids,
       const OrtValue* attn_mask_value,
       const std::vector<const OrtValue*>& implicit_inputs,
-      int num_beams,
       int pad_token_id,
       int start_token_id,
       std::vector<OrtValue>& feeds,
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
index 91b660c197e1f..5e500f560131f 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
@@ -49,7 +49,10 @@ BeamSearch::BeamSearch(const OpKernelInfo& info)
                      BeamSearchCudaDeviceHelper::UpdateGptFeeds<MLFloat16>);

   SetDeviceHelpers_EncoderDecoder(BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<float>,
-                                  BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<MLFloat16>);
+                                  BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<MLFloat16>,
+                                  BeamSearchCudaDeviceHelper::ExpandBuffer<int32_t>,
+                                  BeamSearchCudaDeviceHelper::ExpandBuffer<float>,
+                                  BeamSearchCudaDeviceHelper::ExpandBuffer<MLFloat16>);

   SetConsoleDumper(&g_cuda_dumper);
 }
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index 3c45c2cc60b11..b712908259da1 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -221,6 +221,7 @@ Status ProcessLogits(const OrtValue& logits,  //
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];
+  auto logits_batch_size = logits_shape[0];

   cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);

@@ -228,21 +229,28 @@ Status ProcessLogits(const OrtValue& logits,  //
   //   next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
   // When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1.
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
-  if (input_length > 1) {
-    // TODO(tianleiwu): use one kernel to replace a loop of memory copy.
+
+  // TODO(tianleiwu): use one kernel to replace a loop of memory copy.
+  if (input_length > 1 || logits_batch_size == batch_size) {
     const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
     for (int i = 0; i < batch_beam_size; i++) {
       gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
       gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
                                            cudaMemcpyDeviceToDevice, cuda_stream));
-      current_logits += input_length * vocab_size;
+      if (logits_batch_size == batch_beam_size) {
+        current_logits += input_length * vocab_size;
+      } else if (logits_batch_size == batch_size && i % num_beams == num_beams - 1) {
+        current_logits += input_length * vocab_size;
+      }
     }
   }

 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("logits", logits);
-  dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+  if (input_length > 1 || logits_batch_size == batch_size) {
+    dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+  }
 #endif

   // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
@@ -250,7 +258,9 @@ Status ProcessLogits(const OrtValue& logits,  //
   // The output will be float for consideration of precision and easy integration with remaining parts.
   float* Y_data = next_token_scores.data();

-  const CudaT* X_data = input_length > 1 ? reinterpret_cast<const CudaT*>(next_token_logits.data()) : logits_data;
+  const CudaT* X_data = (input_length == 1 && logits_batch_size == batch_beam_size) ?
+      logits_data :
+      reinterpret_cast<const CudaT*>(next_token_logits.data());

   dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
       cuda_stream, Y_data, X_data, vocab_size, vocab_size, batch_size * num_beams);
@@ -618,6 +628,53 @@ Status UpdateDecoderFeeds(
       t5_decoder_first_past_input_idx,
       t5_decoder_first_present_output_idx,
       stream);
 }

+template <typename T>
+Status ExpandBuffer(void* stream,
+                    const OrtValue& input,
+                    int num_beams,
+                    AllocatorPtr allocator,
+                    OrtValue& expanded,
+                    bool only_copy_shape) {
+  // Input shape (batch_size, xxx). The input is required with data type T.
+  // Output shape (batch_size * num_beams, xxx)
+  const TensorShape& input_shape = input.Get<Tensor>().Shape();
+  const int64_t& batch_size = input_shape[0];
+  const int64_t& chunk_size = static_cast<int64_t>(input_shape.Size() / batch_size);
+
+  int64_t dims[4] = {0};
+  input_shape.CopyDims(dims, input_shape.NumDimensions());
+  dims[0] = batch_size * num_beams;
+  TensorShape expanded_shape(&dims[0], input_shape.NumDimensions());
+
+  MLDataType element_type = input.Get<Tensor>().DataType();
+  ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
+  Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
+
+  if (only_copy_shape) {
+    return Status::OK();
+  }
+
+  cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+
+  const T* input_data = input.Get<Tensor>().Data<T>();
+  T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
+  T* target = expanded_data;
+  for (int i = 0; i < batch_size; i++) {
+    for (int j = 0; j < num_beams; j++) {
+      CUDA_RETURN_IF_ERROR(
+          cudaMemcpyAsync(
+              target,
+              input_data + i * chunk_size,
+              sizeof(T) * chunk_size,
+              cudaMemcpyDeviceToDevice,
+              cuda_stream));
+      target += chunk_size;
+    }
+  }
+
+  return Status::OK();
+}
+
 // Explicit template instantiations of functions
 template void InitBeamState<float>(transformers::IBeamSearchState<float>* beam_state,
                                    gsl::span<int32_t>& sequence_lengths,
@@ -730,6 +787,29 @@ template Status UpdateDecoderFeeds<MLFloat16>(
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper);

+template Status ExpandBuffer<int32_t>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<float>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<MLFloat16>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
 }  // namespace BeamSearchCudaDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h
index a90d0c7ee84c8..14f64e923e781 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h
@@ -97,6 +97,15 @@ Status UpdateDecoderFeeds(
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper);

+template <typename T>
+Status ExpandBuffer(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
 }  // namespace BeamSearchCudaDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
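Taken together, the ProcessLogits changes in the CPU and CUDA helpers allow the first decoder step to consume logits whose batch dimension is the unexpanded batch_size rather than batch_size * num_beams. The sketch below reproduces the new stride rule in isolation, assuming float logits; GatherLastTokenLogits and its parameter names are illustrative, not part of the patch:

#include <cstddef>
#include <cstring>

// When the logits batch dimension equals batch_size (unexpanded), each source
// row is reused for num_beams consecutive beams; when it equals
// batch_size * num_beams (expanded), the source advances on every iteration.
void GatherLastTokenLogits(const float* logits, float* next_token_logits,
                           int batch_size, int num_beams, int input_length,
                           int vocab_size, bool logits_are_expanded) {
  const int batch_beam_size = batch_size * num_beams;
  const float* current = logits + static_cast<size_t>(input_length - 1) * vocab_size;  // last token
  for (int i = 0; i < batch_beam_size; i++) {
    std::memcpy(next_token_logits + static_cast<size_t>(i) * vocab_size,
                current, sizeof(float) * vocab_size);
    if (logits_are_expanded || i % num_beams == num_beams - 1) {
      current += static_cast<size_t>(input_length) * vocab_size;  // advance to the next source row
    }
  }
}

With this rule in place, the softmax input is taken from next_token_logits whenever the gather loop ran, and directly from the raw logits only in the unchanged fast path (input_length == 1 with already-expanded logits).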