diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 891700b86a4c..1c46828c3049 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -71,6 +71,16 @@ class TargetNode : public Object { /*! \return The device type for this target */ TVM_DLL int GetTargetDeviceType() const; + /*! + * \brief Check if the target contains a key + * + * \param query_key The string name of the key to be checked + * + * \return True if the target's `TargetNode::keys` contains the + * specified key, False otherwise. + */ + TVM_DLL bool HasKey(const std::string& query_key) const; + /*! * \brief Returns a human readable representation of \p Target which includes all fields, * especially the host. Useful for diagnostic messages and debugging. diff --git a/src/target/target.cc b/src/target/target.cc index f05d4db2b888..3d51e0ad2766 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -669,6 +669,11 @@ int TargetNode::GetTargetDeviceType() const { return kind->default_device_type; } +bool TargetNode::HasKey(const std::string& query_key) const { + return std::any_of(keys.begin(), keys.end(), + [&query_key](const auto& key) { return key == query_key; }); +} + String TargetNode::ToDebugString() const { std::ostringstream os; os << "Target("; diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 4a3986460b15..18e568c83e74 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -33,6 +33,8 @@ #include +#include "ir_utils.h" + namespace tvm { namespace tir { @@ -102,8 +104,9 @@ namespace transform { Pass CombineContextCall() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = ContextCallCombiner().Combine(std::move(n->body)); + if (IsHostFunc(f).value_or(false)) { + f.CopyOnWrite()->body = ContextCallCombiner().Combine(f->body); + } return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 9b47d84e6aa2..604dbed325ec 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -692,6 +692,16 @@ std::pair GetWmmaFragmentDimSize(const std::string& shape_str, return std::pair(0, 0); } +std::optional IsHostFunc(const PrimFunc& func) { + if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) { + return true; + } else if (auto target = func->GetAttr(tvm::attr::kTarget)) { + return target.value()->HasKey("cpu"); + } else { + return std::nullopt; + } +} + namespace transform { Pass ConvertSSA() { auto pass_func = [](IRModule mod, PassContext ctx) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 59dc95dcd6a0..b48502871372 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -34,6 +34,7 @@ #include #include +#include #include #include #include @@ -351,6 +352,17 @@ CollectStorageAlignAnnotation(const Stmt& body); std::pair GetWmmaFragmentDimSize(const std::string& shape_str, const std::string& scope); +/*! \brief Check if a PrimFunc is a host function + * + * \param func The function to be inspected + * + * \return True if the function is known to run on the host, false if + * the function is known to run on the device. If it cannot be + * determined (e.g. a function without a tvm::attr::kTarget + * attribute), returns std::nullopt. + */ +std::optional IsHostFunc(const PrimFunc& func); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 3271e6e2569a..9f6147c96a89 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -14,33 +14,95 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import tvm -from tvm import te +import tvm.testing + +from tvm.script import tir as T, ir as I + + +def _device_context(dev_type, dev_id): + ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) + return tvm.tir.Call("handle", "tir.tvm_thread_context", [ctx]) + + +class TestCombineContextsInLoop(tvm.testing.CompareBeforeAfter): + """Device contexts should be hoisted and merged""" + + transform = tvm.tir.transform.CombineContextCall() + + def before(self): + @T.prim_func + def func(dev_type: T.int32, n: T.int32): + T.func_attr({"target": T.target("llvm")}) + A = T.allocate([n], "float32", "global") + for i in range(n): + T.call_extern( + "int32", + "fadd", + _device_context(dev_type, 0), + A, + ) + for j in range(10): + T.call_extern( + "int32", + "fadd", + _device_context(dev_type, 1), + A, + ) + T.call_extern( + "int32", + "fadd", + _device_context(dev_type, 0), + A, + ) + + return func + def expected(dev_type: T.int32, n: T.int32): + T.func_attr({"target": T.target("llvm")}) + ctx_cache_: T.handle = T.call_extern("handle", "device_context", dev_type, 0) + ctx_cache__1: T.handle = T.call_extern("handle", "device_context", dev_type, 1) + A = T.allocate([n], "float32", "global") + for i in range(n): + T.call_extern("int32", "fadd", ctx_cache_, A) + for j in range(10): + T.call_extern("int32", "fadd", ctx_cache__1, A) + T.call_extern("int32", "fadd", ctx_cache_, A) -def test_for(): - dev_type = te.var("dev_type") - def device_context(dev_id): - ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) - return tvm.tir.Call("handle", "tir.tvm_thread_context", [ctx]) +class TestCombineContextsInLoopWithoutTarget(TestCombineContextsInLoop): + """CombineContextCall only updates host-side functions""" - ib = tvm.tir.ir_builder.create() - n = te.var("n") - A = ib.allocate("float32", n, name="A", scope="global") - with ib.for_range(0, n, name="i") as i: - ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data)) - with ib.for_range(0, 10, name="j") as j: - ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A.asobject().data)) - ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data)) - body = ib.get() - mod = tvm.IRModule({"func": tvm.tir.PrimFunc([dev_type, n], body)}) + def before(self): + @T.prim_func + def func(dev_type: T.int32, n: T.int32): + A = T.allocate([n], "float32", "global") + for i in range(n): + T.call_extern( + "int32", + "fadd", + _device_context(dev_type, 0), + A, + ) + for j in range(10): + T.call_extern( + "int32", + "fadd", + _device_context(dev_type, 1), + A, + ) + T.call_extern( + "int32", + "fadd", + _device_context(dev_type, 0), + A, + ) - mod = tvm.tir.transform.CombineContextCall()(mod) + return func - assert mod["func"].body.value.dtype == "handle" - assert mod["func"].body.body.value.dtype == "handle" + expected = before if __name__ == "__main__": - test_for() + tvm.testing.main()