@@ -1384,11 +1384,10 @@ def test_cache_read_allocate_const():
13841384 def before (A : T .Buffer ((8 ), "float32" ), C : T .Buffer ((8 ), "float32" )):
13851385 B = T .allocate_const ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
13861386 B_buf = T .decl_buffer ((8 ), dtype = "float32" , data = B )
1387- for i in T . serial ( 128 ):
1387+ for i in range ( 8 ):
13881388 with T .block ("C" ):
1389- vi = T .axis .remap ( "S" , [ i ] )
1389+ vi = T .axis .spatial ( 8 , i )
13901390 T .reads (A [vi ], B_buf [vi ])
1391- T .writes (C [vi ])
13921391 C [vi ] = A [vi ] + B_buf [vi ]
13931392
13941393 @T .prim_func
@@ -1400,20 +1399,15 @@ def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
14001399 for ax0 in range (8 ):
14011400 with T .block ("A_global" ):
14021401 v0 = T .axis .spatial (8 , ax0 )
1403- T .reads (A [v0 ])
1404- T .writes (A_global [v0 ])
14051402 A_global [v0 ] = A [v0 ]
14061403 for ax0 in range (8 ):
14071404 with T .block ("B_buf_global" ):
14081405 v0 = T .axis .spatial (8 , ax0 )
14091406 T .reads (B_buf [v0 ])
1410- T .writes (B_buf_global [v0 ])
14111407 B_buf_global [v0 ] = B_buf [v0 ]
1412- for i in range (128 ):
1408+ for i in range (8 ):
14131409 with T .block ("C" ):
1414- vi = T .axis .spatial (128 , i )
1415- T .reads (A_global [vi ], B_buf_global [vi ])
1416- T .writes (C [vi ])
1410+ vi = T .axis .spatial (8 , i )
14171411 C [vi ] = A_global [vi ] + B_buf_global [vi ]
14181412
14191413 sch = tir .Schedule (before )
0 commit comments