Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "core/platform/env_var_utils.h"

namespace onnxruntime {
namespace contrib {
Expand Down Expand Up @@ -136,7 +137,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
temperature = 1.0f;
}
}

// The following parameter is read from an environment variable for testing purposes.
use_fast_topk = ParseEnvironmentVariableWithDefault<bool>(kBeamSearchUseFastTopK, true);
}

void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,14 @@ struct IGenerationParameters {
int extra_decoding_ids_input_id = -1;
int cross_qk_output_id = -1;
int no_speech_probs_output_id = -1;

// Parameter for testing the slow top-k path. It can be overridden via the environment variable below.
bool use_fast_topk = true;
};

// Environment variable to enable/disable fast topk kernel on GPU. Default is 1 (enabled).
constexpr const char* kBeamSearchUseFastTopK = "ORT_BEAM_SEARCH_USE_FAST_TOPK";

} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ Status ProcessLogits(const OrtValue& logits, //
beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size());
}

if (num_beams <= 32) {
gsl::span<float> scores_to_process = beam_state->next_scores;
if (parameters->use_fast_topk && num_beams <= 32) {
constexpr size_t max_parts_of_vocab = 128;
size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams;
float* topk_tmp_buffer = beam_state->topk_buffer.data();
Expand All @@ -546,13 +547,6 @@ Status ProcessLogits(const OrtValue& logits, //
beam_state->next_tokens.data(),
beam_state->next_indices.data(),
cuda_stream);

// Select [batch_size, 2 * num_beams] from [batch_size * num_beams, 2 * num_beams]
#ifdef DEBUG_GENERATION
dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, 2 * num_beams);
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
#endif
} else {
// Apply top-k selection like the following:
// next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
Expand Down Expand Up @@ -588,18 +582,20 @@ Status ProcessLogits(const OrtValue& logits, //
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
batch_size, top_k, vocab_size, cuda_stream);

#ifdef DEBUG_GENERATION
dumper->Print("next_scores before scorer", topk_scores->Data<float>(), batch_size, top_k);
dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k);
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
#endif
scores_to_process = gsl::span<float>(topk_scores->MutableData<float>(), batch_size * top_k);
}

// gsl::span doesn't convert from non const to const, so all we're doing here is making each const.
gsl::span<const float> next_scores(beam_state->next_scores.data(), beam_state->next_scores.size());
gsl::span<const float> next_scores(scores_to_process.data(), scores_to_process.size());
gsl::span<const int32_t> next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size());
gsl::span<const int32_t> next_indices(beam_state->next_indices.data(), beam_state->next_indices.size());

#ifdef DEBUG_GENERATION
dumper->Print("next_scores before scorer", next_scores.data(), batch_size, 2 * num_beams);
dumper->Print("next_tokens before scorer", next_tokens.data(), batch_size, 2 * num_beams);
dumper->Print("next_indices before scorer", next_indices.data(), batch_size, 2 * num_beams);
#endif

beam_scorer->Process(
*sequences,
next_scores,
Expand Down Expand Up @@ -735,6 +731,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
next_tokens,
next_indices,
stream_);

CUDA_CALL_THROW(cudaEventRecord(event_process_complete_.Get(), stream_));

cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_,
Expand Down
14 changes: 13 additions & 1 deletion onnxruntime/test/contrib_ops/beam_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/model_tester.h"
#include "test/util/include/current_test_name.h"
#include "test/util/include/scoped_env_vars.h"
#include "contrib_ops/cpu/transformers/generation_shared.h"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_options.h"
Expand All @@ -19,7 +21,7 @@ extern std::unique_ptr<Ort::Env> ort_env;
namespace onnxruntime {
namespace test {

TEST(BeamSearchTest, GptBeamSearchFp32) {
void RunGptBeamSearchFp32() {
std::vector<int64_t> input_ids_shape{3, 12};
std::vector<int32_t> input_ids{
0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
Expand Down Expand Up @@ -107,6 +109,16 @@ TEST(BeamSearchTest, GptBeamSearchFp32) {
ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
}

// Baseline: runs the shared FP32 GPT beam-search scenario with no
// ORT_BEAM_SEARCH_USE_FAST_TOPK override, so the default (fast top-k) path is used.
TEST(BeamSearchTest, GptBeamSearchFp32) {
RunGptBeamSearchFp32();
}

// Same scenario, but with ORT_BEAM_SEARCH_USE_FAST_TOPK=0 so the slower
// top-k fallback path is exercised. The expected output must match the
// fast-path run above.
// NOTE(review): ScopedEnvironmentVariables presumably restores the prior
// environment on scope exit — confirm against scoped_env_vars.h.
TEST(BeamSearchTest, GptBeamSearchFp32_DisableFastTopK) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{{onnxruntime::contrib::transformers::kBeamSearchUseFastTopK, "0"}}};
RunGptBeamSearchFp32();
}

TEST(BeamSearchTest, GptBeamSearchFp16) {
std::vector<int64_t> input_ids_shape{3, 12};
std::vector<int32_t> input_ids{
Expand Down
Loading