Commit 6e68fd9

Decouple TE compute and schedule lowering in ScheduleBuilder
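
This refactor splits the monolithic ScheduleBuilder into two passes. A new LowerToTECompute visitor handles only the lowering of a Relay primitive function to TE tensors, exposing its byproducts (function inputs, scalars, constant tensors, the fused-name stream, and the anchor OpImplementation) as public fields. The slimmed-down ScheduleBuilder then re-visits the function body to select the anchor op and drives auto_scheduler, MetaSchedule, or TOPI scheduling over the lowered tensors.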
1 parent 5b76768 commit 6e68fd9

1 file changed: +141 -114 lines changed

src/relay/backend/te_compiler_cache.cc

@@ -45,6 +45,7 @@
 #include "../../te/operation/create_primfunc.h"
 #include "../op/memory/memory.h"
 #include "../transforms/pass_utils.h"
+#include "tvm/relay/op_strategy.h"
 #include "utils.h"
 
 namespace tvm {
@@ -115,99 +116,24 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
 }
 
 // Construct a schedule for a given Relay primitive function and target.
-class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
-  explicit ScheduleBuilder(Target target, bool create_schedule = true)
-      : target_(target),
-        device_copy_op_(Op::Get("device_copy")),
-        create_schedule_(create_schedule) {
-    // Whether to use auto_scheduler schedule.
-    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
-    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
-  }
+  explicit LowerToTECompute(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
-  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
-    Array<tvm::te::Tensor> fn_inputs;
+  Array<te::Tensor> Lower(const Function& relay_func,
+                          std::function<std::string(std::string)> renamer) {
     for (Var param : relay_func->params) {
       Array<tvm::te::Tensor> inputs;
       for (const auto& ttype : FlattenTupleType(param->checked_type())) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
-        fn_inputs.push_back(tensor);
         inputs.push_back(tensor);
+        fn_inputs_.push_back(tensor);
       }
       memo_[param] = inputs;
     }
     readable_name_stream_ << "fused";
-    auto outputs = this->VisitExpr(relay_func->body);
-    auto candidate_name = readable_name_stream_.str();
-    constexpr static size_t kMaxFuncNameLength = 80;
-    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
-    // whenever the value of kMaxFuncNameLength changes
-    if (candidate_name.size() > kMaxFuncNameLength) {
-      std::stringstream truncated_name;
-      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
-      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
-      candidate_name = truncated_name.str();
-    }
-
-    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
-    // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto prim_fn_var = GlobalVar(renamer(candidate_name));
-    prim_fn_var->checked_type_ = relay_func->checked_type();
-
-    // Fusion over tupled results may leave identity relationships
-    // between inputs and outputs, and those should not be scheduled.
-    // Hence schedule only non PlaceholderOp outputs.
-    tvm::Array<te::Tensor> tensor_outs;
-    for (const auto& tensor : outputs) {
-      if (!tensor->op.as<te::PlaceholderOpNode>()) {
-        tensor_outs.push_back(tensor);
-      }
-    }
-
-    te::Schedule schedule{nullptr};
-    tir::PrimFunc prim_func{nullptr};
-    // No need to register schedule for device copy op.
-    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
-      if (use_auto_scheduler_) {
-        const auto* fauto_schedule =
-            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
-        ICHECK(fauto_schedule != nullptr)
-            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
-        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
-        if (obj.defined()) {
-          schedule = Downcast<te::Schedule>(obj);
-        }
-      }
-      if (use_meta_schedule_) {
-        prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
-        Optional<ObjectRef> opt_mod_or_base_func =
-            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
-                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
-                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
-        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
-          prim_func = GetRef<tir::PrimFunc>(result);
-        } else {
-          prim_func = tir::PrimFunc(nullptr);
-        }
-      }
-
-      // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
-      if (!schedule.defined() && !prim_func.defined()) {
-        ICHECK(anchor_implementation_.defined());
-        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
-      }
-      if (schedule.defined()) {
-        for (const auto& scalar : scalars_) {
-          if (schedule->Contain(scalar)) {
-            schedule[scalar].compute_inline();
-          }
-        }
-      }
-    }
-
-    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
-                      IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
+    return this->VisitExpr(relay_func->body);
   }
 
   Array<te::Tensor> VisitExpr_(const VarNode* op) final {
@@ -254,7 +180,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
   }
 
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
     ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
 
@@ -278,28 +203,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
 
-    Array<te::Tensor> outputs;
-    OpImplementation impl;
     // TODO(mbs): device_copy cleanup
     ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
+
     LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
-    outputs = lowered_out->outputs;
-    impl = lowered_out->implementation;
+    Array<te::Tensor> outputs = lowered_out->outputs;
+    anchor_implementation_ = lowered_out->implementation;
 
-    if (create_schedule_) {
-      int op_pattern = fpattern[op];
-      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
-        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
-            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
-            << " anchor=" << anchor_op_ << " current=" << op;
-      }
-      if (op_pattern >= anchor_op_pattern_) {
-        anchor_op_ = op;
-        anchor_attrs_ = call_node->attrs;
-        anchor_op_pattern_ = op_pattern;
-        anchor_implementation_ = impl;
-      }
-    }
     if (outputs.size() != 1) {
       const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
       ICHECK(tuple_type) << "Expected output to be a tuple type "
@@ -308,8 +218,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
       ICHECK_EQ(tuple_type->fields.size(), outputs.size());
     }
 
-    // TODO(mbs): device_copy cleanup
-    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
     readable_name_stream_ << '_' << op->name;
     return outputs;
   }
@@ -347,27 +255,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     return {tuple[op->index]};
   }
 
+ public:
+  // Additional outputs
+  Array<tvm::te::Tensor> fn_inputs_;
+  Array<te::Operation> scalars_;
+  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
+  std::ostringstream readable_name_stream_;
+  OpImplementation anchor_implementation_;
+
+ private:
+  tvm::Target target_;
+  // Index of the global constants
+  static int const_index;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
+};
+
+int LowerToTECompute::const_index = 0;
+
+// Construct a schedule for a given Relay primitive function and target.
+class ScheduleBuilder : ExprVisitor {
+ public:
+  explicit ScheduleBuilder(Target target, bool create_schedule = true)
+      : target_(target),
+
+        create_schedule_(create_schedule) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
+  }
+
+  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
+    LowerToTECompute lower_te_compute(target_);
+    Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
+    std::string candidate_name = lower_te_compute.readable_name_stream_.str();
+    VisitExpr(relay_func->body);
+
+    constexpr static size_t kMaxFuncNameLength = 80;
+    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
+    // whenever the value of kMaxFuncNameLength changes
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
+    // no other GlobalVar ctors should appear inside the lowering machinery.
+    auto prim_fn_var = GlobalVar(renamer(candidate_name));
+    prim_fn_var->checked_type_ = relay_func->checked_type();
+
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule{nullptr};
+    tir::PrimFunc prim_func{nullptr};
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+      if (use_meta_schedule_) {
+        prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
+        Optional<ObjectRef> opt_mod_or_base_func =
+            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
+                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
+                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
+        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
+          prim_func = GetRef<tir::PrimFunc>(result);
+        } else {
+          prim_func = tir::PrimFunc(nullptr);
+        }
+      }
+
+      // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
+      if (!schedule.defined() && !prim_func.defined()) {
+        ICHECK(lower_te_compute.anchor_implementation_.defined());
+        schedule =
+            lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      if (schedule.defined()) {
+        for (const auto& scalar : lower_te_compute.scalars_) {
+          if (schedule->Contain(scalar)) {
+            schedule[scalar].compute_inline();
+          }
+        }
+      }
+    }
+
+    return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
+                      prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
+                      lower_te_compute.constant_tensors_);
+  }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    if (create_schedule_) {
+      int op_pattern = fpattern[op];
+      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+            << " anchor=" << anchor_op_ << " current=" << op;
+      }
+      if (op_pattern >= anchor_op_pattern_) {
+        anchor_op_ = op;
+        anchor_attrs_ = call_node->attrs;
+        anchor_op_pattern_ = op_pattern;
+      }
+    }
+  }
+
  private:
   tvm::Target target_;
   Op anchor_op_;
   Attrs anchor_attrs_;
   int anchor_op_pattern_{0};
-  OpImplementation anchor_implementation_;
-  std::ostringstream readable_name_stream_;
-  Array<te::Operation> scalars_;
-  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
   bool use_auto_scheduler_;
   bool use_meta_schedule_;
-  // Cache device copy op for equivalence checking to reduce registry lookup
-  // overhead for each invocation of call node when retrieving schedules.
-  const Op& device_copy_op_;
   bool create_schedule_;
-  // Index of the global constants
-  static int const_index;
 };
 
-int ScheduleBuilder::const_index = 0;
-
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.

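The two-phase shape of the refactor can be seen outside TVM as well. Below is a minimal, self-contained C++ sketch (hypothetical names, not TVM code): "Lowerer" stands in for LowerToTECompute, exposing its byproducts as public fields the way fn_inputs_ and readable_name_stream_ are exposed above, and "Builder" stands in for the new ScheduleBuilder, which consumes those byproducts in a second pass.

// Hypothetical sketch of the decoupling pattern; all names are illustrative.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct Tensor {
  std::string name;
};

// Phase 1: lower ops to tensors; record byproducts, make no scheduling decisions.
class Lowerer {
 public:
  std::vector<Tensor> Lower(const std::vector<std::string>& ops) {
    readable_name_stream_ << "fused";
    std::vector<Tensor> outputs;
    for (const std::string& op : ops) {
      readable_name_stream_ << '_' << op;
      outputs.push_back(Tensor{op + ".out"});
    }
    return outputs;
  }
  // Byproduct exposed to the scheduling phase, mirroring the public fields
  // of LowerToTECompute above.
  std::ostringstream readable_name_stream_;
};

// Phase 2: a separate pass that selects an anchor op and schedules the
// tensors produced by phase 1.
class Builder {
 public:
  void Create(const std::vector<std::string>& ops) {
    Lowerer lowerer;
    std::vector<Tensor> outputs = lowerer.Lower(ops);     // phase 1: compute lowering
    std::string anchor = ops.empty() ? "" : ops.back();   // stand-in for pattern-based anchor selection
    std::cout << "schedule " << lowerer.readable_name_stream_.str()
              << " with anchor=" << anchor << ", outputs=" << outputs.size() << "\n";
  }
};

int main() {
  Builder().Create({"conv2d", "add", "relu"});
  return 0;
}

The design choice mirrored here is that phase 1 stays side-effect free with respect to scheduling: everything phase 2 needs is handed over through the lowerer object rather than accumulated in one class that does both jobs.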