@@ -70,51 +70,28 @@ class TensorRTCompilerConfig : public Attrs {
// Register the compiler-config node type and expose it as a pass-config option.
// The option key registered here is the same key SaveGlobalAttributes passes to
// PassContext::GetConfig further down in this file.
7070TVM_REGISTER_NODE_TYPE (TensorRTCompilerConfigNode);
7171TVM_REGISTER_PASS_CONFIG_OPTION (" relay.ext.tensorrt.options" , TensorRTCompilerConfig);
7272
// Short local aliases for the JSON graph types (tvm::runtime::json) and the
// serializer helpers (backend::contrib) used by the two classes below.
73+ using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
74+ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
75+ using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
76+ using OpAttrExtractor = backend::contrib::OpAttrExtractor;
77+ using JSONSerializer = backend::contrib::JSONSerializer;
78+
// Forward declaration: CollectFromCompositeFunctionBody holds a
// TensorRTJSONSerializer* and is declared before the serializer class itself.
79+ class TensorRTJSONSerializer ;
80+
7381/* !
74- * \brief Generates an TensorRTModule from a relay expression by serializing the expression to a
75- * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
76- * runtime.
82+ * \brief Collect the constants and attributes from all operator calls in the body
83+ * of a "Composite" function.
7784 */
78- class TensorRTJSONSerializer : public backend ::contrib::JSONSerializer {
79- using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
80- using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
81-
85+ class CollectFromCompositeFunctionBody : public ExprVisitor {
8286 public:
83- TensorRTJSONSerializer (const std::string& symbol, const Expr& expr)
84- : JSONSerializer(symbol, expr) {}
85-
86- std::vector<JSONGraphNodeEntry> VisitExpr_ (const CallNode* cn) {
87- std::string name;
88- if (const auto * op_node = cn->op .as <OpNode>()) {
89- name = op_node->name ;
90- } else {
91- return JSONSerializer::VisitExpr_ (cn);
92- }
87+ explicit CollectFromCompositeFunctionBody (TensorRTJSONSerializer* serializer)
88+ : serializer_(serializer), node_(std::make_shared<JSONGraphNode>()) {}
9389
94- std::vector<JSONGraphNodeEntry> inputs;
95- for (const auto & arg : cn->args ) {
96- auto res = VisitExpr (arg);
97- inputs.insert (inputs.end (), res.begin (), res.end ());
98- }
99- auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
100- " kernel" , /* op_type_ */
101- inputs, 1 /* num_outputs_ */ );
102- if (name == " nn.pad" ) {
103- SetPadNodeAttribute (node, cn);
104- } else if (name == " strided_slice" ) {
105- SetStridedSliceNodeAttribute (node, cn);
106- } else if (name == " split" ) {
107- SetSplitNodeAttribute (node, cn);
108- } else {
109- SetCallNodeAttribute (node, cn);
110- }
111- // These attributes are global to the whole module.
112- SaveGlobalAttributes (node);
113- return AddNode (node, GetRef<Expr>(cn));
114- }
90+ void VisitExpr_ (const ConstantNode* constant_node) final ;
91+ void VisitExpr_ (const CallNode* call_node) final ;
11592
116- void SetPadNodeAttribute (std::shared_ptr<JSONGraphNode> node, const CallNode* cn ) {
117- const auto * pad_attr = cn ->attrs .as <PadAttrs>();
93+ void SetPadNodeAttribute (const CallNode* call_node ) {
94+ const auto * pad_attr = call_node ->attrs .as <PadAttrs>();
11895 ICHECK (pad_attr);
11996 auto p = pad_attr->pad_width ;
12097 const int dim_h = (p.size () == 5 ) ? 3 : 2 ;
@@ -125,16 +102,16 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
125102 std::to_string (p[dim_w][1 ].as <IntImmNode>()->value )};
126103 std::vector<dmlc::any> padding_attr;
127104 padding_attr.emplace_back (padding);
128- node ->SetAttr (" padding" , padding_attr);
105+ node_ ->SetAttr (" padding" , padding_attr);
129106 }
130107
131- void SetStridedSliceNodeAttribute (std::shared_ptr<JSONGraphNode> node, const CallNode* cn ) {
132- const auto * attrs = cn ->attrs .as <StridedSliceAttrs>();
108+ void SetStridedSliceNodeAttribute (const CallNode* call_node ) {
109+ const auto * attrs = call_node ->attrs .as <StridedSliceAttrs>();
133110 ICHECK (attrs && attrs->begin && attrs->end && attrs->strides )
134111 << " StridedSlice must have static begin, end, and strides." ;
135112 const bool default_strides =
136113 !attrs->strides .value ().defined () || attrs->strides .value ().size () == 0 ;
137- auto ishape = backend::GetShape (cn ->args [0 ]->checked_type ());
114+ auto ishape = backend::GetShape (call_node ->args [0 ]->checked_type ());
138115
139116 auto process_slice_index = [](Integer x, int default_value, int dim_value) {
140117 if (!x.defined ()) return default_value;
@@ -173,19 +150,19 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
173150 start_attr.emplace_back (start);
174151 size_attr.emplace_back (size);
175152 strides_attr.emplace_back (strides);
176- node ->SetAttr (" start" , start_attr);
177- node ->SetAttr (" size" , size_attr);
178- node ->SetAttr (" strides" , strides_attr);
153+ node_ ->SetAttr (" start" , start_attr);
154+ node_ ->SetAttr (" size" , size_attr);
155+ node_ ->SetAttr (" strides" , strides_attr);
179156 }
180157
181- void SetSplitNodeAttribute (std::shared_ptr<JSONGraphNode> node, const CallNode* cn ) {
182- const auto * split_attr = cn ->attrs .as <SplitAttrs>();
158+ void SetSplitNodeAttribute (const CallNode* call_node ) {
159+ const auto * split_attr = call_node ->attrs .as <SplitAttrs>();
183160 ICHECK (split_attr);
184161
185162 std::vector<std::string> indices_or_sections;
186163 std::vector<std::string> mode;
187164 std::vector<std::string> axis = {std::to_string (split_attr->axis )};
188- if (const IntImmNode * sections = split_attr->indices_or_sections .as <IntImmNode>()) {
165+ if (const auto * sections = split_attr->indices_or_sections .as <IntImmNode>()) {
189166 mode.emplace_back (" sections" );
190167 indices_or_sections.emplace_back (std::to_string (sections->value ));
191168 } else {
@@ -202,12 +179,80 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
202179 indices_or_sections_attr.emplace_back (indices_or_sections);
203180 mode_attr.emplace_back (mode);
204181 axis_attr.emplace_back (axis);
205- node->SetAttr (" indices_or_sections" , indices_or_sections_attr);
206- node->SetAttr (" mode" , mode_attr);
207- node->SetAttr (" axis" , axis_attr);
182+ node_->SetAttr (" indices_or_sections" , indices_or_sections_attr);
183+ node_->SetAttr (" mode" , mode_attr);
184+ node_->SetAttr (" axis" , axis_attr);
185+ }
186+
/*!
 * \brief Fallback attribute capture for operator calls that have no dedicated
 * Set*NodeAttribute handler (everything other than nn.pad, strided_slice and split).
 * Runs an OpAttrExtractor bound to node_ over the call's attrs object.
 * \param call_node The operator call whose attrs are extracted.
 */
187+ void SetGenericAttributes (const CallNode* call_node) {
188+ OpAttrExtractor extractor (node_);
// attrs.get() is held as a const Object*, while Extract is invoked with a
// non-const Object*, hence the const_cast.
// NOTE(review): presumably Extract only reads the attrs object - confirm.
189+ const Object* attr_obj = call_node->attrs .get ();
190+ extractor.Extract (const_cast <Object*>(attr_obj));
208191 }
209192
210- void SaveGlobalAttributes (std::shared_ptr<JSONGraphNode> node) {
193+ TensorRTJSONSerializer* serializer_;
194+ /* ! \brief Accumulated translated arguments. */
195+ std::vector<JSONGraphNodeEntry> args_;
196+ /* !
197+ * \brief Temporary node into which we'll accumulate attributes. Ideally this would be the
198+ * final JSONGraphNode; however, we don't yet know how many inputs it will have.
199+ */
200+ JSONGraphObjectPtr node_;
201+ };
202+
203+ /* !
204+ * \brief Generates a TensorRTModule from a relay expression by serializing the expression to a
205+ * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until
206+ * runtime.
207+ */
208+ class TensorRTJSONSerializer : public JSONSerializer {
209+ public:
210+ TensorRTJSONSerializer (const std::string& symbol, const Expr& expr)
211+ : JSONSerializer(symbol, expr) {}
212+
213+ using JSONSerializer::VisitExpr_;
214+
/*!
 * \brief Serialize a call to an inline "Composite" function as one JSON graph node.
 *
 * The node is named after the function's Composite attribute. Its inputs are the
 * translated call arguments followed by any constants collected from the composite
 * body. Its attributes are those gathered from the operator calls in the body, plus
 * the global settings added by SaveGlobalAttributes.
 *
 * ICHECK-fails if the callee is not a FunctionNode or has no Composite attribute.
 *
 * \param call_node The call to the composite function.
 * \return The entries produced by AddNode for the newly created node.
 */
215+ std::vector<JSONGraphNodeEntry> VisitExpr_ (const CallNode* call_node) final {
216+ // The call must be to an inline "Composite" function
217+ const auto * function_node = call_node->op .as <FunctionNode>();
218+ ICHECK (function_node != nullptr );
219+ auto opt_composite = function_node->GetAttr <String>(attr::kComposite );
220+ ICHECK (opt_composite.defined ());
221+ std::string name = opt_composite.value ();
222+
223+ // Collect the constants and attributes of all operator calls inside the composite body.
// The collector is given `this` so that constants it meets in the body are
// translated by this serializer's own VisitExpr.
// Note the ordering: the body is walked BEFORE the call arguments are visited,
// although the body's constants are appended AFTER the arguments in `inputs`.
224+ CollectFromCompositeFunctionBody collector (this );
225+ collector.VisitExpr (function_node->body );
226+
227+ // Capture the args to the "Composite" function as inputs for this node.
228+ std::vector<JSONGraphNodeEntry> inputs;
229+ for (const auto & arg : call_node->args ) {
// One argument may contribute several entries; all of them are appended in order.
230+ auto res = VisitExpr (arg);
231+ inputs.insert (inputs.end (), res.begin (), res.end ());
232+ }
233+
234+ // Capture constants from the composite function body as additional inputs for this node.
235+ for (const auto & node : collector.args_ ) {
236+ inputs.emplace_back (node);
237+ }
238+
239+ // Create the final node.
// This is only possible now that the complete input list is known; until this
// point the attributes lived on the collector's temporary node.
240+ auto node = std::make_shared<JSONGraphNode>(name,
241+ /* op_type=*/ " kernel" , inputs,
242+ /* num_output=*/ 1 );
243+
244+ // Transfer attributes from the collector's node to the final node.
245+ node->CaptureAttrs (*collector.node_ );
246+
247+ // Capture global settings on the JSON node.
248+ SaveGlobalAttributes (node);
249+
250+ VLOG (1 ) << name << " has " << node->GetInputs ().size () << " inputs" ;
251+
252+ return AddNode (node, GetRef<Expr>(call_node));
253+ }
254+
255+ static void SaveGlobalAttributes (std::shared_ptr<JSONGraphNode> node) {
211256 auto ctx = transform::PassContext::Current ();
212257 auto cfg = ctx->GetConfig <TensorRTCompilerConfig>(" relay.ext.tensorrt.options" );
213258 if (!cfg.defined ()) {
@@ -236,6 +281,28 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
236281 }
237282};
238283
/*!
 * \brief Handle a constant found inside the composite function body.
 *
 * The constant is passed to the owning serializer's VisitExpr, and every entry
 * that returns is appended to args_. The serializer later adds these entries as
 * extra inputs of the composite's JSON node.
 *
 * \param constant_node The constant encountered in the body.
 */
284+ void CollectFromCompositeFunctionBody::VisitExpr_ (const ConstantNode* constant_node) {
285+ for (const auto & entry : serializer_->VisitExpr (GetRef<Constant>(constant_node))) {
286+ args_.emplace_back (entry);
287+ }
288+ }
289+
/*!
 * \brief Record the attributes of one operator call inside the composite body,
 * then continue the traversal into that call.
 *
 * nn.pad, strided_slice and split use dedicated handlers; every other operator
 * goes through SetGenericAttributes. All handlers write to the same node_.
 * NOTE(review): with several operator calls in one body, attributes sharing a key
 * presumably overwrite one another - confirm against JSONGraphNode::SetAttr.
 *
 * ICHECK-fails if the callee is not an OpNode.
 *
 * \param call_node The operator call encountered in the body.
 */
290+ void CollectFromCompositeFunctionBody::VisitExpr_ (const CallNode* call_node) {
291+ const auto * op_node = call_node->op .as <OpNode>();
292+ ICHECK (op_node != nullptr );
293+ std::string name = op_node->name ;
294+ if (name == " nn.pad" ) {
295+ SetPadNodeAttribute (call_node);
296+ } else if (name == " strided_slice" ) {
297+ SetStridedSliceNodeAttribute (call_node);
298+ } else if (name == " split" ) {
299+ SetSplitNodeAttribute (call_node);
300+ } else {
301+ SetGenericAttributes (call_node);
302+ }
// Hand over to the base-class visitor so traversal continues into this call
// (presumably its op and arguments, which is how nested calls and constants
// in the body would be reached - confirm against ExprVisitor).
303+ ExprVisitor::VisitExpr_ (call_node);
304+ }
305+
239306/* !
240307 * \brief Create a runtime module for TensorRT.
241308 * \param ref The ext_func Relay expression/module to be executed using extern ops.
@@ -246,12 +313,15 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) {
246313 Function func = Downcast<Function>(ref);
247314 std::string func_name = backend::GetExtSymbol (func);
248315
316+ VLOG (1 ) << " TensorRT partition:" << std::endl << PrettyPrint (func);
249317 TensorRTJSONSerializer serializer (func_name, func);
250318 serializer.serialize ();
251319 std::string graph_json = serializer.GetJSON ();
320+ VLOG (1 ) << " TensorRT JSON:" << std::endl << graph_json;
252321 auto param_names = serializer.GetParams ();
253322 const auto * pf = runtime::Registry::Get (" runtime.tensorrt_runtime_create" );
254323 ICHECK (pf != nullptr ) << " Cannot find TensorRT runtime module create function." ;
324+ VLOG (1 ) << " Creating tensorrt runtime::Module for '" << func_name << " '" ;
255325 runtime::Module lib = (*pf)(func_name, graph_json, param_names);
256326 return lib;
257327}
0 commit comments