Handle subroutine calls in tir.MakePackedAPI

Summary of the change carried by this patch:

* RequiresPackedAPI(func) decides whether a PrimFunc is converted.  It
  returns NullOpt when the function carries a calling-convention
  attribute other than CallingConv::kDefault, or when it has no
  "global_symbol" attribute; otherwise it returns the global symbol.
* MakePackedAPI(PrimFunc) now takes its argument by value and returns it
  untouched when RequiresPackedAPI yields no symbol, instead of failing
  an ICHECK.  The target attribute is only required for functions that
  are actually converted.
* SubroutineCallRewriter rewrites every call whose op is a GlobalVar of
  a function that will be converted into
  tir::builtin::tvm_call_cpacked(symbol, <original args>..., <zero handle>).
  Calls to functions that are not converted are left as they are.
* The module pass collects the GlobalVar -> symbol map first, rewrites
  call sites, converts each PrimFunc, and only writes back functions
  that differ from the original.
* New tests cover: a function without "global_symbol" is left
  unchanged; a call to an internal subroutine keeps its GlobalVar op; a
  call to an externally visible subroutine becomes tir.tvm_call_cpacked.

diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index de1f17608273..dd9d471c5066 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -135,20 +135,91 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
   return rewriter(body);
 }
 
+class SubroutineCallRewriter : public StmtExprMutator {
+ public:
+  static Optional<Stmt> Apply(const Map<GlobalVar, String>& packed_func_methods, Stmt stmt) {
+    SubroutineCallRewriter rewriter(packed_func_methods);
+    stmt = rewriter.VisitStmt(std::move(stmt));
+    if (rewriter.made_change_) {
+      return stmt;
+    } else {
+      return NullOpt;
+    }
+  }
+
+ private:
+  explicit SubroutineCallRewriter(const Map<GlobalVar, String>& packed_func_methods)
+      : packed_func_methods(packed_func_methods) {}
+
+  PrimExpr VisitExpr_(const CallNode* op) override {
+    auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
+
+    if (auto* gvar_ptr = node->op.as<GlobalVarNode>()) {
+      auto gvar = GetRef<GlobalVar>(gvar_ptr);
+      if (auto symbol = packed_func_methods.Get(gvar)) {
+        Array<PrimExpr> cpacked_args;
+        cpacked_args.push_back(tir::StringImm(symbol.value()));
+        for (auto arg : node->args) {
+          cpacked_args.push_back(arg);
+        }
+
+        // push an empty handle to be compatible with current cpacked convention
+        cpacked_args.push_back(tir::make_zero(DataType::Handle()));
+        made_change_ = true;
+        return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(), cpacked_args);
+      }
+    }
+
+    return node;
+  }
+  const Map<GlobalVar, String>& packed_func_methods;
+  bool made_change_{false};
+};
+
 inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
   return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
 }
 
-PrimFunc MakePackedAPI(PrimFunc&& func) {
+/*! \brief Return the global_symbol of the function, if it should be updated
+ *
+ * \param func The function to be inspected
+ *
+ * \returns The global_symbol to be used for the function at call
+ * sites, or NullOpt if the function is to remain unchanged.
+ */
+Optional<String> RequiresPackedAPI(const PrimFunc& func) {
+  // A function with an explicit calling convention has already been
+  // lowered, and should not be modified.
+  if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
+    if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
+      return NullOpt;
+    }
+  }
+
+  // Internal function calls do not need the PackedFunc API
   auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
+  if (!global_symbol.defined()) {
+    return NullOpt;
+  }
 
-  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
-  ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
-  int target_device_type = target.value()->GetTargetDeviceType();
+  return global_symbol;
+}
 
+PrimFunc MakePackedAPI(PrimFunc func) {
+  auto global_symbol = RequiresPackedAPI(func);
+  if (!global_symbol.defined()) {
+    return func;
+  }
   std::string name_hint = global_symbol.value();
 
+  Target target = [&]() {
+    auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
+    ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
+                << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
+    return opt.value();
+  }();
+  int target_device_type = target->GetTargetDeviceType();
+
   auto* func_ptr = func.CopyOnWrite();
   const Stmt nop = Evaluate(0);
   int num_args = static_cast<int>(func_ptr->params.size());
@@ -292,7 +363,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
   func_ptr->params = args;
 
   Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
-  ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << global_symbol << " variables " << undefined
+  ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
                                  << " are used, but are not passed in as API arguments";
 
   func_ptr->buffer_map = Map<Var, Buffer>();
@@ -300,31 +371,47 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
   func_ptr->ret_type = PrimType(DataType::Int(32));
 
   // return the function.
-  return std::move(func);
+  return func;
 }
 
 namespace transform {
 
 Pass MakePackedAPI() {
-  auto pass_func = [](IRModule m, PassContext ctx) {
-    IRModuleNode* mptr = m.CopyOnWrite();
-    std::vector<std::pair<GlobalVar, PrimFunc>> updates;
+  auto pass_func = [](IRModule mod, PassContext ctx) {
+    Map<GlobalVar, String> packed_func_methods;
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto opt = base_func.as<PrimFunc>()) {
+        auto prim_func = opt.value();
+        if (auto global_symbol = RequiresPackedAPI(prim_func)) {
+          packed_func_methods.Set(gvar, global_symbol.value());
+        }
+      }
+    }
+
+    IRModuleNode* mptr = mod.CopyOnWrite();
+    IRModule updates;
 
-    for (const auto& kv : mptr->functions) {
-      if (auto opt = kv.second.as<PrimFunc>()) {
+    for (const auto& [gvar, base_func] : mptr->functions) {
+      if (auto opt = base_func.as<PrimFunc>()) {
         auto func = opt.value();
-        if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
-            CallingConv::kDefault) {
-          auto updated_func = MakePackedAPI(std::move(func));
-          updates.push_back({kv.first, updated_func});
+        auto orig_func = func;
+
+        if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
+          func.CopyOnWrite()->body = body.value();
+        }
+
+        func = MakePackedAPI(std::move(func));
+
+        if (!func.same_as(orig_func)) {
+          updates->Add(gvar, func);
         }
       }
     }
 
-    for (const auto& pair : updates) {
-      mptr->AddUnchecked(pair.first, pair.second);
+    if (updates->functions.size()) {
+      mod.CopyOnWrite()->Update(updates);
     }
-    return m;
+    return mod;
   };
 
   return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py
index 47bb7bf228d4..cd27c0305c8b 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -15,8 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
-from tvm import te
+import tvm.testing
+from tvm import te, tir
+from tvm.script import tir as T, ir as I
 from tvm.driver.build_module import schedule_to_module
 
 
@@ -39,7 +43,9 @@ def test_makeapi():
         )
     )(mod)
 
-    f = tvm.tir.transform.MakePackedAPI()(mod)["main"]
+    before = mod
+    after = tvm.tir.transform.MakePackedAPI()(mod)
+    f = after["main"]
     assert len(f.params) == 6
 
 
@@ -59,6 +65,19 @@ def _find_next(stmt, type):
     return stmt
 
 
+def _find_compute_scope(func):
+    result = None
+
+    def _visitor(stmt):
+        if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "compute_scope":
+            nonlocal result
+            result = stmt
+
+    tir.stmt_functor.post_order_visit(func.body, _visitor)
+
+    return result
+
+
 def test_variable_passed_from_args():
     ib = tvm.tir.ir_builder.create()
 
@@ -143,5 +162,93 @@ def test_device_api_context_implicit_resource_handle():
     assert call_extern.args[2] == device_context_in_resource_handle
 
 
+@pytest.mark.parametrize("use_global_symbol", [True, False])
+def test_no_op_when_global_symbol_is_absent(use_global_symbol):
+    func_attr = {"target": tvm.target.Target("llvm")}
+    if use_global_symbol:
+        func_attr["global_symbol"] = "main"
+
+    @T.prim_func
+    def before():
+        T.func_attr(func_attr)
+        T.evaluate(0)
+
+    after = tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"]
+    if use_global_symbol:
+        assert len(after.params) == 6
+    else:
+        tvm.ir.assert_structural_equal(before, after)
+
+
+def test_internal_subroutine_call():
+    """Internal subroutines should not use the PackedFunc API
+
+    A subroutine without the "global_symbol" attribute is an internal
+    subroutine, and is not directly exposed to a user of the generated
+    `runtime.Module`.  Therefore, it doesn't need to follow the
+    PackedFunc API.
+    """
+
+    @I.ir_module
+    class before:
+        @T.prim_func
+        def main(A: T.Buffer(1, "float32")):
+            T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
+            before.subroutine(A.data)
+
+        @T.prim_func
+        def subroutine(A_data: T.handle("float32")):
+            T.func_attr({"target": T.target("llvm")})
+            T.evaluate(A_data)
+
+    after = tvm.tir.transform.MakePackedAPI()(before)
+    tvm.ir.assert_structural_equal(before["subroutine"], after["subroutine"])
+
+    compute_scope = _find_compute_scope(after["main"])
+    subroutine_call_op = compute_scope.body.value.op
+    assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), (
+        f"The main function's CallNode should use the subroutine's GlobalVar as the operation, "
+        f"but instead has an operation of type {subroutine_call_op}"
+    )
+
+
+def test_subroutine_call_to_externally_visible_subroutine():
+    """Externally-visible subroutines should use the PackedFunc API
+
+    Because the subroutine may be called directly by a user, it must
+    use the PackedFunc API.  Its signature should be updated to the
+    PackedFunc signature, and call sites should be updated to use
+    `T.tvm_call_cpacked`.
+    """
+
+    @I.ir_module
+    class before:
+        @T.prim_func
+        def main(A: T.Buffer(1, "float32")):
+            T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
+            before.subroutine(A.data)
+
+        @T.prim_func
+        def subroutine(A_data: T.handle("float32")):
+            T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
+            T.evaluate(A_data)
+
+    after = tvm.tir.transform.MakePackedAPI()(before)
+
+    main_compute_scope = _find_compute_scope(after["main"])
+    assert main_compute_scope is not None
+    subroutine_compute_scope = _find_compute_scope(after["subroutine"])
+    assert subroutine_compute_scope is not None
+
+    subroutine_call_op = main_compute_scope.body.value.op
+    assert (
+        isinstance(subroutine_call_op, tvm.ir.Op)
+        and subroutine_call_op.name == "tir.tvm_call_cpacked"
+    ), (
+        f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
+        f"but instead has an operation of type {subroutine_call_op}"
+    )
+
+
 if __name__ == "__main__":
     test_makeapi()