10 changes: 8 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -183,7 +183,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>};
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<MLFloat16>};
ORT_RETURN_IF_ERROR(impl.Initialize());

return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@@ -198,7 +201,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
device_copy_func_,
device_copy_int32_func_,
create_encoder_inputs_func_,
update_decoder_feeds_fp16_func_};
update_decoder_feeds_fp16_func_,
expand_buffer_int32_func_,
expand_buffer_float_func_,
expand_buffer_float16_func_};

ORT_RETURN_IF_ERROR(impl.Initialize());

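Note: each constructor argument above keeps the existing fallback idiom, the execution provider's registered helper is used when present, otherwise the BeamSearchCpuDeviceHelper default. A minimal standalone sketch of that idiom, using hypothetical names (`MyFunc`, `registered_func_`, `DefaultImpl`) rather than the real onnxruntime types:

```cpp
#include <functional>
#include <iostream>

// Hypothetical helper signature; the real helpers return onnxruntime::Status.
using MyFunc = std::function<int(int)>;

int DefaultImpl(int x) { return x * 2; }  // stands in for a BeamSearchCpuDeviceHelper default

struct Kernel {
  MyFunc registered_func_;  // empty unless an execution provider registered an override

  int Run(int x) const {
    // Same pattern as the constructor arguments above:
    // prefer the registered override, otherwise fall back to the CPU default.
    MyFunc f = registered_func_ ? registered_func_ : DefaultImpl;
    return f(x);
  }
};

int main() {
  Kernel k;                       // no override registered
  std::cout << k.Run(3) << "\n";  // 6, via DefaultImpl
  k.registered_func_ = [](int x) { return x + 100; };  // device-specific override
  std::cout << k.Run(3) << "\n";  // 103, via the override
  return 0;
}
```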
12 changes: 11 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -74,9 +74,15 @@ class BeamSearch : public IControlFlowKernel {
// device helpers for encoder-decoder model like T5
void SetDeviceHelpers_EncoderDecoder(
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float>& update_decoder_feeds_func,
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func) {
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func,
const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func) {
update_decoder_feeds_func_ = update_decoder_feeds_func;
update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
expand_buffer_int32_func_ = expand_buffer_int32_func;
expand_buffer_float_func_ = expand_buffer_float_func;
expand_buffer_float16_func_ = expand_buffer_float16_func;
}

private:
@@ -106,6 +112,10 @@ class BeamSearch : public IControlFlowKernel {
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float> update_decoder_feeds_func_;
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16> update_decoder_feeds_fp16_func_;

BeamSearchDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
BeamSearchDeviceHelper::ExpandBufferFunc<float> expand_buffer_float_func_;
BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16> expand_buffer_float16_func_;

//------------------------------------------------------------
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
//------------------------------------------------------------
119 changes: 92 additions & 27 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -60,6 +60,48 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
}
}

// TODO(wy): Dispatch it to avoid passing multiple functions to the interface.
template <typename T>
Status ExpandBuffer(void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape) {
  // Input shape: (batch_size, xxx). The input must have data type T.
  // Output shape: (batch_size * num_beams, xxx)
ORT_UNUSED_PARAMETER(stream);

const TensorShape& input_shape = input.Get<Tensor>().Shape();
const int64_t& batch_size = input_shape[0];
const int64_t& chunk_size = static_cast<int64_t>(input_shape.Size() / batch_size);

int64_t dims[4] = {0};
input_shape.CopyDims(dims, input_shape.NumDimensions());
dims[0] = batch_size * num_beams;
TensorShape expanded_shape(&dims[0], input_shape.NumDimensions());

MLDataType element_type = input.Get<Tensor>().DataType();
ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);

if (only_copy_shape) {
return Status::OK();
}

const T* input_data = input.Get<Tensor>().Data<T>();
T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
T* target = expanded_data;
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
target += chunk_size;
}
}

return Status::OK();
}

Status CreateGptInputs(
const Tensor* original_input_ids,
int num_beams,
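Note: the copy loop in ExpandBuffer above replicates each batch row num_beams times, turning a (batch_size, chunk_size) buffer into (batch_size * num_beams, chunk_size); when only_copy_shape is true only the expanded tensor is allocated and the copy is skipped. A standalone sketch of the same access pattern on plain vectors (no OrtValue or allocator), just to make the layout concrete:

```cpp
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const int batch_size = 2, num_beams = 3, chunk_size = 4;
  // Row-major input: batch entry 0 = {0,1,2,3}, batch entry 1 = {10,11,12,13}.
  std::vector<float> input = {0, 1, 2, 3, 10, 11, 12, 13};
  std::vector<float> expanded(static_cast<size_t>(batch_size) * num_beams * chunk_size);

  // Same pattern as ExpandBuffer: each input row is copied num_beams times,
  // so for num_beams == 3 the output row order is b0,b0,b0,b1,b1,b1.
  float* target = expanded.data();
  for (int i = 0; i < batch_size; i++) {
    for (int j = 0; j < num_beams; j++) {
      std::memcpy(target, input.data() + static_cast<size_t>(i) * chunk_size,
                  sizeof(float) * chunk_size);
      target += chunk_size;
    }
  }

  for (int row = 0; row < batch_size * num_beams; row++) {
    for (int k = 0; k < chunk_size; k++) {
      std::cout << expanded[static_cast<size_t>(row) * chunk_size + k] << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}
```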
@@ -200,37 +242,45 @@ Status ProcessLogits(const OrtValue& logits, //
const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
ORT_ENFORCE(logits_shape.NumDimensions() == 3);
auto input_length = logits_shape[1];
auto logits_batch_size = logits_shape[0];

// Get logits for the last token:
// next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
// When input_length == 1 and the logits batch is already batch_beam_size, logits can be fed to SoftmaxCPU below directly, so this copy is only needed otherwise.
gsl::span<T>& next_token_logits = beam_state->next_token_logits;
if (input_length > 1) {

if (input_length > 1 || logits_batch_size == batch_size) {
const T* current_logits = logits_data + (input_length - 1) * vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const T> source(current_logits, vocab_size);
gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
static_cast<gsl::index>(vocab_size));
gsl::copy(source, target);
current_logits += input_length * vocab_size;
if (logits_batch_size == batch_beam_size) {
current_logits += input_length * vocab_size;
} else if (logits_batch_size == batch_size && i % num_beams == num_beams - 1) {
current_logits += input_length * vocab_size;
}
}
}

#ifdef DEBUG_BEAM_SEARCH
dumper->Print("logits", logits);
if (input_length > 1) {
if (input_length > 1 || logits_batch_size == batch_size) {
dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
}
#endif

// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
gsl::span<T>& next_token_scores = beam_state->next_token_scores;
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(batch_beam_size, // rows
vocab_size, // elements per row
input_length > 1 ? next_token_logits.data() : logits_data,
next_token_scores.data(),
true,
thread_pool));
ORT_RETURN_IF_ERROR(
SoftmaxCPU<T>(
batch_beam_size, // rows
vocab_size, // elements per row
(input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
next_token_scores.data(),
true,
thread_pool));

#ifdef DEBUG_BEAM_SEARCH
dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size);
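Note: the widened condition above handles logits arriving in either of two batch layouts, (batch_beam_size, input_length, vocab_size) as before, or (batch_size, input_length, vocab_size) when the subgraph ran on unexpanded inputs. In the second case the last-token logits of each batch entry are broadcast to all of its beams, which is what the `i % num_beams == num_beams - 1` advance implements. A standalone sketch of just that copy with small numbers (plain vectors, no ORT types):

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Broadcast the last-token logits to one row per beam.
// `logits` is row-major (logits_batch_size, input_length, vocab_size);
// the result is (batch_size * num_beams, vocab_size).
std::vector<float> LastTokenLogitsPerBeam(const std::vector<float>& logits,
                                          int logits_batch_size, int input_length,
                                          int vocab_size, int batch_size, int num_beams) {
  const int batch_beam_size = batch_size * num_beams;
  std::vector<float> next_token_logits(static_cast<size_t>(batch_beam_size) * vocab_size);
  const float* current = logits.data() + static_cast<size_t>(input_length - 1) * vocab_size;
  for (int i = 0; i < batch_beam_size; i++) {
    std::copy(current, current + vocab_size,
              next_token_logits.begin() + static_cast<size_t>(i) * vocab_size);
    if (logits_batch_size == batch_beam_size) {
      current += static_cast<size_t>(input_length) * vocab_size;  // one logits row per beam
    } else if (logits_batch_size == batch_size && i % num_beams == num_beams - 1) {
      current += static_cast<size_t>(input_length) * vocab_size;  // advance once per batch entry
    }
  }
  return next_token_logits;
}

int main() {
  // batch_size = 2, num_beams = 2, input_length = 1, vocab_size = 3.
  // Unexpanded logits (batch_size, 1, 3): entry 0 -> {1,2,3}, entry 1 -> {4,5,6}.
  std::vector<float> logits = {1, 2, 3, 4, 5, 6};
  auto rows = LastTokenLogitsPerBeam(logits, /*logits_batch_size*/ 2, /*input_length*/ 1,
                                     /*vocab_size*/ 3, /*batch_size*/ 2, /*num_beams*/ 2);
  // Expected rows: {1,2,3}, {1,2,3}, {4,5,6}, {4,5,6}.
  for (size_t i = 0; i < rows.size(); i++) std::cout << rows[i] << ((i % 3 == 2) ? '\n' : ' ');
  return 0;
}
```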
@@ -456,13 +506,12 @@ Status UpdateGptFeeds(
Status CreateEncoderInputs(
const Tensor* original_encoder_input_ids,
const OrtValue* attn_mask_value,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids) {
OrtValue& encoder_input_ids,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids) {
const TensorShape& input_ids_shape = original_encoder_input_ids->Shape();
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
const int64_t& batch_size = input_ids_shape[0];
@@ -475,14 +524,12 @@ Status CreateEncoderInputs(
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
OrtValue encoder_input_ids;
Tensor::InitOrtValue(element_type,
input_ids_shape,
const_cast<Tensor*>(original_encoder_input_ids)->MutableData<int32_t>(),
allocator->Info(),
encoder_input_ids);

OrtValue encoder_attention_mask;
if (attn_mask_value != nullptr) {
const Tensor& attention_mask = attn_mask_value->Get<Tensor>();
Tensor::InitOrtValue(element_type, input_ids_shape, const_cast<Tensor*>(&attention_mask)->MutableData<int32_t>(),
@@ -511,20 +558,14 @@ Status CreateEncoderInputs(
}
}

// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
// for encoder_input_ids and encoder_attention_mask
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
ExpandInputs<int32_t>(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids);
ExpandInputs<int32_t>(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask);

// decoder_input_ids is optional.
if (start_token_id >= 0) {
// Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID
int64_t dims[] = {batch_size * num_beams, 1};
// Fill decoder_input_ids with the start token ID
int64_t dims[] = {batch_size, 1};
TensorShape decoder_input_ids_shape(&dims[0], 2);
Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids);
int32_t* data = expanded_decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_size * num_beams; i++, data++) {
Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, decoder_input_ids);
int32_t* data = decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_size; i++, data++) {
*data = start_token_id;
}
}
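Note: taken together with the removal of the ExpandInputs calls above, CreateEncoderInputs now produces unexpanded (batch_size, ...) encoder inputs and a (batch_size, 1) decoder_input_ids; expansion to batch_size * num_beams is deferred to ExpandBuffer after the encoder has run once. A pseudocode-level sketch of that order of operations, using hypothetical placeholders (`RunEncoderSubgraph`, `encoder_hidden_states`) rather than the actual BeamSearch T5 implementation:

```cpp
// Pseudocode-level sketch only. RunEncoderSubgraph and encoder_hidden_states are
// hypothetical placeholders; the real flow lives in the BeamSearch T5 implementation.
Status RunEncoderThenExpand(/* ... */) {
  // 1. Build (batch_size, sequence_length) encoder_input_ids / encoder_attention_mask
  //    and (batch_size, 1) decoder_input_ids. num_beams is no longer involved here.
  ORT_RETURN_IF_ERROR(CreateEncoderInputs(original_encoder_input_ids, attn_mask_value,
                                          pad_token_id, start_token_id, allocator,
                                          encoder_input_ids, encoder_attention_mask,
                                          decoder_input_ids));

  // 2. Run the encoder subgraph once on the unexpanded batch.
  ORT_RETURN_IF_ERROR(RunEncoderSubgraph(/* feeds built from the tensors above */));

  // 3. Expand whatever the decoder loop consumes to batch_size * num_beams rows,
  //    using the ExpandBuffer helper added in this change.
  ORT_RETURN_IF_ERROR(ExpandBuffer<float>(stream, encoder_hidden_states, num_beams,
                                          allocator, expanded_hidden_states,
                                          /*only_copy_shape*/ false));
  return Status::OK();
}
```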
@@ -602,7 +643,7 @@ Status UpdateDecoderFeeds(
TensorShape input_ids_shape(&dims[0], 2);
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);

// TODO: decouple has_hidden_state with full input_ids
// TODO(wy): decouple has_hidden_state with full input_ids
if (has_hidden_state) {
gsl::copy(beam_next_tokens, input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>());
} else {
@@ -709,6 +750,30 @@ template Status UpdateDecoderFeeds<float>(

template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

template Status ExpandBuffer<int32_t>(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);

template Status ExpandBuffer<float>(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);

template Status ExpandBuffer<MLFloat16>(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);

} // namespace BeamSearchCpuDeviceHelper
} // namespace contrib
} // namespace onnxruntime
onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
@@ -107,13 +107,12 @@ using UpdateGptFeedsFunc = std::function<Status(
using CreateEncoderInputsFunc = std::function<Status(
const Tensor* original_encoder_input_ids,
const OrtValue* attn_mask_value,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids)>;
OrtValue& encoder_input_ids,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids)>;

// Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5).
template <typename T>
Expand All @@ -132,8 +131,18 @@ using UpdateDecoderFeedsFunc = std::function<Status(
int current_length,
transformers::Sequences& sequences,
const transformers::IConsoleDumper* dumper)>;

template <typename T>
using ExpandBufferFunc = std::function<Status(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape)>;
} // namespace BeamSearchDeviceHelper


// These are CPU specific device helper implementations
namespace BeamSearchCpuDeviceHelper {
Status TopK(
@@ -212,13 +221,12 @@ Status UpdateGptFeeds(
Status CreateEncoderInputs(
const Tensor* original_encoder_input_ids,
const OrtValue* attn_mask_value,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids);
OrtValue& encoder_input_ids,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids);

// Update decoder inputs given decoder outputs of last iteration.
template <typename T>
@@ -244,6 +252,15 @@ Status UpdateDecoderFeeds(
template <typename T>
void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

template <typename T>
Status ExpandBuffer(
void* stream,
const OrtValue& input,
int num_beams,
AllocatorPtr allocator,
OrtValue& expanded,
bool only_copy_shape);

} // namespace BeamSearchCpuDeviceHelper
} // namespace contrib
} // namespace onnxruntime
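Note: the ExpandBufferFunc alias above is the type of the new members added to BeamSearch in beam_search.h, and the CPU implementation declared above binds to it directly. A small sketch using only names that appear in this diff (it assumes the surrounding onnxruntime headers and namespaces):

```cpp
// Sketch: inside namespace onnxruntime::contrib with the headers above included.
// The kernel members declared in beam_search.h have exactly this alias type.
BeamSearchDeviceHelper::ExpandBufferFunc<float> expand_buffer_float_func =
    BeamSearchCpuDeviceHelper::ExpandBuffer<float>;

// Invocation then looks like (stream may be nullptr for the CPU implementation):
//   ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, input, num_beams, allocator,
//                                                expanded, /*only_copy_shape*/ false));
```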
17 changes: 16 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -95,7 +95,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
this->sequences.Init(this->sequences_space, static_cast<int>(batch_beam_size), sequence_length, max_length);
}

// Copy input_ids to sequences[0]
// Copy expanded input_ids to sequences[0]
void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
size_t batch_beam_size,
int max_length,
@@ -109,6 +109,21 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
}
}

// Copy unexpanded input_ids to sequences[0]
void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
size_t batch_beam_size,
int beam_size,
int max_length,
int sequence_length) {
gsl::span<int32_t> sequences_0 = sequences_space;
for (size_t i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < sequence_length; j++) {
const size_t index = SafeInt<gsl::index>(i) * max_length + j;
sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i / beam_size) * sequence_length + j];
}
}
}

private:
BufferUniquePtr final_beam_scores_buffer_;
BufferUniquePtr sequence_lengths_buffer_;
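Note: the new SetSequence overload takes the unexpanded input_ids (batch_size rows) and replicates each row across its beams while writing into the max_length-strided sequences buffer; the `i / beam_size` index maps an expanded row back to its batch entry. A standalone sketch of that indexing with small numbers (plain vectors, no ORT types):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int batch_size = 2, beam_size = 2, sequence_length = 3, max_length = 5;
  const int batch_beam_size = batch_size * beam_size;
  // Unexpanded input_ids: batch entry 0 = {1,2,3}, batch entry 1 = {7,8,9}.
  std::vector<int32_t> input_ids = {1, 2, 3, 7, 8, 9};
  // sequences_space holds batch_beam_size rows, each strided out to max_length.
  std::vector<int32_t> sequences_space(static_cast<size_t>(batch_beam_size) * max_length, 0);

  // Same indexing as the new SetSequence overload: expanded row i copies from
  // unexpanded row i / beam_size.
  for (int i = 0; i < batch_beam_size; i++) {
    for (int j = 0; j < sequence_length; j++) {
      sequences_space[static_cast<size_t>(i) * max_length + j] =
          input_ids[static_cast<size_t>(i / beam_size) * sequence_length + j];
    }
  }

  // Expected rows (zero padded): {1,2,3,0,0}, {1,2,3,0,0}, {7,8,9,0,0}, {7,8,9,0,0}.
  for (int i = 0; i < batch_beam_size; i++) {
    for (int j = 0; j < max_length; j++) {
      std::cout << sequences_space[static_cast<size_t>(i) * max_length + j] << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}
```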