From 6e68fd9aff9f86412f8b7150b18ae1b374927f86 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Mar 2022 14:27:34 +0900 Subject: [PATCH 1/6] Decouple TE compute and schedule lowering in ScheduleBuilder --- src/relay/backend/te_compiler_cache.cc | 255 ++++++++++++++----------- 1 file changed, 141 insertions(+), 114 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index abab8cc6e0a0..a1de51de728d 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -45,6 +45,7 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" +#include "tvm/relay/op_strategy.h" #include "utils.h" namespace tvm { @@ -115,99 +116,24 @@ Array GetShape(const Array& shape) { } // Construct a schedule for a given Relay primitive function and target. -class ScheduleBuilder : public backend::MemoizedExprTranslator> { +class LowerToTECompute : public backend::MemoizedExprTranslator> { public: - explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), - device_copy_op_(Op::Get("device_copy")), - create_schedule_(create_schedule) { - // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - use_meta_schedule_ = backend::IsMetaScheduleEnabled(); - } + explicit LowerToTECompute(Target target) + : target_(target), device_copy_op_(Op::Get("device_copy")) {} - CachedFunc Create(const Function& relay_func, std::function renamer) { - Array fn_inputs; + Array Lower(const Function& relay_func, + std::function renamer) { for (Var param : relay_func->params) { Array inputs; for (const auto& ttype : FlattenTupleType(param->checked_type())) { tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - fn_inputs.push_back(tensor); inputs.push_back(tensor); + fn_inputs_.push_back(tensor); } memo_[param] = inputs; } readable_name_stream_ << "fused"; - auto outputs = this->VisitExpr(relay_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME - // whenever the value of kMaxFuncNameLength changes - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - - // TODO(mbs): This should be the definitive global by which the PrimFunc is known and - // no other GlobalVar ctors should appear inside the lowering machinery. - auto prim_fn_var = GlobalVar(renamer(candidate_name)); - prim_fn_var->checked_type_ = relay_func->checked_type(); - - // Fusion over tupled results may leave identity relationships - // between inputs and outputs, and those should not be scheduled. - // Hence schedule only non PlaceholderOp outputs. - tvm::Array tensor_outs; - for (const auto& tensor : outputs) { - if (!tensor->op.as()) { - tensor_outs.push_back(tensor); - } - } - - te::Schedule schedule{nullptr}; - tir::PrimFunc prim_func{nullptr}; - // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr && create_schedule_) { - if (use_auto_scheduler_) { - const auto* fauto_schedule = - runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); - ICHECK(fauto_schedule != nullptr) - << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); - if (obj.defined()) { - schedule = Downcast(obj); - } - } - if (use_meta_schedule_) { - prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); - Optional opt_mod_or_base_func = - meta_schedule::MetaScheduleContext::QueryInsideWithScope( - prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, - Array{IRModule({{prim_fn_var, prim_func}})}); - if (const auto* result = opt_mod_or_base_func.as()) { - prim_func = GetRef(result); - } else { - prim_func = tir::PrimFunc(nullptr); - } - } - - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined() && !prim_func.defined()) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); - } - if (schedule.defined()) { - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); - } - } - } - } - - return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, - IRModule(Map({})), constant_tensors_); + return this->VisitExpr(relay_func->body); } Array VisitExpr_(const VarNode* op) final { @@ -254,7 +180,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttrMap("TOpPattern"); static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); ICHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -278,28 +203,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); - Array outputs; - OpImplementation impl; // TODO(mbs): device_copy cleanup ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); - outputs = lowered_out->outputs; - impl = lowered_out->implementation; + Array outputs = lowered_out->outputs; + anchor_implementation_ = lowered_out->implementation; - if (create_schedule_) { - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; - } - } if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); ICHECK(tuple_type) << "Expected output to be a tuple type " @@ -308,8 +218,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator ICHECK_EQ(tuple_type->fields.size(), outputs.size()); } - // TODO(mbs): device_copy cleanup - ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; readable_name_stream_ << '_' << op->name; return outputs; } @@ -347,27 +255,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator return {tuple[op->index]}; } + public: + // Additional outputs + Array fn_inputs_; + Array scalars_; + std::unordered_map constant_tensors_; + std::ostringstream readable_name_stream_; + OpImplementation anchor_implementation_; + + private: + tvm::Target target_; + // Index of the global constants + static int const_index; + // Cache device copy op for equivalence checking to reduce registry lookup + // overhead for each invocation of call node when retrieving schedules. + const Op& device_copy_op_; +}; + +int LowerToTECompute::const_index = 0; + +// Construct a schedule for a given Relay primitive function and target. +class ScheduleBuilder : ExprVisitor { + public: + explicit ScheduleBuilder(Target target, bool create_schedule = true) + : target_(target), + + create_schedule_(create_schedule) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + use_meta_schedule_ = backend::IsMetaScheduleEnabled(); + } + + CachedFunc Create(const Function& relay_func, std::function renamer) { + LowerToTECompute lower_te_compute(target_); + Array outputs = lower_te_compute.Lower(relay_func, renamer); + std::string candidate_name = lower_te_compute.readable_name_stream_.str(); + VisitExpr(relay_func->body); + + constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // TODO(mbs): This should be the definitive global by which the PrimFunc is known and + // no other GlobalVar ctors should appear inside the lowering machinery. + auto prim_fn_var = GlobalVar(renamer(candidate_name)); + prim_fn_var->checked_type_ = relay_func->checked_type(); + + // Fusion over tupled results may leave identity relationships + // between inputs and outputs, and those should not be scheduled. + // Hence schedule only non PlaceholderOp outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + + te::Schedule schedule{nullptr}; + tir::PrimFunc prim_func{nullptr}; + // No need to register schedule for device copy op. + if (anchor_attrs_.as() == nullptr && create_schedule_) { + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + if (use_meta_schedule_) { + prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs); + Optional opt_mod_or_base_func = + meta_schedule::MetaScheduleContext::QueryInsideWithScope( + prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, + Array{IRModule({{prim_fn_var, prim_func}})}); + if (const auto* result = opt_mod_or_base_func.as()) { + prim_func = GetRef(result); + } else { + prim_func = tir::PrimFunc(nullptr); + } + } + + // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. + if (!schedule.defined() && !prim_func.defined()) { + ICHECK(lower_te_compute.anchor_implementation_.defined()); + schedule = + lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + } + if (schedule.defined()) { + for (const auto& scalar : lower_te_compute.scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } + } + } + } + + return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule, + prim_func, {}, IRModule(Map({})), + lower_te_compute.constant_tensors_); + } + + void VisitExpr_(const CallNode* call_node) final { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + if (create_schedule_) { + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + } + } + } + private: tvm::Target target_; Op anchor_op_; Attrs anchor_attrs_; int anchor_op_pattern_{0}; - OpImplementation anchor_implementation_; - std::ostringstream readable_name_stream_; - Array scalars_; - std::unordered_map constant_tensors_; bool use_auto_scheduler_; bool use_meta_schedule_; - // Cache device copy op for equivalence checking to reduce registry lookup - // overhead for each invocation of call node when retrieving schedules. - const Op& device_copy_op_; bool create_schedule_; - // Index of the global constants - static int const_index; }; -int ScheduleBuilder::const_index = 0; - /*! * \brief Create schedule for target. * \param source_func The primitive function to be lowered. From eb1bc7e789b66eaf3d4fe01d5154c135ab275dc2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Mar 2022 18:13:42 +0900 Subject: [PATCH 2/6] fixed merge conflict --- src/relay/backend/te_compiler_cache.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index a1de51de728d..2c2042859ddb 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -278,9 +278,7 @@ int LowerToTECompute::const_index = 0; class ScheduleBuilder : ExprVisitor { public: explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), - - create_schedule_(create_schedule) { + : target_(target), create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); use_meta_schedule_ = backend::IsMetaScheduleEnabled(); @@ -289,6 +287,7 @@ class ScheduleBuilder : ExprVisitor { CachedFunc Create(const Function& relay_func, std::function renamer) { LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); + Array fn_inputs = lower_te_compute.fn_inputs_; std::string candidate_name = lower_te_compute.readable_name_stream_.str(); VisitExpr(relay_func->body); @@ -332,7 +331,7 @@ class ScheduleBuilder : ExprVisitor { } } if (use_meta_schedule_) { - prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs); + prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); Optional opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope( prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, @@ -359,9 +358,8 @@ class ScheduleBuilder : ExprVisitor { } } - return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule, - prim_func, {}, IRModule(Map({})), - lower_te_compute.constant_tensors_); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); } void VisitExpr_(const CallNode* call_node) final { From 4cd3a1657c4e2e13abe7281b7cdef5dff73b37ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Mar 2022 18:43:15 +0900 Subject: [PATCH 3/6] removed create_schedule stuff --- src/relay/backend/te_compiler_cache.cc | 76 +++++++++++++------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 2c2042859ddb..fc3a3ab335f4 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -46,6 +46,7 @@ #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" #include "tvm/relay/op_strategy.h" +#include "tvm/tir/function.h" #include "utils.h" namespace tvm { @@ -115,7 +116,7 @@ Array GetShape(const Array& shape) { return res; } -// Construct a schedule for a given Relay primitive function and target. +// Lowers Relay primitive Function to TE Compute class LowerToTECompute : public backend::MemoizedExprTranslator> { public: explicit LowerToTECompute(Target target) @@ -133,7 +134,21 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorVisitExpr(relay_func->body); + + Array outputs = this->VisitExpr(relay_func->body); + + candidate_name_ = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes + if (candidate_name_.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name_.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hex << std::hash{}(candidate_name_) << "_"; + candidate_name_ = truncated_name.str(); + } + + return outputs; } Array VisitExpr_(const VarNode* op) final { @@ -260,11 +275,12 @@ class LowerToTECompute : public backend::MemoizedExprTranslator fn_inputs_; Array scalars_; std::unordered_map constant_tensors_; - std::ostringstream readable_name_stream_; + std::string candidate_name_; OpImplementation anchor_implementation_; private: tvm::Target target_; + std::ostringstream readable_name_stream_; // Index of the global constants static int const_index; // Cache device copy op for equivalence checking to reduce registry lookup @@ -277,33 +293,20 @@ int LowerToTECompute::const_index = 0; // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : ExprVisitor { public: - explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), create_schedule_(create_schedule) { + explicit ScheduleBuilder(Target target) : target_(target) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - use_meta_schedule_ = backend::IsMetaScheduleEnabled(); } CachedFunc Create(const Function& relay_func, std::function renamer) { LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); Array fn_inputs = lower_te_compute.fn_inputs_; - std::string candidate_name = lower_te_compute.readable_name_stream_.str(); VisitExpr(relay_func->body); - constexpr static size_t kMaxFuncNameLength = 80; - // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME - // whenever the value of kMaxFuncNameLength changes - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - // TODO(mbs): This should be the definitive global by which the PrimFunc is known and // no other GlobalVar ctors should appear inside the lowering machinery. - auto prim_fn_var = GlobalVar(renamer(candidate_name)); + auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_)); prim_fn_var->checked_type_ = relay_func->checked_type(); // Fusion over tupled results may leave identity relationships @@ -319,7 +322,7 @@ class ScheduleBuilder : ExprVisitor { te::Schedule schedule{nullptr}; tir::PrimFunc prim_func{nullptr}; // No need to register schedule for device copy op. - if (anchor_attrs_.as() == nullptr && create_schedule_) { + if (anchor_attrs_.as() == nullptr) { if (use_auto_scheduler_) { const auto* fauto_schedule = runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); @@ -330,7 +333,7 @@ class ScheduleBuilder : ExprVisitor { schedule = Downcast(obj); } } - if (use_meta_schedule_) { + if (backend::IsMetaScheduleEnabled()) { prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); Optional opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope( @@ -368,18 +371,16 @@ class ScheduleBuilder : ExprVisitor { ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); - if (create_schedule_) { - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - } + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; } } @@ -389,8 +390,6 @@ class ScheduleBuilder : ExprVisitor { Attrs anchor_attrs_; int anchor_op_pattern_{0}; bool use_auto_scheduler_; - bool use_meta_schedule_; - bool create_schedule_; }; /*! @@ -775,9 +774,12 @@ std::string GetUniqueName(std::string name, std::unordered_map } TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { - return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { - return name; - }); + auto tgt = tvm::Target("ext_dev"); + LowerToTECompute lower_te_compute(tgt); + auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; }); + return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_, + outputs, te::Schedule(), tir::PrimFunc(), {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); }); } // namespace tec From 0c6d4a603335ae2cba2771e939eff1ddeb98fbe3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 10:45:08 +0900 Subject: [PATCH 4/6] add public, fix include path convention --- src/relay/backend/te_compiler_cache.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index fc3a3ab335f4..276c7f9f017e 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -28,11 +28,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -45,8 +47,6 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" -#include "tvm/relay/op_strategy.h" -#include "tvm/tir/function.h" #include "utils.h" namespace tvm { @@ -138,6 +138,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator outputs = this->VisitExpr(relay_func->body); candidate_name_ = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME // whenever the value of kMaxFuncNameLength changes @@ -291,7 +292,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator Date: Fri, 11 Mar 2022 10:57:02 +0900 Subject: [PATCH 5/6] Forgot visiting arg in ScheduleBuilder CallNode vsit --- src/relay/backend/te_compiler_cache.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 276c7f9f017e..74b9013b3659 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -300,6 +300,7 @@ class ScheduleBuilder : public ExprVisitor { } CachedFunc Create(const Function& relay_func, std::function renamer) { + LOG(INFO) << relay_func; LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); Array fn_inputs = lower_te_compute.fn_inputs_; @@ -350,6 +351,8 @@ class ScheduleBuilder : public ExprVisitor { // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. if (!schedule.defined() && !prim_func.defined()) { ICHECK(lower_te_compute.anchor_implementation_.defined()); + LOG(INFO) << lower_te_compute.candidate_name_; + LOG(INFO) << anchor_attrs_; schedule = lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); } @@ -372,6 +375,10 @@ class ScheduleBuilder : public ExprVisitor { ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + int op_pattern = fpattern[op]; if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) From 6f019014a4614f43aefcf642981bfb15d64b09f3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 11:25:44 +0900 Subject: [PATCH 6/6] fixed anchor impl selection --- src/relay/backend/te_compiler_cache.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 74b9013b3659..ffcce6e1c8da 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -224,7 +224,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator(call_node), inputs, target_); Array outputs = lowered_out->outputs; - anchor_implementation_ = lowered_out->implementation; + op_implementations_[op.operator->()] = lowered_out->implementation; if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); @@ -276,8 +276,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator fn_inputs_; Array scalars_; std::unordered_map constant_tensors_; + std::unordered_map op_implementations_; std::string candidate_name_; - OpImplementation anchor_implementation_; private: tvm::Target target_; @@ -300,7 +300,6 @@ class ScheduleBuilder : public ExprVisitor { } CachedFunc Create(const Function& relay_func, std::function renamer) { - LOG(INFO) << relay_func; LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); Array fn_inputs = lower_te_compute.fn_inputs_; @@ -350,11 +349,9 @@ class ScheduleBuilder : public ExprVisitor { // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. if (!schedule.defined() && !prim_func.defined()) { - ICHECK(lower_te_compute.anchor_implementation_.defined()); - LOG(INFO) << lower_te_compute.candidate_name_; - LOG(INFO) << anchor_attrs_; - schedule = - lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->()); + ICHECK(anchor_impl != lower_te_compute.op_implementations_.end()); + schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_); } if (schedule.defined()) { for (const auto& scalar : lower_te_compute.scalars_) {