Skip to content

Commit f5c5912

Browse files
committed
Update reference examples for inject software pipeline unit tests
1 parent 17beccb commit f5c5912

File tree

1 file changed: 51 additions (+51) and 53 deletions (-53).

tests/python/unittest/test_tir_transform_inject_software_pipeline.py

Lines changed: 51 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ def transformed_simple_compute(
139139
for i in T.serial(0, 15):
140140
with T.block():
141141
T.reads([A[tx, i + 1]])
142-
T.writes([B[(i + 1) % 2, tx, 0]])
143-
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
142+
T.writes([B[1 - i % 2, tx, 0]])
143+
B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
144144
with T.block():
145145
T.reads([B[i % 2, tx, 0]])
146146
T.writes([C[tx, i]])
@@ -202,8 +202,8 @@ def transformed_simple_compute_with_other_annotation(
202202
):
203203
with T.block():
204204
T.reads([A[tx, i + 1]])
205-
T.writes([B[(i + 1) % 2, tx, 0]])
206-
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
205+
T.writes([B[1 - i % 2, tx, 0]])
206+
B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
207207
with T.block():
208208
T.reads([B[i % 2, tx, 0]])
209209
T.writes([C[tx, i]])
@@ -266,7 +266,7 @@ def transformed_three_stage_compute(
266266
T.where(i == 1)
267267
T.reads(B[0:2, tx, 0])
268268
T.writes(C[0:2, tx, 0])
269-
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
269+
C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
270270
with T.block():
271271
T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
272272
T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
@@ -278,7 +278,7 @@ def transformed_three_stage_compute(
278278
with T.block():
279279
T.reads(B[0:2, tx, 0])
280280
T.writes(C[0:2, tx, 0])
281-
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
281+
C[1 - i % 2, tx, 0] = B[1 - i % 2, tx, 0] + T.float32(2)
282282
with T.block():
283283
T.reads(C[0:2, tx, 0])
284284
T.writes(D[tx, i])
@@ -291,7 +291,7 @@ def transformed_three_stage_compute(
291291
T.where(i < 1)
292292
T.reads(B[0:2, tx, 0])
293293
T.writes(C[0:2, tx, 0])
294-
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
294+
C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
295295
with T.block():
296296
T.reads(C[0:2, tx, 0])
297297
T.writes(D[tx, i + 14])
@@ -391,12 +391,12 @@ def transformed_dag_interleaving(
391391
BS[tx, 0] = B[tx, i + 1] + T.float32(2)
392392
with T.block():
393393
T.reads(AS[tx, 0])
394-
T.writes(AL[(i + 1) % 2, 0, 0])
395-
AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
394+
T.writes(AL[1 - i % 2, 0, 0])
395+
AL[1 - i % 2, 0, 0] = AS[tx, 0]
396396
with T.block():
397397
T.reads(BS[tx, 0])
398-
T.writes(BL[(i + 1) % 2, 0, 0])
399-
BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
398+
T.writes(BL[1 - i % 2, 0, 0])
399+
BL[1 - i % 2, 0, 0] = BS[tx, 0]
400400
with T.block():
401401
T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
402402
T.writes(C[tx, i])
@@ -475,12 +475,12 @@ def transformed_nested_pipeline_simple(
475475
for i in T.serial(0, 15):
476476
with T.block():
477477
T.reads([A[tx, i + 1, 0:16]])
478-
T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
478+
T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
479479
for j in T.serial(0, 16):
480480
with T.block():
481481
T.reads([A[tx, i + 1, j]])
482-
T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
483-
A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
482+
T.writes([A_shared[1 - i % 2, tx, 0, j]])
483+
A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
484484
with T.block():
485485
T.reads([A_shared[i % 2, tx, i, 0]])
486486
T.writes([B[0, tx, i, 0]])
@@ -491,10 +491,10 @@ def transformed_nested_pipeline_simple(
491491
for j in T.serial(0, 15):
492492
with T.block():
493493
T.reads([A_shared[i % 2, tx, i, j + 1]])
494-
T.writes([B[(j + 1) % 2, tx, i, 0]])
495-
B[(j + 1) % 2, tx, i, 0] = A_shared[
496-
i % 2, tx, 0, j + 1
497-
] * T.float32(2)
494+
T.writes([B[1 - j % 2, tx, i, 0]])
495+
B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32(
496+
2
497+
)
498498
with T.block():
499499
T.reads([B[j % 2, tx, i, 0]])
500500
T.writes([C[tx, i, j]])
@@ -516,8 +516,8 @@ def transformed_nested_pipeline_simple(
516516
for j in T.serial(0, 15):
517517
with T.block():
518518
T.reads([A_shared[1, tx, 15, j + 1]])
519-
T.writes([B[(j + 1) % 2, tx, 15, 0]])
520-
B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
519+
T.writes([B[1 - j % 2, tx, 15, 0]])
520+
B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
521521
with T.block():
522522
T.reads([B[j % 2, tx, 15, 0]])
523523
T.writes([C[tx, 15, j]])
@@ -603,30 +603,30 @@ def transformed_nested_pipeline_prefetch_inner(
603603
for i in T.serial(0, 15):
604604
with T.block():
605605
T.reads([A[tx, i + 1, 0:16]])
606-
T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
606+
T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
607607
for j in T.serial(0, 16):
608608
with T.block():
609609
T.reads([A[tx, i + 1, j]])
610-
T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
611-
A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
610+
T.writes([A_shared[1 - i % 2, tx, 0, j]])
611+
A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
612612
with T.block():
613613
T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
614614
T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
615615
for j in T.serial(0, 15):
616616
with T.block():
617617
T.reads([A_shared[i % 2, tx, i, j + 1]])
618-
T.writes([B[(j + 1) % 2, tx, i, 0]])
619-
B[(j + 1) % 2, tx, i, 0] = A_shared[
620-
i % 2, tx, 0, j + 1
621-
] * T.float32(2)
618+
T.writes([B[1 - j % 2, tx, i, 0]])
619+
B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32(
620+
2
621+
)
622622
with T.block():
623623
T.reads([B[j % 2, tx, i, 0]])
624624
T.writes([C[tx, i, j]])
625625
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
626626
with T.block():
627-
T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]])
627+
T.reads([A_shared[1 - i % 2, tx, i + 1, 0]])
628628
T.writes([B[0, tx, i + 1, 0]])
629-
B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2)
629+
B[0, tx, i + 1, 0] = A_shared[1 - i % 2, tx, 0, 0] * T.float32(2)
630630
with T.block():
631631
T.reads([B[1, tx, i, 0]])
632632
T.writes([C[tx, i, 15]])
@@ -640,8 +640,8 @@ def transformed_nested_pipeline_prefetch_inner(
640640
for j in T.serial(0, 15):
641641
with T.block():
642642
T.reads([A_shared[1, tx, 15, j + 1]])
643-
T.writes([B[(j + 1) % 2, tx, 15, 0]])
644-
B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
643+
T.writes([B[1 - j % 2, tx, 15, 0]])
644+
B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
645645
with T.block():
646646
T.reads([B[j % 2, tx, 15, 0]])
647647
T.writes([C[tx, 15, j]])
@@ -768,8 +768,8 @@ def transformed_nested_pipeline_interleaving(
768768
for j in T.serial(0, 15):
769769
with T.block():
770770
T.reads([A_local[tx, i, j + 1]])
771-
T.writes([B[(j + 1) % 2, tx, i, 0]])
772-
B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
771+
T.writes([B[1 - j % 2, tx, i, 0]])
772+
B[1 - j % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
773773
with T.block():
774774
T.reads([B[j % 2, tx, i, 0]])
775775
T.writes([C[tx, i, j]])
@@ -799,8 +799,8 @@ def transformed_nested_pipeline_interleaving(
799799
for j in T.serial(0, 15):
800800
with T.block():
801801
T.reads([A_local[tx, 15, j + 1]])
802-
T.writes([B[(j + 1) % 2, tx, 15, 0]])
803-
B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
802+
T.writes([B[1 - j % 2, tx, 15, 0]])
803+
B[1 - j % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
804804
with T.block():
805805
T.reads([B[j % 2, tx, 15, 0]])
806806
T.writes([C[tx, 15, j]])
@@ -929,27 +929,25 @@ def transformed_nested_pipeline_double_buffer(
929929
for j in T.serial(0, 15):
930930
with T.block():
931931
T.reads([A_local[i % 2, tx, i, j + 1]])
932-
T.writes([B[(j + 1) % 2, tx, i, 0]])
933-
B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(
934-
2
935-
)
932+
T.writes([B[1 - j % 2, tx, i, 0]])
933+
B[1 - j % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(2)
936934
with T.block():
937935
T.reads([B[j % 2, tx, i, 0]])
938936
T.writes([C[tx, i, j]])
939937
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
940938
with T.block():
941939
T.reads([A_shared[tx, 0, 0:16]])
942-
T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
940+
T.writes([A_local[1 - i % 2, 0, 0, 0:16]])
943941
for j in T.serial(0, 16):
944942
with T.block():
945943
T.reads([A_shared[tx, 0, j]])
946-
T.writes([A_local[(i + 1) % 2, 0, 0, j]])
944+
T.writes([A_local[1 - i % 2, 0, 0, j]])
947945
T.block_attr({"double_buffer_scope": 0})
948-
A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j]
946+
A_local[1 - i % 2, 0, 0, j] = A_shared[tx, i + 1, j]
949947
with T.block():
950-
T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]])
948+
T.reads([A_local[1 - i % 2, tx, i + 1, 0]])
951949
T.writes([B[0, tx, i + 1, 0]])
952-
B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2)
950+
B[0, tx, i + 1, 0] = A_local[1 - i % 2, 0, 0, 0] * T.float32(2)
953951
with T.block():
954952
T.reads([B[1, tx, i, 0]])
955953
T.writes([C[tx, i, 15]])
@@ -963,8 +961,8 @@ def transformed_nested_pipeline_double_buffer(
963961
for j in T.serial(0, 15):
964962
with T.block():
965963
T.reads([A_local[1, tx, 15, j + 1]])
966-
T.writes([B[(j + 1) % 2, tx, 15, 0]])
967-
B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2)
964+
T.writes([B[1 - j % 2, tx, 15, 0]])
965+
B[1 - j % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2)
968966
with T.block():
969967
T.reads([B[j % 2, tx, 15, 0]])
970968
T.writes([C[tx, 15, j]])
@@ -1135,7 +1133,7 @@ def ref(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
11351133
with T.block():
11361134
T.where(i + 1 < 16)
11371135
T.reads(A[tx, i + 1])
1138-
T.writes(B[(i + 1) % 2, tx, 0])
1136+
T.writes(B[1 - i % 2, tx, 0])
11391137
with T.attr(0, "async_commit_queue_scope", 0):
11401138
with T.attr(0, "async_scope", 1):
11411139
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
@@ -1350,8 +1348,8 @@ def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> N
13501348
B[i % 2, tx, 0] = A[tx, i] * T.float32(2)
13511349
with T.block():
13521350
T.where(i == 1 and i - 1 < 16)
1353-
T.reads(B[(i + 1) % 2, tx, 0])
1354-
T.writes(C[(i + 1) % 2, tx, 0])
1351+
T.reads(B[1 - i % 2, tx, 0])
1352+
T.writes(C[1 - i % 2, tx, 0])
13551353
with T.attr(0, "async_commit_queue_scope", 1):
13561354
with T.attr(0, "async_wait_queue_scope", 0):
13571355
with T.attr(0, "async_wait_inflight_count", 1):
@@ -1372,8 +1370,8 @@ def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> N
13721370
B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
13731371
with T.block():
13741372
T.where(i + 2 - 1 < 16)
1375-
T.reads(B[(i + 1) % 2, tx, 0])
1376-
T.writes(C[(i + 1) % 2, tx, 0])
1373+
T.reads(B[1 - i % 2, tx, 0])
1374+
T.writes(C[1 - i % 2, tx, 0])
13771375
with T.attr(0, "async_commit_queue_scope", 1):
13781376
with T.attr(0, "async_wait_queue_scope", 0):
13791377
with T.attr(0, "async_wait_inflight_count", 1):
@@ -1394,8 +1392,8 @@ def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> N
13941392
for i in T.unroll(2):
13951393
with T.block():
13961394
T.where(i + 16 - 1 < 16)
1397-
T.reads(B[(i + 1) % 2, tx, 0])
1398-
T.writes(C[(i + 1) % 2, tx, 0])
1395+
T.reads(B[1 - i % 2, tx, 0])
1396+
T.writes(C[1 - i % 2, tx, 0])
13991397
with T.attr(0, "async_commit_queue_scope", 1):
14001398
with T.attr(0, "async_wait_queue_scope", 0):
14011399
with T.attr(0, "async_wait_inflight_count", 0 - i):

0 commit comments

Comments
 (0)