diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index f844b51f5394..8a5317a3c84a 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -46,7 +46,8 @@ transform::Pass BindTarget(Target target) { func = WithAttr(std::move(func), tvm::attr::kTarget, new_target); } } else if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) { - func = WithAttr(std::move(func), tvm::attr::kTarget, target_host); + func = + WithAttr(std::move(func), tvm::attr::kTarget, Target::WithHost(target_host, target_host)); } else if (is_externally_exposed) { func = WithAttr(std::move(func), tvm::attr::kTarget, target); } else { diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py index ea0ad7ba4a8a..ed04985bdda1 100644 --- a/tests/python/unittest/test_tir_host_func.py +++ b/tests/python/unittest/test_tir_host_func.py @@ -22,6 +22,7 @@ # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring # fmt: off + @I.ir_module class Module: @T.prim_func @@ -33,7 +34,7 @@ def main( T.func_attr( { "global_symbol": "test", - "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), + "target": tvm.target.Target("llvm", host="llvm"), "tir.noalias": True, } )