@@ -1227,11 +1227,29 @@ def test_iter_map_simplify_unit_loop_order():
12271227
12281228
12291229def assert_normalize_to_iter_sum (index , input_iters , args , base ):
1230+ """Assert the result of arith.normalize_to_iter_sum is correct
1231+
1232+ Parameters
1233+ ----------
1234+ index : tvm.tir.PrimExpr
1235+ The index to be normalized
1236+ input_iters : Mapping[Var, Range]
1237+ The input iterators
1238+ args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]]
1239+ The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be
1240+ either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the
1241+ iterator normalized to PrimExpr and the second element is the scale.
1242+ base : tvm.tir.PrimExpr
1243+ The expected base
1244+ """
12301245 res = tvm .arith .normalize_to_iter_sum (index , input_iters )
12311246
12321247 assert isinstance (res , tvm .arith .IterSumExpr )
12331248 assert len (res .args ) == len (args )
12341249 for split , item in zip (res .args , args ):
1250+ if isinstance (item , tvm .arith .IterSplitExpr ):
1251+ tvm .ir .assert_structural_equal (split , item )
1252+ continue
12351253 tvm .testing .assert_prim_expr_equal (split .scale , item [1 ])
12361254 tvm .testing .assert_prim_expr_equal (
12371255 tvm .arith .normalize_iter_map_to_expr (split ), item [0 ] * item [1 ]
@@ -1245,6 +1263,7 @@ def test_normalize_to_iter_sum():
12451263 z = tvm .tir .Var ("z" , "int64" )
12461264 a = tvm .tir .Var ("a" , "int64" )
12471265 n = tvm .tir .Var ("n" , "int64" )
1266+ flm = tvm .tir .floormod
12481267
12491268 assert_normalize_to_iter_sum (
12501269 z + ((y + x * 4 + 2 ) * n ) + 3 ,
@@ -1285,6 +1304,21 @@ def test_normalize_to_iter_sum():
12851304 0 ,
12861305 )
12871306
1307+ # non-divisible
1308+ assert_normalize_to_iter_sum (
1309+ x // 5 ,
1310+ var_dom ([(x , 4096 )]),
1311+ [
1312+ tvm .arith .IterSplitExpr (
1313+ tvm .arith .IterMark (x , 4096 ),
1314+ lower_factor = tvm .tir .const (5 , "int64" ),
1315+ extent = tvm .tir .const (820 , "int64" ),
1316+ scale = tvm .tir .const (1 , "int64" ),
1317+ )
1318+ ],
1319+ 0 ,
1320+ )
1321+
12881322 # iter simplify
12891323 assert_normalize_to_iter_sum (
12901324 z * 2 + 2 * y * 3 + 4 * (x // 4 ) + (x % 4 ),
0 commit comments