Commit cf2753e

[Relax][UnitTest] Validate IRModule with multiple targets (#16960)

This commit adds a unit test verifying that a single `IRModule` can contain functions to be executed on multiple distinct targets. Previously, this test case caused errors when running the `LegalizeOps` and `ApplyDefaultSchedule` transforms.

1 parent 604fbbd commit cf2753e
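
For orientation, here is a condensed sketch of the pattern the new test exercises, using only constructs that appear in the diff below. The `vdevice` entry in the module's global infos names the virtual devices that tensor annotations may refer to: tensors annotated with `"llvm"` pin their function to the CPU, while unannotated functions follow the `tvm.target.Target` in scope during `tvm.relax.build`. The module name `Sketch`, the function names `on_llvm` and `on_default`, and the shape `[8]` are illustrative, not part of the commit:

import tvm
from tvm import relax
from tvm.dlight import ApplyDefaultSchedule
from tvm.dlight.gpu import Fallback
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Sketch:
    # Declare the virtual devices that tensor annotations below may name.
    I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

    # Every tensor names the "llvm" vdevice, so this function is compiled
    # for the CPU regardless of the target used at build time.
    @R.function
    def on_llvm(A: R.Tensor([8], "float32", "llvm"), B: R.Tensor([8], "float32", "llvm")):
        C = R.add(A, B)
        return C

    # No vdevice annotation: this function follows the target that is in
    # scope when tvm.relax.build is called.
    @R.function
    def on_default(A: R.Tensor([8], "float32"), B: R.Tensor([8], "float32")):
        C = R.add(A, B)
        return C


# Legalize high-level Relax ops into TIR, give the resulting GPU kernels
# a default schedule, then build with "cuda" as the target for the
# unannotated function.  This mirrors the pipeline used in the test below.
seq = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        ApplyDefaultSchedule(Fallback()),
    ]
)
with tvm.target.Target("cuda"):
    built = relax.build(seq(Sketch))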

1 file changed: +59 −0 lines

tests/python/relax/test_vm_build.py

@@ -1246,5 +1246,64 @@ def test_set_input_get_failure_rpc(exec_mode):
     run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode)
 
 
+@tvm.testing.requires_gpu
+def test_relax_module_with_multiple_targets(exec_mode):
+    """Relax functions may contain kernels for multiple targets
+
+    In this example, the module contains one function to execute on
+    LLVM, and one function to execute on CUDA.
+
+    """
+
+    @I.ir_module
+    class Module:
+        I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
+
+        @R.function
+        def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
+            C = R.add(A, B)
+            return C
+
+        @R.function
+        def func_llvm(
+            A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
+        ):
+            C = R.add(A, B)
+            return C
+
+    seq = tvm.ir.transform.Sequential(
+        [
+            tvm.relax.transform.LegalizeOps(),
+            tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()),
+        ],
+        name="LegalizeAndSchedule",
+    )
+    with tvm.target.Target("cuda"):
+        built = tvm.relax.build(seq(Module))
+
+    np_A = np.random.random([32, 32]).astype("float32")
+    np_B = np.random.random([32, 32]).astype("float32")
+
+    dev_llvm = tvm.device("llvm")
+    vm_llvm = tvm.relax.VirtualMachine(built, device=dev_llvm)
+    llvm_output = vm_llvm["func_llvm"](
+        tvm.nd.array(np_A, dev_llvm),
+        tvm.nd.array(np_B, dev_llvm),
+    )
+
+    dev_cuda = tvm.device("cuda")
+    vm_cuda = tvm.relax.VirtualMachine(built, device=dev_cuda)
+
+    cuda_output = vm_cuda["func_cuda"](
+        tvm.nd.array(np_A, dev_cuda),
+        tvm.nd.array(np_B, dev_cuda),
+    )
+
+    np_C = np_A + np_B
+
+    tvm.testing.assert_allclose(llvm_output.numpy(), np_C)
+    tvm.testing.assert_allclose(cuda_output.numpy(), np_C)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
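
A note on the design the test pins down: only `func_llvm` carries explicit `"llvm"` vdevice annotations, while `func_cuda` is left unannotated and therefore picks up the `cuda` target that is active during `tvm.relax.build`. Previously, running `LegalizeOps` and `ApplyDefaultSchedule` over such a mixed-target module raised errors, which is the regression this test guards against.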
