From 668d9227cc59796576fb17f4db7e3db0f6024303 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Mar 2023 08:52:09 -0500 Subject: [PATCH 1/3] [TIR] MakePackedAPI, handle missing kGlobalSymbol Previously, `MakePackedAPI` required all functions to have the `kGlobalSymbol` attribute. This commit updates the behavior such that `MakePackedAPI` only modifies PrimFuncs that have the `kGlobalSymbol` attribute, and passes through any other PrimFunc unmodified. --- src/tir/transforms/make_packed_api.cc | 32 +++++--- .../test_tir_transform_make_packed_api.py | 73 ++++++++++++++++++- 2 files changed, 94 insertions(+), 11 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index de1f17608273..355ead469787 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -139,16 +139,30 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } -PrimFunc MakePackedAPI(PrimFunc&& func) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; - - auto target = func->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute"; - int target_device_type = target.value()->GetTargetDeviceType(); +PrimFunc MakePackedAPI(PrimFunc func) { + // A function with an explicit calling convention has already been + // lowered, and should not be modified. + if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { + if (CallingConv(opt.value()->value) != CallingConv::kDefault) { + return func; + } + } + // An internal subroutine does not require the PackedFunc API. 
+ auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + if (!global_symbol.defined()) { + return func; + } std::string name_hint = global_symbol.value(); + Target target = [&]() { + auto opt = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget (" + << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs; + return opt.value(); + }(); + int target_device_type = target->GetTargetDeviceType(); + auto* func_ptr = func.CopyOnWrite(); const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); @@ -292,7 +306,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); - ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << global_symbol << " variables " << undefined + ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined << " are used, but are not passed in as API arguments"; func_ptr->buffer_map = Map(); @@ -300,7 +314,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function. - return std::move(func); + return func; } namespace transform { diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 47bb7bf228d4..6276f907fcea 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -15,8 +15,12 @@ # specific language governing permissions and limitations # under the License. 
+import pytest + import tvm -from tvm import te +import tvm.testing +from tvm import te, tir +from tvm.script import tir as T, ir as I from tvm.driver.build_module import schedule_to_module @@ -39,7 +43,9 @@ def test_makeapi(): ) )(mod) - f = tvm.tir.transform.MakePackedAPI()(mod)["main"] + before = mod + after = tvm.tir.transform.MakePackedAPI()(mod) + f = after["main"] assert len(f.params) == 6 @@ -59,6 +65,19 @@ def _find_next(stmt, type): return stmt +def _find_compute_scope(func): + result = None + + def _visitor(stmt): + if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "compute_scope": + nonlocal result + result = stmt + + tir.stmt_functor.post_order_visit(func.body, _visitor) + + return result + + def test_variable_passed_from_args(): ib = tvm.tir.ir_builder.create() @@ -143,5 +162,55 @@ def test_device_api_context_implicit_resource_handle(): assert call_extern.args[2] == device_context_in_resource_handle +@pytest.mark.parametrize("use_global_symbol", [True, False]) +def test_no_op_when_global_symbol_is_absent(use_global_symbol): + func_attr = {"target": tvm.target.Target("llvm")} + if use_global_symbol: + func_attr["global_symbol"] = "main" + + @T.prim_func + def before(): + T.func_attr(func_attr) + T.evaluate(0) + + after = tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"] + if use_global_symbol: + assert len(after.params) == 6 + else: + tvm.ir.assert_structural_equal(before, after) + + +def test_internal_subroutine_call(): + """Internal subroutines should not use the PackedFunc API + + A subroutine without the "global_symbol" attribute is an internal + subroutine, and is not directly exposed to a user of the generated + `runtime.Module`. Therefore, it doesn't need to follow the + PackedFunc API. 
+ """ + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + before.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"target": T.target("llvm")}) + T.evaluate(A_data) + + after = tvm.tir.transform.MakePackedAPI()(before) + tvm.ir.assert_structural_equal(before["subroutine"], after["subroutine"]) + + compute_scope = _find_compute_scope(after["main"]) + subroutine_call_op = compute_scope.body.value.op + assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), ( + f"The main function's CallNode should use the subroutine's GLobalVar as the operation, " + f"but instead has an operation of type {subroutine_call_op}" + ) + + if __name__ == "__main__": test_makeapi() From d706d35a2e7490a9646cc14b6390494fa7a1bb30 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Mar 2023 12:54:32 -0500 Subject: [PATCH 2/3] [TIR] Update calls to externally-exposed subroutines in MakePackedAPI When a function is updated to use the `PackedFunc` API, any calls made to that function from elsewhere in the `IRModule` should be updated as well. 
--- src/tir/transforms/make_packed_api.cc | 82 ++++++++++++++++--- .../test_tir_transform_make_packed_api.py | 38 +++++++++ 2 files changed, 107 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 355ead469787..8ae79c016565 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -135,6 +135,47 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { return rewriter(body); } +class SubroutineCallRewriter : public StmtExprMutator { + public: + static Optional Apply(const Map& external_methods, Stmt stmt) { + SubroutineCallRewriter rewriter(external_methods); + stmt = rewriter.VisitStmt(std::move(stmt)); + if (rewriter.made_change_) { + return stmt; + } else { + return NullOpt; + } + } + + private: + explicit SubroutineCallRewriter(const Map& external_methods) + : external_methods_(external_methods) {} + + PrimExpr VisitExpr_(const CallNode* op) override { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + + if (auto* gvar_ptr = node->op.as()) { + auto gvar = GetRef(gvar_ptr); + if (auto symbol = external_methods_.Get(gvar)) { + Array cpacked_args; + cpacked_args.push_back(tir::StringImm(symbol.value())); + for (auto arg : node->args) { + cpacked_args.push_back(arg); + } + + // push an empty handle to be compatible with current cpacked convention + cpacked_args.push_back(tir::make_zero(DataType::Handle())); + made_change_ = true; + return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(), cpacked_args); + } + } + + return node; + } + const Map& external_methods_; + bool made_change_{false}; +}; + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } @@ -148,7 +189,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } } - // An internal subroutine does not require the PackedFunc API. 
+ // Internal function calls do not need the PackedFunc API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol.defined()) { return func; @@ -320,25 +361,40 @@ PrimFunc MakePackedAPI(PrimFunc func) { namespace transform { Pass MakePackedAPI() { - auto pass_func = [](IRModule m, PassContext ctx) { - IRModuleNode* mptr = m.CopyOnWrite(); - std::vector> updates; + auto pass_func = [](IRModule mod, PassContext ctx) { + Map external_methods; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* prim_func = base_func.as()) { + if (auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { + external_methods.Set(gvar, global_symbol.value()); + } + } + } - for (const auto& kv : mptr->functions) { - if (auto opt = kv.second.as()) { + IRModuleNode* mptr = mod.CopyOnWrite(); + IRModule updates; + + for (const auto& [gvar, base_func] : mptr->functions) { + if (auto opt = base_func.as()) { auto func = opt.value(); - if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDefault) { - auto updated_func = MakePackedAPI(std::move(func)); - updates.push_back({kv.first, updated_func}); + auto orig_func = func; + + if (auto body = SubroutineCallRewriter::Apply(external_methods, func->body)) { + func.CopyOnWrite()->body = body.value(); + } + + func = MakePackedAPI(std::move(func)); + + if (!func.same_as(orig_func)) { + updates->Add(gvar, func); } } } - for (const auto& pair : updates) { - mptr->AddUnchecked(pair.first, pair.second); + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); } - return m; + return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index 6276f907fcea..cd27c0305c8b 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ 
b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -212,5 +212,43 @@ def subroutine(A_data: T.handle("float32")): ) +def test_subroutine_call_to_externally_visible_subroutine(): + """Externally-visible subroutines should use the PackedFunc API + + Because the subroutine may be called directly by a user, it must + use the PackedFunc API. Its signature should be updated to the + PackedFunc signature, and call sites should be updated to use + `T.tvm_call_cpacked`. + """ + + @I.ir_module + class before: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "main", "target": T.target("llvm")}) + before.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")}) + T.evaluate(A_data) + + after = tvm.tir.transform.MakePackedAPI()(before) + + main_compute_scope = _find_compute_scope(after["main"]) + assert main_compute_scope is not None + subroutine_compute_scope = _find_compute_scope(after["subroutine"]) + assert subroutine_compute_scope is not None + + subroutine_call_op = main_compute_scope.body.value.op + assert ( + isinstance(subroutine_call_op, tvm.ir.Op) + and subroutine_call_op.name == "tir.tvm_call_cpacked" + ), ( + f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', " + f"but instead has an operation of type {subroutine_call_op}" + ) + + if __name__ == "__main__": test_makeapi() From 9859694db50440d714dc61fc284881082a8ac3f0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 May 2023 13:14:27 -0500 Subject: [PATCH 3/3] Bugfix, don't update the callsite unless the callee is also updated --- src/tir/transforms/make_packed_api.cc | 43 +++++++++++++++++++-------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 8ae79c016565..dd9d471c5066 100644 --- 
a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -137,8 +137,8 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { class SubroutineCallRewriter : public StmtExprMutator { public: - static Optional Apply(const Map& external_methods, Stmt stmt) { - SubroutineCallRewriter rewriter(external_methods); + static Optional Apply(const Map& packed_func_methods, Stmt stmt) { + SubroutineCallRewriter rewriter(packed_func_methods); stmt = rewriter.VisitStmt(std::move(stmt)); if (rewriter.made_change_) { return stmt; @@ -148,15 +148,15 @@ class SubroutineCallRewriter : public StmtExprMutator { } private: - explicit SubroutineCallRewriter(const Map& external_methods) - : external_methods_(external_methods) {} + explicit SubroutineCallRewriter(const Map& packed_func_methods) + : packed_func_methods(packed_func_methods) {} PrimExpr VisitExpr_(const CallNode* op) override { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); if (auto* gvar_ptr = node->op.as()) { auto gvar = GetRef(gvar_ptr); - if (auto symbol = external_methods_.Get(gvar)) { + if (auto symbol = packed_func_methods.Get(gvar)) { Array cpacked_args; cpacked_args.push_back(tir::StringImm(symbol.value())); for (auto arg : node->args) { @@ -172,7 +172,7 @@ class SubroutineCallRewriter : public StmtExprMutator { return node; } - const Map& external_methods_; + const Map& packed_func_methods; bool made_change_{false}; }; @@ -180,17 +180,33 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } -PrimFunc MakePackedAPI(PrimFunc func) { +/*! \brief Return the global_symbol of the function, if it should be updated + * + * \param func The function to be inspected + * + * \returns The global_symbol to be used for the function at call + * sites, or NullOpt if the function is to remain unchanged. 
+ */ +Optional RequiresPackedAPI(const PrimFunc& func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { if (CallingConv(opt.value()->value) != CallingConv::kDefault) { - return func; + return NullOpt; } } // Internal function calls do not need the PackedFunc API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + if (!global_symbol.defined()) { + return NullOpt; + } + + return global_symbol; +} + +PrimFunc MakePackedAPI(PrimFunc func) { + auto global_symbol = RequiresPackedAPI(func); if (!global_symbol.defined()) { return func; } @@ -362,11 +378,12 @@ namespace transform { Pass MakePackedAPI() { auto pass_func = [](IRModule mod, PassContext ctx) { - Map external_methods; + Map packed_func_methods; for (const auto& [gvar, base_func] : mod->functions) { - if (auto* prim_func = base_func.as()) { - if (auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { - external_methods.Set(gvar, global_symbol.value()); + if (auto opt = base_func.as()) { + auto prim_func = opt.value(); + if (auto global_symbol = RequiresPackedAPI(prim_func)) { + packed_func_methods.Set(gvar, global_symbol.value()); } } } @@ -379,7 +396,7 @@ Pass MakePackedAPI() { auto func = opt.value(); auto orig_func = func; - if (auto body = SubroutineCallRewriter::Apply(external_methods, func->body)) { + if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) { func.CopyOnWrite()->body = body.value(); }