Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ struct OrtTensorRTProviderOptionsV2 {
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
int trt_timing_cache_enable; // enable TensorRT timing cache. Default 0 = false, nonzero = true
};
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX

SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
Expand Down
95 changes: 91 additions & 4 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,36 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map<s
}
return true;
}

// Reads a serialized TensorRT timing cache from disk.
// Returns the raw cache bytes, or an empty vector when the file cannot be
// opened or read (the caller then builds a fresh timing cache and writes it out).
inline std::vector<char> loadTimingCacheFile(const std::string& inFileName)
{
    std::ifstream iFile(inFileName, std::ios::in | std::ios::binary);
    if (!iFile)
    {
        LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName
                              << ". A new timing cache will be generated and written.";
        return std::vector<char>();
    }
    iFile.seekg(0, std::ifstream::end);
    // tellg() returns -1 on failure; guard before converting to an unsigned size
    // so we never attempt a huge bogus allocation.
    const std::streampos fsize = iFile.tellg();
    if (fsize <= 0)
    {
        return std::vector<char>();
    }
    iFile.seekg(0, std::ifstream::beg);
    std::vector<char> content(static_cast<size_t>(fsize));
    if (!iFile.read(content.data(), fsize))
    {
        LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to read timing cache from: " << inFileName
                              << ". A new timing cache will be generated and written.";
        return std::vector<char>();
    }
    return content;
}

// Writes a serialized TensorRT timing cache blob to disk so subsequent
// engine builds can reuse the kernel timing results.
// Logs a warning and returns without writing if the file cannot be opened.
inline void saveTimingCacheFile(const std::string& outFileName, const nvinfer1::IHostMemory* blob)
{
    std::ofstream oFile(outFileName, std::ios::out | std::ios::binary);
    if (!oFile)
    {
        LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName;
        return;
    }
    // data() yields void*; use a named cast instead of a C-style cast.
    oFile.write(static_cast<const char*>(blob->data()), blob->size());
}
} // namespace

namespace google {
Expand Down Expand Up @@ -427,7 +457,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
}
dump_subgraphs_ = info.dump_subgraphs;
engine_cache_enable_ = info.engine_cache_enable;
if (engine_cache_enable_ || int8_enable_) {
timing_cache_enable_ = info.timing_cache_enable;
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
cache_path_ = info.engine_cache_path;
}
engine_decryption_enable_ = info.engine_decryption_enable;
Expand Down Expand Up @@ -497,7 +528,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true);
}

if (engine_cache_enable_ || int8_enable_) {
const std::string timing_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kTimingCacheEnable);
if (!timing_cache_enable_env.empty()) {
timing_cache_enable_ = (std::stoi(timing_cache_enable_env) == 0 ? false : true);
}

if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath);
cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath);
if (!engine_cache_path.empty() && cache_path_.empty()) {
Expand All @@ -519,6 +555,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (!force_sequential_engine_build_env.empty()) {
force_sequential_engine_build_ = (std::stoi(force_sequential_engine_build_env) == 0 ? false : true);
}

}

// Validate setting
Expand All @@ -539,7 +576,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
dla_core_ = 0;
}

if (engine_cache_enable_ || int8_enable_) {
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
if (!cache_path_.empty() && !fs::is_directory(cache_path_)) {
if (!fs::create_directory(cache_path_)) {
throw std::runtime_error("Failed to create directory " + cache_path_);
Expand Down Expand Up @@ -1302,6 +1339,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
if (!has_dynamic_shape) {
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + ".engine";
const std::string timing_cache_path = cache_path + ".timing";
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
if (engine_cache_enable_ && engine_file) {
engine_file.seekg(0, std::ios::end);
Expand Down Expand Up @@ -1344,6 +1382,18 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
}

// Load timing cache from file. Create a fresh cache if the file doesn't exist
std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
if (timing_cache_enable_) {
std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
if (timing_cache == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not create timing cache: " + timing_cache_path);
}
trt_config->setTimingCache(*timing_cache, false);
}

// Build engine
{
auto lock = GetEngineBuildLock();
Expand All @@ -1369,6 +1419,18 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
serializedModel->destroy();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
}

// serialize and save timing cache
if (timing_cache_enable_)
{
auto timing_cache = trt_config->getTimingCache();
std::unique_ptr<nvinfer1::IHostMemory> timingCacheHostData{timing_cache->serialize()};
if (timingCacheHostData == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not serialize timing cache: " + timing_cache_path);
}
saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
}
}

// Build context
Expand Down Expand Up @@ -1422,7 +1484,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
runtime_.get(), nullptr, allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
runtime_.get(), nullptr, allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_};
*state = p.release();
return 0;
};
Expand Down Expand Up @@ -1458,6 +1520,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + ".engine";
const std::string profile_cache_path = cache_path + ".profile";
const std::string timing_cache_path = cache_path + ".timing";
if (trt_state->engine_cache_enable && trt_engine == nullptr) {
std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
Expand Down Expand Up @@ -1685,6 +1748,18 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
trt_config->setDLACore(trt_state->dla_core);
}

// Load timing cache from file. Create a fresh cache if the file doesn't exist
std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
if (trt_state->timing_cache_enable) {
std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
if (timing_cache == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not create timing cache: " + timing_cache_path);
}
trt_config->setTimingCache(*timing_cache, false);
}

// Build engine
{
auto lock = GetEngineBuildLock();
Expand Down Expand Up @@ -1716,6 +1791,18 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
serializedModel->destroy();
}

// serialize and save timing cache
if (trt_state->timing_cache_enable)
{
auto timing_cache = trt_config->getTimingCache();
std::unique_ptr<nvinfer1::IHostMemory> timingCacheHostData{timing_cache->serialize()};
if (timingCacheHostData == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not serialize timing cache: " + timing_cache_path);
}
saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
}

// Build context
*(trt_state->context) = tensorrt_ptr::unique_pointer<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH";
static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE";
static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH";
static const std::string kForceSequentialEngineBuild= "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD";
static const std::string kTimingCacheEnable = "ORT_TENSORRT_TIMING_CACHE_ENABLE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
Expand Down Expand Up @@ -107,6 +108,7 @@ struct TensorrtFuncState {
bool engine_decryption_enable;
int (*engine_decryption)(const char*, char*, size_t*);
int (*engine_encryption)(const char*, char*, size_t);
bool timing_cache_enable;
};

// Logical device representation.
Expand Down Expand Up @@ -167,6 +169,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool engine_decryption_enable_ = false;
int (*engine_decryption_)(const char*, char*, size_t*);
int (*engine_encryption_)(const char*, char*, size_t);
bool timing_cache_enable_ = false;

std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>> engines_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ constexpr const char* kCachePath = "trt_engine_cache_path";
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
constexpr const char* kTimingCacheEnable = "trt_timing_cache_enable";
// add new provider option name here.
} // namespace provider_option_names
} // namespace tensorrt
Expand Down Expand Up @@ -64,7 +65,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
.Parse(options)); // add new provider option here.
.AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable)
.Parse(options)); //add new provider option here.

return info;
}
Expand All @@ -88,6 +90,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
{tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)},
// add new provider option here.
};
return options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct TensorrtExecutionProviderInfo {
bool engine_decryption_enable{false};
std::string engine_decryption_lib_path{""};
bool force_sequential_engine_build{false};
bool timing_cache_enable{false};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct Tensorrt_Provider : Provider {
info.engine_decryption_enable = options.trt_engine_decryption_enable != 0;
info.engine_decryption_lib_path = options.trt_engine_decryption_lib_path == nullptr ? "" : options.trt_engine_decryption_lib_path;
info.force_sequential_engine_build = options.trt_force_sequential_engine_build != 0;
info.timing_cache_enable = options.trt_timing_cache_enable != 0;
return std::make_shared<TensorrtProviderFactory>(info);
}

Expand Down Expand Up @@ -135,6 +136,7 @@ struct Tensorrt_Provider : Provider {
}

trt_options.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build;
trt_options.trt_timing_cache_enable = internal_options.timing_cache_enable;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,8 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti
trt_options_converted.trt_engine_decryption_lib_path = legacy_trt_options->trt_engine_decryption_lib_path;
trt_options_converted.trt_force_sequential_engine_build = legacy_trt_options->trt_force_sequential_engine_build;
// Add new provider option below
// Use default value as this field is not available in OrtTensorRTProviderOptionsV
// Use default value as this field is not available in OrtTensorRTProviderOptions
trt_options_converted.trt_timing_cache_enable = 0;

return trt_options_converted;
}
Expand Down Expand Up @@ -1489,6 +1490,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT
(*out)->trt_engine_decryption_enable = false;
(*out)->trt_engine_decryption_lib_path = nullptr;
(*out)->trt_force_sequential_engine_build = false;
(*out)->trt_timing_cache_enable = false;
return nullptr;
#else
ORT_UNUSED_PARAMETER(out);
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
nullptr,
0,
nullptr,
0,
0};
for (auto option : it->second) {
if (option.first == "device_id") {
Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
bool trt_engine_decryption_enable = false;
std::string trt_engine_decryption_lib_path = "";
bool trt_force_sequential_engine_build = false;
bool trt_timing_cache_enable = false;

#ifdef _MSC_VER
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
Expand Down Expand Up @@ -206,8 +207,16 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be a boolean i.e. true or false. Default value is false.\n");
}
} else if (key == "trt_timing_cache_enable") {
if (value == "true" || value == "True") {
trt_timing_cache_enable = true;
} else if (value == "false" || value == "False") {
trt_timing_cache_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_timing_cache_enable' should be a boolean i.e. true or false. Default value is false.\n");
}
} else {
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build'] \n");
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_timing_cache_enable'] \n");
}
}
OrtTensorRTProviderOptionsV2 tensorrt_options;
Expand All @@ -229,6 +238,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable;
tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str();
tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build;
tensorrt_options.trt_timing_cache_enable = trt_timing_cache_enable;
session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);

OrtCUDAProviderOptions cuda_options;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/providers/cpu/model_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ TEST_P(ModelTest, Run) {
nullptr,
0,
nullptr,
0,
0};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(TensorrtExecutionProviderWithOptions(&params)));
} else {
Expand Down
Loading