 #include "../../te/operation/create_primfunc.h"
 #include "../op/memory/memory.h"
 #include "../transforms/pass_utils.h"
+#include "tvm/relay/op_strategy.h"
 #include "utils.h"
 
 namespace tvm {
@@ -115,99 +116,25 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
 }
 
-// Construct a schedule for a given Relay primitive function and target.
+// Lower a Relay primitive function to TE compute operations for a given target.
-class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
-  explicit ScheduleBuilder(Target target, bool create_schedule = true)
-      : target_(target),
-        device_copy_op_(Op::Get("device_copy")),
-        create_schedule_(create_schedule) {
-    // Whether to use auto_scheduler schedule.
-    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
-    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
-  }
+  explicit LowerToTECompute(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
-  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
-    Array<tvm::te::Tensor> fn_inputs;
+  Array<te::Tensor> Lower(const Function& relay_func,
+                          std::function<std::string(std::string)> renamer) {
     for (Var param : relay_func->params) {
       Array<tvm::te::Tensor> inputs;
       for (const auto& ttype : FlattenTupleType(param->checked_type())) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
-        fn_inputs.push_back(tensor);
         inputs.push_back(tensor);
+        fn_inputs_.push_back(tensor);
       }
       memo_[param] = inputs;
     }
     readable_name_stream_ << "fused";
     auto outputs = this->VisitExpr(relay_func->body);
-    auto candidate_name = readable_name_stream_.str();
-    constexpr static size_t kMaxFuncNameLength = 80;
-    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
-    // whenever the value of kMaxFuncNameLength changes
-    if (candidate_name.size() > kMaxFuncNameLength) {
-      std::stringstream truncated_name;
-      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
-      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
-      candidate_name = truncated_name.str();
-    }
-
-    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
-    // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto prim_fn_var = GlobalVar(renamer(candidate_name));
-    prim_fn_var->checked_type_ = relay_func->checked_type();
-
-    // Fusion over tupled results may leave identity relationships
-    // between inputs and outputs, and those should not be scheduled.
-    // Hence schedule only non PlaceholderOp outputs.
-    tvm::Array<te::Tensor> tensor_outs;
-    for (const auto& tensor : outputs) {
-      if (!tensor->op.as<te::PlaceholderOpNode>()) {
-        tensor_outs.push_back(tensor);
-      }
-    }
-
-    te::Schedule schedule{nullptr};
-    tir::PrimFunc prim_func{nullptr};
-    // No need to register schedule for device copy op.
-    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
-      if (use_auto_scheduler_) {
-        const auto* fauto_schedule =
-            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
-        ICHECK(fauto_schedule != nullptr)
-            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
-        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
-        if (obj.defined()) {
-          schedule = Downcast<te::Schedule>(obj);
-        }
-      }
-      if (use_meta_schedule_) {
-        prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
-        Optional<ObjectRef> opt_mod_or_base_func =
-            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
-                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
-                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
-        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
-          prim_func = GetRef<tir::PrimFunc>(result);
-        } else {
-          prim_func = tir::PrimFunc(nullptr);
-        }
-      }
-
-      // Use the TOPI schedule if the user specified one, or if the function has no auto_scheduler schedule.
-      if (!schedule.defined() && !prim_func.defined()) {
-        ICHECK(anchor_implementation_.defined());
-        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
-      }
-      if (schedule.defined()) {
-        for (const auto& scalar : scalars_) {
-          if (schedule->Contain(scalar)) {
-            schedule[scalar].compute_inline();
-          }
-        }
-      }
-    }
-
-    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
-                      IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
+    return outputs;
   }
 
   Array<te::Tensor> VisitExpr_(const VarNode* op) final {
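The truncation logic removed above (and reinstated in `ScheduleBuilder::Create` further down) caps candidate function names at `kMaxFuncNameLength` and appends a hex hash of the *full* name, so two fused functions whose names share the same 80-character prefix still get distinct symbols. A minimal standalone sketch of the same scheme; `TruncateName` is a hypothetical helper, not part of this patch:

```cpp
#include <functional>
#include <ios>
#include <sstream>
#include <string>

// Hypothetical helper mirroring the kMaxFuncNameLength logic above: names
// longer than max_len are cut at max_len and suffixed with "_<hash>_", where
// the hash is computed over the full original name to preserve uniqueness.
std::string TruncateName(const std::string& candidate, size_t max_len = 80) {
  if (candidate.size() <= max_len) return candidate;
  std::stringstream truncated_name;
  truncated_name << candidate.substr(0, max_len);
  truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate) << "_";
  return truncated_name.str();
}
```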
@@ -254,7 +181,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
   }
 
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
     ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
 
@@ -278,28 +204,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
 
-    Array<te::Tensor> outputs;
-    OpImplementation impl;
     // TODO(mbs): device_copy cleanup
     ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
+
     LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
-    outputs = lowered_out->outputs;
-    impl = lowered_out->implementation;
+    Array<te::Tensor> outputs = lowered_out->outputs;
+    anchor_implementation_ = lowered_out->implementation;
 
-    if (create_schedule_) {
-      int op_pattern = fpattern[op];
-      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
-        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
-            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
-            << " anchor=" << anchor_op_ << " current=" << op;
-      }
-      if (op_pattern >= anchor_op_pattern_) {
-        anchor_op_ = op;
-        anchor_attrs_ = call_node->attrs;
-        anchor_op_pattern_ = op_pattern;
-        anchor_implementation_ = impl;
-      }
-    }
     if (outputs.size() != 1) {
       const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
       ICHECK(tuple_type) << "Expected output to be a tuple type "
@@ -308,8 +219,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
       ICHECK_EQ(tuple_type->fields.size(), outputs.size());
     }
 
-    // TODO(mbs): device_copy cleanup
-    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
     readable_name_stream_ << '_' << op->name;
     return outputs;
   }
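`lower_call` above is looked up through TVM's global function registry rather than linked directly, which is how this C++ visitor reaches the `relay.backend.lower_call` handler registered from Python. For context, a minimal sketch of that registry round trip — registering a typed packed function and retrieving it by name — using a hypothetical `demo.add_one` global:

```cpp
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

// Hypothetical global, registered the same way relay.backend.lower_call is
// (the real handler is registered from Python at import time).
TVM_REGISTER_GLOBAL("demo.add_one").set_body_typed([](int x) { return x + 1; });

int CallDemo() {
  // Registry::Get returns nullptr when nothing is registered under the name,
  // hence the ICHECK guard, matching the flower_call check above.
  const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("demo.add_one");
  ICHECK(f != nullptr) << "demo.add_one is not registered";
  return (*f)(41);  // returns 42
}
```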
@@ -347,27 +256,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     return {tuple[op->index]};
   }
 
+ public:
+  // Outputs of the lowering, read by the caller after Lower() returns.
+  Array<tvm::te::Tensor> fn_inputs_;
+  Array<te::Operation> scalars_;
+  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
+  std::ostringstream readable_name_stream_;
+  OpImplementation anchor_implementation_;
+
+ private:
+  tvm::Target target_;
+  // Index of the global constants
+  static int const_index;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
+};
+
+int LowerToTECompute::const_index = 0;
+
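After this split, `LowerToTECompute` is a pure Relay-to-TE translator: it produces the output tensors and exposes its by-products through public fields instead of scheduling anything itself. A minimal sketch of the intended calling pattern, mirroring `ScheduleBuilder::Create` below (`target`, `relay_func`, and `renamer` are assumed to be in scope):

```cpp
// Sketch: drive the lowering, then read the by-products off the public fields.
LowerToTECompute lower_te_compute(target);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);

// fn_inputs_ holds the te::placeholder tensors created for the params,
// readable_name_stream_ the accumulated "fused_..." candidate name, and
// anchor_implementation_ the OpImplementation of the most recently lowered call.
std::string candidate_name = lower_te_compute.readable_name_stream_.str();
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
```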
+// Construct a schedule for a given Relay primitive function and target.
+class ScheduleBuilder : ExprVisitor {
+ public:
+  explicit ScheduleBuilder(Target target, bool create_schedule = true)
+      : target_(target), create_schedule_(create_schedule) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
+  }
+
+  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
+    LowerToTECompute lower_te_compute(target_);
+    Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
+    std::string candidate_name = lower_te_compute.readable_name_stream_.str();
+    VisitExpr(relay_func->body);
+
+    constexpr static size_t kMaxFuncNameLength = 80;
+    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
+    // whenever the value of kMaxFuncNameLength changes
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
+    // no other GlobalVar ctors should appear inside the lowering machinery.
+    auto prim_fn_var = GlobalVar(renamer(candidate_name));
+    prim_fn_var->checked_type_ = relay_func->checked_type();
+
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule{nullptr};
+    tir::PrimFunc prim_func{nullptr};
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+      if (use_meta_schedule_) {
+        prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
+        Optional<ObjectRef> opt_mod_or_base_func =
+            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
+                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
+                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
+        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
+          prim_func = GetRef<tir::PrimFunc>(result);
+        } else {
+          prim_func = tir::PrimFunc(nullptr);
+        }
+      }
+
+      // Use the TOPI schedule if the user specified one, or if the function has no auto_scheduler schedule.
+      if (!schedule.defined() && !prim_func.defined()) {
+        ICHECK(lower_te_compute.anchor_implementation_.defined());
+        schedule =
+            lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      if (schedule.defined()) {
+        for (const auto& scalar : lower_te_compute.scalars_) {
+          if (schedule->Contain(scalar)) {
+            schedule[scalar].compute_inline();
+          }
+        }
+      }
+    }
+
+    return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
+                      prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
+                      lower_te_compute.constant_tensors_);
+  }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    if (create_schedule_) {
+      int op_pattern = fpattern[op];
+      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+            << " anchor=" << anchor_op_ << " current=" << op;
+      }
+      if (op_pattern >= anchor_op_pattern_) {
+        anchor_op_ = op;
+        anchor_attrs_ = call_node->attrs;
+        anchor_op_pattern_ = op_pattern;
+      }
+    }
+  }
+
  private:
   tvm::Target target_;
   Op anchor_op_;
   Attrs anchor_attrs_;
   int anchor_op_pattern_{0};
-  OpImplementation anchor_implementation_;
-  std::ostringstream readable_name_stream_;
-  Array<te::Operation> scalars_;
-  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
   bool use_auto_scheduler_;
   bool use_meta_schedule_;
-  // Cache device copy op for equivalence checking to reduce registry lookup
-  // overhead for each invocation of call node when retrieving schedules.
-  const Op& device_copy_op_;
   bool create_schedule_;
-  // Index of the global constants
-  static int const_index;
 };
 
-int ScheduleBuilder::const_index = 0;
-
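What remains in `ScheduleBuilder` is anchor selection: the visitor keeps the call whose `TOpPattern` is highest (TVM orders these `kElemWise < kBroadcast < kInjective < kCommReduce < kOutEWiseFusable < ...`), and that op's attrs drive the TOPI schedule choice in `Create`. A standalone sketch of the same selection rule, with simplified stand-ins for the pattern enum:

```cpp
#include <string>
#include <vector>

// Simplified stand-ins; the real values live in tvm/relay/op_attr_types.h.
enum OpPattern { kElemWise = 0, kBroadcast = 1, kInjective = 2, kCommReduce = 3, kOutEWiseFusable = 4 };

struct CallInfo {
  std::string op_name;
  OpPattern pattern;
};

// The op with the highest pattern seen so far becomes the anchor; ties go to
// the later op, matching the >= comparison in the visitor above.
CallInfo PickAnchor(const std::vector<CallInfo>& calls) {
  CallInfo anchor{"", kElemWise};
  int anchor_pattern = -1;  // below every real pattern, so the first call always seeds the anchor
  for (const CallInfo& call : calls) {
    if (call.pattern >= anchor_pattern) {
      anchor = call;
      anchor_pattern = call.pattern;
    }
  }
  return anchor;
}

// e.g. {{"add", kElemWise}, {"conv2d", kOutEWiseFusable}, {"relu", kElemWise}}
// -> anchor is "conv2d"; the trailing elemwise op does not displace it.
```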
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.