Skip to content

Commit 23e6cf2

Browse files
committed
Fix unit tests
1 parent 3251680 commit 23e6cf2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/python/relax/test_vm_cuda_graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl
3636
R.func_attr({"global_symbol": "main"})
3737
gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),))
3838
storage: R.Object = gv[0]
39-
alloc: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
39+
alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
4040
_: R.Tuple = cls.add(x, alloc)
4141
storage1: R.Object = gv[1]
4242
gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, storage1, storage)
4343
gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),))
4444
storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8"))
45-
alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
45+
alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
4646
lv4: R.Tensor((16, 16), dtype="float32") = gv2[0]
4747
_3: R.Tuple = cls.add(lv4, alloc3)
4848
lv5: R.Tensor(dtype="float32") = alloc3
@@ -71,12 +71,12 @@ def cuda_graph_capture(alloc: R.Tensor((16, 16), dtype="float32"), storage1: R.O
7171
cls = Module
7272
R.func_attr({"global_symbol": "cuda_graph_capture"})
7373
lv0: R.Tensor((16, 16), dtype="float32") = alloc
74-
alloc1: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
74+
alloc1 = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
7575
_1: R.Tuple = cls.add(lv0, alloc1)
7676
lv1: R.Tensor(dtype="float32") = alloc1
7777
lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,)
7878
lv3: R.Tensor(dtype="float32") = lv2[0]
79-
alloc2: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
79+
alloc2 = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
8080
_2: R.Tuple = cls.add(lv3, alloc2)
8181
lv4: R.Tensor(dtype="float32") = alloc2
8282
gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,)

0 commit comments

Comments
 (0)