Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions onnxruntime/test/perftest/command_args_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ ABSL_FLAG(std::string, plugin_ep_options, "",
"--plugin_ep_options \"ep_1_option_1_key|ep_1_option_1_value ...;;ep_3_option_1_key|ep_3_option_1_value ...;... \"");
ABSL_FLAG(bool, list_ep_devices, false, "Prints all available device indices and their properties (including metadata). This option makes the program exit early without performing inference.\n");
ABSL_FLAG(std::string, select_ep_devices, "", "Specifies a semicolon-separated list of device indices to add to the session and run with.");
ABSL_FLAG(std::string, filter_ep_devices, "",
"Specifies EP or Device metadata entries as key-value pairs to filter ep devices passed to AppendExecutionProvider_V2.\n"
"[Usage]: --filter_ep_devices \"<key1>|<value1> <key2>|<value2>\" \n"
"Devices that match any of the key-value pair will be appended to the session. --select_ep_devices will take precedence over this option.\n");
ABSL_FLAG(bool, compile_ep_context, DefaultPerformanceTestConfig().run_config.compile_ep_context, "Generate an EP context model");
ABSL_FLAG(std::string, compile_model_path, "model_ctx.onnx", "The compiled model path for saving EP context model. Overwrites if already exists");
ABSL_FLAG(bool, compile_binary_embed, DefaultPerformanceTestConfig().run_config.compile_binary_embed, "Embed binary blob within EP context node");
Expand Down Expand Up @@ -490,6 +494,22 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a
if (!select_ep_devices.empty()) test_config.selected_ep_device_indices = select_ep_devices;
}

// --filter_ep_devices
{
const auto& filter_ep_devices = absl::GetFlag(FLAGS_filter_ep_devices);
if (!filter_ep_devices.empty()) {
ORT_TRY {
ParseEpDeviceFilterKeyValuePairs(filter_ep_devices, test_config.filter_ep_device_kv_pairs);
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
fprintf(stderr, "Error parsing filter_ep_devices: %s\n", ex.what());
});
return false;
}
}
}

// --compile_ep_context
test_config.run_config.compile_ep_context = absl::GetFlag(FLAGS_compile_ep_context);

Expand Down
24 changes: 23 additions & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,36 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
if (added_ep_device_index_set.find(index) == added_ep_device_index_set.end()) {
added_ep_devices[device.EpName()].push_back(device);
added_ep_device_index_set.insert(index);
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s] has been added to session.\n", index, device.EpName());
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast<int>(index), device.EpName(), device.Device().Type());
}
} else {
std::string err_msg = "[Plugin EP] [WARNING] : The EP device index and its corresponding OrtEpDevice is not created from " +
performance_test_config.machine_config.provider_type_name + ". Will skip adding this device.\n";
fprintf(stderr, "%s", err_msg.c_str());
}
}
} else if (!performance_test_config.filter_ep_device_kv_pairs.empty()) {
// Find and select the OrtEpDevice associated with the EP in "--filter_ep_devices".
for (size_t index = 0; index < ep_devices.size(); ++index) {
auto device = ep_devices[index];
if (ep_set.find(std::string(device.EpName())) == ep_set.end())
continue;

// Check both EP metadata and device metadata for a match
auto ep_metadata_kv_pairs = device.EpMetadata().GetKeyValuePairs();
auto device_metadata_kv_pairs = device.Device().Metadata().GetKeyValuePairs();
for (const auto& kv : performance_test_config.filter_ep_device_kv_pairs) {
auto ep_metadata_itr = ep_metadata_kv_pairs.find(kv.first);
auto device_metadata_itr = device_metadata_kv_pairs.find(kv.first);

if ((ep_metadata_itr != ep_metadata_kv_pairs.end() && kv.second == ep_metadata_itr->second) ||
(device_metadata_itr != device_metadata_kv_pairs.end() && kv.second == device_metadata_itr->second)) {
added_ep_devices[device.EpName()].push_back(device);
fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast<int>(index), device.EpName(), device.Device().Type());
break;
}
}
}
} else {
// Find and select the OrtEpDevice associated with the EP in "--plugin_eps".
for (size_t index = 0; index < ep_devices.size(); ++index) {
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/test/perftest/strings_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,22 @@ void ParseEpDeviceIndexList(const std::string& input, std::vector<int>& result)
}
}
}

void ParseEpDeviceFilterKeyValuePairs(const std::string& input, std::vector<std::pair<std::string, std::string>>& result) {
std::stringstream ss(input);
std::string token;

while (std::getline(ss, token, ' ')) {
if (!token.empty()) {
size_t delimiter_location = token.find("|");
if (delimiter_location == std::string::npos || delimiter_location == 0 || delimiter_location == token.size() - 1) {
ORT_THROW("Use a '|' to separate the key and value for the device filter you are trying to use.\n");
}
std::string key = token.substr(0, delimiter_location);
std::string value = token.substr(delimiter_location + 1);
result.emplace_back(std::make_pair(std::move(key), std::move(value)));
}
}
}
} // namespace perftest
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/test/perftest/strings_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ void ParseEpList(const std::string& input, std::vector<std::string>& result);
void ParseEpOptions(const std::string& input, std::vector<std::unordered_map<std::string, std::string>>& result);

void ParseEpDeviceIndexList(const std::string& input, std::vector<int>& result);

void ParseEpDeviceFilterKeyValuePairs(const std::string& input, std::vector<std::pair<std::string, std::string>>& result);
} // namespace perftest
} // namespace onnxruntime
1 change: 1 addition & 0 deletions onnxruntime/test/perftest/test_configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct PerformanceTestConfig {
std::basic_string<ORTCHAR_T> plugin_ep_names_and_libs;
std::vector<std::string> registered_plugin_eps;
std::string selected_ep_device_indices;
std::vector<std::pair<std::string, std::string>> filter_ep_device_kv_pairs;
bool list_available_ep_devices = false;
};

Expand Down
Loading