Skip to content

Commit bb768b0

Browse files
committed
[TIR] Update SplitHostDevice to post-process with ConvertSSA
Avoid duplicate variable defitions between the host and device PrimFunc.
1 parent c2c4572 commit bb768b0

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

src/tir/transforms/split_host_device.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ Pass SplitHostDevice() {
273273
}
274274
}
275275
mod->Update(device_mod);
276-
return mod;
276+
return ConvertSSA()(mod);
277277
};
278278

279279
return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {});

tests/python/unittest/test_tir_transform_split_host_device.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tvm
1818
from tvm import te
1919
import tvm.testing
20+
from tvm.script import tir as T, ir as I
2021

2122

2223
@tvm.testing.requires_cuda
@@ -48,5 +49,29 @@ def test_split_host_device_func_attr():
4849
assert fdevice.attrs["tir.is_global_func"].value
4950

5051

52+
def test_ssa_across_entire_module():
53+
"""The host and device functions should not share TIR vars
54+
55+
Any arguments that are passed from the host to the device should
56+
be in terms of independent TIR variables.
57+
"""
58+
59+
@I.ir_module
60+
class before:
61+
@T.prim_func
62+
def main():
63+
T.func_attr({"global_symbol": "main", "target": T.target("cuda")})
64+
for i in range(16):
65+
T.attr(0, "device_scope", 0)
66+
for j in range(16):
67+
T.evaluate(i)
68+
69+
after = tvm.tir.transform.SplitHostDevice()(before)
70+
loop_var = after["main"].body.loop_var
71+
param_var = after["main_kernel0"].params[0]
72+
73+
assert not loop_var.same_as(param_var)
74+
75+
5176
if __name__ == "__main__":
5277
test_split_host_device_func_attr()

0 commit comments

Comments
 (0)