Skip to content

Commit a167655

Browse files
committed
removed create_schedule stuff
1 parent 3b97529 commit a167655

File tree

1 file changed

+39
-37
lines changed

1 file changed

+39
-37
lines changed

src/relay/backend/te_compiler_cache.cc

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "../op/memory/memory.h"
4747
#include "../transforms/pass_utils.h"
4848
#include "tvm/relay/op_strategy.h"
49+
#include "tvm/tir/function.h"
4950
#include "utils.h"
5051

5152
namespace tvm {
@@ -115,7 +116,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
115116
return res;
116117
}
117118

118-
// Construct a schedule for a given Relay primitive function and target.
119+
// Lowers Relay primitive Function to TE Compute
119120
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
120121
public:
121122
explicit LowerToTECompute(Target target)
@@ -133,7 +134,21 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
133134
memo_[param] = inputs;
134135
}
135136
readable_name_stream_ << "fused";
136-
return this->VisitExpr(relay_func->body);
137+
138+
Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);
139+
140+
candidate_name_ = readable_name_stream_.str();
141+
constexpr static size_t kMaxFuncNameLength = 80;
142+
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
143+
// whenever the value of kMaxFuncNameLength changes
144+
if (candidate_name_.size() > kMaxFuncNameLength) {
145+
std::stringstream truncated_name;
146+
truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
147+
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
148+
candidate_name_ = truncated_name.str();
149+
}
150+
151+
return outputs;
137152
}
138153

139154
Array<te::Tensor> VisitExpr_(const VarNode* op) final {
@@ -260,11 +275,12 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
260275
Array<tvm::te::Tensor> fn_inputs_;
261276
Array<te::Operation> scalars_;
262277
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
263-
std::ostringstream readable_name_stream_;
278+
std::string candidate_name_;
264279
OpImplementation anchor_implementation_;
265280

266281
private:
267282
tvm::Target target_;
283+
std::ostringstream readable_name_stream_;
268284
// Index of the global constants
269285
static int const_index;
270286
// Cache device copy op for equivalence checking to reduce registry lookup
@@ -277,33 +293,20 @@ int LowerToTECompute::const_index = 0;
277293
// Construct a schedule for a given Relay primitive function and target.
278294
class ScheduleBuilder : ExprVisitor {
279295
public:
280-
explicit ScheduleBuilder(Target target, bool create_schedule = true)
281-
: target_(target), create_schedule_(create_schedule) {
296+
explicit ScheduleBuilder(Target target) : target_(target) {
282297
// Whether to use auto_scheduler schedule.
283298
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
284-
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
285299
}
286300

287301
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
288302
LowerToTECompute lower_te_compute(target_);
289303
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
290304
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
291-
std::string candidate_name = lower_te_compute.readable_name_stream_.str();
292305
VisitExpr(relay_func->body);
293306

294-
constexpr static size_t kMaxFuncNameLength = 80;
295-
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
296-
// whenever the value of kMaxFuncNameLength changes
297-
if (candidate_name.size() > kMaxFuncNameLength) {
298-
std::stringstream truncated_name;
299-
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
300-
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
301-
candidate_name = truncated_name.str();
302-
}
303-
304307
// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
305308
// no other GlobalVar ctors should appear inside the lowering machinery.
306-
auto prim_fn_var = GlobalVar(renamer(candidate_name));
309+
auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
307310
prim_fn_var->checked_type_ = relay_func->checked_type();
308311

309312
// Fusion over tupled results may leave identity relationships
@@ -319,7 +322,7 @@ class ScheduleBuilder : ExprVisitor {
319322
te::Schedule schedule{nullptr};
320323
tir::PrimFunc prim_func{nullptr};
321324
// No need to register schedule for device copy op.
322-
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
325+
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
323326
if (use_auto_scheduler_) {
324327
const auto* fauto_schedule =
325328
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
@@ -330,7 +333,7 @@ class ScheduleBuilder : ExprVisitor {
330333
schedule = Downcast<te::Schedule>(obj);
331334
}
332335
}
333-
if (use_meta_schedule_) {
336+
if (backend::IsMetaScheduleEnabled()) {
334337
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
335338
Optional<ObjectRef> opt_mod_or_base_func =
336339
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
@@ -368,18 +371,16 @@ class ScheduleBuilder : ExprVisitor {
368371
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
369372
Op op = Downcast<Op>(call_node->op);
370373

371-
if (create_schedule_) {
372-
int op_pattern = fpattern[op];
373-
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
374-
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
375-
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
376-
<< " anchor=" << anchor_op_ << " current=" << op;
377-
}
378-
if (op_pattern >= anchor_op_pattern_) {
379-
anchor_op_ = op;
380-
anchor_attrs_ = call_node->attrs;
381-
anchor_op_pattern_ = op_pattern;
382-
}
374+
int op_pattern = fpattern[op];
375+
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
376+
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
377+
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
378+
<< " anchor=" << anchor_op_ << " current=" << op;
379+
}
380+
if (op_pattern >= anchor_op_pattern_) {
381+
anchor_op_ = op;
382+
anchor_attrs_ = call_node->attrs;
383+
anchor_op_pattern_ = op_pattern;
383384
}
384385
}
385386

@@ -389,8 +390,6 @@ class ScheduleBuilder : ExprVisitor {
389390
Attrs anchor_attrs_;
390391
int anchor_op_pattern_{0};
391392
bool use_auto_scheduler_;
392-
bool use_meta_schedule_;
393-
bool create_schedule_;
394393
};
395394

396395
/*!
@@ -775,9 +774,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>
775774
}
776775

777776
TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
778-
return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) {
779-
return name;
780-
});
777+
auto tgt = tvm::Target("ext_dev");
778+
LowerToTECompute lower_te_compute(tgt);
779+
auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; });
780+
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
781+
outputs, te::Schedule(), tir::PrimFunc(), {},
782+
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
781783
});
782784

783785
} // namespace tec

0 commit comments

Comments
 (0)