Skip to content

Commit e767073

Browse files
committed
This PR fixes a bug revealed as part of the IterMapSimplify change.
y - (x % 2) were simplified to y + (x % 2) - 1 which is wrong. Regression tests are added.
1 parent 2f2a385 commit e767073

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

src/arith/iter_affine_map.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,11 +1093,6 @@ class IterMapRewriter : public ExprMutator {
10931093
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
10941094

10951095
static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
1096-
if (sign < 0 && is_const_int(rhs->extent, 2)) {
1097-
lhs->base -= rhs->scale;
1098-
sign = 1;
1099-
}
1100-
11011096
tir::ExprDeepEqual equal;
11021097
for (size_t i = 0; i < lhs->args.size(); ++i) {
11031098
IterSplitExpr lvalue = lhs->args[i];

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,6 @@ def test_compound():
222222
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))
223223

224224

225-
def test_compound_floormod_two():
226-
x = tvm.tir.Var("x", "int32")
227-
fld = tvm.tir.floordiv
228-
flm = tvm.tir.floormod
229-
230-
# extent of 2 are normalized to positive scale
231-
assert_iter_sum_pattern(
232-
expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
233-
dom_map=var_dom([(x, 8)]),
234-
)
235-
236-
237225
def test_predicate():
238226
x = tvm.tir.Var("x", "int32")
239227
y = tvm.tir.Var("y", "int32")
@@ -1190,6 +1178,12 @@ def test_iter_map_simplify_unit_loop_order():
11901178
simplify_trivial_iterators=False,
11911179
)
11921180

1181+
assert_iter_map_simplfy(
1182+
{y + 64 - x % 2 * 64: y + 64 - x % 2 * 64},
1183+
var_dom([(x, 6), (y, 64)]),
1184+
simplify_trivial_iterators=False,
1185+
)
1186+
11931187

11941188
if __name__ == "__main__":
11951189
tvm.testing.main()

0 commit comments

Comments
 (0)