diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 93837e785b4a4..70d61a20c9095 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "contrib_ops/cpu/transformers/beam_search_parameters.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { namespace contrib { @@ -136,7 +137,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { temperature = 1.0f; } } + + // The following parameter is read from environment variable for testing purpose. + use_fast_topk = ParseEnvironmentVariableWithDefault(kBeamSearchUseFastTopK, true); } + void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 8145fbd4a4123..4c31c6cc53499 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -199,8 +199,14 @@ struct IGenerationParameters { int extra_decoding_ids_input_id = -1; int cross_qk_output_id = -1; int no_speech_probs_output_id = -1; + + // Parameter for testing slow topk path. It can be updated by the below environment variable. + bool use_fast_topk = true; }; +// Environment variable to enable/disable fast topk kernel on GPU. Default is 1 (enabled). +constexpr const char* kBeamSearchUseFastTopK = "ORT_BEAM_SEARCH_USE_FAST_TOPK"; + } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 4e65336665bf7..23283706a11cf 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -524,7 +524,8 @@ Status ProcessLogits(const OrtValue& logits, // beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size()); } - if (num_beams <= 32) { + gsl::span scores_to_process = beam_state->next_scores; + if (parameters->use_fast_topk && num_beams <= 32) { constexpr size_t max_parts_of_vocab = 128; size_t candidate_count = SafeInt(batch_beam_size) * 2 * num_beams; float* topk_tmp_buffer = beam_state->topk_buffer.data(); @@ -546,13 +547,6 @@ Status ProcessLogits(const OrtValue& logits, // beam_state->next_tokens.data(), beam_state->next_indices.data(), cuda_stream); - - // Select [batch_size, 2 * num_beams] from [batch_size * num_beams, 2 * num_beams] -#ifdef DEBUG_GENERATION - dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, 2 * num_beams); - dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams); - dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams); -#endif } else { // Apply top-k selection like the following: // next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) @@ -588,18 +582,20 @@ Status ProcessLogits(const OrtValue& logits, // cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), batch_size, top_k, vocab_size, cuda_stream); -#ifdef DEBUG_GENERATION - dumper->Print("next_scores before scorer", topk_scores->Data(), batch_size, top_k); - dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k); - dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k); -#endif + scores_to_process = gsl::span(topk_scores->MutableData(), batch_size * top_k); } // gsl::span doesn't convert from non const to const, so all we're doing here is making each const. - gsl::span next_scores(beam_state->next_scores.data(), beam_state->next_scores.size()); + gsl::span next_scores(scores_to_process.data(), scores_to_process.size()); gsl::span next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size()); gsl::span next_indices(beam_state->next_indices.data(), beam_state->next_indices.size()); +#ifdef DEBUG_GENERATION + dumper->Print("next_scores before scorer", next_scores.data(), batch_size, 2 * num_beams); + dumper->Print("next_tokens before scorer", next_tokens.data(), batch_size, 2 * num_beams); + dumper->Print("next_indices before scorer", next_indices.data(), batch_size, 2 * num_beams); +#endif + beam_scorer->Process( *sequences, next_scores, @@ -735,6 +731,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences, next_tokens, next_indices, stream_); + CUDA_CALL_THROW(cudaEventRecord(event_process_complete_.Get(), stream_)); cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_, diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 1ae15afdf7482..bb5c762a42582 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -9,6 +9,8 @@ #include "test/common/cuda_op_test_utils.h" #include "test/providers/model_tester.h" #include "test/util/include/current_test_name.h" +#include "test/util/include/scoped_env_vars.h" +#include "contrib_ops/cpu/transformers/generation_shared.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" @@ -19,7 +21,7 @@ extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { -TEST(BeamSearchTest, GptBeamSearchFp32) { +void RunGptBeamSearchFp32() { std::vector input_ids_shape{3, 12}; std::vector input_ids{ 0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, @@ -107,6 +109,16 @@ TEST(BeamSearchTest, GptBeamSearchFp32) { ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end())); } +TEST(BeamSearchTest, GptBeamSearchFp32) { + RunGptBeamSearchFp32(); +} + +TEST(BeamSearchTest, GptBeamSearchFp32_DisableFastTopK) { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::transformers::kBeamSearchUseFastTopK, "0"}}}; + RunGptBeamSearchFp32(); +} + TEST(BeamSearchTest, GptBeamSearchFp16) { std::vector input_ids_shape{3, 12}; std::vector input_ids{