Skip to content

Commit 6f01901

Browse files
committed
fixed anchor impl selection
1 parent be6c258 commit 6f01901

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

src/relay/backend/te_compiler_cache.cc

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
224224

225225
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
226226
Array<te::Tensor> outputs = lowered_out->outputs;
227-
anchor_implementation_ = lowered_out->implementation;
227+
op_implementations_[op.operator->()] = lowered_out->implementation;
228228

229229
if (outputs.size() != 1) {
230230
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
@@ -276,8 +276,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
276276
Array<tvm::te::Tensor> fn_inputs_;
277277
Array<te::Operation> scalars_;
278278
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
279+
std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
279280
std::string candidate_name_;
280-
OpImplementation anchor_implementation_;
281281

282282
private:
283283
tvm::Target target_;
@@ -300,7 +300,6 @@ class ScheduleBuilder : public ExprVisitor {
300300
}
301301

302302
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
303-
LOG(INFO) << relay_func;
304303
LowerToTECompute lower_te_compute(target_);
305304
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
306305
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
@@ -350,11 +349,9 @@ class ScheduleBuilder : public ExprVisitor {
350349

351350
// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
352351
if (!schedule.defined() && !prim_func.defined()) {
353-
ICHECK(lower_te_compute.anchor_implementation_.defined());
354-
LOG(INFO) << lower_te_compute.candidate_name_;
355-
LOG(INFO) << anchor_attrs_;
356-
schedule =
357-
lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
352+
auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
353+
ICHECK(anchor_impl != lower_te_compute.op_implementations_.end());
354+
schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_);
358355
}
359356
if (schedule.defined()) {
360357
for (const auto& scalar : lower_te_compute.scalars_) {

0 commit comments

Comments
 (0)