From ed4c9f997c252bf168401d522ced654e2541a626 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 May 2023 09:27:25 -0500 Subject: [PATCH] [TIR] Handle subroutine calls in MakeUnpackedAPI Prior to this commit, MakeUnpackedAPI required all functions to be annotated with `kGlobalSymbol` (`"global_symbol"`). This commit updates the transformation to apply only to functions that have the `kGlobalSymbol` attribute, and to update any internal callsites of the modified functions. This is analogous to the changes made in https://github.com/apache/tvm/pull/14913, which updates `MakePackedAPI`. --- src/tir/transforms/make_unpacked_api.cc | 109 +++++++++--- .../test_tir_transform_make_unpacked_api.py | 158 +++++++++++++++++- 2 files changed, 242 insertions(+), 25 deletions(-) diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index e327b3094594..82685411f592 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,20 +40,79 @@ namespace tvm { namespace tir { -PrimFunc MakeUnpackedAPI(PrimFunc&& func) { +class SubroutineCallRewriter : public StmtExprMutator { + public: + static Optional Apply(const std::unordered_set& external_methods, + Stmt stmt) { + SubroutineCallRewriter rewriter(external_methods); + stmt = rewriter.VisitStmt(std::move(stmt)); + if (rewriter.made_change_) { + return stmt; + } else { + return NullOpt; + } + } + + private: + explicit SubroutineCallRewriter(const std::unordered_set& external_methods) + : external_methods_(external_methods) {} + + PrimExpr VisitExpr_(const CallNode* op) override { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + + if (auto gvar = node->op.as()) { + if (external_methods_.count(gvar)) { + Array args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr { + if (auto* as_call = arg.as()) { + if (as_call->op.same_as(builtin::tvm_stack_make_array())) { + PrimExpr data_ptr = as_call->args[0]; + made_change_ = true; + 
return data_ptr; + } + } + return arg; + }); + if (!args.same_as(node->args)) { + node.CopyOnWrite()->args = args; + } + } + } + + return std::move(node); + } + const std::unordered_set& external_methods_; + bool made_change_{false}; +}; + +PrimFunc MakeUnpackedAPI(PrimFunc func) { + // A function with an explicit calling convention has already been + // lowered, and should not be modified. + if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { + if (CallingConv(opt.value()->value) != CallingConv::kDefault) { + return func; + } + } + + // Internal function calls do not need API updates auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute"; + if (!global_symbol.defined()) { + return func; + } - auto target = func->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute"; + Target target = [&]() { + auto opt = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt) << "MakeUnpackedAPI required the function to be annotated with tvm::attr::kTarget (" + << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs; + return opt.value(); + }(); + int target_device_type = target->GetTargetDeviceType(); auto* func_ptr = func.CopyOnWrite(); // Setup device context - int target_device_type = target.value()->GetTargetDeviceType(); Integer device_type(target_device_type); Integer device_id(0); - PrimExpr node = StringImm("default"); + ObjectRef node = String("default"); const Stmt nop = Evaluate(0); std::vector device_init; @@ -82,31 +141,43 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { func_ptr->buffer_map = Map(); // return the function. 
- return std::move(func); + return func; } namespace transform { Pass MakeUnpackedAPI() { - auto pass_func = [](IRModule m, PassContext ctx) { - IRModuleNode* mptr = m.CopyOnWrite(); - std::vector> updates; + auto pass_func = [](IRModule mod, PassContext ctx) { + std::unordered_set external_methods; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* prim_func = base_func.as()) { + if (prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { + external_methods.insert(gvar.get()); + } + } + } + + IRModule updates; - for (const auto& kv : mptr->functions) { - if (auto opt = kv.second.as()) { + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { auto func = opt.value(); - if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDefault) { - auto updated_func = MakeUnpackedAPI(std::move(func)); - updates.push_back({kv.first, updated_func}); + + if (auto body = SubroutineCallRewriter::Apply(external_methods, func->body)) { + func.CopyOnWrite()->body = body.value(); + } + + func = MakeUnpackedAPI(std::move(func)); + if (!func.same_as(base_func)) { + updates->Add(gvar, func); } } } - for (const auto& pair : updates) { - mptr->AddUnchecked(pair.first, pair.second); + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); } - return m; + return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {}); diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index 245ff53f9105..bb9fe8ab8267 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -17,7 +17,8 @@ import pytest import tvm -from tvm import te +from tvm import te, tir +from tvm.script import tir as T, ir as I import numpy @@ -39,17 +40,20 @@ def mod(mod_without_attrs): return mod -def 
test_fails_if_not_global_symbol(mod_without_attrs): - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( +def test_noop_if_not_global_symbol(mod_without_attrs): + before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( mod_without_attrs ) - with pytest.raises(tvm.TVMError, match="Expect PrimFunc to have the global_symbol attribute"): - f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] + after = tvm.tir.transform.MakeUnpackedAPI()(before) + tvm.ir.assert_structural_equal(before, after) def test_fails_if_no_target(mod_without_attrs): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod_without_attrs) - with pytest.raises(tvm.TVMError, match="Require the target attribute"): + with pytest.raises( + tvm.TVMError, + match="MakeUnpackedAPI required the function to be annotated with tvm::attr::kTarget", + ): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] @@ -134,5 +138,147 @@ def test_body(): assert f.params[2].name == "A" +class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter): + """Internal subroutines do not require modification + + A subroutine without the "global_symbol" attribute is an internal + subroutine, and is not directly exposed to a user of the generated + `runtime.Module`. 
+ """ + + transform = tvm.tir.transform.MakeUnpackedAPI() + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + mod.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("llvm")}) + T.evaluate(A_data) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + mod.subroutine(A_data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("llvm")}) + T.evaluate(A_data) + + return mod + + +class TestSubroutineCallToExternallyVisibleSubroutine(tvm.testing.CompareBeforeAfter): + """Externally-visible subroutines should be updated + + Subroutines that are exposed externally should be updated by + MakeUnpackedAPI. 
+ """ + + transform = tvm.tir.transform.MakeUnpackedAPI() + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + mod.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.evaluate(A_data) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + mod.subroutine(A_data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.evaluate(A_data) + + return mod + + +class TestCallExternallyVisibleSubroutineWithDLTensor(tvm.testing.CompareBeforeAfter): + """Callsites of externally-visible subroutines may require updates + + The MakeUnpackedAPI transform lowers all buffers into a data + pointer to a primitive type. If a subroutine call is currently + passing a DLTensor produced by `T.tvm_stack_make_array` into the + subroutine, the callsite should be updated to instead pass the + data pointer directly. 
+ """ + + transform = tvm.tir.transform.MakeUnpackedAPI() + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + mod.subroutine( + T.tvm_stack_make_array( + A.data, + T.tvm_stack_make_shape(1, dtype="handle"), + T.reinterpret(T.uint64(0), dtype="handle"), + T.uint32(1), + T.Cast("float32", 0), + 0, + dtype="handle", + ) + ) + + @T.prim_func + def subroutine(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.evaluate(A.data) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + mod.subroutine(A_data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")) -> T.int32: + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + T.evaluate(A_data) + + return mod + + if __name__ == "__main__": tvm.testing.main()