diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 891700b86a4c..5c88807682d7 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -218,6 +218,9 @@ class Target : public ObjectRef { */ static Target WithHost(const Target& target, const Target& host); + /*! \return The target with the host stripped out */ + Target WithoutHost() const; + /*! * \brief Returns true if \p this target represents an external codegen. If so, * \p this->kind->name can be used as the "Compiler" attribute on partitioned functions, diff --git a/src/target/target.cc b/src/target/target.cc index f05d4db2b888..e479f592c640 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -662,6 +662,16 @@ Map TargetNode::Export() const { Optional TargetNode::GetHost() const { return this->host.as(); } +Target Target::WithoutHost() const { + if ((*this)->GetHost()) { + auto output = make_object(*get()); + output->host = NullOpt; + return Target(output); + } else { + return *this; + } +} + int TargetNode::GetTargetDeviceType() const { if (Optional device_type = GetAttr("target_device_type")) { return Downcast(device_type)->value; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..825a8da45b27 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +namespace { class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 82685411f592..bdb3a953e99c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +namespace { + class SubroutineCallRewriter : public StmtExprMutator { public: static Optional Apply(const std::unordered_set& external_methods, @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + PrimFunc MakeUnpackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 257e3eacda90..f844b51f5394 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -30,12 +30,32 @@ namespace tvm { namespace tir { namespace transform { transform::Pass BindTarget(Target target) { - auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - if (f->GetAttr(tvm::tir::attr::kIsHostFunc) == 1) { - return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)), - tvm::attr::kTarget, target->host.value_or(Target("llvm"))); + Target without_host = target.WithoutHost(); + Target target_host = Downcast(target->host.value_or(Target("llvm"))); + + auto fpass = [target, target_host, without_host](tir::PrimFunc func, IRModule m, + transform::PassContext ctx) { + bool is_externally_exposed = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + + if (auto func_target = func->GetAttr(tvm::attr::kTarget)) { + auto func_target_host = func_target.value()->GetHost(); + auto target_host = target->GetHost(); + + if (target_host && !func_target_host && is_externally_exposed) { + auto new_target = Target::WithHost(func_target.value(), target_host.value()); + func = WithAttr(std::move(func), tvm::attr::kTarget, new_target); + } + } else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) { + func = WithAttr(std::move(func), tvm::attr::kTarget, target_host); + } else if (is_externally_exposed) { + func = WithAttr(std::move(func), tvm::attr::kTarget, target); + } else { + func = WithAttr(std::move(func), tvm::attr::kTarget, without_host); } - return WithAttr(std::move(f), tvm::attr::kTarget, target); + + func = WithoutAttr(std::move(func), tvm::tir::attr::kIsHostFunc); + + return func; }; return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {}); } diff --git a/tests/python/unittest/test_tir_transform_helpers.py b/tests/python/unittest/test_tir_transform_helpers.py index 657bda591ae2..00fd12521268 100644 --- a/tests/python/unittest/test_tir_transform_helpers.py +++ b/tests/python/unittest/test_tir_transform_helpers.py @@ -85,6 +85,118 @@ def test_bind_target(): assert after["func2"].attrs["target"] == target +class TestBindTarget(tvm.testing.CompareBeforeAfter): + """BindTarget adds the "target" attribute""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda")) + + def before(): + T.evaluate(0) + + def expected(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + +class TestBindTargetWithHostToExposedFunction(tvm.testing.CompareBeforeAfter): + """BindTarget adds the host target to externally-exposed functions""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm")) + + def before(): + T.func_attr({"global_symbol": "main"}) + T.evaluate(0) + + def expected(): + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) + T.evaluate(0) + + +class TestBindTargetWithHostToInternalFunction(tvm.testing.CompareBeforeAfter): + """Internal functions have a target annotation, but without the host + + The host portion of the target annotation provides host + parameters, and is used to expose a function externally as part of + `MakePackedAPI` and `MakeUnpackedAPI`. For internal functions, no + external exposure is required, so the host attribute should not be + used. + """ + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm")) + + def before(): + T.evaluate(0) + + def expected(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + +class TestBindTargetIgnoresExisting(tvm.testing.CompareBeforeAfter): + """BindTarget should not replace existing annotations""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda")) + + def before(): + T.func_attr({"target": T.target("nvptx")}) + T.evaluate(0) + + expected = before + + +class TestBindTargetUpdatesHost(tvm.testing.CompareBeforeAfter): + """BindTarget should update host for existing annotations""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm -opt-level=0")) + + def before(): + T.func_attr({"global_symbol": "func", "target": T.target("nvptx")}) + T.evaluate(0) + + def expected(): + T.func_attr( + { + "global_symbol": "func", + "target": T.target("nvptx", host="llvm -opt-level=0"), + } + ) + T.evaluate(0) + + +class TestBindTargetMultipleFunctions(tvm.testing.CompareBeforeAfter): + """BindTarget may apply to multiple functions in a module""" + + transform = tvm.tir.transform.BindTarget(tvm.target.Target("cuda")) + + def before(self): + @tvm.script.ir_module + class mod: + @T.prim_func + def func1(): + T.evaluate(0) + + @T.prim_func + def func2(): + T.evaluate(0) + + return mod + + def expected(self): + @tvm.script.ir_module + class mod: + @T.prim_func + def func1(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + @T.prim_func + def func2(): + T.func_attr({"target": T.target("cuda")}) + T.evaluate(0) + + return mod + + def test_filter_primfunc(): mod = MockModule assert mod