@@ -570,11 +570,12 @@ def test_explicit_partition_hint():
570570 tvm .ir .assert_structural_equal (mod ["main" ], partitioned_concat )
571571
572572
573- def partition_from_scheduled_tir (prim_func , pass_cfg ):
573+ def partition_from_scheduled_tir (prim_func , pass_cfg , do_flatten = True ):
574574 with tvm .transform .PassContext (config = pass_cfg ):
575575 mod = IRModule .from_expr (prim_func .with_attr ("global_symbol" , "main" ))
576576 mod = tvm .tir .transform .LowerOpaqueBlock ()(mod )
577- mod = tvm .tir .transform .FlattenBuffer ()(mod )
577+ if do_flatten :
578+ mod = tvm .tir .transform .FlattenBuffer ()(mod )
578579 mod = tvm .tir .transform .LoopPartition ()(mod )
579580 mod = tvm .tir .transform .Simplify ()(mod )
580581 mod = tvm .tir .transform .RemoveNoOp ()(mod )
@@ -1037,6 +1038,29 @@ def concat_five_buffers_with_equalities_expected(
10371038 T_concat_1 [i0 * 129 + 129 ] = buffer_e_1 [i0 ]
10381039
10391040
1041+ @T .prim_func
1042+ def nested_partition_with_single_points (A : T .Buffer [(25 ,), "int32" ]):
1043+ for i in T .serial (5 , annotations = {"pragma_loop_partition_hint" : 1 }):
1044+ if i == 1 :
1045+ for j in T .serial (5 , annotations = {"pragma_loop_partition_hint" : 1 }):
1046+ if j > 2 :
1047+ A [i * 5 + j ] = i * 5 + j
1048+ else :
1049+ for j in T .serial (5 , annotations = {"pragma_loop_partition_hint" : 1 }):
1050+ if j > 2 :
1051+ A [i * 5 + j ] = i * 15 + j
1052+
1053+
1054+ @T .prim_func
1055+ def nested_partition_with_single_points_expected (A : T .Buffer [(25 ,), "int32" ]):
1056+ for j in range (2 ):
1057+ A [j + 3 ] = j + 3
1058+ for j in range (2 ):
1059+ A [j + 8 ] = j + 8
1060+ for i , j in T .grid (3 , 2 ):
1061+ A [i * 5 + j + 13 ] = i * 15 + j + 33
1062+
1063+
10401064@pytest .mark .parametrize (
10411065 "origin,expected" ,
10421066 [
@@ -1045,6 +1069,7 @@ def concat_five_buffers_with_equalities_expected(
10451069 (concat_func_end_point_equality , concat_func_end_point_equality_expected ),
10461070 (concat_func_edge_equalities , concat_func_edge_equalities_expected ),
10471071 (concat_five_buffers_with_equalities , concat_five_buffers_with_equalities_expected ),
1072+ (nested_partition_with_single_points , nested_partition_with_single_points_expected ),
10481073 ],
10491074)
10501075def test_single_point_partition (origin , expected ):
@@ -1062,5 +1087,63 @@ def test_single_point_partition(origin, expected):
10621087 tvm .ir .assert_structural_equal (mod ["main" ], expected )
10631088
10641089
1090+ def test_equation_on_floordiv ():
1091+ @T .prim_func
1092+ def before (A : T .Buffer [(2 , 2 , 20 ), "int32" ]):
1093+ for i in T .serial (5 , annotations = {"pragma_loop_partition_hint" : 1 }):
1094+ if i == 1 :
1095+ for vv in T .vectorized (640 , annotations = {"pragma_loop_partition_hint" : 1 }):
1096+ if i * 2 + vv // 320 == 3 :
1097+ A [i - 1 , i * 2 + vv // 320 - 3 , vv % 320 // 16 ] = 1
1098+
1099+ @T .prim_func
1100+ def expected (A : T .Buffer [(2 , 2 , 20 ), "int32" ]):
1101+ for vv in T .vectorized (320 ):
1102+ A [0 , 0 , vv // 16 ] = 1
1103+
1104+ expected = expected .with_attr ({"global_symbol" : "main" })
1105+ after = partition_from_scheduled_tir (
1106+ before .with_attr ("global_symbol" , "main" ), {}, do_flatten = False
1107+ )
1108+ tvm .ir .assert_structural_equal (after ["main" ], expected )
1109+
1110+
1111+ def test_ignore_loop_partition_hint ():
1112+ """Skip unroll body and prologue for pipeline case"""
1113+
1114+ @T .prim_func
1115+ def before (A : T .Buffer [(10 ), "float32" ], D : T .Buffer [(10 ), "float32" ]):
1116+ B = T .decl_buffer ([2 ], "float32" )
1117+ C = T .decl_buffer ([2 ], "float32" )
1118+ for i in T .serial (12 , annotations = {"pragma_loop_partition_hint" : 1 }):
1119+ if T .ignore_loop_partition (i < 10 ):
1120+ B [i % 2 ] = A [i ] + 1.0
1121+ if T .ignore_loop_partition (1 <= i and i < 11 ):
1122+ C [(i - 1 ) % 2 ] = B [(i - 1 ) % 2 ] + 2.0
1123+ if 2 <= i :
1124+ D [i - 2 ] = C [i % 2 ] + 3.0
1125+
1126+ @T .prim_func
1127+ def expected (A : T .Buffer [(10 ), "float32" ], D : T .Buffer [(10 ), "float32" ]):
1128+ B = T .decl_buffer ([2 ], "float32" )
1129+ C = T .decl_buffer ([2 ], "float32" )
1130+ for i in range (2 ):
1131+ B [i ] = A [i ] + 1.0
1132+ if i == 1 :
1133+ C [i - 1 ] = B [i - 1 ] + 2.0
1134+ for i in T .serial (10 ):
1135+ if i < 8 :
1136+ B [i % 2 ] = A [i + 2 ] + 1.0
1137+ if i < 9 :
1138+ C [(i + 1 ) % 2 ] = B [(i + 1 ) % 2 ] + 2.0
1139+ D [i ] = C [i % 2 ] + 3.0
1140+
1141+ expected = expected .with_attr ({"global_symbol" : "main" })
1142+ after = partition_from_scheduled_tir (
1143+ before .with_attr ({"global_symbol" : "main" }), {}, do_flatten = False
1144+ )
1145+ tvm .ir .assert_structural_equal (after ["main" ], expected )
1146+
1147+
10651148if __name__ == "__main__" :
10661149 tvm .testing .main ()
0 commit comments