diff --git a/benchmark/c/main.cpp b/benchmark/c/main.cpp index 5c86684d9f..103cf86787 100644 --- a/benchmark/c/main.cpp +++ b/benchmark/c/main.cpp @@ -112,7 +112,7 @@ void WriteE2EStats(std::string_view label, << "\n"; } -std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer, size_t batch_size) { +std::string GeneratePrompt(const benchmark::Options& opts, size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer, size_t batch_size) { const char* const base_prompt = "A"; auto base_prompt_sequences = OgaSequences::Create(); for (size_t i = 0; i < batch_size; ++i) { @@ -120,12 +120,14 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons } auto params = OgaGeneratorParams::Create(model); - params->SetSearchOption("max_length", static_cast<double>(num_prompt_tokens)); + if (!opts.no_dynamic_max_length) { + params->SetSearchOption("max_length", static_cast<double>(num_prompt_tokens)); + } params->SetSearchOption("min_length", static_cast<double>(num_prompt_tokens)); auto generator = OgaGenerator::Create(model, *params); generator->AppendTokenSequences(*base_prompt_sequences); - while (!generator->IsDone()) { + while (!generator->IsDone() && num_prompt_tokens-- > 0) { generator->GenerateNextToken(); } @@ -159,7 +161,7 @@ void RunBenchmark(const benchmark::Options& opts) { const auto prompt = [&]() -> std::string { if (const size_t* num_prompt_tokens = std::get_if<size_t>(&opts.prompt_num_tokens_or_content)) { - return GeneratePrompt(*num_prompt_tokens, *model, *tokenizer, opts.batch_size); + return GeneratePrompt(opts, *num_prompt_tokens, *model, *tokenizer, opts.batch_size); } return std::get<std::string>(opts.prompt_num_tokens_or_content); }(); @@ -179,7 +181,9 @@ void RunBenchmark(const benchmark::Options& opts) { auto make_generator_params = [&] { auto params = OgaGeneratorParams::Create(*model); - params->SetSearchOption("max_length", static_cast<double>(num_tokens)); + if (!opts.no_dynamic_max_length) { + 
params->SetSearchOption("max_length", static_cast<double>(num_tokens)); + } params->SetSearchOption("min_length", static_cast<double>(num_tokens)); return params; }; @@ -190,8 +194,9 @@ void RunBenchmark(const benchmark::Options& opts) { if (opts.verbose) std::cout << "Running warmup iterations (" << opts.num_warmup_iterations << ")...\n"; for (size_t i = 0; i < opts.num_warmup_iterations; ++i) { auto generator = OgaGenerator::Create(*model, *generator_params); + auto num_tokens_to_generate = opts.num_tokens_to_generate; generator->AppendTokenSequences(*prompt_sequences); - while (!generator->IsDone()) { + while (!generator->IsDone() && num_tokens_to_generate-- > 0) { generator->GenerateNextToken(); } @@ -215,6 +220,7 @@ void RunBenchmark(const benchmark::Options& opts) { if (opts.verbose) std::cout << "Running iterations (" << opts.num_iterations << ")...\n"; for (size_t i = 0; i < opts.num_iterations; ++i) { auto generator = OgaGenerator::Create(*model, *generator_params); + auto num_tokens_to_generate = opts.num_tokens_to_generate; { Timing e2e_gen_timing{e2e_gen_times}; @@ -232,7 +238,7 @@ void RunBenchmark(const benchmark::Options& opts) { generator_done = generator->IsDone(); } - while (!generator_done) { + while (!generator_done && num_tokens_to_generate-- > 0) { { Timing token_gen_timing{token_gen_times}; generator->GenerateNextToken(); diff --git a/benchmark/c/options.cpp b/benchmark/c/options.cpp index e968b8df0a..c3e4c413f4 100644 --- a/benchmark/c/options.cpp +++ b/benchmark/c/options.cpp @@ -49,6 +49,8 @@ namespace { << " Number of warmup runs before benchmarking. 
Default: " << defaults.num_warmup_iterations << "\n" << " -v,--verbose\n" << " Show more informational output.\n" + << " --no_dynamic_max_length\n" + << " Disable dynamic max_length.\n" << " -h,--help\n" << " Show this help message and exit.\n"; @@ -132,6 +134,8 @@ Options ParseOptionsFromCommandLine(int argc, const char* const* argv) { opts.num_warmup_iterations = ParseNumber(next_arg(i)); } else if (arg == "-v" || arg == "--verbose") { opts.verbose = true; + } else if (arg == "--no_dynamic_max_length") { + opts.no_dynamic_max_length = true; } else if (arg == "-h" || arg == "--help") { PrintHelpAndExit(program_name, 0); } else { diff --git a/benchmark/c/options.h b/benchmark/c/options.h index 2eeb61c3de..22f309cee0 100644 --- a/benchmark/c/options.h +++ b/benchmark/c/options.h @@ -20,6 +20,7 @@ struct Options { size_t num_iterations{5}; size_t num_warmup_iterations{1}; bool verbose{}; + bool no_dynamic_max_length{}; }; Options ParseOptionsFromCommandLine(int argc, const char* const* argv); diff --git a/benchmark/python/benchmark_e2e.py b/benchmark/python/benchmark_e2e.py index 850a61a9f6..702432b7a2 100644 --- a/benchmark/python/benchmark_e2e.py +++ b/benchmark/python/benchmark_e2e.py @@ -80,12 +80,17 @@ def generate_prompt(model, tokenizer, prompt_length) -> str: tokens = tokenizer.encode(prompt) params = og.GeneratorParams(model) max_length_to_use = prompt_length + len(tokens) - params.set_search_options(max_length=max_length_to_use, min_length=prompt_length) + params.set_search_options( + min_length=prompt_length, + **({ "max_length": max_length_to_use } if not args.no_dynamic_max_length else {}) + ) generator = og.Generator(model, params) generator.append_tokens(tokens) - while not generator.is_done(): + i = 0 + while not generator.is_done() and i < prompt_length: generator.generate_next_token() + i += 1 return tokenizer.decode(generator.get_sequence(0)) @@ -307,7 +312,7 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length 
top_k=args.top_k, top_p=args.top_p, temperature=temperature, - max_length=max_length, + **({ "max_length": max_length } if not args.no_dynamic_max_length else {}), min_length=max_length, batch_size=batch_size, ) @@ -317,8 +322,10 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length for _ in tqdm(range(args.warmup)): generator = og.Generator(model, params) generator.append_tokens(tokens) - while not generator.is_done(): + i = 0 + while not generator.is_done() and i < generation_length: generator.generate_next_token() + i += 1 if args.print_model_output: print(tokenizer.decode(generator.get_sequence(0))) # Delete the generator to free the captured graph for the next generator, if graph capture is enabled @@ -350,7 +357,7 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length top_k=args.top_k, top_p=args.top_p, temperature=temperature, - max_length=max_length, + **({ "max_length": max_length } if not args.no_dynamic_max_length else {}), min_length=max_length, batch_size=batch_size, ) @@ -543,6 +550,9 @@ def str2strlist(value): choices=["cpu", "cuda", "dml", "follow_config"], help="Execution provider to run the ONNX Runtime session with. 
Defaults to follow_config that uses the execution provider listed in the genai_config.json instead.", ) + parser.add_argument( + "--no_dynamic_max_length", action="store_true", help="Disable dynamic max_length" + ) args = parser.parse_args() # check max_lengths diff --git a/src/models/model.cpp b/src/models/model.cpp index 77e6c82657..044fa38935 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -666,6 +666,7 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options, else if (provider_options.name == "VitisAI") { session_options.AddConfigEntry("session.inter_op.allow_spinning", "0"); session_options.AddConfigEntry("session.intra_op.allow_spinning", "0"); + session_options.AddConfigEntry("model_root", config.config_path.string().c_str()); } else if (provider_options.name == "NvTensorRtRtx") { bool is_multi_profile_enabled = IsMultiProfileEnabled(config.model.decoder.session_options); ConfigureNvTensorRtRtxProfile(config, session_options, is_multi_profile_enabled); @@ -812,7 +813,26 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options, values.emplace_back(option.second.c_str()); } session_options.AppendExecutionProvider(provider_options.name.c_str(), keys.data(), values.data(), keys.size()); - +#if defined(_WIN32) + if (provider_options.name == "VitisAI") { + if (const auto opt_it = std::find_if(provider_options.options.begin(), provider_options.options.end(), + [](const auto& pair) { return pair.first == "external_ep_libray"; }); + opt_it != provider_options.options.end()) { + auto lib_name = opt_it->second; + auto lib = LoadLibrary(lib_name.c_str()); + if (const auto func = (void (*)(void*, const OrtApiBase*, void*, OrtEpFactory**, size_t, size_t*))GetProcAddress(lib, "CreateEpFactories")) { + OrtEpFactory* factory = nullptr; + size_t num = 1; + + func(nullptr, OrtGetApiBase(), nullptr, &factory, num, &num); + } + if (const auto func = (void (*)(OrtSessionOptions*))GetProcAddress(lib, 
"RyzenAI_SetSessionOptions")) + func(&session_options); + fs::path custom_ops_lib_path(lib_name); + session_options.RegisterCustomOpsLibrary(custom_ops_lib_path.c_str()); + } + } +#endif // _WIN32 #endif } }