@@ -224,7 +224,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
224224
225225 LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
226226 Array<te::Tensor> outputs = lowered_out->outputs ;
227- anchor_implementation_ = lowered_out->implementation ;
227+ op_implementations_[op. operator ->()] = lowered_out->implementation ;
228228
229229 if (outputs.size () != 1 ) {
230230 const auto * tuple_type = call_node->checked_type ().as <TupleTypeNode>();
@@ -276,8 +276,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
276276 Array<tvm::te::Tensor> fn_inputs_;
277277 Array<te::Operation> scalars_;
278278 std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
279+ std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
279280 std::string candidate_name_;
280- OpImplementation anchor_implementation_;
281281
282282 private:
283283 tvm::Target target_;
@@ -300,7 +300,6 @@ class ScheduleBuilder : public ExprVisitor {
300300 }
301301
302302 CachedFunc Create (const Function& relay_func, std::function<std::string(std::string)> renamer) {
303- LOG (INFO) << relay_func;
304303 LowerToTECompute lower_te_compute (target_);
305304 Array<te::Tensor> outputs = lower_te_compute.Lower (relay_func, renamer);
306305 Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_ ;
@@ -350,11 +349,9 @@ class ScheduleBuilder : public ExprVisitor {
350349
351350 // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
352351 if (!schedule.defined () && !prim_func.defined ()) {
353- ICHECK (lower_te_compute.anchor_implementation_ .defined ());
354- LOG (INFO) << lower_te_compute.candidate_name_ ;
355- LOG (INFO) << anchor_attrs_;
356- schedule =
357- lower_te_compute.anchor_implementation_ .Schedule (anchor_attrs_, tensor_outs, target_);
352+ auto anchor_impl = lower_te_compute.op_implementations_ .find (anchor_op_.operator ->());
353+ ICHECK (anchor_impl != lower_te_compute.op_implementations_ .end ());
354+ schedule = anchor_impl->second .Schedule (anchor_attrs_, tensor_outs, target_);
358355 }
359356 if (schedule.defined ()) {
360357 for (const auto & scalar : lower_te_compute.scalars_ ) {