Skip to content

Commit 0af9ff9

Browse files
authored
[TIR] Restrict tir.transform.LowerTVMBuiltin to host functions (#14944)
* [Bugfix][TIR][VTA] Update host-side target, even without device func

  This resolves an issue introduced by the combination of #14918 and #14945. The bug occurred for targets that do not require device-side codegen, but do require a `device_type` other than `kDLCPU`. It wasn't caught by CI, as the issue only occurred with the combination of both PRs.

  1. #14918 updated `SplitHostDevice` to only modify the `"target"` attribute when a device-side function has been extracted.
  2. For VTA, there is no device-side function, as everything is done through host-side API calls.
  3. From (1) and (2), the VTA examples kept the target `T.target("ext_dev", host="llvm")` after the `SplitHostDevice` pass, instead of being updated to `T.target("llvm")`.
  4. #14945 restricted `CombineContextCall` to only apply to host-side functions.
  5. From (3) and (4), the `CombineContextCall` pass was no longer applied to the VTA context calls.

  This PR fixes `SplitHostDevice`, updating the target from `T.target("ext_dev", host="llvm")` to `T.target("llvm")`, even if no device sections have been extracted from the function.

* [TIR] Restrict tir.transform.LowerTVMBuiltin to host functions

  Previously, the `tir.transform.LowerTVMBuiltin` pass was applied to every function in an `IRModule`, but was only ever run on modules that contain only host functions. This commit updates `tir.transform.LowerTVMBuiltin` to apply only to host functions.

* Updated "stackvm" target to have "cpu" key.

  Because the presence or absence of the "cpu" key in a target is used to determine whether host-only calls should be run, the key must also be added to "stackvm".

* Update IsHostFunc() to use "host" tag instead of "cpu"

  The current CI failures are due to LowerTVMBuiltin not running on the "hexagon" target, and using a separate tag avoids conflating cpu and host.

* Avoid "host" tag for now

* Update HEXAGON_AOT_LLVM_TARGET to be recognized as host
1 parent 2618091 commit 0af9ff9

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed

python/tvm/contrib/hexagon/pytest_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
RNG_SEEDED = False
4141

4242
HEXAGON_AOT_LLVM_TARGET = (
43-
"llvm -keys=hexagon "
43+
"llvm -keys=hexagon,cpu "
4444
"-mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
4545
"-mcpu=hexagonv68 -mtriple=hexagon"
4646
)

src/target/target_kind.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,10 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
422422
.add_attr_option<Array<String>>("llvm-options")
423423
.add_attr_option<Integer>("num-cores")
424424
.add_attr_option<Integer>("vtcm-capacity")
425-
.set_default_keys({"hexagon"});
425+
.set_default_keys({"hexagon", "cpu"});
426426

427-
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU);
427+
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break
428+
.set_default_keys({"cpu"});
428429

429430
TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev);
430431

src/tir/transforms/lower_tvm_builtin.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,9 +629,11 @@ namespace transform {
629629

630630
Pass LowerTVMBuiltin() {
631631
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
632-
auto* n = f.CopyOnWrite();
633-
n->body = BuiltinLower().Build(n->body);
634-
VLOG(2) << "LowerTVMBuiltin: " << f;
632+
if (IsHostFunc(f).value_or(false)) {
633+
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
634+
f.CopyOnWrite()->body = BuiltinLower().Build(f->body);
635+
VLOG(2) << "LowerTVMBuiltin: " << f;
636+
}
635637
return f;
636638
};
637639
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});

tests/python/unittest/test_tir_transform_lower_tvm_builtin.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def check_packed_func(target="llvm"):
5656
# Construct a valid IRModule to be lowered:
5757
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt))
5858

59-
target = tvm.target.Target(target)
59+
target = tvm.target.Target(target, host="llvm")
6060
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
6161
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
6262
mod = tvm.tir.transform.MakePackedAPI()(mod)
@@ -189,6 +189,97 @@ def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112
189189
tvm.build(func, target="llvm") # should not crash
190190

191191

192+
class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter):
193+
"""Device allocations are lowered to TVMBackend* calls
194+
195+
This test validates the current behavior of LowerTVMBuiltin. This
196+
unit test may be improved in the future by addressing:
197+
198+
- The AttrStmt for "storage_alignment" occurs outside the LetStmt
199+
that defines the pointer, which is currently required by
200+
CodeGenLLVM. This fails to match when `map_free_vars=False`
201+
(default), because the first occurrence is undefined.
202+
203+
- The call to TVMBackendFreeWorkspace uses the allocated pointer,
204+
but occurs outside the LetStmt.
205+
206+
- TVMScript always produces "handle" dtype for
207+
`T.tvm_throw_last_error`, while LowerTVMBuiltin outputs "int32"
208+
dtype.
209+
"""
210+
211+
transform = tvm.tir.transform.LowerTVMBuiltin()
212+
213+
def before():
214+
T.func_attr({"target": T.target("llvm")})
215+
T.attr("dummy", "device_type", 2) # kDLCuda
216+
T.attr("dummy", "device_id", 0)
217+
ptr = T.allocate([16], "float32")
218+
buf = T.decl_buffer(16, "float32", data=ptr)
219+
buf[0] = 0.0
220+
221+
def expected():
222+
T.func_attr({"target": T.target("llvm")})
223+
ptr = T.handle("float32", "global")
224+
T.attr(ptr, "storage_alignment", 64)
225+
with T.LetStmt(T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32), var=ptr):
226+
if T.isnullptr(ptr):
227+
T.Call("int32", "tir.tvm_throw_last_error", [])
228+
buf = T.decl_buffer((16,), data=ptr)
229+
buf[0] = T.float32(0)
230+
if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0:
231+
T.Call("int32", "tir.tvm_throw_last_error", [])
232+
233+
def test_compare(self, before, expected, transform):
234+
after = transform(before)
235+
tvm.ir.assert_structural_equal(after, expected, map_free_vars=True)
236+
237+
238+
class TestLowerCPUAllocation(tvm.testing.CompareBeforeAfter):
239+
"""CPU allocations can be handled at codegen time"""
240+
241+
transform = tvm.tir.transform.LowerTVMBuiltin()
242+
243+
def before():
244+
T.func_attr({"target": T.target("llvm")})
245+
T.attr("dummy", "device_type", 1) # kDLCPU
246+
T.attr("dummy", "device_id", 0)
247+
ptr = T.allocate([16], "float32")
248+
buf = T.decl_buffer(16, "float32", data=ptr)
249+
buf[0] = 0.0
250+
251+
def expected():
252+
T.func_attr({"target": T.target("llvm")})
253+
ptr = T.allocate([16], "float32")
254+
buf = T.decl_buffer(16, "float32", data=ptr)
255+
buf[0] = 0.0
256+
257+
258+
class TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter):
259+
transform = tvm.tir.transform.LowerTVMBuiltin()
260+
261+
def before():
262+
T.func_attr({"target": T.target("llvm")})
263+
T.attr("dummy", "device_id", 0)
264+
ptr = T.allocate([16], "float32")
265+
buf = T.decl_buffer(16, "float32", data=ptr)
266+
buf[0] = 0.0
267+
268+
expected = tvm.TVMError
269+
270+
271+
class TestLowerAllocateRequiresDeviceType(tvm.testing.CompareBeforeAfter):
272+
transform = tvm.tir.transform.LowerTVMBuiltin()
273+
274+
def before():
275+
T.func_attr({"target": T.target("llvm")})
276+
T.attr("dummy", "device_id", 0)
277+
ptr = T.allocate([16], "float32")
278+
buf = T.decl_buffer(16, "float32", data=ptr)
279+
buf[0] = 0.0
280+
281+
expected = tvm.TVMError
282+
283+
192284
if __name__ == "__main__":
193-
test_call_packed_return_non_i32()
194-
test_lower_packed_func()
285+
tvm.testing.main()

0 commit comments

Comments
 (0)