Skip to content

Commit 5fd49f7

Browse files
authored
[TIR] Restrict tir.transform.CombineContextCall to host functions (#14945)
* [Target] Added utility method TargetNode::HasKey()
  This utility method makes it easier to determine if a target contains a specific key.
* [TIR] Added utility method tvm::tir::IsHostFunc(const PrimFunc&)
  For modules that contain both host and device functions, this utility function checks whether a given PrimFunc is a host function, based on the target annotation.
* [TIR] Restrict tir.transform.CombineContextCall to host functions
  Previously, the `tir.transform.CombineContextCall` pass was applied to every function in an `IRModule`, but it was only ever run on modules that contain only host functions. This commit updates `tir.transform.CombineContextCall` to apply only to host functions.
1 parent 94f4e25 commit 5fd49f7

File tree

6 files changed

+124
-22
lines changed

6 files changed

+124
-22
lines changed

include/tvm/target/target.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ class TargetNode : public Object {
7171
/*! \return The device type for this target */
7272
TVM_DLL int GetTargetDeviceType() const;
7373

74+
/*!
75+
* \brief Check if the target contains a key
76+
*
77+
* \param query_key The string name of the key to be checked
78+
*
79+
* \return True if the target's `TargetNode::keys` contains the
80+
* specified key, False otherwise.
81+
*/
82+
TVM_DLL bool HasKey(const std::string& query_key) const;
83+
7484
/*!
7585
* \brief Returns a human readable representation of \p Target which includes all fields,
7686
* especially the host. Useful for diagnostic messages and debugging.

src/target/target.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,11 @@ int TargetNode::GetTargetDeviceType() const {
669669
return kind->default_device_type;
670670
}
671671

672+
// Scans `keys` and reports whether any entry compares equal to query_key.
// Yields false when no entry matches, which includes the case of an empty `keys`.
bool TargetNode::HasKey(const std::string& query_key) const {
673+
return std::any_of(keys.begin(), keys.end(),
674+
[&query_key](const auto& key) { return key == query_key; });
675+
}
676+
672677
String TargetNode::ToDebugString() const {
673678
std::ostringstream os;
674679
os << "Target(";

src/tir/transforms/combine_context_call.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
#include <unordered_map>
3535

36+
#include "ir_utils.h"
37+
3638
namespace tvm {
3739
namespace tir {
3840

@@ -102,8 +104,9 @@ namespace transform {
102104

103105
Pass CombineContextCall() {
104106
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
105-
auto* n = f.CopyOnWrite();
106-
n->body = ContextCallCombiner().Combine(std::move(n->body));
107+
if (IsHostFunc(f).value_or(false)) {
108+
f.CopyOnWrite()->body = ContextCallCombiner().Combine(f->body);
109+
}
107110
return f;
108111
};
109112
return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});

src/tir/transforms/ir_utils.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,16 @@ std::pair<int32_t, int32_t> GetWmmaFragmentDimSize(const std::string& shape_str,
692692
return std::pair<int32_t, int32_t>(0, 0);
693693
}
694694

695+
// Classifies `func` as host (true) or non-host (false) from its attributes,
// or returns std::nullopt when neither attribute below is available.
std::optional<bool> IsHostFunc(const PrimFunc& func) {
696+
// A nonzero kIsHostFunc attribute is checked first and wins over the target.
// NOTE(review): a kIsHostFunc attribute that is present but zero appears to
// fall through to the target check rather than meaning "device" -- confirm.
if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
697+
return true;
698+
// Otherwise decide from the target annotation: host iff it has the "cpu" key.
} else if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
699+
return target.value()->HasKey("cpu");
700+
// No nonzero kIsHostFunc and no target attribute: cannot be determined.
} else {
701+
return std::nullopt;
702+
}
703+
}
704+
695705
namespace transform {
696706
Pass ConvertSSA() {
697707
auto pass_func = [](IRModule mod, PassContext ctx) {

src/tir/transforms/ir_utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <tvm/tir/op.h>
3535

3636
#include <limits>
37+
#include <optional>
3738
#include <string>
3839
#include <unordered_map>
3940
#include <utility>
@@ -351,6 +352,17 @@ CollectStorageAlignAnnotation(const Stmt& body);
351352
std::pair<int32_t, int32_t> GetWmmaFragmentDimSize(const std::string& shape_str,
352353
const std::string& scope);
353354

355+
/*! \brief Check if a PrimFunc is a host function
356+
*
357+
* \param func The function to be inspected
358+
*
359+
* \return True if the function is known to run on the host, false if
360+
* the function is known to run on the device. If it cannot be
361+
* determined (e.g. a function without a tvm::attr::kTarget
362+
* attribute), returns std::nullopt.
363+
*/
364+
std::optional<bool> IsHostFunc(const PrimFunc& func);
365+
354366
} // namespace tir
355367
} // namespace tvm
356368
#endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_

tests/python/unittest/test_tir_transform_combine_context_call.py

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,95 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
1718
import tvm
18-
from tvm import te
19+
import tvm.testing
20+
21+
from tvm.script import tir as T, ir as I
22+
23+
24+
# Test helper: wraps an extern "device_context" call (taking dev_type and
# dev_id) in a "tir.tvm_thread_context" call, both typed as "handle".
def _device_context(dev_type, dev_id):
25+
ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
26+
return tvm.tir.Call("handle", "tir.tvm_thread_context", [ctx])
27+
28+
29+
# Before/after comparison test for tir.transform.CombineContextCall on a
# function annotated with an "llvm" target.
class TestCombineContextsInLoop(tvm.testing.CompareBeforeAfter):
30+
"""Device contexts should be hoisted and merged"""
31+
32+
transform = tvm.tir.transform.CombineContextCall()
33+
34+
# Input: device contexts 0 and 1 are requested inline, inside nested loops,
# with context 0 requested twice.
def before(self):
35+
@T.prim_func
36+
def func(dev_type: T.int32, n: T.int32):
37+
T.func_attr({"target": T.target("llvm")})
38+
A = T.allocate([n], "float32", "global")
39+
for i in range(n):
40+
T.call_extern(
41+
"int32",
42+
"fadd",
43+
_device_context(dev_type, 0),
44+
A,
45+
)
46+
for j in range(10):
47+
T.call_extern(
48+
"int32",
49+
"fadd",
50+
_device_context(dev_type, 1),
51+
A,
52+
)
53+
T.call_extern(
54+
"int32",
55+
"fadd",
56+
_device_context(dev_type, 0),
57+
A,
58+
)
59+
60+
return func
1961

62+
# Expected output: each distinct device_context call is bound once to a
# ctx_cache_ variable ahead of the allocation and loops, and both uses of
# context 0 share the same variable.
def expected(dev_type: T.int32, n: T.int32):
63+
T.func_attr({"target": T.target("llvm")})
64+
ctx_cache_: T.handle = T.call_extern("handle", "device_context", dev_type, 0)
65+
ctx_cache__1: T.handle = T.call_extern("handle", "device_context", dev_type, 1)
66+
A = T.allocate([n], "float32", "global")
67+
for i in range(n):
68+
T.call_extern("int32", "fadd", ctx_cache_, A)
69+
for j in range(10):
70+
T.call_extern("int32", "fadd", ctx_cache__1, A)
71+
T.call_extern("int32", "fadd", ctx_cache_, A)
2072

21-
def test_for():
22-
dev_type = te.var("dev_type")
2373

24-
def device_context(dev_id):
25-
ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
26-
return tvm.tir.Call("handle", "tir.tvm_thread_context", [ctx])
74+
# Same loop structure as the parent test, but the function carries no
# "target" attribute; the pass is expected to leave it unchanged.
class TestCombineContextsInLoopWithoutTarget(TestCombineContextsInLoop):
75+
"""CombineContextCall only updates host-side functions"""
2776

28-
ib = tvm.tir.ir_builder.create()
29-
n = te.var("n")
30-
A = ib.allocate("float32", n, name="A", scope="global")
31-
with ib.for_range(0, n, name="i") as i:
32-
ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data))
33-
with ib.for_range(0, 10, name="j") as j:
34-
ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A.asobject().data))
35-
ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data))
36-
body = ib.get()
37-
mod = tvm.IRModule({"func": tvm.tir.PrimFunc([dev_type, n], body)})
77+
def before(self):
78+
@T.prim_func
79+
def func(dev_type: T.int32, n: T.int32):
80+
A = T.allocate([n], "float32", "global")
81+
for i in range(n):
82+
T.call_extern(
83+
"int32",
84+
"fadd",
85+
_device_context(dev_type, 0),
86+
A,
87+
)
88+
for j in range(10):
89+
T.call_extern(
90+
"int32",
91+
"fadd",
92+
_device_context(dev_type, 1),
93+
A,
94+
)
95+
T.call_extern(
96+
"int32",
97+
"fadd",
98+
_device_context(dev_type, 0),
99+
A,
100+
)
38101

39-
mod = tvm.tir.transform.CombineContextCall()(mod)
102+
return func
40103

41-
assert mod["func"].body.value.dtype == "handle"
42-
assert mod["func"].body.body.value.dtype == "handle"
104+
# The expected result is the input itself, i.e. the pass must be a no-op here.
expected = before
43105

44106

45107
if __name__ == "__main__":
46-
test_for()
108+
tvm.testing.main()

0 commit comments

Comments
 (0)