Skip to content

Commit 86c1546

Browse files
committed
[ARITH][BUGFIX] Fix a bug of iter map floormod(x,2) simplify
This PR fixes a previous bug introduced in itermap detection. Specifically, y - (x % 2) were simplified to y + (x % 2) - 1. Which is wrong. The working rule is y + ((x + 1) % 2) - 1, but that rule will change the base iterator which is not desirable here. We also removed the rule that simplifies (x + 1) % 2 => 1 - x % 2 as benefit is minimal and it introduces extra negative co-efficients that hurts analysis in general (as negative co-efficients are harder in many cases).
1 parent 4e07a8e commit 86c1546

File tree

6 files changed

+86
-65
lines changed

6 files changed

+86
-65
lines changed

src/arith/iter_affine_map.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -898,11 +898,6 @@ class IterMapRewriter : public ExprMutator {
898898
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
899899

900900
static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
901-
if (sign < 0 && is_const_int(rhs->extent, 2)) {
902-
lhs->base -= rhs->scale;
903-
sign = 1;
904-
}
905-
906901
tir::ExprDeepEqual equal;
907902
for (size_t i = 0; i < lhs->args.size(); ++i) {
908903
IterSplitExpr lvalue = lhs->args[i];

src/arith/rewrite_simplify.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
306306

307307
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2));
308308

309+
// Simplify (x + 1) % 2 + x % 2 => 1
310+
// NOTE: we should avoid simplifying (x + 1) %2 => 1 - x % 2 though
311+
// mainly because introducing extra negative signs to expression can harm itertaor
312+
// analysis which usually relies on positive itertator co-efficients.
313+
TVM_TRY_REWRITE_IF(floormod(x + c1, 2) + floormod(x, 2), OneWithTypeLike(x),
314+
floormod(c1.Eval()->value, 2) == 1);
315+
TVM_TRY_REWRITE_IF(floormod(x, 2) + floormod(x + c1, 2), OneWithTypeLike(x),
316+
floormod(c1.Eval()->value, 2) == 1);
317+
309318
// canonicalization rule
310319
// will try rewrite again after canonicalization.
311320

@@ -1018,10 +1027,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
10181027
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
10191028
c2.Eval()->value > 0);
10201029

1021-
TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1,
1022-
floormod(c1.Eval()->value, 2) == 1);
1023-
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
1024-
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
1030+
// (x + 5) % 2 -> (x + 1) %2, (x + 3) % 3 => x
1031+
TVM_TRY_REWRITE_IF(
1032+
floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2),
1033+
c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value || c1.Eval()->value < 0));
10251034

10261035
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2),
10271036
c2.Eval()->value > 0);

tests/python/unittest/test_arith_canonical_simplify.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,5 +415,12 @@ def test_proddiv_simplify():
415415
ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))
416416

417417

418+
def test_floormod_two():
419+
ck = CanonicalChecker()
420+
flm = tvm.te.floormod
421+
x, y = te.var("x"), te.var("y")
422+
ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)
423+
424+
418425
if __name__ == "__main__":
419426
tvm.testing.main()

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,14 +199,14 @@ def test_compound():
199199
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))
200200

201201

202-
def test_compound_floormod_two():
202+
def test_compound_floormod_two_regression():
203203
x = tvm.tir.Var("x", "int32")
204204
fld = tvm.tir.floordiv
205205
flm = tvm.tir.floormod
206-
207-
# extent of 2 are normalized to positive scale
208-
assert_iter_sum_pattern(
209-
expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
206+
# regression
207+
# extent of 2 of negative scale cannot be normalized
208+
assert_iter_sum_failure(
209+
[fld(x, 2) * 2 - flm(x, 2) + 1],
210210
dom_map=var_dom([(x, 8)]),
211211
)
212212

tests/python/unittest/test_arith_rewrite_simplify.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ class TestSubIndex(BaseCompare):
392392
TestCase(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)),
393393
TestCase(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1),
394394
TestCase(fld(y, 3) * 3 - y, 0 - flm(y, 3)),
395-
TestCase(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6),
396-
TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)),
395+
TestCase(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6),
396+
TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)),
397397
TestCase(y - fld(y + z, 5) * 5, flm(y + z, 5) - z),
398398
TestCase(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)),
399399
TestCase(y - fld(y - z, 5) * 5, flm(y - z, 5) + z),
@@ -554,13 +554,15 @@ class TestFloormodIndex(BaseCompare):
554554
TestCase(flm(x + 10, 2), flm(x, 2)),
555555
TestCase(flm(x + y * 10, 2), flm(x, 2)),
556556
TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)),
557-
TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
558557
TestCase(flm(x * (-10), 2), 0),
559558
TestCase(flm(x * (-10) + y, 2), flm(y, 2)),
560559
TestCase(flm(x + (-10), 2), flm(x, 2)),
561560
TestCase(flm(x + y * (-10), 2), flm(x, 2)),
562561
TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]),
563562
TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]),
563+
# NOTE: the followng case is covered by canonical simplify
564+
# long range simplifcation in general can be covered by canonical simplify
565+
# TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
564566
)
565567

566568

@@ -574,13 +576,14 @@ class TestFloorModTwo(BaseCompare):
574576
require identifying more related terms in order to apply.
575577
576578
(x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c1%2 ) + (c1//2 - c2//2)
579+
580+
We should not introduce extra negative coeficient to iterators
581+
however during simplification
577582
"""
578583

579584
x, y, z = te.var("x"), te.var("y"), te.var("z")
580585
test_case = tvm.testing.parameter(
581586
# Removing offsets from floormod
582-
TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
583-
TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
584587
TestCase(flm(x, 2) + flm(x + 1, 2), 1),
585588
TestCase(flm(x + 1, 2) + flm(x, 2), 1),
586589
# Difference of floordiv yields floormod
@@ -592,8 +595,13 @@ class TestFloorModTwo(BaseCompare):
592595
# Sum of floordiv and floormod to yield floordiv
593596
TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
594597
TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
595-
# Removal of floormod where possible
596-
TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
598+
# regression: although we can rewrite (x + 1) %2 => 1 - x%2
599+
# doing so would introduce negative co-efficient to iterators
600+
# which makes later iter map detection harder, in principle we
601+
# should not introduce additional negative signs of iterator in rewriting
602+
TestCase(flm(x + 1, 2), flm(x + 1, 2)),
603+
TestCase(flm(x + 5, 2), flm(x + 1, 2)),
604+
TestCase(flm(x + 1, 2) * 8192, flm(x + 1, 2) * 8192, [x >= 0, x < 2]),
597605
)
598606

599607

tests/python/unittest/test_tir_transform_inject_software_pipeline.py

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ def transformed_simple_compute(
139139
for i in T.serial(0, 15):
140140
with T.block():
141141
T.reads([A[tx, i + 1]])
142-
T.writes([B[1 - i % 2, tx, 0]])
143-
B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
142+
T.writes([B[(i + 1) % 2, tx, 0]])
143+
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
144144
with T.block():
145145
T.reads([B[i % 2, tx, 0]])
146146
T.writes([C[tx, i]])
@@ -202,8 +202,8 @@ def transformed_simple_compute_with_other_annotation(
202202
):
203203
with T.block():
204204
T.reads([A[tx, i + 1]])
205-
T.writes([B[1 - i % 2, tx, 0]])
206-
B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
205+
T.writes([B[(i + 1) % 2, tx, 0]])
206+
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
207207
with T.block():
208208
T.reads([B[i % 2, tx, 0]])
209209
T.writes([C[tx, i]])
@@ -266,7 +266,7 @@ def transformed_three_stage_compute(
266266
T.where(i == 1)
267267
T.reads(B[0:2, tx, 0])
268268
T.writes(C[0:2, tx, 0])
269-
C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
269+
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
270270
with T.block():
271271
T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
272272
T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
@@ -278,7 +278,7 @@ def transformed_three_stage_compute(
278278
with T.block():
279279
T.reads(B[0:2, tx, 0])
280280
T.writes(C[0:2, tx, 0])
281-
C[1 - i % 2, tx, 0] = B[1 - i % 2, tx, 0] + T.float32(2)
281+
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
282282
with T.block():
283283
T.reads(C[0:2, tx, 0])
284284
T.writes(D[tx, i])
@@ -291,7 +291,7 @@ def transformed_three_stage_compute(
291291
T.where(i < 1)
292292
T.reads(B[0:2, tx, 0])
293293
T.writes(C[0:2, tx, 0])
294-
C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
294+
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
295295
with T.block():
296296
T.reads(C[0:2, tx, 0])
297297
T.writes(D[tx, i + 14])
@@ -391,12 +391,12 @@ def transformed_dag_interleaving(
391391
BS[tx, 0] = B[tx, i + 1] + T.float32(2)
392392
with T.block():
393393
T.reads(AS[tx, 0])
394-
T.writes(AL[1 - i % 2, 0, 0])
395-
AL[1 - i % 2, 0, 0] = AS[tx, 0]
394+
T.writes(AL[(i + 1) % 2, 0, 0])
395+
AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
396396
with T.block():
397397
T.reads(BS[tx, 0])
398-
T.writes(BL[1 - i % 2, 0, 0])
399-
BL[1 - i % 2, 0, 0] = BS[tx, 0]
398+
T.writes(BL[(i + 1) % 2, 0, 0])
399+
BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
400400
with T.block():
401401
T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
402402
T.writes(C[tx, i])
@@ -475,12 +475,12 @@ def transformed_nested_pipeline_simple(
475475
for i in T.serial(0, 15):
476476
with T.block():
477477
T.reads([A[tx, i + 1, 0:16]])
478-
T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
478+
T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
479479
for j in T.serial(0, 16):
480480
with T.block():
481481
T.reads([A[tx, i + 1, j]])
482-
T.writes([A_shared[1 - i % 2, tx, 0, j]])
483-
A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
482+
T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
483+
A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
484484
with T.block():
485485
T.reads([A_shared[i % 2, tx, i, 0]])
486486
T.writes([B[0, tx, i, 0]])
@@ -491,10 +491,10 @@ def transformed_nested_pipeline_simple(
491491
for j in T.serial(0, 15):
492492
with T.block():
493493
T.reads([A_shared[i % 2, tx, i, j + 1]])
494-
T.writes([B[1 - j % 2, tx, i, 0]])
495-
B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32(
496-
2
497-
)
494+
T.writes([B[(j + 1) % 2, tx, i, 0]])
495+
B[(j + 1) % 2, tx, i, 0] = A_shared[
496+
i % 2, tx, 0, j + 1
497+
] * T.float32(2)
498498
with T.block():
499499
T.reads([B[j % 2, tx, i, 0]])
500500
T.writes([C[tx, i, j]])
@@ -516,8 +516,8 @@ def transformed_nested_pipeline_simple(
516516
for j in T.serial(0, 15):
517517
with T.block():
518518
T.reads([A_shared[1, tx, 15, j + 1]])
519-
T.writes([B[1 - j % 2, tx, 15, 0]])
520-
B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
519+
T.writes([B[(j + 1) % 2, tx, 15, 0]])
520+
B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
521521
with T.block():
522522
T.reads([B[j % 2, tx, 15, 0]])
523523
T.writes([C[tx, 15, j]])
@@ -603,30 +603,30 @@ def transformed_nested_pipeline_prefetch_inner(
603603
for i in T.serial(0, 15):
604604
with T.block():
605605
T.reads([A[tx, i + 1, 0:16]])
606-
T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
606+
T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
607607
for j in T.serial(0, 16):
608608
with T.block():
609609
T.reads([A[tx, i + 1, j]])
610-
T.writes([A_shared[1 - i % 2, tx, 0, j]])
611-
A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
610+
T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
611+
A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
612612
with T.block():
613613
T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
614614
T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
615615
for j in T.serial(0, 15):
616616
with T.block():
617617
T.reads([A_shared[i % 2, tx, i, j + 1]])
618-
T.writes([B[1 - j % 2, tx, i, 0]])
619-
B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32(
620-
2
621-
)
618+
T.writes([B[(j + 1) % 2, tx, i, 0]])
619+
B[(j + 1) % 2, tx, i, 0] = A_shared[
620+
i % 2, tx, 0, j + 1
621+
] * T.float32(2)
622622
with T.block():
623623
T.reads([B[j % 2, tx, i, 0]])
624624
T.writes([C[tx, i, j]])
625625
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
626626
with T.block():
627-
T.reads([A_shared[1 - i % 2, tx, i + 1, 0]])
627+
T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]])
628628
T.writes([B[0, tx, i + 1, 0]])
629-
B[0, tx, i + 1, 0] = A_shared[1 - i % 2, tx, 0, 0] * T.float32(2)
629+
B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2)
630630
with T.block():
631631
T.reads([B[1, tx, i, 0]])
632632
T.writes([C[tx, i, 15]])
@@ -640,8 +640,8 @@ def transformed_nested_pipeline_prefetch_inner(
640640
for j in T.serial(0, 15):
641641
with T.block():
642642
T.reads([A_shared[1, tx, 15, j + 1]])
643-
T.writes([B[1 - j % 2, tx, 15, 0]])
644-
B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
643+
T.writes([B[(j + 1) % 2, tx, 15, 0]])
644+
B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
645645
with T.block():
646646
T.reads([B[j % 2, tx, 15, 0]])
647647
T.writes([C[tx, 15, j]])
@@ -768,8 +768,8 @@ def transformed_nested_pipeline_interleaving(
768768
for j in T.serial(0, 15):
769769
with T.block():
770770
T.reads([A_local[tx, i, j + 1]])
771-
T.writes([B[1 - j % 2, tx, i, 0]])
772-
B[1 - j % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
771+
T.writes([B[(j + 1) % 2, tx, i, 0]])
772+
B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
773773
with T.block():
774774
T.reads([B[j % 2, tx, i, 0]])
775775
T.writes([C[tx, i, j]])
@@ -799,8 +799,8 @@ def transformed_nested_pipeline_interleaving(
799799
for j in T.serial(0, 15):
800800
with T.block():
801801
T.reads([A_local[tx, 15, j + 1]])
802-
T.writes([B[1 - j % 2, tx, 15, 0]])
803-
B[1 - j % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
802+
T.writes([B[(j + 1) % 2, tx, 15, 0]])
803+
B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
804804
with T.block():
805805
T.reads([B[j % 2, tx, 15, 0]])
806806
T.writes([C[tx, 15, j]])
@@ -929,25 +929,27 @@ def transformed_nested_pipeline_double_buffer(
929929
for j in T.serial(0, 15):
930930
with T.block():
931931
T.reads([A_local[i % 2, tx, i, j + 1]])
932-
T.writes([B[1 - j % 2, tx, i, 0]])
933-
B[1 - j % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(2)
932+
T.writes([B[(j + 1) % 2, tx, i, 0]])
933+
B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(
934+
2
935+
)
934936
with T.block():
935937
T.reads([B[j % 2, tx, i, 0]])
936938
T.writes([C[tx, i, j]])
937939
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
938940
with T.block():
939941
T.reads([A_shared[tx, 0, 0:16]])
940-
T.writes([A_local[1 - i % 2, 0, 0, 0:16]])
942+
T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
941943
for j in T.serial(0, 16):
942944
with T.block():
943945
T.reads([A_shared[tx, 0, j]])
944-
T.writes([A_local[1 - i % 2, 0, 0, j]])
946+
T.writes([A_local[(i + 1) % 2, 0, 0, j]])
945947
T.block_attr({"double_buffer_scope": 0})
946-
A_local[1 - i % 2, 0, 0, j] = A_shared[tx, i + 1, j]
948+
A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j]
947949
with T.block():
948-
T.reads([A_local[1 - i % 2, tx, i + 1, 0]])
950+
T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]])
949951
T.writes([B[0, tx, i + 1, 0]])
950-
B[0, tx, i + 1, 0] = A_local[1 - i % 2, 0, 0, 0] * T.float32(2)
952+
B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2)
951953
with T.block():
952954
T.reads([B[1, tx, i, 0]])
953955
T.writes([C[tx, i, 15]])
@@ -961,8 +963,8 @@ def transformed_nested_pipeline_double_buffer(
961963
for j in T.serial(0, 15):
962964
with T.block():
963965
T.reads([A_local[1, tx, 15, j + 1]])
964-
T.writes([B[1 - j % 2, tx, 15, 0]])
965-
B[1 - j % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2)
966+
T.writes([B[(j + 1) % 2, tx, 15, 0]])
967+
B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2)
966968
with T.block():
967969
T.reads([B[j % 2, tx, 15, 0]])
968970
T.writes([C[tx, 15, j]])

0 commit comments

Comments
 (0)