Skip to content

Commit ff6caa8

Browse files
add printer/parser test, fix lint
1 parent 2c70beb commit ff6caa8

File tree

4 files changed

+54
-10
lines changed

4 files changed

+54
-10
lines changed

src/tir/op/op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,12 @@ PrimExpr break_loop(Span span) {
265265

266266
TVM_FFI_STATIC_INIT_BLOCK({
267267
namespace refl = tvm::ffi::reflection;
268-
refl::GlobalDef().def("tir.thread_return", thread_return)
269-
.def("tir.continue_loop", continue_loop)
270-
.def("tir.break_loop", break_loop);
268+
refl::GlobalDef()
269+
.def("tir.thread_return", thread_return)
270+
.def("tir.continue_loop", continue_loop)
271+
.def("tir.break_loop", break_loop);
271272
});
272273

273-
274274
// maximum and min limits
275275
PrimExpr max_value(const DataType& dtype, Span span) {
276276
using namespace tir;

tests/python/tir-base/test_tir_base.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,9 @@ def func(In: T.Buffer[(2,), "int32"], Out: T.Buffer[(2,), "int32"]):
135135
func = build_tir_func(func)
136136
a = np.asarray([49, 8], "int32")
137137
b = np.zeros([2], "int32")
138-
a = tvm.nd.array(a)
139-
b = tvm.nd.array(b)
140138
func(a, b)
141-
assert b.numpy()[0] == 13
142-
assert b.numpy()[1] == 9
139+
assert b[0] == 13
140+
assert b[1] == 9
143141

144142

145143
def test_continue_loop():
@@ -165,8 +163,8 @@ def func(Out: T.Buffer[(2,), "int32"]):
165163
b = np.zeros([2], "int32")
166164
b = tvm.nd.array(b)
167165
func(b)
168-
assert b.numpy()[0] == 34
169-
assert b.numpy()[1] == 5 # 6, 12, 18, 24, 30
166+
assert b[0] == 34
167+
assert b[1] == 5 # 6, 12, 18, 24, 30
170168

171169

172170
def test_exception():

tests/python/tvmscript/test_tvmscript_printer_tir.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,5 +1046,34 @@ def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
10461046
_assert_print(main, expected_output)
10471047

10481048

1049+
def test_func_with_loop_jumps():
1050+
from tvm.script import tir as T
1051+
1052+
@T.prim_func
1053+
def main(a: T.handle, b: T.handle):
1054+
A = T.match_buffer(a, (4,), "float32")
1055+
B = T.match_buffer(b, (4,), "float32")
1056+
for i in range(1000):
1057+
if i % 13 == 0:
1058+
A[1] = A[1] + 1
1059+
continue
1060+
if A[0] >= B[0]:
1061+
break
1062+
1063+
expected_output = """
1064+
# from tvm.script import tir as T
1065+
1066+
@T.prim_func
1067+
def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
1068+
for i in range(1000):
1069+
if i % 13 == 0:
1070+
A[1] = A[1] + T.float32(1.0)
1071+
T.continue_loop()
1072+
if A[0] >= B[0]:
1073+
T.break_loop()
1074+
"""
1075+
_assert_print(main, expected_output)
1076+
1077+
10491078
if __name__ == "__main__":
10501079
tvm.testing.main()

tests/python/tvmscript/test_tvmscript_roundtrip.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4002,6 +4002,22 @@ def func(
40024002
return func
40034003

40044004

4005+
def func_with_loop_jumps():
4006+
@T.prim_func
4007+
def func(In: T.Buffer((1,), "int32"), Out: T.Buffer((2,), "int32")):
4008+
Out[0] = 0
4009+
Out[1] = 0
4010+
for i in range(1000):
4011+
if i % 13 == 0:
4012+
Out[1] = Out[1] + 1
4013+
continue
4014+
Out[0] = Out[0] + 1
4015+
if Out[0] >= In[0]:
4016+
break
4017+
4018+
return func
4019+
4020+
40054021
def op_of_literal():
40064022
op_list = [
40074023
(T.exp, 0),
@@ -4220,6 +4236,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")):
42204236
return_zero_private,
42214237
return_zero_private_with_attr,
42224238
func_attr_with_list,
4239+
func_with_loop_jumps,
42234240
*op_of_literal(),
42244241
*relax_match_cast_struct_info_proxy(),
42254242
relax_symbolic_size_var,

0 commit comments

Comments
 (0)