@@ -1915,5 +1915,79 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle):
19151915 )
19161916
19171917
def test_compute_at_sliced_concatenate():
    """Check ``compute_at`` on a concatenate whose result is only partly consumed.

    ``before`` builds three stages: a producer filling ``X``, a concatenation of
    ``Z``, ``Y`` and ``X`` along axis 1 into ``Concat`` (101 channels), and a slice
    copying the first 87 channels of ``Concat`` into ``Slice``.  The test moves the
    concatenation and then the producer under the outermost loop of the slice
    block and compares the scheduled function against ``expect``.
    """

    @T.prim_func
    def before():
        X = T.alloc_buffer((1, 16, 28, 64), "float32")
        Y = T.alloc_buffer((1, 32, 28, 64), "float32")
        Z = T.alloc_buffer((1, 53, 28, 64), "float32")
        Concat = T.alloc_buffer((1, 101, 28, 64), "float32")
        Slice = T.alloc_buffer((1, 87, 28, 64), "float32")
        # Producer stage: every element of X is set to 1.0.
        for ax0, ax1, ax2, ax3 in T.grid(1, 16, 28, 64):
            with T.block("compute"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                X[v_ax0, v_ax1, v_ax2, v_ax3] = 1.0
        # Concatenation along axis 1: channels [0, 53) read Z, [53, 85) read Y,
        # and [85, 101) read X.
        for ax0, ax1, ax2, ax3 in T.grid(1, 101, 28, 64):
            with T.block("T_concat"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
                    85 <= v_ax1,
                    X[v_ax0, v_ax1 - 85, v_ax2, v_ax3],
                    T.if_then_else(
                        53 <= v_ax1,
                        Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3],
                        Z[v_ax0, v_ax1, v_ax2, v_ax3],
                    ),
                )
        # Slice stage: copies only the first 87 of Concat's 101 channels.
        for ax0, ax1, ax2, ax3 in T.grid(1, 87, 28, 64):
            with T.block("T_strided_slice"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3]

    @T.prim_func
    def expect():
        X = T.alloc_buffer((1, 16, 28, 64))
        Y = T.alloc_buffer((1, 32, 28, 64))
        Z = T.alloc_buffer((1, 53, 28, 64))
        Concat = T.alloc_buffer((1, 101, 28, 64))
        Slice = T.alloc_buffer((1, 87, 28, 64))
        for ax0 in range(1):
            # Only 2 of X's 16 channels are produced here: the slice reads Concat
            # channels [0, 87), of which only 85 and 86 map onto X (indices 0 and 1).
            for ax0_1, ax1, ax2 in T.grid(2, 28, 64):
                with T.block("compute"):
                    v_ax0 = T.axis.spatial(1, 0)
                    v_ax1 = T.axis.spatial(16, ax0_1)
                    v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2])
                    X[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(1)
            # The concat loop covers 87 channels although its axis domain stays 101.
            for ax0_1, ax1, ax2 in T.grid(87, 28, 64):
                with T.block("T_concat"):
                    v_ax0 = T.axis.spatial(1, 0)
                    v_ax1 = T.axis.spatial(101, ax0_1)
                    v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2])
                    Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
                        85 <= v_ax1,
                        X[v_ax0, v_ax1 - 85, v_ax2, v_ax3],
                        T.if_then_else(
                            53 <= v_ax1,
                            Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3],
                            Z[v_ax0, v_ax1, v_ax2, v_ax3],
                        ),
                    )
            for ax1, ax2, ax3 in T.grid(87, 28, 64):
                with T.block("T_strided_slice"):
                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3]

    sch = tir.Schedule(before, debug_mask="all")
    producer_block = sch.get_block("compute")
    concat_block = sch.get_block("T_concat")
    slice_block = sch.get_block("T_strided_slice")
    # Target location: the outermost loop surrounding the slice block.
    outer_slice_loop = sch.get_loops(slice_block)[0]
    # Order matters: the concatenation is moved first, then the producer of X.
    sch.compute_at(concat_block, outer_slice_loop)
    sch.compute_at(producer_block, outer_slice_loop)
    assert_structural_equal_ignore_global_symbol(expect, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=before)
1990+
1991+
# Script entry point: hand control to tvm.testing.main() when run directly.
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments