From ba8aaca44a912bcb0988a81020bff87dcf2513a1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Mar 2023 15:51:46 -0500 Subject: [PATCH 1/4] [Driver] Single-module lowering flow in driver_api.cc Prior to this commit, a build that used multiple targets needed to provide `tvm::build` with a `Map` specifying which target should be used to compile each `IRModule`. As a result, lowering passes could not introduce new targets based on a PrimFunc's content (e.g. a `with T.target()` frame to delegate out to another device), nor simplify based on cross-device subroutines (e.g. simplify a host-side conditional based on the known output of a device-side internal subroutine). This commit makes the `tvm::attr::kTarget` attribute (`"target"`) be the single source of truth for where a `PrimFunc` will be executed. Other existing methods for specifying the target (the `target` parameter for `tvm.build`, the keys in a `Map`, the parameter to the pass `tir::transform::BindTarget`) are still accepted as inputs, and may provide a default value for `tvm::attr::kTarget` if the attribute is missing, but may not overwrite the target attribute. This is part of a series of commits to simplify the handling of multi-target builds. --- apps/extension/tests/test_ext.py | 2 +- include/tvm/driver/driver_api.h | 3 +- .../backend/contrib/ethosu/tir/compiler.py | 4 +- src/driver/driver_api.cc | 228 +++++++---- .../example_target_hooks/relay_to_tir.cc | 2 +- src/target/llvm/llvm_module.cc | 16 +- src/target/source/codegen_c_host.h | 10 +- src/target/target_kind.cc | 2 +- src/tir/op/op.cc | 5 + src/tir/transforms/annotate_device_regions.cc | 115 +++++- .../transforms/lower_device_kernel_launch.cc | 4 +- src/tir/transforms/lower_intrin.cc | 4 + src/tir/transforms/split_host_device.cc | 3 +- ...t_tir_transform_annotate_device_regions.py | 71 ++++ .../test_tir_transform_split_host_device.py | 5 + vta/python/vta/transform.py | 9 +- vta/scripts/tune_resnet.py | 373 ++++++++++++++++++ vta/tutorials/matrix_multiply.py | 6 +- vta/tutorials/optimize/convolution_opt.py | 6 +- vta/tutorials/optimize/matrix_multiply_opt.py | 6 +- vta/tutorials/vta_get_started.py | 6 +- 21 files changed, 762 insertions(+), 118 deletions(-) create mode 100644 vta/scripts/tune_resnet.py diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index 994a673298f1..d387263a06a8 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -39,7 +39,7 @@ def test_ext_dev(): def check_llvm(): if not tvm.testing.device_enabled("llvm"): return - f = tvm.build(s, [A, B], "ext_dev", "llvm") + f = tvm.build(s, [A, B], "ext_dev", "ext_dev") dev = tvm.ext_dev(0) # launch the kernel. a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index fffcab49667c..14ea5119e0e5 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -54,7 +54,8 @@ using tvm::transform::Pass; * \param target The device Target. * \return The composite Pass for the fused module. // */ -TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, + Optional target = NullOpt); /*! 
* \brief Configures and returns the composite Pass for the device Target after device/host from diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index d47b3d4a7de6..8f6232347059 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -209,7 +209,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function: primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"]) primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name)) + primfunc = primfunc.with_attr( + "target", tvm.target.Target(compiler_name, host=compiler_name) + ) return primfunc def __call__(self, *args, **kwargs): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 1e576bc91002..e6335449b1c2 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -281,17 +281,6 @@ Array CreatePassList(bool disable_loop_partition) { return pass_list; } -IRModule LowerWithPassList(IRModule mod, Array pass_list) { - auto optimize = tvm::transform::Sequential(pass_list); - mod = optimize(std::move(mod)); - return mod; -} - -IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { - mod = seq(std::move(mod)); - return mod; -} - // Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, @@ -343,7 +332,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") IRModule LowerModule(IRModule mod, bool simple_mode) { Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); + tvm::transform::Sequential optimize(pass_list, "tvm.lower"); + return optimize(std::move(mod)); } TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { @@ -360,10 +350,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_ f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - - // Get the pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); + return LowerModule(mod, simple_mode); } TVM_REGISTER_GLOBAL("driver.lower_primfunc") @@ -385,9 +372,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const std const std::unordered_map& binds, GlobalVarSupply global_var_supply, bool simple_mode) { IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply); - // Get the legacy TE pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(mod, pass_list); + return LowerModule(mod, simple_mode); } TVM_REGISTER_GLOBAL("driver.lower_schedule") @@ -403,35 +388,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode); }); -/** - * This function takes the input module that contains both the device and host opts. - * Then, it applies transformation on the original module before splitting into separate modules for - * device and host. Then it also applies transformations on the new splitted modules. 
- */ -std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { - Target target = target_arg, target_host = target_host_arg; - CheckAndUpdateHostConsistency(&target, &target_host); - - ICHECK(mod_mixed.defined()) << "This module must be defined"; +IRModule MergeModules(const Map& inputs) { + if (inputs.size() == 1) { + auto [target, mod] = *inputs.begin(); + return tir::transform::BindTarget(target)(mod); + } - mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); + // Take the attrs from the first module so the eventual modules have them. + IRModule first_module = (*inputs.begin()).second; + IRModule merged = IRModule(Map(), {}, {}, {}, first_module->attrs); - IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); + for (auto [target, mod] : inputs) { + mod = tir::transform::BindTarget(target)(mod); + merged->Update(mod); + } - IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); + return merged; +} - auto keys = target->GetKeys(); +Map SplitModule(const IRModule& module) { + Map split; - CheckAndUpdateHostConsistency(&target, &target_host); + for (auto [gvar, base_func] : module->functions) { + auto target_str = base_func->GetAttr(tvm::attr::kTarget).value()->str(); + if (auto it = split.find(target_str); it != split.end()) { + (*it).second->Add(gvar, base_func); + } else { + split.Set(target_str, IRModule({{gvar, base_func}}, {}, {}, {}, module->attrs)); + } + } - bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && device_mod->functions.size() == 0) { - DLOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; + Map out; + for (auto [str, mod] : split) { + out.Set(Target(str), mod); } - return {host_mod, device_mod}; + return out; } /*! @@ -478,52 +470,86 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, // Update target host for all targets CheckAndUpdateHostConsistency(&inputs, &target_host); - // Take the attrs from the first module so the eventual modules have them. - // Ideally this would just be one unified module all the way through; - IRModule first_module = (*inputs.begin()).second; - IRModule mhost_all = IRModule(Map(), {}, {}, {}, first_module->attrs); - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - for (const auto& it : inputs) { - if (it.second.defined()) { - const Target& target = it.first; - const IRModule& ir_module = it.second; - auto pair = SplitMixedModule(ir_module, target, target_host); - auto& host_mod = pair.first; - auto& device_mod = pair.second; - - ICHECK(host_mod.defined()) << "The split host module must be defined"; - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - // We don't want library modules going back into host codegen - // unless they're supposed to. Here if we overrode the target host - // to allow lowering previously we check that it's meant to be placed - // back into the host Module. 
- bool overrides_host_target = - target->GetTargetDeviceType() == target_host->GetTargetDeviceType(); - bool non_host_target_kind = target->kind != target_host->kind; - if (overrides_host_target && non_host_target_kind) { - device_modules.push_back(codegen::Build(host_mod, it.first)); - } else { - mhost_all->Update(host_mod); + auto has_gpu_function = [](const IRModule& mod) -> bool { + for (const auto& [gvar, func] : mod->functions) { + if (auto target = func->GetAttr(tvm::attr::kTarget)) { + if (target.value()->HasKey("gpu")) { + return true; + } } + } + return false; + }; + + IRModule merged = MergeModules(inputs); + + bool contains_gpu_function_pre = has_gpu_function(merged); + merged = MixedModulePassManager(merged)(merged); + bool contains_gpu_function_post = has_gpu_function(merged); + if (contains_gpu_function_pre && !contains_gpu_function_post) { + DLOG(WARNING) << "Specified GPU targets, " + << "but cannot find device code. Did you forget to bind?"; + } + + Map split = SplitModule(merged); - if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); + Map built; + for (const auto& [target, mod] : split) { + built.Set(target, codegen::Build(mod, target)); + } + + auto host_target = [&]() -> Target { + // All targets that contain a kIsEntryFunc=True function + Array targets_with_entry_func; + + // All targets that can run on the CPU and contain at least one + // function without kIsEntryFunc=False. + Array cpu_targets; + for (const auto& [target, mod] : split) { + bool contains_entry_func = false; + bool may_contain_entry_func = false; + for (const auto& [gvar, func] : mod->functions) { + Optional is_entry_func = func->attrs.GetAttr(tvm::tir::attr::kIsEntryFunc); + if (is_entry_func.defined() && is_entry_func.value()->value) { + contains_entry_func = true; + } else if (!is_entry_func.defined()) { + may_contain_entry_func = true; + } + } + + if (contains_entry_func) { + targets_with_entry_func.push_back(target); + } + + if (may_contain_entry_func && target->HasKey("cpu")) { + cpu_targets.push_back(target); } } - } - runtime::Module mhost = codegen::Build(mhost_all, target_host); - for (const auto& it : device_modules) { - if (it.operator->()) { - mhost.Import(it); + if (targets_with_entry_func.size()) { + ICHECK_EQ(targets_with_entry_func.size(), 1) + << "Expected at most one function " + << "annotated with tvm::tir::attr::kIsEntryFunc " + << "(\"" << tvm::tir::attr::kIsEntryFunc << "\"), " + << "but found: " << targets_with_entry_func; + return targets_with_entry_func[0]; + } else if (cpu_targets.size() == 1) { + return cpu_targets[0]; + } else { + LOG(FATAL) << "Could not determine which target is the host. 
" + << "No function was annotated with tvm::tir::attr::kIsEntryFunc (\"" + << tvm::tir::attr::kIsEntryFunc << "\"), " + << "and " << cpu_targets.size() << " targets have the 'cpu' key"; + } + }(); + + auto runtime_module = built[host_target]; + for (const auto& [target, mod] : built) { + if (!mod.same_as(runtime_module)) { + runtime_module.Import(mod); } } - - return mhost; + return runtime_module; } TVM_REGISTER_GLOBAL("driver.tir_to_runtime") @@ -564,13 +590,16 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return TIRToRuntime(inputs, target_host); } -transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional target) { transform::PassContext pass_ctx = transform::PassContext::Current(); Array mixed_pass_list; - // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first - mixed_pass_list.push_back(tir::transform::BindTarget(target)); + // FPComputeLegalize uses the target attrs added by BindTarget, so + // BindTarget must come first. + if (target) { + mixed_pass_list.push_back(tir::transform::BindTarget(target.value())); + } mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); // VerifyVTCMLimit must occur before LowerVtcmAlloc @@ -625,7 +654,28 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); - return transform::Sequential(mixed_pass_list); + // Only applies to the device functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::LowerWarpMemory()); + + // Only applies to the host functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::LowerTVMBuiltin()); + + // Apply to both host and device functions + mixed_pass_list.push_back(tir::transform::Simplify()); + mixed_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + mixed_pass_list.push_back(tir::transform::LowerIntrin()); + mixed_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + + // Only applies to the host functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::CombineContextCall()); + if (pass_ctx->GetConfig("tir.enable_debug", Bool(false)).value()) { + mixed_pass_list.push_back(tir::transform::InstallDebugSpans()); + } + + return transform::Sequential(mixed_pass_list, "tvm.build"); } TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") @@ -634,6 +684,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") }); transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { + LOG(WARNING) << "Use of driver.host_mod_passes is deprecated. 
" + << "All lowering passes are now included " + << "as part of driver.mixed_mod_passes."; + transform::PassContext pass_ctx = transform::PassContext::Current(); bool enable_debug = pass_ctx->GetConfig("tir.enable_debug", Bool(false)).value(); @@ -659,7 +713,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho host_pass_list.push_back(tir::transform::InstallDebugSpans()); } - return transform::Sequential(host_pass_list); + return transform::Sequential(host_pass_list, "tir.host_mod_passes"); } TVM_REGISTER_GLOBAL("driver.host_mod_passes") @@ -668,6 +722,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes") }); transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { + LOG(WARNING) << "Use of driver.device_mod_passes is deprecated. " + << "All lowering passes are now included " + << "as part of driver.mixed_mod_passes."; + Array device_pass_list; runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == @@ -683,7 +741,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); device_pass_list.push_back(tir::transform::LowerIntrin()); - return transform::Sequential(device_pass_list); + return transform::Sequential(device_pass_list, "tir.device_mod_passes"); } TVM_REGISTER_GLOBAL("driver.device_mod_passes") diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 2b037181653c..90c0fd41dc7c 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -64,7 +64,7 @@ class ConvertAddToSubtract : public MixedModeMutator { explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) : ir_module_(ir_module), host_target_(host_target), - custom_target_(Target("example_target_hook")) {} + custom_target_(Target(Target("example_target_hook"), Target("example_target_hook"))) {} IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 34bbb6a0c6a9..f6b4ea0624ef 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -618,12 +618,16 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name, return nullptr; } -TVM_REGISTER_GLOBAL("target.build.llvm") - .set_body_typed([](IRModule mod, Target target) -> runtime::Module { - auto n = make_object(); - n->Init(mod, target); - return runtime::Module(n); - }); +namespace { +runtime::Module BuildLLVM(IRModule mod, Target target) { + auto n = make_object(); + n->Init(mod, target); + return runtime::Module(n); +} +} // namespace + +TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed(BuildLLVM); +TVM_REGISTER_GLOBAL("target.build.ext_dev").set_body_typed(BuildLLVM); TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 3e013492efc2..297b5d2ad8a6 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -90,7 +90,15 @@ class CodeGenCHost : public CodeGenC { Array function_names_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; - /*! 
\brief whether to emit forwared function declarations in the resulting C code */ + /*! \brief whether to emit forwared function declarations in the resulting C code + * + * Determines the behavior when encountering an unknown symbol as + * the callee in a `CallNode` whose operation is + * `builtin::call_extern`. If true, the unknown symbol will be + * forward-declared as a function, derived from the TIR types of + * CallNode's argument/return value. If false, the forward + * declaration is omitted. + */ bool emit_fwd_func_decl_; FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 979b755af846..7a1d2b07508b 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -444,7 +444,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break .set_default_keys({"cpu"}); -TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev); +TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev).set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index dad4ea98d614..466943df2992 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -97,6 +97,11 @@ Type GetType(const PrimExpr& expr) { return PointerType(PrimType(address->dtype)); } } + + if (expr.as()) { + return PointerType(PrimType(DataType::Int(8))); + } + // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); return GetTypeFromRuntimeDataType(dtype); diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index a81af7d7805b..67980a934219 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -29,31 +29,131 @@ #include #include +#include +#include +#include + namespace tvm { namespace tir { -class DeviceRegionAnnotater : public StmtMutator { +class DeviceRegionAnnotater : public StmtExprMutator { + using Parent = StmtExprMutator; + public: + static Stmt Apply(Target host_target, Target device_target, Stmt body) { + bool same_host_and_device = host_target->str() == device_target->str(); + if (same_host_and_device) { + return body; + } + + DeviceRegionAnnotater mutator(device_target); + body = mutator(body); + + // If no region was found that must be on the device, but the + // device and host differ (e.g. `T.target('c', host='llvm')`), + // then the entire region should be annotated. This preserves the + // host-side handling of DLTensor arguments, while ensuring that + // any device targets are used for the codegen. + if (mutator.current_region_ == Region::Either && !same_host_and_device) { + body = AttrStmt(device_target, tvm::attr::kTarget, 0, body); + } + + return body; + } + + private: explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. + current_region_ = Region::Device; return GetRef(op); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. 
+ current_region_ = Region::Device; Stmt body = GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { // All other annotations are ignored - return StmtMutator::VisitStmt_(op); + return Parent::VisitStmt_(op); } } - private: + Stmt VisitStmt_(const SeqStmtNode* op) final { + std::vector regions; + Array seq = op->seq.Map([&](Stmt stmt) { + current_region_ = Region::Either; + stmt = VisitStmt(stmt); + regions.push_back(current_region_); + return stmt; + }); + + bool has_host_function = std::any_of(regions.begin(), regions.end(), + [](const auto& reg) { return reg == Region::Host; }); + if (has_host_function) { + current_region_ = Region::Host; + + Array new_seq; + Array device_seq; + auto finish_device_seq = [&]() { + if (device_seq.size()) { + new_seq.push_back( + AttrStmt(device_target_, tvm::attr::kTarget, 0, SeqStmt::Flatten(device_seq))); + device_seq.clear(); + } + }; + + for (size_t i = 0; i < seq.size(); i++) { + if (regions[i] == Region::Host) { + finish_device_seq(); + new_seq.push_back(seq[i]); + } else { + device_seq.push_back(seq[i]); + } + } + finish_device_seq(); + + return SeqStmt::Flatten(new_seq); + } else if (seq.same_as(op->seq)) { + return GetRef(op); + } else { + return SeqStmt(seq); + } + } + + PrimExpr VisitExpr_(const CallNode* op) final { + // TODO(Lunderberg): Make a new attribute in builtin.cc to label + // host-only operations. + bool is_host_only_op = + op->op.same_as(builtin::tvm_call_packed()) || op->op.same_as(builtin::tvm_call_cpacked()) || + op->op.same_as(builtin::tvm_call_packed_lowered()) || + op->op.same_as(builtin::tvm_call_cpacked_lowered()) || + op->op.same_as(builtin::anylist_getitem()) || + op->op.same_as(builtin::anylist_resetitem()) || + op->op.same_as(builtin::anylist_setitem_call_packed()) || + op->op.same_as(builtin::anylist_setitem_call_cpacked()) || + op->op.same_as(builtin::tvm_struct_get()) || op->op.same_as(builtin::tvm_struct_set()) || + op->op.same_as(builtin::tvm_throw_last_error()) || + op->op.same_as(builtin::tvm_stack_alloca()) || + op->op.same_as(builtin::tvm_stack_make_shape()) || + op->op.same_as(builtin::tvm_stack_make_array()); + if (is_host_only_op) { + current_region_ = Region::Host; + } + return Parent::VisitExpr_(op); + } + Target device_target_; + + enum class Region { + Either, + Host, + Device, + }; + Region current_region_{Region::Either}; }; namespace transform { @@ -64,9 +164,12 @@ Pass AnnotateDeviceRegions() { ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; Target target = opt_target.value(); - if (target->GetHost()) { - DeviceRegionAnnotater mutator(target.WithoutHost()); - func.CopyOnWrite()->body = mutator(func->body); + if (auto opt_host = target->GetHost()) { + auto new_body = + DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), func->body); + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } } return func; }; diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 932116485fa1..a33376bd69ee 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -260,7 +260,9 @@ class DeviceKernelMutator : public StmtExprMutator { bool same_device_type = caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); - if (same_device_type) { + bool linkable_module = (caller_target->GetTargetDeviceType() == kDLCPU) && + (callee_target->GetTargetDeviceType() == kDLExtDev); + if 
(same_device_type || linkable_module) { // Calls to another target using the same device (e.g. LLVM // calling a custom TIRToRuntime target) do not require a kernel // launch, but need to be replaced with call_extern. diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 212ccf6e5616..fbc5d4fda92d 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -44,6 +44,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") : IRMutatorWithAnalyzer(analyzer) { + if (target == "ext_dev") { + target = "llvm"; + } + std::vector patterns; patterns.push_back(target + ".FLowerIntrinsic"); patterns.push_back(target + ".FLegalize"); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index c90384fea73a..abc7ce91efb0 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -97,7 +97,8 @@ class HostDeviceSplitter : public StmtMutator { PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}, - {tir::attr::kIsGlobalFunc, Bool(true)}}); + {tir::attr::kIsGlobalFunc, Bool(true)}, + {tir::attr::kIsEntryFunc, Bool(false)}}); GlobalVar kernel_symbol_global = var_supply_(); (*device_mod_)->Add(kernel_symbol_global, device_func); diff --git a/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py b/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py index efa43027e9c6..7b869ddf7694 100644 --- a/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py +++ b/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py @@ -54,5 +54,76 @@ def expected(A: T.Buffer(1, "float32")): A[0] = 0.0 +class TestAnnotateEntireBody(BaseCompare): + """Annotation inserted to wrap entire function + + Function is assumed to belong on the device. + """ + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + A[0] = 0.0 + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(T.target("cuda"), "target", 0) + A[0] = 0.0 + + +class TestNoAnnotationForSameHostDevice(BaseCompare): + """No annotation is needed if host/device are the same""" + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm", host="llvm")}) + A[0] = 0.0 + + expected = before + + +class TestAnnotationAvoidsHostConstructs(BaseCompare): + """Device annotation does not contain host-only functions + + Calls that must be on the host side (e.g. T.call_packed) remain on + the host. + """ + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + A[0] = 0.0 + T.call_packed("dummy_function", dtype="void") + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + with T.attr(T.target("cuda"), "target", 0): + A[0] = 0.0 + T.call_packed("dummy_function", dtype="void") + + +class TestAnnotationNoRepetition(BaseCompare): + """Device annotation does not contain host-only functions + + When placing everything that isn't a host-specific function into + target block, sequential device statements should be in the same + block. 
+ """ + + def before(A: T.Buffer(2, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + A[0] = 0.0 + A[1] = 1.0 + T.call_packed("dummy_function", dtype="void") + + def expected(A: T.Buffer(2, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + with T.attr(T.target("cuda"), "target", 0): + A[0] = 0.0 + A[1] = 1.0 + T.call_packed("dummy_function", dtype="void") + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py b/tests/python/tir-transform/test_tir_transform_split_host_device.py index 2d0d8a68d83e..fde0940eb707 100644 --- a/tests/python/tir-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py @@ -123,6 +123,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -160,6 +161,7 @@ def main_kernel(n: T.int32) -> T.int32: "target": T.target("llvm"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -201,6 +203,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -262,6 +265,7 @@ def main_kernel_1(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -329,6 +333,7 @@ def default_function_kernel( T.func_attr( { "target": T.target("cuda"), + "tir.is_entry_func": False, "tir.is_global_func": True, "tir.noalias": True, } diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index ae83a9d66392..95df5e156e5a 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -203,7 +203,14 @@ def _post_order(op): ), op.body, ) - alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt) + alloc = tvm.tir.Allocate( + buffer_var, + op.dtype, + op.extents, + op.condition, + let_stmt, + annotations={"disable_lower_builtin": True}, + ) del var_remap[buffer_var] bufs_to_delete = [ old_buf for old_buf in buf_remap if old_buf.data.same_as(buffer_var) diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py new file mode 100644 index 000000000000..3f5c693b78a0 --- /dev/null +++ b/vta/scripts/tune_resnet.py @@ -0,0 +1,373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Perform ResNet autoTVM tuning on VTA using Relay.""" + +import argparse, os, time +from mxnet.gluon.model_zoo import vision +import numpy as np +from PIL import Image + +from tvm import topi +import tvm +from tvm import te +from tvm import rpc, autotvm, relay +from tvm.autotvm.measure.measure_methods import request_remote +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_executor, utils, download +from tvm.contrib.debugger import debug_executor +import vta +from vta.testing import simulator +from vta.top import graph_pack +from tvm.autotvm.task import extract_from_program + + +def parse_arguments(): + + parser = argparse.ArgumentParser(description="Train a model for image classification.") + parser.add_argument( + "--model", + type=str, + default="resnet18_v1", + choices=["resnet18_v1"], + help="Input model name.", + ) + parser.add_argument( + "--start-name", + type=str, + default="nn.max_pool2d", + help="The name of the node where packing starts", + ) + parser.add_argument( + "--stop-name", + type=str, + default="nn.global_avg_pool2d", + help="The name of the node where packing stops", + ) + parser.add_argument( + "--debug-profile", action="store_true", help="Show layer-wise time cost profiling results" + ) + parser.add_argument( + "--device", default="vta", choices=["vta", "arm_cpu"], help="Select device target" + ) + parser.add_argument( + "--measurements", type=int, default=1, help="Number of measurements during AutoTVM search" + ) + parser.add_argument("--tuner", type=str, default="random", help="AutoTVM search strategy") + parser.add_argument( + "--log-filename", type=str, default="resnet-18.log", help="AutoTVM log file name" + ) + + return parser.parse_args() + + +def register_vta_tuning_tasks(): + from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args + + @tvm.te.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.tir.const(a_min, x.dtype) + const_max = tvm.tir.const(a_max, x.dtype) + x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA") + x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB") + return x + + # init autotvm env to register VTA operator + TaskExtractEnv() + + @autotvm.task.register("topi_nn_conv2d", override=True) + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.conv2d(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.Target.current().device_name == "vta": + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = te.create_schedule([res.op]) + return s, [A, W, res] + + @autotvm.task.register("topi_nn_dense", override=True) + def _topi_nn_dense(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.dense(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.Target.current().device_name == "vta": + s = topi.generic.schedule_dense([res]) + else: + s = te.create_schedule([res.op]) + + return s, [A, W, res] + + +def compile_network(opt, env, target): + + # Populate the shape and data type dictionary + 
dtype_dict = {"data": "float32"} + shape_dict = {"data": (env.BATCH, 3, 224, 224)} + + # Get off the shelf gluon model, and convert to relay + gluon_model = vision.get_model(opt.model, pretrained=True) + mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) + + # Update shape and type dictionary + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Perform quantization in Relay + # Note: We set opt_level to 3 in order to fold batch norm + with tvm.transform.PassContext(opt_level=3): + with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]): + relay_prog = relay.quantize.quantize(mod["main"], params=params) + + # Perform graph packing and constant folding for VTA target + if target.device_name == "vta": + assert env.BLOCK_IN == env.BLOCK_OUT + relay_prog = graph_pack( + relay_prog, + env.BATCH, + env.BLOCK_OUT, + env.WGT_WIDTH, + start_name=opt.start_name, + stop_name=opt.stop_name, + ) + + return relay_prog, params + + +def tune_tasks( + tasks, + measure_option, + tuner="xgb", + n_trial=1000, + early_stopping=None, + log_filename="tuning.log", + use_transfer_learning=True, + try_winograd=True, +): + + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) + + # create tuner + if tuner == "xgb": + tuner_obj = XGBTuner(tsk, loss_type="reg") + elif tuner == "xgb_knob": + tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="knob") + elif tuner == "xgb_itervar": + tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="itervar") + elif tuner == "xgb_curve": + tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="curve") + elif tuner == "xgb_rank": + tuner_obj = XGBTuner(tsk, loss_type="rank") + elif tuner == "xgb_rank_knob": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob") + elif tuner == "xgb_rank_itervar": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="itervar") + elif tuner == "xgb_rank_curve": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="curve") + elif tuner == "xgb_rank_binary": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary") + elif tuner == "xgb_rank_binary_knob": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="knob") + elif tuner == "xgb_rank_binary_itervar": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="itervar") + elif tuner == "xgb_rank_binary_curve": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="curve") + elif tuner == "ga": + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == "random": + tuner_obj = RandomTuner(tsk) + elif tuner == "gridsearch": + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + n_trial_ = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial_, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(n_trial_, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + + # pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + + +if __name__ == "__main__": + + opt = parse_arguments() + + # Make sure that TVM was compiled with RPC=1 + assert 
tvm.runtime.enabled("rpc") + + # Read in VTA environment + env = vta.get_env() + + # Get remote from fleet node + tracker_host = os.environ.get("TVM_TRACKER_HOST", None) + tracker_port = os.environ.get("TVM_TRACKER_PORT", None) + if not tracker_host or not tracker_port: + print("Set your AutoTVM tracker node host and port variables to run the autotuner") + exit() + + # Get remote + if env.TARGET != "sim": + + # Measure build start time + reconfig_start = time.time() + + # Get remote from fleet node + remote = autotvm.measure.request_remote( + env.TARGET, tracker_host, int(tracker_port), timeout=10000 + ) + + # Reconfigure the JIT runtime and FPGA. + # You can program the FPGA with your own custom bitstream + # by passing the path to the bitstream file instead of None. + vta.reconfig_runtime(remote) + vta.program_fpga(remote, bitstream=None) + + # Report on reconfiguration time + reconfig_time = time.time() - reconfig_start + print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time)) + + # In simulation mode, host the RPC server locally. + else: + remote = rpc.LocalSession() + + # VTA target and execution context + target = env.target if opt.device == "vta" else env.target_vta_cpu + ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0) + + # Compile Relay program + print("Initial compile...") + relay_prog, params = compile_network(opt, env, target) + + # Register VTA tuning tasks + register_vta_tuning_tasks() + + # Perform task extraction on Relay program + print("Extracting tasks...") + tasks = extract_from_program( + func=relay_prog, + params=params, + ops=(relay.op.get("nn.conv2d"),), + target=tvm.target.Target(target, host=env.target_host), + ) + + # Perform Autotuning + print("Tuning...") + tuning_opt = { + "log_filename": opt.log_filename, + "tuner": opt.tuner, + "n_trial": 1e9, + "early_stopping": None, + "measure_option": autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner( + env.TARGET, + tracker_host, + tracker_port, + number=4, + min_repeat_ms=150, + repeat=opt.measurements, + timeout=60, + # check_correctness=True, # TODO: re-enable when check_correctness works again. 
+ ), + ), + } + tune_tasks(tasks, **tuning_opt) + + # Compile kernels with history best records + with autotvm.tophub.context(target, extra_files=[opt.log_filename]): + + # Compile network + print("Compiling network with best tuning parameters...") + if target.device_name != "vta": + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, params = relay.build( + relay_prog, + target=tvm.target.Target(target, host=env.target_host), + params=params, + ) + else: + with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, params = relay.build( + relay_prog, + target=tvm.target.Target(target, host=env.target_host), + params=params, + ) + + # Export library + temp = utils.tempdir() + lib.export_library(temp.relpath("graphlib.so")) + remote.upload(temp.relpath("graphlib.so")) + lib = remote.load_module("graphlib.so") + + # If detailed runtime info is needed build with debug runtime + if opt.debug_profile: + m = debug_executor.create(graph, lib, ctx) + else: + m = graph_executor.create(graph, lib, ctx) + + # Set the network parameters and synthetic input + image = tvm.nd.array((np.random.uniform(size=(1, 3, 224, 224))).astype("float32")) + m.set_input(**params) + m.set_input("data", image) + + # Perform inference + timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements) + tcost = timer() + prof_res = np.array(tcost.results) * 1000 # convert to millisecond + print( + "Mean inference time (std dev): %.2f ms (%.2f ms)" + % (np.mean(prof_res), np.std(prof_res)) + ) + + # Display profile information + if opt.debug_profile: + m.run() diff --git a/vta/tutorials/matrix_multiply.py b/vta/tutorials/matrix_multiply.py index 0d1167854458..1d1dd98dfaf3 100644 --- a/vta/tutorials/matrix_multiply.py +++ b/vta/tutorials/matrix_multiply.py @@ -392,13 +392,13 @@ # Write the compiled module into an object file. 
temp = utils.tempdir() -my_gemm.save(temp.relpath("gemm.o")) +my_gemm.export_library(temp.relpath("gemm.so")) # Send the executable over RPC -remote.upload(temp.relpath("gemm.o")) +remote.upload(temp.relpath("gemm.so")) # Load the compiled module -f = remote.load_module("gemm.o") +f = remote.load_module("gemm.so") ###################################################################### # Running the Function diff --git a/vta/tutorials/optimize/convolution_opt.py b/vta/tutorials/optimize/convolution_opt.py index 521a73ab510d..3c757fdc0c2b 100644 --- a/vta/tutorials/optimize/convolution_opt.py +++ b/vta/tutorials/optimize/convolution_opt.py @@ -374,9 +374,9 @@ s, [data, kernel, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_conv" ) temp = utils.tempdir() -my_conv.save(temp.relpath("conv2d.o")) -remote.upload(temp.relpath("conv2d.o")) -f = remote.load_module("conv2d.o") +my_conv.export_library(temp.relpath("conv2d.so")) +remote.upload(temp.relpath("conv2d.so")) +f = remote.load_module("conv2d.so") # Get the remote device context ctx = remote.ext_dev(0) diff --git a/vta/tutorials/optimize/matrix_multiply_opt.py b/vta/tutorials/optimize/matrix_multiply_opt.py index b470475b16e7..ea70b5260c56 100644 --- a/vta/tutorials/optimize/matrix_multiply_opt.py +++ b/vta/tutorials/optimize/matrix_multiply_opt.py @@ -314,9 +314,9 @@ s, [data, weight, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_gemm" ) temp = utils.tempdir() -my_gemm.save(temp.relpath("gemm.o")) -remote.upload(temp.relpath("gemm.o")) -f = remote.load_module("gemm.o") +my_gemm.export_library(temp.relpath("gemm.so")) +remote.upload(temp.relpath("gemm.so")) +f = remote.load_module("gemm.so") # Get the remote device context ctx = remote.ext_dev(0) diff --git a/vta/tutorials/vta_get_started.py b/vta/tutorials/vta_get_started.py index 3482258dece8..6edb34184fb4 100644 --- a/vta/tutorials/vta_get_started.py +++ b/vta/tutorials/vta_get_started.py @@ -327,17 +327,17 @@ # Write the compiled module into an object file. temp = utils.tempdir() -my_vadd.save(temp.relpath("vadd.o")) +my_vadd.export_library(temp.relpath("vadd.so")) # Send the executable over RPC -remote.upload(temp.relpath("vadd.o")) +remote.upload(temp.relpath("vadd.so")) ###################################################################### # Loading the Module # ~~~~~~~~~~~~~~~~~~ # We can load the compiled module from the file system to run the code. -f = remote.load_module("vadd.o") +f = remote.load_module("vadd.so") ###################################################################### # Running the Function From e7bebaf9b56cf17ccc2e75e001df9feed32c0768 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 11 Sep 2024 10:37:35 -0500 Subject: [PATCH 2/4] Apply InlinePrivateFunctions to avoid requiring TIR-to-TIR codegen --- src/driver/driver_api.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e6335449b1c2..3d6112d59fa6 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -654,6 +654,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional Date: Fri, 5 May 2023 11:37:59 -0500 Subject: [PATCH 3/4] [TIR] Output DeclBuffer in FlattenBuffer If a flattened buffer is produced for use in `BufferLoad` and `BufferStore` statements, generate a `DeclBuffer`. This is a subset of the changes made in https://github.com/apache/tvm/pull/14778, broken out for ease of testing and review. 
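
For illustration, the expected output of FlattenBuffer changes roughly as
follows (a minimal sketch based on the updated unit tests in this patch;
the buffer name and shape are only illustrative):

    # Before: the flattened allocation appears as a raw T.allocate,
    # with a separate T.Buffer aliasing its data pointer.
    B_data = T.allocate([16], "float32", scope="global")
    B = T.Buffer([16], "float32", scope="global", data=B_data)

    # After: the flattened buffer is introduced with an explicit
    # DeclBuffer, so later passes see its definition directly.
    B = T.decl_buffer(16, "float32", scope="global")

Flattened aliases of the argument buffers (e.g.
`A = T.decl_buffer(256, dtype="float32", data=input_A.data)`) are likewise
declared with DeclBuffer at the top of the function body.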
--- src/tir/transforms/flatten_buffer.cc | 51 +++++++++--- .../test_tir_transform_flatten_buffer.py | 81 +++++-------------- 2 files changed, 62 insertions(+), 70 deletions(-) diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index c04e12b8395e..0d8cc8553c95 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -41,13 +41,29 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { static PrimFunc Flatten(PrimFunc func) { arith::Analyzer ana; auto pass = BufferFlattener(&ana); - auto writer = func.CopyOnWrite(); pass.MarkBufferMapShapes(func); - writer->body = pass.VisitStmt(func->body); + auto body = pass.VisitStmt(func->body); + // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided // arguments. The flattened buffers used in the updated // function body alias the argument buffers. + for (size_t i = func->params.size(); i > 0; i--) { + auto handle = func->params[i - 1]; + if (auto opt = func->buffer_map.Get(handle)) { + auto old_buf = opt.value(); + if (pass.buffers_used_.count(old_buf)) { + auto new_buf = pass.GetFlattenedBuffer(old_buf); + if (!old_buf.same_as(new_buf)) { + body = DeclBuffer(new_buf, std::move(body)); + } + } + } + } + + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = std::move(body); + } return func; } @@ -153,11 +169,14 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { } Stmt VisitStmt_(const DeclBufferNode* op) final { - // TODO(rfc-70): Update the DeclBuffer node instead of - // stripping it out. Stripping it out in the current - // implementation as not all lowering passes support - // DeclBuffer. - return VisitStmt(op->body); + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + + auto new_buf = GetFlattenedBuffer(node->buffer); + if (!node->buffer.same_as(new_buf)) { + node.CopyOnWrite()->buffer = new_buf; + } + + return std::move(node); } Buffer GetFlattenedBuffer(Buffer buf) { @@ -166,16 +185,23 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { return it->second; } auto flattened = buf.GetFlattenedBuffer(); - auto writer = flattened.CopyOnWrite(); // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. if (flattened->dtype == DataType::Bool()) { - writer->dtype = DataType::Int(8); + flattened.CopyOnWrite()->dtype = DataType::Int(8); } // canonicalize shape - for (size_t i = 0; i < flattened->shape.size(); ++i) { - writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); + bool shape_is_changed = false; + Array new_shape; + for (const auto& dim : flattened->shape) { + auto new_dim = analyzer_->canonical_simplify(dim); + shape_is_changed = shape_is_changed || !StructuralEqual()(dim, new_dim); + new_shape.push_back(new_dim); + } + + if (shape_is_changed) { + flattened.CopyOnWrite()->shape = std::move(new_shape); } buffer_remap_[buf] = flattened; @@ -226,6 +252,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { template Node VisitBufferAccess(Node node) { ICHECK(node->buffer.defined()); + buffers_used_.insert(node->buffer); auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices); Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); @@ -264,6 +291,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { /*! \brief Map of buffers being remapped. */ std::unordered_map buffer_remap_; + std::unordered_set buffers_used_; + /*! \brief The updated external buffer map. 
*/ Map updated_extern_buffer_map_; }; diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py index 20f91b639497..cb29e79160c9 100644 --- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py @@ -41,42 +41,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i, j] = B_new[0, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) for i in T.serial(0, 16): - B_new_data = T.allocate([16], "float32", scope="global") - B_new = T.Buffer([16], "float32", scope="global", data=B_new_data) - for j in T.serial(0, 16): - B_new[j] = A[((i * 16) + j)] + 1.0 - for j in T.serial(0, 16): - C[((i * 16) + j)] = B_new[j] * 2.0 - - -class TestElementwiseWithoutDeclBuffer(BaseCompare): - """2-d buffers are flattened to 1-d - - Like TestElementwise, but the TIR doesn't have the DeclBuffer - node. The T.Buffer declaration applies only during the - parsing the TVMScript, and doesn't occur in the TIR itself. In - this case, the allocation should be assumed to be targeting flat - memory, and should be flattened to a 1-d allocation. - """ - - def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): - for i in T.serial(0, 16): - B_new_data = T.allocate([1, 16], "float32", "global") - B_new = T.Buffer([1, 16], "float32", data=B_new_data) - for j in T.serial(0, 16): - B_new[0, j] = A[i, j] + 1.0 - for j in T.serial(0, 16): - C[i, j] = B_new[0, j] * 2.0 - - def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) - for i in T.serial(0, 16): - B_new_data = T.allocate([16], "float32", "global") - B_new = T.Buffer(16, "float32", data=B_new_data) + B_new = T.decl_buffer(16, "float32", scope="global") for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -101,8 +69,8 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) i0 = T.env_thread("blockIdx.x") i1 = T.env_thread("threadIdx.x") @@ -111,8 +79,7 @@ def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B_data = T.allocate([16], "float32", scope="local") - B = T.Buffer([16], "float32", scope="local", data=B_data) + B = T.decl_buffer(16, "float32", scope="local") for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -136,12 +103,11 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: input_A = T.match_buffer(a, (n, m), "float32") input_C = T.match_buffer(c, (n, m), 
"float32") - A = T.Buffer(n * m, "float32", data=input_A.data) - C = T.Buffer(n * m, "float32", data=input_C.data) + A = T.decl_buffer(n * m, "float32", data=input_A.data) + C = T.decl_buffer(n * m, "float32", data=input_C.data) for i in range(0, n): - B_data = T.allocate([m], "float32", scope="global") - B = T.Buffer([m], "float32", scope="global", data=B_data) + B = T.decl_buffer(m, "float32", scope="global") for j in range(0, m): B[j] = A[i * m + j] + 1.0 for j in range(0, m): @@ -161,8 +127,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None: def expected(a: T.handle, b: T.handle, n: T.int32) -> None: input_A = T.match_buffer(a, (32, n, n), "float32") input_B = T.match_buffer(b, (32, n, n), "float32") - A = T.Buffer(n * n * 32, "float32", data=input_A.data) - B = T.Buffer(n * n * 32, "float32", data=input_B.data) + A = T.decl_buffer(n * n * 32, "float32", data=input_A.data) + B = T.decl_buffer(n * n * 32, "float32", data=input_B.data) for i in range(0, n * n * 32): B[i] = A[i] @@ -185,8 +151,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None: def expected(a: T.handle, b: T.handle, n: T.int32) -> None: input_A = T.match_buffer(a, (32, n, n), "float32") input_B = T.match_buffer(b, (32, n, n), "float32") - A = T.Buffer(n * n * 32, "float32", data=input_A.data) - B = T.Buffer(n * n * 32, "float32", data=input_B.data) + A = T.decl_buffer(n * n * 32, "float32", data=input_A.data) + B = T.decl_buffer(n * n * 32, "float32", data=input_B.data) for bx, tx in T.grid((n * n + 1) // 2, 64): if bx * 64 + tx < n * n * 32: @@ -205,14 +171,12 @@ def before(A: T.Buffer((4, 32), "float32"), D: T.Buffer((4, 32), "float32")): D[i, j] = C[i, j] * 2.0 def expected(input_A: T.Buffer((4, 32), "float32"), input_D: T.Buffer((4, 32), "float32")): - A = T.Buffer(128, "float32", data=input_A.data) - D = T.Buffer(128, "float32", data=input_D.data) + A = T.decl_buffer(128, "float32", data=input_A.data) + D = T.decl_buffer(128, "float32", data=input_D.data) for i, j in T.grid(4, 32): - B_data = T.allocate([128], "float32", scope="global") - B = T.Buffer([128], "float32", scope="global", data=B_data) - C_data = T.allocate([128], "float32", scope="global") - C = T.Buffer([128], "float32", scope="global", data=C_data) + B = T.decl_buffer(128, "float32", scope="global") + C = T.decl_buffer(128, "float32", scope="global") B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -231,11 +195,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")): - A = T.Buffer(256, dtype="float32", data=input_A.data) - C = T.Buffer(256, dtype="float32", data=input_C.data) + A = T.decl_buffer(256, dtype="float32", data=input_A.data) + C = T.decl_buffer(256, dtype="float32", data=input_C.data) for i0 in T.serial(0, 4): - B_new_data = T.allocate([68], "float32", scope="global") - B_new = T.Buffer([68], "float32", scope="global", data=B_new_data) + B_new = T.decl_buffer(68, "float32", scope="global") for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 @@ -252,8 +215,8 @@ def before(A: T.Buffer(10, "bool"), B: T.Buffer(10, "bool")) -> None: B[i0] = A[i0] def expected(input_A: T.Buffer(10, "bool"), input_B: T.Buffer(10, "bool")) -> None: - A = T.Buffer(10, dtype="int8", data=input_A.data) - B = T.Buffer(10, dtype="int8", data=input_B.data) + A = 
T.decl_buffer(10, dtype="int8", data=input_A.data) + B = T.decl_buffer(10, dtype="int8", data=input_B.data) # body for i0 in T.serial(10): B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") From 021a2d19bf7145e555d13468188705792350fc5e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 9 May 2023 14:26:01 -0500 Subject: [PATCH 4/4] [TIR] Output DeclBuffer nodes during StorageFlatten When producing a flattened buffer for use in `BufferLoad` and `BufferStore` nodes, generate a `DeclBuffer` for the flattened buffer. This is a subset of the changes made in https://github.com/apache/tvm/pull/14778, broken out for ease of testing and review. --- src/tir/transforms/storage_flatten.cc | 29 ++++++++++++++--- tests/python/te/test_te_build_lower.py | 2 +- tests/python/te/test_te_hybrid_script.py | 4 +++ tests/python/te/test_te_schedule.py | 2 +- tests/python/tir-base/test_lower_build.py | 17 ++++------ .../test_tir_transform_flatten_buffer.py | 3 +- .../test_tir_transform_loop_partition.py | 32 +++++++++++-------- .../test_tir_transform_narrow_datatype.py | 6 +++- .../test_tir_transform_storage_flatten.py | 10 +++--- 9 files changed, 66 insertions(+), 39 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 06554f5f1dd1..2025c2bde481 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1346,12 +1346,30 @@ class StorageFlattener : public StmtExprMutator { auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer); - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); + Stmt body = pass(func->body); + + for (size_t i = func->params.size(); i > 0; i--) { + auto handle = func->params[i - 1]; + if (auto opt = func->buffer_map.Get(handle)) { + auto old_buf = opt.value(); + if (pass.buf_map_.count(old_buf)) { + auto new_buf = pass.GetBufferEntry(old_buf).flattened_buffer; + if (!old_buf.same_as(new_buf)) { + body = DeclBuffer(new_buf, std::move(body)); + } + } + } + } + // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided // arguments. The flattened buffers used in the updated // function body alias the argument buffers. 
+ + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1550,9 +1568,10 @@ class StorageFlattener : public StmtExprMutator { buffer_var_defines_.erase(op->buffer->data.get()); buf_map_[key].in_scope = false; - Stmt ret = - Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, - make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); + Stmt ret = body; + ret = DeclBuffer(e.flattened_buffer, body); + ret = Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, + make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/tests/python/te/test_te_build_lower.py b/tests/python/te/test_te_build_lower.py index 50d5119b43a0..6da7a2df3563 100644 --- a/tests/python/te/test_te_build_lower.py +++ b/tests/python/te/test_te_build_lower.py @@ -56,7 +56,7 @@ def test_split_uneven_unique_likely(): sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) stmt = tvm.lower(sch, [a, b, c])["main"].body - assert isinstance(stmt.body.body, tvm.tir.stmt.IfThenElse) + assert isinstance(stmt.body.body.body.body.body, tvm.tir.stmt.IfThenElse) if __name__ == "__main__": diff --git a/tests/python/te/test_te_hybrid_script.py b/tests/python/te/test_te_hybrid_script.py index 862e80ffb6ce..60a47699d5ce 100644 --- a/tests/python/te/test_te_hybrid_script.py +++ b/tests/python/te/test_te_hybrid_script.py @@ -756,6 +756,8 @@ def outer_product(a, b): sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -777,6 +779,8 @@ def outer_product(a, b): sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/te/test_te_schedule.py b/tests/python/te/test_te_schedule.py index d46db2b702c0..b3690cedc640 100644 --- a/tests/python/te/test_te_schedule.py +++ b/tests/python/te/test_te_schedule.py @@ -325,7 +325,7 @@ def test_legalize_invalid_attach(): s[A].compute_at(s[B], B.op.axis[1]) s[B].fuse(B.op.axis[0], B.op.axis[1]) stmt = tvm.lower(s, [A, B], simple_mode=True)["main"].body - assert isinstance(stmt, tvm.tir.stmt.For) + assert isinstance(stmt.body.body, tvm.tir.stmt.For) def test_compute_at(): diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py index 0e610cc1659b..f6a871cb0001 100644 --- a/tests/python/tir-base/test_lower_build.py +++ b/tests/python/tir-base/test_lower_build.py @@ -60,9 +60,9 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer(16384, data=A.data) + B_flat = T.decl_buffer(16384, data=B.data) + C_flat = T.decl_buffer(16384, data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 @@ -82,9 +82,9 @@ def main( ) -> None: # function attr dict 
T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer(16384, data=A.data) + B_flat = T.decl_buffer(16384, data=B.data) + C_flat = T.decl_buffer(16384, data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 @@ -144,7 +144,4 @@ def test_lower_build_lowered_module(): if __name__ == "__main__": - test_lower_build_te_schedule() - test_lower_build_tir_func() - test_lower_build_tir_module() - test_lower_build_lowered_module() + tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py index cb29e79160c9..a7965e4db423 100644 --- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py @@ -292,8 +292,7 @@ def before(): T.evaluate(A[i0, i1, i2, i3, i4, i5]) def expected(): - A_data = T.allocate([30, 1001], dtype="float32", scope="global") - A = T.Buffer([30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data) + A = T.decl_buffer([30, 1001], axis_separators=[1], dtype="float32", scope="global") for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py index 6468ac5396ef..1ab395f17809 100644 --- a/tests/python/tir-transform/test_tir_transform_loop_partition.py +++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py @@ -17,7 +17,7 @@ import pytest import tvm import tvm.testing -from tvm import te +from tvm import te, tir from tvm.ir.module import IRModule from tvm.script import tir as T import numpy @@ -182,7 +182,11 @@ def test_vectorize(): s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) stmt = tvm.lower(s, [A, B], name="main")["main"] - body = stmt.body.body.body.body + + body = stmt + while not isinstance(body, tir.IfThenElse): + body = body.body + assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -233,7 +237,11 @@ def test_thread_axis2(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) stmt = tvm.lower(s, [A, B], name="main")["main"] - for_body = stmt.body.body.body.body[0] + + while not isinstance(stmt, tir.SeqStmt): + stmt = stmt.body + + for_body = stmt[0] assert "threadIdx" not in str(for_body.extent) @@ -712,32 +720,28 @@ def main(): @T.prim_func def partitioned_main(): - placeholder_0_dm = T.allocate([16384], "int8", "global") - placeholder_0_dm_1 = T.Buffer([16384], dtype="int8", data=placeholder_0_dm) + placeholder_0_dm = T.decl_buffer([16384], "int8") for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp = T.allocate([4096], "int8", "global") - pad_temp_1 = T.Buffer([4096], dtype="int8", data=pad_temp) + pad_temp = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1: - pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 ] for i2_0 in T.unroll(2): - pad_temp_2 = T.allocate([4096], "int8", "global") - pad_temp_3 = T.Buffer([4096], dtype="int8", data=pad_temp_2) + pad_temp_2 = 
T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0: - pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_2[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128 ] for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp_4 = T.allocate([4096], "int8", "global") - pad_temp_5 = T.Buffer([4096], dtype="int8", data=pad_temp_4) + pad_temp_4 = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14: - pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_4[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192 ] diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index c03dd7a5291d..e2641a65f287 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -170,6 +170,8 @@ def check(m, target_bits, target_dtype): B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name="B") s = te.create_schedule(B.op) stmt = lower_sch(s, [A, B], target_bits) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body assert stmt[1].loop_var.dtype == target_dtype # i32 -> i32 @@ -221,6 +223,8 @@ def check(shapex, shapey, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body # outer loop assert stmt.loop_var.dtype == target_dtype # inner loop @@ -262,7 +266,7 @@ def check(shape, index, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) - assert stmt.value.indices[0].dtype == target_dtype + assert stmt.body.body.value.indices[0].dtype == target_dtype check( (const(2**16, "int64"), const(2**15 + 1, "int64")), diff --git a/tests/python/tir-transform/test_tir_transform_storage_flatten.py b/tests/python/tir-transform/test_tir_transform_storage_flatten.py index 8ddfbb5adfd3..d3adea149fb9 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_flatten.py +++ b/tests/python/tir-transform/test_tir_transform_storage_flatten.py @@ -53,7 +53,7 @@ def test_flatten_prefetch(): mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -80,7 +80,7 @@ def test_flatten_storage_align(): [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body.body assert stmt.extents[0].value == 17 * 8 @@ -114,9 +114,9 @@ def main(A_param: T.handle, C_param: T.handle): ] )(mod) - stmt = mod["main"].body - assert isinstance(stmt.body, tvm.tir.Allocate) - assert list(stmt.body.extents) == [8] + stmt = mod["main"].body.body.body.body + assert isinstance(stmt, tvm.tir.Allocate) + assert list(stmt.extents) == [8] mod = tvm.tir.transform.ThreadSync("shared")(mod) f = mod["main"]