@@ -46,6 +46,7 @@
 #include "../op/memory/memory.h"
 #include "../transforms/pass_utils.h"
 #include "tvm/relay/op_strategy.h"
+#include "tvm/tir/function.h"
 #include "utils.h"
 
 namespace tvm {
@@ -115,7 +116,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
   return res;
 }
 
-// Construct a schedule for a given Relay primitive function and target.
+// Lowers a Relay primitive Function to a TE compute.
 class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
   explicit LowerToTECompute(Target target)
@@ -133,7 +134,21 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
       memo_[param] = inputs;
     }
     readable_name_stream_ << "fused";
-    return this->VisitExpr(relay_func->body);
+
+    Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);
+
+    candidate_name_ = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
+    // whenever the value of kMaxFuncNameLength changes
+    if (candidate_name_.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
+      candidate_name_ = truncated_name.str();
+    }
+
+    return outputs;
   }
 
   Array<te::Tensor> VisitExpr_(const VarNode* op) final {
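`Lower()` now both lowers the body and derives the fused function's name, truncating overlong names and appending a hex hash so distinct long names stay distinct. A minimal standalone sketch of just that truncation rule, assuming nothing beyond what the hunk above shows (the helper name is hypothetical):

```cpp
#include <functional>
#include <sstream>
#include <string>

// Hypothetical helper mirroring the patch's logic: keep an 80-char prefix and
// append a hex hash of the full name so truncated names remain unique.
std::string TruncateName(const std::string& name, size_t max_len = 80) {
  if (name.size() <= max_len) return name;
  std::stringstream out;
  out << name.substr(0, max_len) << "_" << std::hex << std::hash<std::string>{}(name) << "_";
  return out.str();
}
```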
@@ -260,11 +275,12 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
   Array<tvm::te::Tensor> fn_inputs_;
   Array<te::Operation> scalars_;
   std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
-  std::ostringstream readable_name_stream_;
+  std::string candidate_name_;
   OpImplementation anchor_implementation_;
 
  private:
   tvm::Target target_;
+  std::ostringstream readable_name_stream_;
   // Index of the global constants
   static int const_index;
   // Cache device copy op for equivalence checking to reduce registry lookup
@@ -277,33 +293,20 @@ int LowerToTECompute::const_index = 0;
 // Construct a schedule for a given Relay primitive function and target.
 class ScheduleBuilder : ExprVisitor {
  public:
-  explicit ScheduleBuilder(Target target, bool create_schedule = true)
-      : target_(target), create_schedule_(create_schedule) {
+  explicit ScheduleBuilder(Target target) : target_(target) {
     // Whether to use auto_scheduler schedule.
     use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
-    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
   }
 
   CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
     LowerToTECompute lower_te_compute(target_);
     Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
     Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
-    std::string candidate_name = lower_te_compute.readable_name_stream_.str();
     VisitExpr(relay_func->body);
 
-    constexpr static size_t kMaxFuncNameLength = 80;
-    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
-    // whenever the value of kMaxFuncNameLength changes
-    if (candidate_name.size() > kMaxFuncNameLength) {
-      std::stringstream truncated_name;
-      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
-      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
-      candidate_name = truncated_name.str();
-    }
-
     // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
     // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto prim_fn_var = GlobalVar(renamer(candidate_name));
+    auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
     prim_fn_var->checked_type_ = relay_func->checked_type();
 
     // Fusion over tupled results may leave identity relationships
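With the truncation moved into `LowerToTECompute`, `Create()` simply reads the finished `candidate_name_` and passes it through `renamer`. A hedged sketch of a typical renamer, built on the `GetUniqueName` helper visible in the last hunk; `target`, `relay_func`, and the `name_map` table are assumed to be in scope and are illustrative, not from this patch:

```cpp
// Deduplicate candidate names against a module-level table, then lower.
std::unordered_map<std::string, int> name_map;
auto renamer = [&](std::string name) { return GetUniqueName(name, &name_map); };
CachedFunc cfunc = ScheduleBuilder(target).Create(relay_func, renamer);
```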
@@ -319,7 +322,7 @@ class ScheduleBuilder : ExprVisitor {
     te::Schedule schedule{nullptr};
     tir::PrimFunc prim_func{nullptr};
     // No need to register schedule for device copy op.
-    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
       if (use_auto_scheduler_) {
         const auto* fauto_schedule =
             runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
@@ -330,7 +333,7 @@ class ScheduleBuilder : ExprVisitor {
           schedule = Downcast<te::Schedule>(obj);
         }
       }
-      if (use_meta_schedule_) {
+      if (backend::IsMetaScheduleEnabled()) {
         prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
         Optional<ObjectRef> opt_mod_or_base_func =
             meta_schedule::MetaScheduleContext::QueryInsideWithScope(
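Dropping the cached `use_meta_schedule_` flag in favor of calling `backend::IsMetaScheduleEnabled()` at the branch means the decision tracks the `PassContext` that is current when `Create()` runs, rather than the one in effect at construction time. To my understanding the helper is a thin config lookup, roughly as sketched here (the option key follows TVM's `relay.backend.use_meta_schedule` registration; the function name is mine):

```cpp
// Rough paraphrase of backend::IsMetaScheduleEnabled(): read the PassContext
// option each time, defaulting to false when it is unset.
bool IsMetaScheduleEnabledSketch() {
  return transform::PassContext::Current()
      ->GetConfig<Bool>("relay.backend.use_meta_schedule", Bool(false))
      .value();
}
```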
@@ -368,18 +371,16 @@ class ScheduleBuilder : ExprVisitor {
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
 
-    if (create_schedule_) {
-      int op_pattern = fpattern[op];
-      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
-        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
-            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
-            << " anchor=" << anchor_op_ << " current=" << op;
-      }
-      if (op_pattern >= anchor_op_pattern_) {
-        anchor_op_ = op;
-        anchor_attrs_ = call_node->attrs;
-        anchor_op_pattern_ = op_pattern;
-      }
+    int op_pattern = fpattern[op];
+    if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+          << " anchor=" << anchor_op_ << " current=" << op;
+    }
+    if (op_pattern >= anchor_op_pattern_) {
+      anchor_op_ = op;
+      anchor_attrs_ = call_node->attrs;
+      anchor_op_pattern_ = op_pattern;
     }
   }
 
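With `create_schedule_` gone, the anchor-op bookkeeping always runs: the call with the highest TOPI pattern becomes the anchor, and a second op at or above `kCommReduce` is rejected unless auto-scheduling is enabled. For reference, the ordering this comparison relies on comes from `OpPatternKind` as declared in tvm/relay/op_attr_types.h (comments are my gloss):

```cpp
enum OpPatternKind {
  kElemWise = 0,         // e.g. add, relu
  kBroadcast = 1,
  kInjective = 2,        // injective index maps, e.g. transpose
  kCommReduce = 3,       // commutative reductions, e.g. sum
  kOutEWiseFusable = 4,  // e.g. conv2d; elemwise ops fusable onto its output
  kTuple = 7,
  kOpaque = 8
};
```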
@@ -389,8 +390,6 @@ class ScheduleBuilder : ExprVisitor {
   Attrs anchor_attrs_;
   int anchor_op_pattern_{0};
   bool use_auto_scheduler_;
-  bool use_meta_schedule_;
-  bool create_schedule_;
 };
 
 /*!
@@ -775,9 +774,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>
 }
 
 TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
-  return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) {
-    return name;
-  });
+  auto tgt = tvm::Target("ext_dev");
+  LowerToTECompute lower_te_compute(tgt);
+  auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; });
+  return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
+                    outputs, te::Schedule(), tir::PrimFunc(), {},
+                    IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
 });
 
 }  // namespace tec