@@ -1172,67 +1172,6 @@ def block_predicate_cache_write_output_buf() -> None:
11721172use_block_name = tvm .testing .parameter (by_dict = {"block_obj" : False , "block_name" : True })
11731173
11741174
1175- @T .prim_func
1176- def cache_write_allocate_const (
1177- A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((128 , 128 ), "float16" )
1178- ):
1179- B = T .alloc_buffer ([128 , 128 ], dtype = "float32" )
1180- const = T .allocate_const ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1181- const_1 = T .Buffer ([8 ], dtype = "float32" , data = const )
1182- const2 = T .allocate_const ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1183- const_2 = T .Buffer ([8 ], dtype = "float32" , data = const )
1184- for i , j in T .grid (128 , 128 ):
1185- for x in range (8 ):
1186- with T .block ("B" ):
1187- vi , vj , vx = T .axis .remap ("SSS" , [i , j , x ])
1188- T .reads (A [vi , vj ], const_1 [vx ], const_2 [vx ])
1189- T .writes (B [vi , vj ])
1190- B [vi , vj ] = A [vi , vj ] * const_1 [vx ] + const_2 [vx ]
1191- for i , j in T .grid (128 , 128 ):
1192- with T .block ("C" ):
1193- vi , vj = T .axis .remap ("SS" , [i , j ])
1194- T .reads (B [vi , vj ])
1195- T .writes (C [vi , vj ])
1196- C [vi , vj ] = B [vi , vj ] + 1.0
1197-
1198-
1199- @T .prim_func
1200- def cache_write_allocate_const_output (
1201- A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((128 , 128 ), "float16" )
1202- ):
1203- B = T .alloc_buffer ([128 , 128 ], dtype = "float32" )
1204- A_global = T .alloc_buffer ([128 , 128 ], dtype = "float32" )
1205- C_global = T .alloc_buffer ([128 , 128 ], dtype = "float16" )
1206- const_2 = T .allocate_const ([0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1207- const_1 = T .Buffer ([8 ], dtype = "float32" , data = const_2 )
1208- const_2_1 = T .Buffer ([8 ], dtype = "float32" , data = const_2 )
1209- const2 = T .allocate_const ([0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1210- for ax0 , ax1 in T .grid (128 , 128 ):
1211- with T .block ("A_global" ):
1212- v0 , v1 = T .axis .remap ("SS" , [ax0 , ax1 ])
1213- T .reads (A [v0 , v1 ])
1214- T .writes (A_global [v0 , v1 ])
1215- A_global [v0 , v1 ] = A [v0 , v1 ]
1216- for i , j , x in T .grid (128 , 128 , 8 ):
1217- with T .block ("B" ):
1218- vi , vj , vx = T .axis .remap ("SSS" , [i , j , x ])
1219- T .reads (A_global [vi , vj ], const_1 [vx ], const_2_1 [vx ])
1220- T .writes (B [vi , vj ])
1221- B [vi , vj ] = A_global [vi , vj ] * const_1 [vx ] + const_2_1 [vx ]
1222- for i , j in T .grid (128 , 128 ):
1223- with T .block ("C" ):
1224- vi , vj = T .axis .remap ("SS" , [i , j ])
1225- T .reads (B [vi , vj ])
1226- T .writes (C_global [vi , vj ])
1227- C_global [vi , vj ] = B [vi , vj ] + T .float32 (1 )
1228- for ax0 , ax1 in T .grid (128 , 128 ):
1229- with T .block ("C_global" ):
1230- v0 , v1 = T .axis .remap ("SS" , [ax0 , ax1 ])
1231- T .reads (C_global [v0 , v1 ])
1232- T .writes (C [v0 , v1 ])
1233- C [v0 , v1 ] = C_global [v0 , v1 ]
1234-
1235-
12361175def test_cache_read_elementwise (use_block_name ):
12371176 sch = tir .Schedule (elementwise , debug_mask = "all" )
12381177 block_b = sch .get_block ("B" )
@@ -1493,14 +1432,79 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name):
14931432 sch .cache_write (block_b , 0 , "test_scope" )
14941433
14951434
1496- def test_cache_write_allocate_const ():
1497- sch = tir .Schedule (cache_write_allocate_const )
1435+ @pytest .mark .parametrize ("use_decl_buffer" , [True , False ])
1436+ def test_cache_write_allocate_const (use_decl_buffer ):
1437+ def apply_decl_buffer (* args , ** kwargs ):
1438+ if use_decl_buffer :
1439+ return T .decl_buffer (* args , ** kwargs )
1440+ else :
1441+ return T .Buffer (* args , ** kwargs )
1442+
1443+ @T .prim_func
1444+ def before (A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((128 , 128 ), "float16" )):
1445+ B = T .alloc_buffer ([128 , 128 ], dtype = "float32" )
1446+ const1 = T .allocate_const ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1447+ const1_buf = apply_decl_buffer ([8 ], dtype = "float32" , data = const1 )
1448+ const2 = T .allocate_const ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1449+ const2_buf = apply_decl_buffer ([8 ], dtype = "float32" , data = const2 )
1450+ for i , j in T .grid (128 , 128 ):
1451+ for x in range (8 ):
1452+ with T .block ("B" ):
1453+ vi , vj , vx = T .axis .remap ("SSS" , [i , j , x ])
1454+ T .reads (A [vi , vj ], const1_buf [vx ], const2_buf [vx ])
1455+ T .writes (B [vi , vj ])
1456+ B [vi , vj ] = A [vi , vj ] * const1_buf [vx ] + const2_buf [vx ]
1457+ for i , j in T .grid (128 , 128 ):
1458+ with T .block ("C" ):
1459+ vi , vj = T .axis .remap ("SS" , [i , j ])
1460+ T .reads (B [vi , vj ])
1461+ T .writes (C [vi , vj ])
1462+ C [vi , vj ] = B [vi , vj ] + 1.0
1463+
1464+ @T .prim_func
1465+ def expected (A : T .Buffer ((128 , 128 ), "float32" ), C : T .Buffer ((128 , 128 ), "float16" )):
1466+ B = T .alloc_buffer ([128 , 128 ], dtype = "float32" )
1467+ A_global = T .alloc_buffer ([128 , 128 ], dtype = "float32" )
1468+ C_global = T .alloc_buffer ([128 , 128 ], dtype = "float16" )
1469+ const1 = T .allocate_const ([0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1470+ const1_buf = apply_decl_buffer ([8 ], dtype = "float32" , data = const1 )
1471+ const2 = T .allocate_const ([0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 ], "float32" , [8 ])
1472+ const2_buf = apply_decl_buffer ([8 ], dtype = "float32" , data = const2 )
1473+ for ax0 , ax1 in T .grid (128 , 128 ):
1474+ with T .block ("A_global" ):
1475+ v0 , v1 = T .axis .remap ("SS" , [ax0 , ax1 ])
1476+ T .reads (A [v0 , v1 ])
1477+ T .writes (A_global [v0 , v1 ])
1478+ A_global [v0 , v1 ] = A [v0 , v1 ]
1479+ for i , j , x in T .grid (128 , 128 , 8 ):
1480+ with T .block ("B" ):
1481+ vi , vj , vx = T .axis .remap ("SSS" , [i , j , x ])
1482+ T .reads (A_global [vi , vj ], const1_buf [vx ], const2_buf [vx ])
1483+ T .writes (B [vi , vj ])
1484+ B [vi , vj ] = A_global [vi , vj ] * const1_buf [vx ] + const2_buf [vx ]
1485+ for i , j in T .grid (128 , 128 ):
1486+ with T .block ("C" ):
1487+ vi , vj = T .axis .remap ("SS" , [i , j ])
1488+ T .reads (B [vi , vj ])
1489+ T .writes (C_global [vi , vj ])
1490+ C_global [vi , vj ] = B [vi , vj ] + T .float32 (1 )
1491+ for ax0 , ax1 in T .grid (128 , 128 ):
1492+ with T .block ("C_global" ):
1493+ v0 , v1 = T .axis .remap ("SS" , [ax0 , ax1 ])
1494+ T .reads (C_global [v0 , v1 ])
1495+ T .writes (C [v0 , v1 ])
1496+ C [v0 , v1 ] = C_global [v0 , v1 ]
1497+
1498+ sch = tir .Schedule (before )
14981499 block_b = sch .get_block ("B" )
14991500 block_c = sch .get_block ("C" )
15001501 sch .cache_read (block_b , 0 , "global" )
15011502 sch .cache_write (block_c , 0 , "global" )
1502- tvm .ir .assert_structural_equal (cache_write_allocate_const_output , sch .mod ["main" ])
1503- verify_trace_roundtrip (sch = sch , mod = cache_write_allocate_const )
1503+
1504+ after = sch .mod ["main" ]
1505+
1506+ tvm .ir .assert_structural_equal (expected , after )
1507+ verify_trace_roundtrip (sch = sch , mod = before )
15041508
15051509
15061510def test_reindex_cache_read ():
0 commit comments