2828#include < tvm/relay/expr_functor.h>
2929#include < tvm/relay/op.h>
3030#include < tvm/relay/op_attr_types.h>
31+ #include < tvm/relay/op_strategy.h>
3132#include < tvm/runtime/device_api.h>
3233#include < tvm/runtime/registry.h>
3334#include < tvm/te/operation.h>
3435#include < tvm/te/schedule.h>
3536#include < tvm/te/schedule_pass.h>
37+ #include < tvm/tir/function.h>
3638#include < tvm/topi/tags.h>
3739
3840#include < functional>
@@ -114,100 +116,40 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
114116 return res;
115117}
116118
117- // Construct a schedule for a given Relay primitive function and target.
118- class ScheduleBuilder : public backend ::MemoizedExprTranslator<Array<te::Tensor>> {
119+ // Lowers Relay primitive Function to TE Compute
120+ class LowerToTECompute : public backend ::MemoizedExprTranslator<Array<te::Tensor>> {
119121 public:
120- explicit ScheduleBuilder (Target target, bool create_schedule = true )
121- : target_(target),
122- device_copy_op_(Op::Get(" device_copy" )),
123- create_schedule_(create_schedule) {
124- // Whether to use auto_scheduler schedule.
125- use_auto_scheduler_ = backend::IsAutoSchedulerEnabled ();
126- use_meta_schedule_ = backend::IsMetaScheduleEnabled ();
127- }
122+ explicit LowerToTECompute (Target target)
123+ : target_(target), device_copy_op_(Op::Get(" device_copy" )) {}
128124
129- CachedFunc Create (const Function& relay_func, std::function<std::string(std::string)> renamer) {
130- Array<tvm::te::Tensor> fn_inputs;
125+ Array<te::Tensor> Lower (const Function& relay_func,
126+ std::function<std::string(std::string)> renamer) {
131127 for (Var param : relay_func->params ) {
132128 Array<tvm::te::Tensor> inputs;
133129 for (const auto & ttype : FlattenTupleType (param->checked_type ())) {
134130 tvm::te::Tensor tensor = tvm::te::placeholder (GetShape (ttype->shape ), ttype->dtype );
135- fn_inputs.push_back (tensor);
136131 inputs.push_back (tensor);
132+ fn_inputs_.push_back (tensor);
137133 }
138134 memo_[param] = inputs;
139135 }
140136 readable_name_stream_ << " fused" ;
141- auto outputs = this ->VisitExpr (relay_func->body );
142- auto candidate_name = readable_name_stream_.str ();
137+
138+ Array<te::Tensor> outputs = this ->VisitExpr (relay_func->body );
139+
140+ candidate_name_ = readable_name_stream_.str ();
141+
143142 constexpr static size_t kMaxFuncNameLength = 80 ;
144143 // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
145144 // whenever the value of kMaxFuncNameLength changes
146- if (candidate_name .size () > kMaxFuncNameLength ) {
145+ if (candidate_name_ .size () > kMaxFuncNameLength ) {
147146 std::stringstream truncated_name;
148- truncated_name << candidate_name.substr (0 , kMaxFuncNameLength );
149- truncated_name << " _" << std::hex << std::hash<std::string>{}(candidate_name) << " _" ;
150- candidate_name = truncated_name.str ();
151- }
152-
153- // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
154- // no other GlobalVar ctors should appear inside the lowering machinery.
155- auto prim_fn_var = GlobalVar (renamer (candidate_name));
156- prim_fn_var->checked_type_ = relay_func->checked_type ();
157-
158- // Fusion over tupled results may leave identity relationships
159- // between inputs and outputs, and those should not be scheduled.
160- // Hence schedule only non PlaceholderOp outputs.
161- tvm::Array<te::Tensor> tensor_outs;
162- for (const auto & tensor : outputs) {
163- if (!tensor->op .as <te::PlaceholderOpNode>()) {
164- tensor_outs.push_back (tensor);
165- }
166- }
167-
168- te::Schedule schedule{nullptr };
169- tir::PrimFunc prim_func{nullptr };
170- // No need to register schedule for device copy op.
171- if (anchor_attrs_.as <DeviceCopyAttrs>() == nullptr && create_schedule_) {
172- if (use_auto_scheduler_) {
173- const auto * fauto_schedule =
174- runtime::Registry::Get (" auto_scheduler.relay_integration.auto_schedule_topi_compute" );
175- ICHECK (fauto_schedule != nullptr )
176- << " auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered" ;
177- ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint , tensor_outs);
178- if (obj.defined ()) {
179- schedule = Downcast<te::Schedule>(obj);
180- }
181- }
182- if (use_meta_schedule_) {
183- prim_func = tir::CreatePrimFunc (Concat (fn_inputs, tensor_outs));
184- Optional<ObjectRef> opt_mod_or_base_func =
185- meta_schedule::MetaScheduleContext::QueryInsideWithScope (
186- prim_fn_var->name_hint , IRModule ({{prim_fn_var, relay_func}}), target_,
187- Array<IRModule>{IRModule ({{prim_fn_var, prim_func}})});
188- if (const auto * result = opt_mod_or_base_func.as <tir::PrimFuncNode>()) {
189- prim_func = GetRef<tir::PrimFunc>(result);
190- } else {
191- prim_func = tir::PrimFunc (nullptr );
192- }
193- }
194-
195- // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
196- if (!schedule.defined () && !prim_func.defined ()) {
197- ICHECK (anchor_implementation_.defined ());
198- schedule = anchor_implementation_.Schedule (anchor_attrs_, tensor_outs, target_);
199- }
200- if (schedule.defined ()) {
201- for (const auto & scalar : scalars_) {
202- if (schedule->Contain (scalar)) {
203- schedule[scalar].compute_inline ();
204- }
205- }
206- }
147+ truncated_name << candidate_name_.substr (0 , kMaxFuncNameLength );
148+ truncated_name << " _" << std::hex << std::hash<std::string>{}(candidate_name_) << " _" ;
149+ candidate_name_ = truncated_name.str ();
207150 }
208151
209- return CachedFunc (target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
210- IRModule (Map<GlobalVar, BaseFunc>({})), constant_tensors_);
152+ return outputs;
211153 }
212154
213155 Array<te::Tensor> VisitExpr_ (const VarNode* op) final {
@@ -254,7 +196,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
254196 }
255197
256198 Array<te::Tensor> VisitExpr_ (const CallNode* call_node) final {
257- static auto fpattern = Op::GetAttrMap<TOpPattern>(" TOpPattern" );
258199 static auto flower_call = tvm::runtime::Registry::Get (" relay.backend.lower_call" );
259200 ICHECK (flower_call) << " relay.backend.lower_call is not registered." ;
260201
@@ -278,28 +219,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
278219 ICHECK (call_node->op .as <OpNode>()) << " Primitive function only allows call into primitive ops" ;
279220 Op op = Downcast<Op>(call_node->op );
280221
281- Array<te::Tensor> outputs;
282- OpImplementation impl;
283222 // TODO(mbs): device_copy cleanup
284223 ICHECK_NE (op, device_copy_op_) << " device_copy cannot be lowered" ;
224+
285225 LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
286- outputs = lowered_out->outputs ;
287- impl = lowered_out->implementation ;
288-
289- if (create_schedule_) {
290- int op_pattern = fpattern[op];
291- if (!use_auto_scheduler_ && op_pattern >= kCommReduce ) {
292- ICHECK (!anchor_op_.defined () || anchor_op_pattern_ < kCommReduce )
293- << " Cannot apply TOPI schedule to a primitive function with two complicated ops"
294- << " anchor=" << anchor_op_ << " current=" << op;
295- }
296- if (op_pattern >= anchor_op_pattern_) {
297- anchor_op_ = op;
298- anchor_attrs_ = call_node->attrs ;
299- anchor_op_pattern_ = op_pattern;
300- anchor_implementation_ = impl;
301- }
302- }
226+ Array<te::Tensor> outputs = lowered_out->outputs ;
227+ op_implementations_[op.operator ->()] = lowered_out->implementation ;
228+
303229 if (outputs.size () != 1 ) {
304230 const auto * tuple_type = call_node->checked_type ().as <TupleTypeNode>();
305231 ICHECK (tuple_type) << " Expected output to be a tuple type "
@@ -308,8 +234,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
308234 ICHECK_EQ (tuple_type->fields .size (), outputs.size ());
309235 }
310236
311- // TODO(mbs): device_copy cleanup
312- ICHECK_NE (op, device_copy_op_) << " device_copy cannot be lowered" ;
313237 readable_name_stream_ << ' _' << op->name ;
314238 return outputs;
315239 }
@@ -347,26 +271,131 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
347271 return {tuple[op->index ]};
348272 }
349273
274+ public:
275+ // Additional outputs
276+ Array<tvm::te::Tensor> fn_inputs_;
277+ Array<te::Operation> scalars_;
278+ std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
279+ std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
280+ std::string candidate_name_;
281+
350282 private:
351283 tvm::Target target_;
352- Op anchor_op_;
353- Attrs anchor_attrs_;
354- int anchor_op_pattern_{0 };
355- OpImplementation anchor_implementation_;
356284 std::ostringstream readable_name_stream_;
357- Array<te::Operation> scalars_;
358- std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
359- bool use_auto_scheduler_;
360- bool use_meta_schedule_;
285+ // Index of the global constants
286+ static int const_index;
361287 // Cache device copy op for equivalence checking to reduce registry lookup
362288 // overhead for each invocation of call node when retrieving schedules.
363289 const Op& device_copy_op_;
364- bool create_schedule_;
365- // Index of the global constants
366- static int const_index;
367290};
368291
369- int ScheduleBuilder::const_index = 0 ;
292+ int LowerToTECompute::const_index = 0 ;
293+
294+ // Construct a schedule for a given Relay primitive function and target.
295+ class ScheduleBuilder : public ExprVisitor {
296+ public:
297+ explicit ScheduleBuilder (Target target) : target_(target) {
298+ // Whether to use auto_scheduler schedule.
299+ use_auto_scheduler_ = backend::IsAutoSchedulerEnabled ();
300+ }
301+
302+ CachedFunc Create (const Function& relay_func, std::function<std::string(std::string)> renamer) {
303+ LowerToTECompute lower_te_compute (target_);
304+ Array<te::Tensor> outputs = lower_te_compute.Lower (relay_func, renamer);
305+ Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_ ;
306+ VisitExpr (relay_func->body );
307+
308+ // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
309+ // no other GlobalVar ctors should appear inside the lowering machinery.
310+ auto prim_fn_var = GlobalVar (renamer (lower_te_compute.candidate_name_ ));
311+ prim_fn_var->checked_type_ = relay_func->checked_type ();
312+
313+ // Fusion over tupled results may leave identity relationships
314+ // between inputs and outputs, and those should not be scheduled.
315+ // Hence schedule only non PlaceholderOp outputs.
316+ tvm::Array<te::Tensor> tensor_outs;
317+ for (const auto & tensor : outputs) {
318+ if (!tensor->op .as <te::PlaceholderOpNode>()) {
319+ tensor_outs.push_back (tensor);
320+ }
321+ }
322+
323+ te::Schedule schedule{nullptr };
324+ tir::PrimFunc prim_func{nullptr };
325+ // No need to register schedule for device copy op.
326+ if (anchor_attrs_.as <DeviceCopyAttrs>() == nullptr ) {
327+ if (use_auto_scheduler_) {
328+ const auto * fauto_schedule =
329+ runtime::Registry::Get (" auto_scheduler.relay_integration.auto_schedule_topi_compute" );
330+ ICHECK (fauto_schedule != nullptr )
331+ << " auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered" ;
332+ ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint , tensor_outs);
333+ if (obj.defined ()) {
334+ schedule = Downcast<te::Schedule>(obj);
335+ }
336+ }
337+ if (backend::IsMetaScheduleEnabled ()) {
338+ prim_func = tir::CreatePrimFunc (Concat (fn_inputs, tensor_outs));
339+ Optional<ObjectRef> opt_mod_or_base_func =
340+ meta_schedule::MetaScheduleContext::QueryInsideWithScope (
341+ prim_fn_var->name_hint , IRModule ({{prim_fn_var, relay_func}}), target_,
342+ Array<IRModule>{IRModule ({{prim_fn_var, prim_func}})});
343+ if (const auto * result = opt_mod_or_base_func.as <tir::PrimFuncNode>()) {
344+ prim_func = GetRef<tir::PrimFunc>(result);
345+ } else {
346+ prim_func = tir::PrimFunc (nullptr );
347+ }
348+ }
349+
350+ // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
351+ if (!schedule.defined () && !prim_func.defined ()) {
352+ auto anchor_impl = lower_te_compute.op_implementations_ .find (anchor_op_.operator ->());
353+ ICHECK (anchor_impl != lower_te_compute.op_implementations_ .end ());
354+ schedule = anchor_impl->second .Schedule (anchor_attrs_, tensor_outs, target_);
355+ }
356+ if (schedule.defined ()) {
357+ for (const auto & scalar : lower_te_compute.scalars_ ) {
358+ if (schedule->Contain (scalar)) {
359+ schedule[scalar].compute_inline ();
360+ }
361+ }
362+ }
363+ }
364+
365+ return CachedFunc (target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
366+ IRModule (Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_ );
367+ }
368+
369+ void VisitExpr_ (const CallNode* call_node) final {
370+ static auto fpattern = Op::GetAttrMap<TOpPattern>(" TOpPattern" );
371+
372+ ICHECK (call_node->op .as <OpNode>()) << " Primitive function only allows call into primitive ops" ;
373+ Op op = Downcast<Op>(call_node->op );
374+
375+ for (Expr arg : call_node->args ) {
376+ VisitExpr (arg);
377+ }
378+
379+ int op_pattern = fpattern[op];
380+ if (!use_auto_scheduler_ && op_pattern >= kCommReduce ) {
381+ ICHECK (!anchor_op_.defined () || anchor_op_pattern_ < kCommReduce )
382+ << " Cannot apply TOPI schedule to a primitive function with two complicated ops"
383+ << " anchor=" << anchor_op_ << " current=" << op;
384+ }
385+ if (op_pattern >= anchor_op_pattern_) {
386+ anchor_op_ = op;
387+ anchor_attrs_ = call_node->attrs ;
388+ anchor_op_pattern_ = op_pattern;
389+ }
390+ }
391+
392+ private:
393+ tvm::Target target_;
394+ Op anchor_op_;
395+ Attrs anchor_attrs_;
396+ int anchor_op_pattern_{0 };
397+ bool use_auto_scheduler_;
398+ };
370399
371400/* !
372401 * \brief Create schedule for target.
@@ -750,9 +779,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>
750779}
751780
752781TVM_REGISTER_GLOBAL (" relay.backend.LowerToTE" ).set_body_typed([](Function prim_func) {
753- return ScheduleBuilder (tvm::Target (" ext_dev" ), false ).Create (prim_func, [&](std::string name) {
754- return name;
755- });
782+ auto tgt = tvm::Target (" ext_dev" );
783+ LowerToTECompute lower_te_compute (tgt);
784+ auto outputs = lower_te_compute.Lower (prim_func, [&](std::string name) { return name; });
785+ return CachedFunc (tgt, GlobalVar (lower_te_compute.candidate_name_ ), lower_te_compute.fn_inputs_ ,
786+ outputs, te::Schedule (), tir::PrimFunc (), {},
787+ IRModule (Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_ );
756788});
757789
758790} // namespace tec
0 commit comments