 #include "../../te/operation/create_primfunc.h"
 #include "../op/memory/memory.h"
 #include "../transforms/pass_utils.h"
+#include "tvm/relay/op_strategy.h"
 #include "utils.h"

 namespace tvm {
@@ -115,99 +116,24 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
 }

-// Construct a schedule for a given Relay primitive function and target.
+// Lowers a Relay primitive function to a TE compute, i.e. the array of
+// te::Tensor outputs; constructing a schedule is left to ScheduleBuilder below.
-class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
-  explicit ScheduleBuilder(Target target, bool create_schedule = true)
-      : target_(target),
-        device_copy_op_(Op::Get("device_copy")),
-        create_schedule_(create_schedule) {
-    // Whether to use auto_scheduler schedule.
-    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
-    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
-  }
+  explicit LowerToTECompute(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}

-  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
-    Array<tvm::te::Tensor> fn_inputs;
+  Array<te::Tensor> Lower(const Function& relay_func,
+                          std::function<std::string(std::string)> renamer) {
     for (Var param : relay_func->params) {
       Array<tvm::te::Tensor> inputs;
       for (const auto& ttype : FlattenTupleType(param->checked_type())) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
-        fn_inputs.push_back(tensor);
         inputs.push_back(tensor);
+        fn_inputs_.push_back(tensor);
       }
       memo_[param] = inputs;
     }
     readable_name_stream_ << "fused";
-    auto outputs = this->VisitExpr(relay_func->body);
-    auto candidate_name = readable_name_stream_.str();
-    constexpr static size_t kMaxFuncNameLength = 80;
-    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
-    // whenever the value of kMaxFuncNameLength changes
-    if (candidate_name.size() > kMaxFuncNameLength) {
-      std::stringstream truncated_name;
-      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
-      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
-      candidate_name = truncated_name.str();
-    }
-
-    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
-    // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto prim_fn_var = GlobalVar(renamer(candidate_name));
-    prim_fn_var->checked_type_ = relay_func->checked_type();
-
-    // Fusion over tupled results may leave identity relationships
-    // between inputs and outputs, and those should not be scheduled.
-    // Hence schedule only non PlaceholderOp outputs.
-    tvm::Array<te::Tensor> tensor_outs;
-    for (const auto& tensor : outputs) {
-      if (!tensor->op.as<te::PlaceholderOpNode>()) {
-        tensor_outs.push_back(tensor);
-      }
-    }
-
-    te::Schedule schedule{nullptr};
-    tir::PrimFunc prim_func{nullptr};
-    // No need to register schedule for device copy op.
-    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
-      if (use_auto_scheduler_) {
-        const auto* fauto_schedule =
-            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
-        ICHECK(fauto_schedule != nullptr)
-            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
-        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
-        if (obj.defined()) {
-          schedule = Downcast<te::Schedule>(obj);
-        }
-      }
-      if (use_meta_schedule_) {
-        prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
-        Optional<ObjectRef> opt_mod_or_base_func =
-            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
-                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
-                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
-        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
-          prim_func = GetRef<tir::PrimFunc>(result);
-        } else {
-          prim_func = tir::PrimFunc(nullptr);
-        }
-      }
-
-      // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
-      if (!schedule.defined() && !prim_func.defined()) {
-        ICHECK(anchor_implementation_.defined());
-        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
-      }
-      if (schedule.defined()) {
-        for (const auto& scalar : scalars_) {
-          if (schedule->Contain(scalar)) {
-            schedule[scalar].compute_inline();
-          }
-        }
-      }
-    }
-
-    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
-                      IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
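+    // The memoized traversal below lowers each primitive call in the body to
+    // its TE compute; no schedule is built here anymore.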
+    return this->VisitExpr(relay_func->body);
   }

   Array<te::Tensor> VisitExpr_(const VarNode* op) final {
@@ -254,7 +180,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
   }

   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
     ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

@@ -278,28 +203,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);

-    Array<te::Tensor> outputs;
-    OpImplementation impl;
     // TODO(mbs): device_copy cleanup
     ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
+
     LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
-    outputs = lowered_out->outputs;
-    impl = lowered_out->implementation;
+    Array<te::Tensor> outputs = lowered_out->outputs;
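+    // Record the OpImplementation chosen by lower_call. This is overwritten on
+    // every call visited, so it ends up holding the implementation of the last
+    // lowered op, which ScheduleBuilder reads as the anchor implementation.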
+    anchor_implementation_ = lowered_out->implementation;

-    if (create_schedule_) {
-      int op_pattern = fpattern[op];
-      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
-        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
-            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
-            << " anchor=" << anchor_op_ << " current=" << op;
-      }
-      if (op_pattern >= anchor_op_pattern_) {
-        anchor_op_ = op;
-        anchor_attrs_ = call_node->attrs;
-        anchor_op_pattern_ = op_pattern;
-        anchor_implementation_ = impl;
-      }
-    }
     if (outputs.size() != 1) {
       const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
       ICHECK(tuple_type) << "Expected output to be a tuple type "
@@ -308,8 +218,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
       ICHECK_EQ(tuple_type->fields.size(), outputs.size());
     }

-    // TODO(mbs): device_copy cleanup
-    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
     readable_name_stream_ << '_' << op->name;
     return outputs;
   }
@@ -347,27 +255,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     return {tuple[op->index]};
   }

+ public:
+  // Additional outputs of Lower(), for ScheduleBuilder to consume:
+  Array<tvm::te::Tensor> fn_inputs_;
+  Array<te::Operation> scalars_;
+  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
+  std::ostringstream readable_name_stream_;
+  OpImplementation anchor_implementation_;
+
+ private:
+  tvm::Target target_;
+  // Index of the global constants
+  static int const_index;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
+};
+
+int LowerToTECompute::const_index = 0;
+
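+// ScheduleBuilder now only selects the anchor op and builds the schedule (or
+// queries a tuned TIR PrimFunc); all Relay-to-TE lowering happens in
+// LowerToTECompute above.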
+// Construct a schedule for a given Relay primitive function and target.
+class ScheduleBuilder : ExprVisitor {
+ public:
+  explicit ScheduleBuilder(Target target, bool create_schedule = true)
+      : target_(target), create_schedule_(create_schedule) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
+  }
+
+  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
+    LowerToTECompute lower_te_compute(target_);
+    Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
+    std::string candidate_name = lower_te_compute.readable_name_stream_.str();
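+    // Walk the body a second time purely to record the anchor op and its
+    // attributes (see VisitExpr_ below); no tensors are produced on this pass.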
+    VisitExpr(relay_func->body);
+
+    constexpr static size_t kMaxFuncNameLength = 80;
+    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
+    // whenever the value of kMaxFuncNameLength changes
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
+    // no other GlobalVar ctors should appear inside the lowering machinery.
+    auto prim_fn_var = GlobalVar(renamer(candidate_name));
+    prim_fn_var->checked_type_ = relay_func->checked_type();
+
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule{nullptr};
+    tir::PrimFunc prim_func{nullptr};
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+      if (use_meta_schedule_) {
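+        // Build a TIR PrimFunc directly from the fused outputs so that a tuned
+        // version of it can be queried from the MetaSchedule context.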
+        prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
+        Optional<ObjectRef> opt_mod_or_base_func =
+            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
+                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
+                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
+        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
+          prim_func = GetRef<tir::PrimFunc>(result);
+        } else {
+          prim_func = tir::PrimFunc(nullptr);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined() && !prim_func.defined()) {
+        ICHECK(lower_te_compute.anchor_implementation_.defined());
+        schedule =
+            lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      if (schedule.defined()) {
+        for (const auto& scalar : lower_te_compute.scalars_) {
+          if (schedule->Contain(scalar)) {
+            schedule[scalar].compute_inline();
+          }
+        }
+      }
+    }
+
+    return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
+                      prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
+                      lower_te_compute.constant_tensors_);
+  }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
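+    // The call with the highest TOpPattern becomes the anchor; TOPI scheduling
+    // supports at most one op at or above kCommReduce per fused function.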
+    if (create_schedule_) {
+      int op_pattern = fpattern[op];
+      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+            << " anchor=" << anchor_op_ << " current=" << op;
+      }
+      if (op_pattern >= anchor_op_pattern_) {
+        anchor_op_ = op;
+        anchor_attrs_ = call_node->attrs;
+        anchor_op_pattern_ = op_pattern;
+      }
+    }
+  }
+
  private:
   tvm::Target target_;
   Op anchor_op_;
   Attrs anchor_attrs_;
   int anchor_op_pattern_{0};
-  OpImplementation anchor_implementation_;
-  std::ostringstream readable_name_stream_;
-  Array<te::Operation> scalars_;
-  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
   bool use_auto_scheduler_;
   bool use_meta_schedule_;
-  // Cache device copy op for equivalence checking to reduce registry lookup
-  // overhead for each invocation of call node when retrieving schedules.
-  const Op& device_copy_op_;
   bool create_schedule_;
-  // Index of the global constants
-  static int const_index;
 };

-int ScheduleBuilder::const_index = 0;
-
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.