Skip to content

Commit f65e24d

Browse files
committed
Improve PrimFunc readability
1 parent 52847a0 commit f65e24d

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

tests/python/tir-schedule/test_tir_schedule_cache_read_write.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,11 +1384,10 @@ def test_cache_read_allocate_const():
13841384
def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
13851385
B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
13861386
B_buf = T.decl_buffer((8), dtype="float32", data=B)
1387-
for i in T.serial(128):
1387+
for i in range(8):
13881388
with T.block("C"):
1389-
vi = T.axis.remap("S", [i])
1389+
vi = T.axis.spatial(8, i)
13901390
T.reads(A[vi], B_buf[vi])
1391-
T.writes(C[vi])
13921391
C[vi] = A[vi] + B_buf[vi]
13931392

13941393
@T.prim_func
@@ -1400,20 +1399,15 @@ def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
14001399
for ax0 in range(8):
14011400
with T.block("A_global"):
14021401
v0 = T.axis.spatial(8, ax0)
1403-
T.reads(A[v0])
1404-
T.writes(A_global[v0])
14051402
A_global[v0] = A[v0]
14061403
for ax0 in range(8):
14071404
with T.block("B_buf_global"):
14081405
v0 = T.axis.spatial(8, ax0)
14091406
T.reads(B_buf[v0])
1410-
T.writes(B_buf_global[v0])
14111407
B_buf_global[v0] = B_buf[v0]
1412-
for i in range(128):
1408+
for i in range(8):
14131409
with T.block("C"):
1414-
vi = T.axis.spatial(128, i)
1415-
T.reads(A_global[vi], B_buf_global[vi])
1416-
T.writes(C[vi])
1410+
vi = T.axis.spatial(8, i)
14171411
C[vi] = A_global[vi] + B_buf_global[vi]
14181412

14191413
sch = tir.Schedule(before)

0 commit comments

Comments
 (0)