Skip to content

Commit 076fa33

Browse files
authored
[TECompiler] Decouple TE compute and schedule lowering in ScheduleBuilder (apache#10561)
* Decouple TE compute and schedule lowering in ScheduleBuilder * fixed merge conflict * removed create_schedule stuff * add public, fix include path convention * Forgot visiting arg in ScheduleBuilder CallNode visit * fixed anchor impl selection
1 parent 51ae845 commit 076fa33

File tree

1 file changed

+146
-114
lines changed

1 file changed

+146
-114
lines changed

src/relay/backend/te_compiler_cache.cc

Lines changed: 146 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@
2828
#include <tvm/relay/expr_functor.h>
2929
#include <tvm/relay/op.h>
3030
#include <tvm/relay/op_attr_types.h>
31+
#include <tvm/relay/op_strategy.h>
3132
#include <tvm/runtime/device_api.h>
3233
#include <tvm/runtime/registry.h>
3334
#include <tvm/te/operation.h>
3435
#include <tvm/te/schedule.h>
3536
#include <tvm/te/schedule_pass.h>
37+
#include <tvm/tir/function.h>
3638
#include <tvm/topi/tags.h>
3739

3840
#include <functional>
@@ -114,100 +116,40 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
114116
return res;
115117
}
116118

117-
// Construct a schedule for a given Relay primitive function and target.
118-
class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
119+
// Lowers Relay primitive Function to TE Compute
120+
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
119121
public:
120-
explicit ScheduleBuilder(Target target, bool create_schedule = true)
121-
: target_(target),
122-
device_copy_op_(Op::Get("device_copy")),
123-
create_schedule_(create_schedule) {
124-
// Whether to use auto_scheduler schedule.
125-
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
126-
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
127-
}
122+
explicit LowerToTECompute(Target target)
123+
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
128124

129-
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
130-
Array<tvm::te::Tensor> fn_inputs;
125+
Array<te::Tensor> Lower(const Function& relay_func,
126+
std::function<std::string(std::string)> renamer) {
131127
for (Var param : relay_func->params) {
132128
Array<tvm::te::Tensor> inputs;
133129
for (const auto& ttype : FlattenTupleType(param->checked_type())) {
134130
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
135-
fn_inputs.push_back(tensor);
136131
inputs.push_back(tensor);
132+
fn_inputs_.push_back(tensor);
137133
}
138134
memo_[param] = inputs;
139135
}
140136
readable_name_stream_ << "fused";
141-
auto outputs = this->VisitExpr(relay_func->body);
142-
auto candidate_name = readable_name_stream_.str();
137+
138+
Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);
139+
140+
candidate_name_ = readable_name_stream_.str();
141+
143142
constexpr static size_t kMaxFuncNameLength = 80;
144143
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
145144
// whenever the value of kMaxFuncNameLength changes
146-
if (candidate_name.size() > kMaxFuncNameLength) {
145+
if (candidate_name_.size() > kMaxFuncNameLength) {
147146
std::stringstream truncated_name;
148-
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
149-
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
150-
candidate_name = truncated_name.str();
151-
}
152-
153-
// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
154-
// no other GlobalVar ctors should appear inside the lowering machinery.
155-
auto prim_fn_var = GlobalVar(renamer(candidate_name));
156-
prim_fn_var->checked_type_ = relay_func->checked_type();
157-
158-
// Fusion over tupled results may leave identity relationships
159-
// between inputs and outputs, and those should not be scheduled.
160-
// Hence schedule only non PlaceholderOp outputs.
161-
tvm::Array<te::Tensor> tensor_outs;
162-
for (const auto& tensor : outputs) {
163-
if (!tensor->op.as<te::PlaceholderOpNode>()) {
164-
tensor_outs.push_back(tensor);
165-
}
166-
}
167-
168-
te::Schedule schedule{nullptr};
169-
tir::PrimFunc prim_func{nullptr};
170-
// No need to register schedule for device copy op.
171-
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
172-
if (use_auto_scheduler_) {
173-
const auto* fauto_schedule =
174-
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
175-
ICHECK(fauto_schedule != nullptr)
176-
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
177-
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
178-
if (obj.defined()) {
179-
schedule = Downcast<te::Schedule>(obj);
180-
}
181-
}
182-
if (use_meta_schedule_) {
183-
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
184-
Optional<ObjectRef> opt_mod_or_base_func =
185-
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
186-
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
187-
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
188-
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
189-
prim_func = GetRef<tir::PrimFunc>(result);
190-
} else {
191-
prim_func = tir::PrimFunc(nullptr);
192-
}
193-
}
194-
195-
// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
196-
if (!schedule.defined() && !prim_func.defined()) {
197-
ICHECK(anchor_implementation_.defined());
198-
schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
199-
}
200-
if (schedule.defined()) {
201-
for (const auto& scalar : scalars_) {
202-
if (schedule->Contain(scalar)) {
203-
schedule[scalar].compute_inline();
204-
}
205-
}
206-
}
147+
truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
148+
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
149+
candidate_name_ = truncated_name.str();
207150
}
208151

209-
return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
210-
IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
152+
return outputs;
211153
}
212154

213155
Array<te::Tensor> VisitExpr_(const VarNode* op) final {
@@ -254,7 +196,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
254196
}
255197

256198
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
257-
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
258199
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
259200
ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
260201

@@ -278,28 +219,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
278219
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
279220
Op op = Downcast<Op>(call_node->op);
280221

281-
Array<te::Tensor> outputs;
282-
OpImplementation impl;
283222
// TODO(mbs): device_copy cleanup
284223
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
224+
285225
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
286-
outputs = lowered_out->outputs;
287-
impl = lowered_out->implementation;
288-
289-
if (create_schedule_) {
290-
int op_pattern = fpattern[op];
291-
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
292-
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
293-
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
294-
<< " anchor=" << anchor_op_ << " current=" << op;
295-
}
296-
if (op_pattern >= anchor_op_pattern_) {
297-
anchor_op_ = op;
298-
anchor_attrs_ = call_node->attrs;
299-
anchor_op_pattern_ = op_pattern;
300-
anchor_implementation_ = impl;
301-
}
302-
}
226+
Array<te::Tensor> outputs = lowered_out->outputs;
227+
op_implementations_[op.operator->()] = lowered_out->implementation;
228+
303229
if (outputs.size() != 1) {
304230
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
305231
ICHECK(tuple_type) << "Expected output to be a tuple type "
@@ -308,8 +234,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
308234
ICHECK_EQ(tuple_type->fields.size(), outputs.size());
309235
}
310236

311-
// TODO(mbs): device_copy cleanup
312-
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
313237
readable_name_stream_ << '_' << op->name;
314238
return outputs;
315239
}
@@ -347,26 +271,131 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
347271
return {tuple[op->index]};
348272
}
349273

274+
public:
275+
// Additional outputs
276+
Array<tvm::te::Tensor> fn_inputs_;
277+
Array<te::Operation> scalars_;
278+
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
279+
std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
280+
std::string candidate_name_;
281+
350282
private:
351283
tvm::Target target_;
352-
Op anchor_op_;
353-
Attrs anchor_attrs_;
354-
int anchor_op_pattern_{0};
355-
OpImplementation anchor_implementation_;
356284
std::ostringstream readable_name_stream_;
357-
Array<te::Operation> scalars_;
358-
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
359-
bool use_auto_scheduler_;
360-
bool use_meta_schedule_;
285+
// Index of the global constants
286+
static int const_index;
361287
// Cache device copy op for equivalence checking to reduce registry lookup
362288
// overhead for each invocation of call node when retrieving schedules.
363289
const Op& device_copy_op_;
364-
bool create_schedule_;
365-
// Index of the global constants
366-
static int const_index;
367290
};
368291

369-
int ScheduleBuilder::const_index = 0;
292+
int LowerToTECompute::const_index = 0;
293+
294+
// Construct a schedule for a given Relay primitive function and target.
295+
class ScheduleBuilder : public ExprVisitor {
296+
public:
297+
explicit ScheduleBuilder(Target target) : target_(target) {
298+
// Whether to use auto_scheduler schedule.
299+
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
300+
}
301+
302+
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
303+
LowerToTECompute lower_te_compute(target_);
304+
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
305+
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
306+
VisitExpr(relay_func->body);
307+
308+
// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
309+
// no other GlobalVar ctors should appear inside the lowering machinery.
310+
auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
311+
prim_fn_var->checked_type_ = relay_func->checked_type();
312+
313+
// Fusion over tupled results may leave identity relationships
314+
// between inputs and outputs, and those should not be scheduled.
315+
// Hence schedule only non PlaceholderOp outputs.
316+
tvm::Array<te::Tensor> tensor_outs;
317+
for (const auto& tensor : outputs) {
318+
if (!tensor->op.as<te::PlaceholderOpNode>()) {
319+
tensor_outs.push_back(tensor);
320+
}
321+
}
322+
323+
te::Schedule schedule{nullptr};
324+
tir::PrimFunc prim_func{nullptr};
325+
// No need to register schedule for device copy op.
326+
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
327+
if (use_auto_scheduler_) {
328+
const auto* fauto_schedule =
329+
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
330+
ICHECK(fauto_schedule != nullptr)
331+
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
332+
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
333+
if (obj.defined()) {
334+
schedule = Downcast<te::Schedule>(obj);
335+
}
336+
}
337+
if (backend::IsMetaScheduleEnabled()) {
338+
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
339+
Optional<ObjectRef> opt_mod_or_base_func =
340+
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
341+
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
342+
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
343+
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
344+
prim_func = GetRef<tir::PrimFunc>(result);
345+
} else {
346+
prim_func = tir::PrimFunc(nullptr);
347+
}
348+
}
349+
350+
// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
351+
if (!schedule.defined() && !prim_func.defined()) {
352+
auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
353+
ICHECK(anchor_impl != lower_te_compute.op_implementations_.end());
354+
schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_);
355+
}
356+
if (schedule.defined()) {
357+
for (const auto& scalar : lower_te_compute.scalars_) {
358+
if (schedule->Contain(scalar)) {
359+
schedule[scalar].compute_inline();
360+
}
361+
}
362+
}
363+
}
364+
365+
return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
366+
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
367+
}
368+
369+
void VisitExpr_(const CallNode* call_node) final {
370+
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
371+
372+
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
373+
Op op = Downcast<Op>(call_node->op);
374+
375+
for (Expr arg : call_node->args) {
376+
VisitExpr(arg);
377+
}
378+
379+
int op_pattern = fpattern[op];
380+
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
381+
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
382+
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
383+
<< " anchor=" << anchor_op_ << " current=" << op;
384+
}
385+
if (op_pattern >= anchor_op_pattern_) {
386+
anchor_op_ = op;
387+
anchor_attrs_ = call_node->attrs;
388+
anchor_op_pattern_ = op_pattern;
389+
}
390+
}
391+
392+
private:
393+
tvm::Target target_;
394+
Op anchor_op_;
395+
Attrs anchor_attrs_;
396+
int anchor_op_pattern_{0};
397+
bool use_auto_scheduler_;
398+
};
370399

371400
/*!
372401
* \brief Create schedule for target.
@@ -750,9 +779,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>
750779
}
751780

752781
// Lower a primitive Relay function to its TE compute only — no schedule is
// built — using an identity renamer and a placeholder "ext_dev" target.
TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
  Target target("ext_dev");
  LowerToTECompute lowerer(target);
  Array<te::Tensor> outputs = lowerer.Lower(prim_func, [](std::string name) { return name; });
  // Schedule and PrimFunc are deliberately left undefined.
  return CachedFunc(target, GlobalVar(lowerer.candidate_name_), lowerer.fn_inputs_, outputs,
                    te::Schedule(), tir::PrimFunc(), {}, IRModule(Map<GlobalVar, BaseFunc>({})),
                    lowerer.constant_tensors_);
});
757789

758790
} // namespace tec

0 commit comments

Comments (0)