Commit f21f8f4 (parent: 48793f3)

Decouple TE compute and schedule lowering in ScheduleBuilder
1 file changed: src/relay/backend/te_compiler_cache.cc (+141, -113)
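In short, the old ScheduleBuilder performed Relay-to-TE lowering and schedule construction in a single MemoizedExprTranslator pass; this commit splits it into LowerToTECompute (compute lowering only) and a slimmer ScheduleBuilder (anchor-op selection and scheduling) that drives it. A minimal sketch of the resulting flow, assuming the file-local classes from the diff below are in scope; the helper LowerAndSchedule and the identity renamer are illustrative, not part of the commit:

// Sketch only: mirrors the control flow this commit introduces inside
// te_compiler_cache.cc. LowerAndSchedule is a hypothetical helper.
CachedFunc LowerAndSchedule(const Function& relay_func, const Target& target) {
  auto renamer = [](std::string name) { return name; };  // illustrative renamer

  // Phase 1: Relay -> TE compute. LowerToTECompute only emits te::Tensors and
  // records side outputs (fn_inputs_, constant_tensors_, scalars_,
  // anchor_implementation_) for the scheduling phase to consume.
  LowerToTECompute lower_te_compute(target);
  Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
  (void)outputs;  // ScheduleBuilder::Create repeats this lowering internally.

  // Phase 2: schedule construction. ScheduleBuilder is now a plain ExprVisitor
  // that revisits the Relay body to pick the anchor op, then queries
  // auto_scheduler / meta_schedule or falls back to the anchor's TOPI schedule.
  return ScheduleBuilder(target).Create(relay_func, renamer);
}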
@@ -45,6 +45,7 @@
 #include "../../te/operation/create_primfunc.h"
 #include "../op/memory/memory.h"
 #include "../transforms/pass_utils.h"
+#include "tvm/relay/op_strategy.h"
 #include "utils.h"
 
 namespace tvm {
@@ -115,99 +116,25 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
 }
 
 // Construct a schedule for a given Relay primitive function and target.
-class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
-  explicit ScheduleBuilder(Target target, bool create_schedule = true)
-      : target_(target),
-        device_copy_op_(Op::Get("device_copy")),
-        create_schedule_(create_schedule) {
-    // Whether to use auto_scheduler schedule.
-    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
-    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
-  }
+  explicit LowerToTECompute(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
-  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
-    Array<tvm::te::Tensor> fn_inputs;
+  Array<te::Tensor> Lower(const Function& relay_func,
+                          std::function<std::string(std::string)> renamer) {
     for (Var param : relay_func->params) {
       Array<tvm::te::Tensor> inputs;
       for (const auto& ttype : FlattenTupleType(param->checked_type())) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
-        fn_inputs.push_back(tensor);
         inputs.push_back(tensor);
+        fn_inputs_.push_back(tensor);
       }
       memo_[param] = inputs;
     }
     readable_name_stream_ << "fused";
     auto outputs = this->VisitExpr(relay_func->body);
-    auto candidate_name = readable_name_stream_.str();
-    constexpr static size_t kMaxFuncNameLength = 80;
-    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
-    // whenever the value of kMaxFuncNameLength changes
-    if (candidate_name.size() > kMaxFuncNameLength) {
-      std::stringstream truncated_name;
-      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
-      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
-      candidate_name = truncated_name.str();
-    }
-
-    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
-    // no other GlobalVar ctors should appear inside the lowering machinery.
-    auto prim_fn_var = GlobalVar(renamer(candidate_name));
-    prim_fn_var->checked_type_ = relay_func->checked_type();
-
-    // Fusion over tupled results may leave identity relationships
-    // between inputs and outputs, and those should not be scheduled.
-    // Hence schedule only non PlaceholderOp outputs.
-    tvm::Array<te::Tensor> tensor_outs;
-    for (const auto& tensor : outputs) {
-      if (!tensor->op.as<te::PlaceholderOpNode>()) {
-        tensor_outs.push_back(tensor);
-      }
-    }
-
-    te::Schedule schedule{nullptr};
-    tir::PrimFunc prim_func{nullptr};
-    // No need to register schedule for device copy op.
-    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
-      if (use_auto_scheduler_) {
-        const auto* fauto_schedule =
-            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
-        ICHECK(fauto_schedule != nullptr)
-            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
-        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
-        if (obj.defined()) {
-          schedule = Downcast<te::Schedule>(obj);
-        }
-      }
-      if (use_meta_schedule_) {
-        prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
-        Optional<ObjectRef> opt_mod_or_base_func =
-            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
-                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
-                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
-        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
-          prim_func = GetRef<tir::PrimFunc>(result);
-        } else {
-          prim_func = tir::PrimFunc(nullptr);
-        }
-      }
-
-      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
-      if (!schedule.defined() && !prim_func.defined()) {
-        ICHECK(anchor_implementation_.defined());
-        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
-      }
-      if (schedule.defined()) {
-        for (const auto& scalar : scalars_) {
-          if (schedule->Contain(scalar)) {
-            schedule[scalar].compute_inline();
-          }
-        }
-      }
-    }
-
-    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
-                      IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
+    return outputs;
   }
 
   Array<te::Tensor> VisitExpr_(const VarNode* op) final {
@@ -254,7 +181,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
   }
 
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
     ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
 
@@ -278,28 +204,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
 
-    Array<te::Tensor> outputs;
-    OpImplementation impl;
     // TODO(mbs): device_copy cleanup
     ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
+
     LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
-    outputs = lowered_out->outputs;
-    impl = lowered_out->implementation;
+    Array<te::Tensor> outputs = lowered_out->outputs;
+    anchor_implementation_ = lowered_out->implementation;
 
-    if (create_schedule_) {
-      int op_pattern = fpattern[op];
-      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
-        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
-            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
-            << " anchor=" << anchor_op_ << " current=" << op;
-      }
-      if (op_pattern >= anchor_op_pattern_) {
-        anchor_op_ = op;
-        anchor_attrs_ = call_node->attrs;
-        anchor_op_pattern_ = op_pattern;
-        anchor_implementation_ = impl;
-      }
-    }
     if (outputs.size() != 1) {
       const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
       ICHECK(tuple_type) << "Expected output to be a tuple type "
@@ -308,8 +219,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
       ICHECK_EQ(tuple_type->fields.size(), outputs.size());
     }
 
-    // TODO(mbs): device_copy cleanup
-    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
     readable_name_stream_ << '_' << op->name;
     return outputs;
   }
@@ -347,27 +256,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
     return {tuple[op->index]};
   }
 
+ public:
+  // outputs
+  Array<tvm::te::Tensor> fn_inputs_;
+  Array<te::Operation> scalars_;
+  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
+  std::ostringstream readable_name_stream_;
+  OpImplementation anchor_implementation_;
+
+ private:
+  tvm::Target target_;
+  // Index of the global constants
+  static int const_index;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
+};
+
+int LowerToTECompute::const_index = 0;
+
+// Construct a schedule for a given Relay primitive function and target.
+class ScheduleBuilder : ExprVisitor {
+ public:
+  explicit ScheduleBuilder(Target target, bool create_schedule = true)
+      : target_(target),
+        create_schedule_(create_schedule) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
+  }
+
+  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
+    LowerToTECompute lower_te_compute(target_);
+    Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
+    std::string candidate_name = lower_te_compute.readable_name_stream_.str();
+    VisitExpr(relay_func->body);
+
+    constexpr static size_t kMaxFuncNameLength = 80;
+    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
+    // whenever the value of kMaxFuncNameLength changes
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
+    // no other GlobalVar ctors should appear inside the lowering machinery.
+    auto prim_fn_var = GlobalVar(renamer(candidate_name));
+    prim_fn_var->checked_type_ = relay_func->checked_type();
+
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule{nullptr};
+    tir::PrimFunc prim_func{nullptr};
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+      if (use_meta_schedule_) {
+        prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
+        Optional<ObjectRef> opt_mod_or_base_func =
+            meta_schedule::MetaScheduleContext::QueryInsideWithScope(
+                prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
+                Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
+        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
+          prim_func = GetRef<tir::PrimFunc>(result);
+        } else {
+          prim_func = tir::PrimFunc(nullptr);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined() && !prim_func.defined()) {
+        ICHECK(lower_te_compute.anchor_implementation_.defined());
+        schedule =
+            lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      if (schedule.defined()) {
+        for (const auto& scalar : lower_te_compute.scalars_) {
+          if (schedule->Contain(scalar)) {
+            schedule[scalar].compute_inline();
+          }
+        }
+      }
+    }
+
+    return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
+                      prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
+                      lower_te_compute.constant_tensors_);
+  }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    if (create_schedule_) {
+      int op_pattern = fpattern[op];
+      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+            << " anchor=" << anchor_op_ << " current=" << op;
+      }
+      if (op_pattern >= anchor_op_pattern_) {
+        anchor_op_ = op;
+        anchor_attrs_ = call_node->attrs;
+        anchor_op_pattern_ = op_pattern;
+      }
+    }
+  }
+
  private:
   tvm::Target target_;
   Op anchor_op_;
   Attrs anchor_attrs_;
   int anchor_op_pattern_{0};
-  OpImplementation anchor_implementation_;
-  std::ostringstream readable_name_stream_;
-  Array<te::Operation> scalars_;
-  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
   bool use_auto_scheduler_;
   bool use_meta_schedule_;
-  // Cache device copy op for equivalence checking to reduce registry lookup
-  // overhead for each invocation of call node when retrieving schedules.
-  const Op& device_copy_op_;
   bool create_schedule_;
-  // Index of the global constants
-  static int const_index;
 };
 
-int ScheduleBuilder::const_index = 0;
-
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.
