Skip to content

Commit

Permalink
delay tensorrt registry (#45824)
Browse files Browse the repository at this point in the history
* Delay TensorRT registry
* Add unused define
* Fix TensorRT test
* fix function to reference
* Update trt_plugin.h
  • Loading branch information
JZZ-NOTE authored Sep 14, 2022
1 parent 6891a4f commit d7d35ff
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
Expand Down Expand Up @@ -117,6 +118,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);

static std::once_flag trt_plugin_registered;
std::call_once(trt_plugin_registered, []() {
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
});

auto model_precision =
static_cast<phi::DataType>(Get<int>("model_precision"));
if (model_precision == phi::DataType::BFLOAT16) {
Expand Down
38 changes: 35 additions & 3 deletions paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ namespace inference {
namespace tensorrt {
namespace plugin {

#if defined(_WIN32)
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define UNUSED __attribute__((unused))
#endif

class PluginTensorRT;

typedef std::function<PluginTensorRT*(const void*, size_t)>
Expand Down Expand Up @@ -372,6 +379,26 @@ class TensorRTPluginCreator : public nvinfer1::IPluginCreator {
std::vector<nvinfer1::PluginField> plugin_attributes_;
};

class TrtPluginRegistry {
public:
static TrtPluginRegistry* Global() {
static TrtPluginRegistry registry;
return &registry;
}
bool Regist(const std::string& name, const std::function<void()>& func) {
map.emplace(name, func);
return true;
}
void RegistToTrt() {
for (auto& it : map) {
it.second();
}
}

private:
std::unordered_map<std::string, std::function<void()>> map;
};

template <typename T>
class TrtPluginRegistrarV2 {
public:
Expand All @@ -386,9 +413,14 @@ class TrtPluginRegistrarV2 {
T creator;
};

#define REGISTER_TRT_PLUGIN_V2(name) \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name {}
#define REGISTER_TRT_PLUGIN_V2(name) REGISTER_TRT_PLUGIN_V2_HELPER(name)

#define REGISTER_TRT_PLUGIN_V2_HELPER(name) \
UNUSED static bool REGISTER_TRT_PLUGIN_V2_HELPER##name = \
TrtPluginRegistry::Global()->Regist(#name, []() -> void { \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name{}; \
});

} // namespace plugin
} // namespace tensorrt
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {

TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
auto *x = engine_->DeclareInput(
Expand Down
20 changes: 5 additions & 15 deletions paddle/fluid/platform/dynload/tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,11 @@ void* GetDsoHandle(const std::string& dso_name) {

void* dso_handle = dlopen(dso_name.c_str(), dynload_flags);

if (nullptr == dso_handle) {
auto error_msg =
"You are using Paddle compiled with TensorRT, but TensorRT dynamic "
"library is not found. Ignore this if TensorRT is not needed.\n"
"The TensorRT that Paddle depends on is not configured correctly.\n"
" Suggestions:\n"
" 1. Check if the TensorRT is installed correctly and its version"
" is matched with paddlepaddle you installed.\n"
" 2. Configure environment variables as "
"follows:\n"
" - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n"
" - Windows: set PATH by `set PATH=XXX;%PATH%`\n"
" - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n";
LOG(WARNING) << error_msg;
}
PADDLE_ENFORCE_NOT_NULL(dso_handle,
paddle::platform::errors::NotFound(
"TensorRT is needed, "
"but TensorRT dynamic library is not found."));

return dso_handle;
}

Expand Down
19 changes: 4 additions & 15 deletions paddle/phi/backends/dynload/tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,10 @@ void* GetDsoHandle(const std::string& dso_name) {

void* dso_handle = dlopen(dso_name.c_str(), dynload_flags);

if (nullptr == dso_handle) {
auto error_msg =
"You are using Paddle compiled with TensorRT, but TensorRT dynamic "
"library is not found. Ignore this if TensorRT is not needed.\n"
"The TensorRT that Paddle depends on is not configured correctly.\n"
" Suggestions:\n"
" 1. Check if the TensorRT is installed correctly and its version"
" is matched with paddlepaddle you installed.\n"
" 2. Configure environment variables as "
"follows:\n"
" - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n"
" - Windows: set PATH by `set PATH=XXX;%PATH%`\n"
" - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n";
LOG(WARNING) << error_msg;
}
PADDLE_ENFORCE_NOT_NULL(dso_handle,
paddle::platform::errors::NotFound(
"TensorRT is needed, "
"but TensorRT dynamic library is not found."));
return dso_handle;
}

Expand Down

0 comments on commit d7d35ff

Please sign in to comment.