Skip to content
20 changes: 13 additions & 7 deletions benchmark/c/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,22 @@ void WriteE2EStats(std::string_view label,
<< "\n";
}

std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer, size_t batch_size) {
std::string GeneratePrompt(const benchmark::Options& opts, size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer, size_t batch_size) {
const char* const base_prompt = "A";
auto base_prompt_sequences = OgaSequences::Create();
for (size_t i = 0; i < batch_size; ++i) {
tokenizer.Encode(base_prompt, *base_prompt_sequences);
}

auto params = OgaGeneratorParams::Create(model);
params->SetSearchOption("max_length", static_cast<double>(num_prompt_tokens));
if (!opts.no_dynamic_max_length) {
params->SetSearchOption("max_length", static_cast<double>(num_prompt_tokens));
}
params->SetSearchOption("min_length", static_cast<double>(num_prompt_tokens));

auto generator = OgaGenerator::Create(model, *params);
generator->AppendTokenSequences(*base_prompt_sequences);
while (!generator->IsDone()) {
while (!generator->IsDone() && num_prompt_tokens-- > 0) {
generator->GenerateNextToken();
}

Expand Down Expand Up @@ -159,7 +161,7 @@ void RunBenchmark(const benchmark::Options& opts) {

const auto prompt = [&]() -> std::string {
if (const size_t* num_prompt_tokens = std::get_if<size_t>(&opts.prompt_num_tokens_or_content)) {
return GeneratePrompt(*num_prompt_tokens, *model, *tokenizer, opts.batch_size);
return GeneratePrompt(opts, *num_prompt_tokens, *model, *tokenizer, opts.batch_size);
}
return std::get<std::string>(opts.prompt_num_tokens_or_content);
}();
Expand All @@ -179,7 +181,9 @@ void RunBenchmark(const benchmark::Options& opts) {

auto make_generator_params = [&] {
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", static_cast<double>(num_tokens));
if (!opts.no_dynamic_max_length) {
params->SetSearchOption("max_length", static_cast<double>(num_tokens));
}
params->SetSearchOption("min_length", static_cast<double>(num_tokens));
return params;
};
Expand All @@ -190,8 +194,9 @@ void RunBenchmark(const benchmark::Options& opts) {
if (opts.verbose) std::cout << "Running warmup iterations (" << opts.num_warmup_iterations << ")...\n";
for (size_t i = 0; i < opts.num_warmup_iterations; ++i) {
auto generator = OgaGenerator::Create(*model, *generator_params);
auto num_tokens_to_generate = opts.num_tokens_to_generate;
generator->AppendTokenSequences(*prompt_sequences);
while (!generator->IsDone()) {
while (!generator->IsDone() && num_tokens_to_generate-- > 0) {
generator->GenerateNextToken();
}

Expand All @@ -215,6 +220,7 @@ void RunBenchmark(const benchmark::Options& opts) {
if (opts.verbose) std::cout << "Running iterations (" << opts.num_iterations << ")...\n";
for (size_t i = 0; i < opts.num_iterations; ++i) {
auto generator = OgaGenerator::Create(*model, *generator_params);
auto num_tokens_to_generate = opts.num_tokens_to_generate;

{
Timing e2e_gen_timing{e2e_gen_times};
Expand All @@ -232,7 +238,7 @@ void RunBenchmark(const benchmark::Options& opts) {
generator_done = generator->IsDone();
}

while (!generator_done) {
while (!generator_done && num_tokens_to_generate-- > 0) {
{
Timing token_gen_timing{token_gen_times};
generator->GenerateNextToken();
Expand Down
4 changes: 4 additions & 0 deletions benchmark/c/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ namespace {
<< " Number of warmup runs before benchmarking. Default: " << defaults.num_warmup_iterations << "\n"
<< " -v,--verbose\n"
<< " Show more informational output.\n"
<< " --no_dynamic_max_length\n"
<< " Disable dynamic max_length.\n"
<< " -h,--help\n"
<< " Show this help message and exit.\n";

Expand Down Expand Up @@ -132,6 +134,8 @@ Options ParseOptionsFromCommandLine(int argc, const char* const* argv) {
opts.num_warmup_iterations = ParseNumber<size_t>(next_arg(i));
} else if (arg == "-v" || arg == "--verbose") {
opts.verbose = true;
} else if (arg == "--no_dynamic_max_length") {
opts.no_dynamic_max_length = true;
} else if (arg == "-h" || arg == "--help") {
PrintHelpAndExit(program_name, 0);
} else {
Expand Down
1 change: 1 addition & 0 deletions benchmark/c/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct Options {
size_t num_iterations{5};
size_t num_warmup_iterations{1};
bool verbose{};
bool no_dynamic_max_length{};
};

Options ParseOptionsFromCommandLine(int argc, const char* const* argv);
Expand Down
20 changes: 15 additions & 5 deletions benchmark/python/benchmark_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,17 @@ def generate_prompt(model, tokenizer, prompt_length) -> str:
tokens = tokenizer.encode(prompt)
params = og.GeneratorParams(model)
max_length_to_use = prompt_length + len(tokens)
params.set_search_options(max_length=max_length_to_use, min_length=prompt_length)
params.set_search_options(
min_length=prompt_length,
**({ "max_length": max_length_to_use } if not args.no_dynamic_max_length else {})
)

generator = og.Generator(model, params)
generator.append_tokens(tokens)
while not generator.is_done():
i = 0
while not generator.is_done() and i < prompt_length:
generator.generate_next_token()
i += 1
return tokenizer.decode(generator.get_sequence(0))


Expand Down Expand Up @@ -307,7 +312,7 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
top_k=args.top_k,
top_p=args.top_p,
temperature=temperature,
max_length=max_length,
**({ "max_length": max_length } if not args.no_dynamic_max_length else {}),
min_length=max_length,
batch_size=batch_size,
)
Expand All @@ -317,8 +322,10 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
for _ in tqdm(range(args.warmup)):
generator = og.Generator(model, params)
generator.append_tokens(tokens)
while not generator.is_done():
i = 0
while not generator.is_done() and i < generation_length:
generator.generate_next_token()
i += 1
if args.print_model_output:
print(tokenizer.decode(generator.get_sequence(0)))
# Delete the generator to free the captured graph for the next generator, if graph capture is enabled
Expand Down Expand Up @@ -350,7 +357,7 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
top_k=args.top_k,
top_p=args.top_p,
temperature=temperature,
max_length=max_length,
**({ "max_length": max_length } if not args.no_dynamic_max_length else {}),
min_length=max_length,
batch_size=batch_size,
)
Expand Down Expand Up @@ -543,6 +550,9 @@ def str2strlist(value):
choices=["cpu", "cuda", "dml", "follow_config"],
help="Execution provider to run the ONNX Runtime session with. Defaults to follow_config that uses the execution provider listed in the genai_config.json instead.",
)
parser.add_argument(
"--no_dynamic_max_length", action="store_true", help="Disable dynamic max_length"
)
args = parser.parse_args()

# check max_lengths
Expand Down
22 changes: 21 additions & 1 deletion src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
else if (provider_options.name == "VitisAI") {
session_options.AddConfigEntry("session.inter_op.allow_spinning", "0");
session_options.AddConfigEntry("session.intra_op.allow_spinning", "0");
session_options.AddConfigEntry("model_root", config.config_path.string().c_str());
} else if (provider_options.name == "NvTensorRtRtx") {
bool is_multi_profile_enabled = IsMultiProfileEnabled(config.model.decoder.session_options);
ConfigureNvTensorRtRtxProfile(config, session_options, is_multi_profile_enabled);
Expand Down Expand Up @@ -812,7 +813,26 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
values.emplace_back(option.second.c_str());
}
session_options.AppendExecutionProvider(provider_options.name.c_str(), keys.data(), values.data(), keys.size());

#if defined(_WIN32)
if (provider_options.name == "VitisAI") {
if (const auto opt_it = std::find_if(provider_options.options.begin(), provider_options.options.end(),
[](const auto& pair) { return pair.first == "external_ep_libray"; });
opt_it != provider_options.options.end()) {
auto lib_name = opt_it->second;
auto lib = LoadLibrary(lib_name.c_str());
if (const auto func = (void (*)(void*, const OrtApiBase*, void*, OrtEpFactory**, size_t, size_t*))GetProcAddress(lib, "CreateEpFactories")) {
OrtEpFactory* factory = nullptr;
size_t num = 1;

func(nullptr, OrtGetApiBase(), nullptr, &factory, num, &num);
}
if (const auto func = (void (*)(OrtSessionOptions*))GetProcAddress(lib, "RyzenAI_SetSessionOptions"))
func(&session_options);
fs::path custom_ops_lib_path(lib_name);
session_options.RegisterCustomOpsLibrary(custom_ops_lib_path.c_str());
}
}
#endif // WIN32
#endif
}
}
Expand Down
Loading