Skip to content

Commit be2ae94

Browse files
mbaretmbs-octoml
andauthored
[TENSORRT] Improvements and fixes for TensorRT (#11203)
A number of small fixes and refactors to improve the robustness of the TensorRT integration. Co-authored-by: Mark Shields <[email protected]> Co-authored-by: Mark Shields <[email protected]>
1 parent 8d4f4dd commit be2ae94

File tree

12 files changed

+901
-1105
lines changed

12 files changed

+901
-1105
lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 511 additions & 612 deletions
Large diffs are not rendered by default.

src/relay/backend/contrib/tensorrt/codegen.cc

Lines changed: 124 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -70,51 +70,28 @@ class TensorRTCompilerConfig : public Attrs {
7070
TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode);
7171
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", TensorRTCompilerConfig);
7272

73+
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
74+
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
75+
using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
76+
using OpAttrExtractor = backend::contrib::OpAttrExtractor;
77+
using JSONSerializer = backend::contrib::JSONSerializer;
78+
79+
class TensorRTJSONSerializer;
80+
7381
/*!
74-
* \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
75-
* json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
76-
* runtime.
82+
* \brief Collect the constants and attributes from all operator calls in the body
83+
* of a "Composite" function.
7784
*/
78-
class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
79-
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
80-
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
81-
85+
class CollectFromCompositeFunctionBody : public ExprVisitor {
8286
public:
83-
TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
84-
: JSONSerializer(symbol, expr) {}
85-
86-
std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) {
87-
std::string name;
88-
if (const auto* op_node = cn->op.as<OpNode>()) {
89-
name = op_node->name;
90-
} else {
91-
return JSONSerializer::VisitExpr_(cn);
92-
}
87+
explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer)
88+
: serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}
9389

94-
std::vector<JSONGraphNodeEntry> inputs;
95-
for (const auto& arg : cn->args) {
96-
auto res = VisitExpr(arg);
97-
inputs.insert(inputs.end(), res.begin(), res.end());
98-
}
99-
auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
100-
"kernel", /* op_type_ */
101-
inputs, 1 /* num_outputs_ */);
102-
if (name == "nn.pad") {
103-
SetPadNodeAttribute(node, cn);
104-
} else if (name == "strided_slice") {
105-
SetStridedSliceNodeAttribute(node, cn);
106-
} else if (name == "split") {
107-
SetSplitNodeAttribute(node, cn);
108-
} else {
109-
SetCallNodeAttribute(node, cn);
110-
}
111-
// These attributes are global to the whole module.
112-
SaveGlobalAttributes(node);
113-
return AddNode(node, GetRef<Expr>(cn));
114-
}
90+
void VisitExpr_(const ConstantNode* constant_node) final;
91+
void VisitExpr_(const CallNode* call_node) final;
11592

116-
void SetPadNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
117-
const auto* pad_attr = cn->attrs.as<PadAttrs>();
93+
void SetPadNodeAttribute(const CallNode* call_node) {
94+
const auto* pad_attr = call_node->attrs.as<PadAttrs>();
11895
ICHECK(pad_attr);
11996
auto p = pad_attr->pad_width;
12097
const int dim_h = (p.size() == 5) ? 3 : 2;
@@ -125,16 +102,16 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
125102
std::to_string(p[dim_w][1].as<IntImmNode>()->value)};
126103
std::vector<dmlc::any> padding_attr;
127104
padding_attr.emplace_back(padding);
128-
node->SetAttr("padding", padding_attr);
105+
node_->SetAttr("padding", padding_attr);
129106
}
130107

131-
void SetStridedSliceNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
132-
const auto* attrs = cn->attrs.as<StridedSliceAttrs>();
108+
void SetStridedSliceNodeAttribute(const CallNode* call_node) {
109+
const auto* attrs = call_node->attrs.as<StridedSliceAttrs>();
133110
ICHECK(attrs && attrs->begin && attrs->end && attrs->strides)
134111
<< "StridedSlice must have static begin, end, and strides.";
135112
const bool default_strides =
136113
!attrs->strides.value().defined() || attrs->strides.value().size() == 0;
137-
auto ishape = backend::GetShape(cn->args[0]->checked_type());
114+
auto ishape = backend::GetShape(call_node->args[0]->checked_type());
138115

139116
auto process_slice_index = [](Integer x, int default_value, int dim_value) {
140117
if (!x.defined()) return default_value;
@@ -173,19 +150,19 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
173150
start_attr.emplace_back(start);
174151
size_attr.emplace_back(size);
175152
strides_attr.emplace_back(strides);
176-
node->SetAttr("start", start_attr);
177-
node->SetAttr("size", size_attr);
178-
node->SetAttr("strides", strides_attr);
153+
node_->SetAttr("start", start_attr);
154+
node_->SetAttr("size", size_attr);
155+
node_->SetAttr("strides", strides_attr);
179156
}
180157

181-
void SetSplitNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
182-
const auto* split_attr = cn->attrs.as<SplitAttrs>();
158+
void SetSplitNodeAttribute(const CallNode* call_node) {
159+
const auto* split_attr = call_node->attrs.as<SplitAttrs>();
183160
ICHECK(split_attr);
184161

185162
std::vector<std::string> indices_or_sections;
186163
std::vector<std::string> mode;
187164
std::vector<std::string> axis = {std::to_string(split_attr->axis)};
188-
if (const IntImmNode* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
165+
if (const auto* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
189166
mode.emplace_back("sections");
190167
indices_or_sections.emplace_back(std::to_string(sections->value));
191168
} else {
@@ -202,12 +179,80 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
202179
indices_or_sections_attr.emplace_back(indices_or_sections);
203180
mode_attr.emplace_back(mode);
204181
axis_attr.emplace_back(axis);
205-
node->SetAttr("indices_or_sections", indices_or_sections_attr);
206-
node->SetAttr("mode", mode_attr);
207-
node->SetAttr("axis", axis_attr);
182+
node_->SetAttr("indices_or_sections", indices_or_sections_attr);
183+
node_->SetAttr("mode", mode_attr);
184+
node_->SetAttr("axis", axis_attr);
185+
}
186+
187+
void SetGenericAttributes(const CallNode* call_node) {
188+
OpAttrExtractor extractor(node_);
189+
const Object* attr_obj = call_node->attrs.get();
190+
extractor.Extract(const_cast<Object*>(attr_obj));
208191
}
209192

210-
void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
193+
TensorRTJSONSerializer* serializer_;
194+
/*! \brief Accumulated translated arguments. */
195+
std::vector<JSONGraphNodeEntry> args_;
196+
/*!
197+
* \brief Temporary node into which we'll accumulate attributes. Ideally this would be the
198+
* final JSONGraphNode however we don't yet know how many inputs that will have.
199+
*/
200+
JSONGraphObjectPtr node_;
201+
};
202+
203+
/*!
204+
* \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
205+
* json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
206+
* runtime.
207+
*/
208+
class TensorRTJSONSerializer : public JSONSerializer {
209+
public:
210+
TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
211+
: JSONSerializer(symbol, expr) {}
212+
213+
using JSONSerializer::VisitExpr_;
214+
215+
std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
216+
// The call must be to an inline "Composite" function
217+
const auto* function_node = call_node->op.as<FunctionNode>();
218+
ICHECK(function_node != nullptr);
219+
auto opt_composite = function_node->GetAttr<String>(attr::kComposite);
220+
ICHECK(opt_composite.defined());
221+
std::string name = opt_composite.value();
222+
223+
// Collect the constants and attributes of all operator calls inside the composite body.
224+
CollectFromCompositeFunctionBody collector(this);
225+
collector.VisitExpr(function_node->body);
226+
227+
// Capture the args to the "Composite" function as inputs for this node.
228+
std::vector<JSONGraphNodeEntry> inputs;
229+
for (const auto& arg : call_node->args) {
230+
auto res = VisitExpr(arg);
231+
inputs.insert(inputs.end(), res.begin(), res.end());
232+
}
233+
234+
// Capture constants from the composite function body as additional inputs for this node.
235+
for (const auto& node : collector.args_) {
236+
inputs.emplace_back(node);
237+
}
238+
239+
// Create the final node.
240+
auto node = std::make_shared<JSONGraphNode>(name,
241+
/*op_type=*/"kernel", inputs,
242+
/*num_output=*/1);
243+
244+
// Transfer attributes from the collector's node to the final node.
245+
node->CaptureAttrs(*collector.node_);
246+
247+
// Capture global settings on the JSON node.
248+
SaveGlobalAttributes(node);
249+
250+
VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";
251+
252+
return AddNode(node, GetRef<Expr>(call_node));
253+
}
254+
255+
static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
211256
auto ctx = transform::PassContext::Current();
212257
auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
213258
if (!cfg.defined()) {
@@ -236,6 +281,28 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
236281
}
237282
};
238283

284+
void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) {
285+
for (const auto& entry : serializer_->VisitExpr(GetRef<Constant>(constant_node))) {
286+
args_.emplace_back(entry);
287+
}
288+
}
289+
290+
void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
291+
const auto* op_node = call_node->op.as<OpNode>();
292+
ICHECK(op_node != nullptr);
293+
std::string name = op_node->name;
294+
if (name == "nn.pad") {
295+
SetPadNodeAttribute(call_node);
296+
} else if (name == "strided_slice") {
297+
SetStridedSliceNodeAttribute(call_node);
298+
} else if (name == "split") {
299+
SetSplitNodeAttribute(call_node);
300+
} else {
301+
SetGenericAttributes(call_node);
302+
}
303+
ExprVisitor::VisitExpr_(call_node);
304+
}
305+
239306
/*!
240307
* \brief Create a runtime module for TensorRT.
241308
* \param ref The ext_func Relay expression/module to be executed using extern ops.
@@ -246,12 +313,15 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) {
246313
Function func = Downcast<Function>(ref);
247314
std::string func_name = backend::GetExtSymbol(func);
248315

316+
VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func);
249317
TensorRTJSONSerializer serializer(func_name, func);
250318
serializer.serialize();
251319
std::string graph_json = serializer.GetJSON();
320+
VLOG(1) << "TensorRT JSON:" << std::endl << graph_json;
252321
auto param_names = serializer.GetParams();
253322
const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create");
254323
ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function.";
324+
VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'";
255325
runtime::Module lib = (*pf)(func_name, graph_json, param_names);
256326
return lib;
257327
}

src/relay/transforms/inline_composites.cc

Lines changed: 0 additions & 94 deletions
This file was deleted.

src/runtime/contrib/json/json_node.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,12 @@ class JSONGraphNode {
281281
*/
282282
bool HasAttr(const std::string& key) const { return attrs_.find(key) != attrs_.end(); }
283283

284+
void CaptureAttrs(const JSONGraphNode& that) {
285+
for (const auto& kv : that.attrs_) {
286+
attrs_[kv.first] = kv.second;
287+
}
288+
}
289+
284290
virtual ~JSONGraphNode() {}
285291

286292
private:

0 commit comments

Comments
 (0)