@@ -2162,5 +2162,97 @@ def main(
21622162 _check (Module , Expected )
21632163
21642164
def test_use_as_inplace_and_dps():
    """Exercise one PrimFunc that is called both in DPS form and in-place.

    The input module's ``fused_sums`` (marked ``Primitive``) calls the same
    TIR function ``add`` three times: once through ``R.call_tir`` and twice
    through ``R.call_tir_inplace`` with ``inplace_indices=[2]``, each in-place
    call receiving the previous result as its third argument.

    The expected module contains a single PrimFunc ``fused_sums`` made of three
    consecutive ``T_add`` loop nests that all write ``p_output0``, and a
    ``main`` that invokes it through ``R.call_tir``.

    Both modules are handed to ``_check`` (defined elsewhere in this file;
    presumably it runs the pass under test on ``Module`` and compares the
    result against ``Expected`` -- confirm against its definition).
    """

    @I.ir_module
    class Module:
        # we will use it both in-place and normally (DPS)
        @T.prim_func(private=True)
        def add(
            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
            B: T.Buffer((), "float32"),
            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    # Out is only written, never read: the result depends on A and B alone.
                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]

        @R.function(private=True)
        def fused_sums(
            x: R.Tensor((10, 20), dtype="float32"),
            p0: R.Tensor((), dtype="float32"),
        ) -> R.Tensor((10, 20), dtype="float32"):
            R.func_attr({"Primitive": 1})
            cls = Module
            with R.dataflow():
                # First use of `add`: ordinary destination-passing-style call.
                lv = R.call_tir(
                    cls.add,
                    (x, p0),
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                # Second use: in-place, with argument index 2 (`lv`) as the output.
                lv1 = R.call_tir_inplace(
                    cls.add,
                    (x, p0, lv),
                    inplace_indices=[2],
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                # Third use: in-place again, this time on the previous result `lv1`.
                lv2 = R.call_tir_inplace(
                    cls.add,
                    (x, p0, lv1),
                    inplace_indices=[2],
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                R.output(lv2)
            return lv2

        @R.function
        def main(
            x: R.Tensor((10, 20), dtype="float32"),
            p0: R.Tensor((), dtype="float32"),
        ) -> R.Tensor((10, 20), dtype="float32"):
            cls = Module
            with R.dataflow():
                gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_sums(x, p0)
                R.output(gv1)
            return gv1

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def fused_sums(
            x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
            p0: T.Buffer((), "float32"),
            p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # One loop nest per original call to `add`; all three write the same
            # output buffer with x + p0.
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]

        @R.function
        def main(
            x: R.Tensor((10, 20), dtype="float32"),
            p0: R.Tensor((), dtype="float32"),
        ) -> R.Tensor((10, 20), dtype="float32"):
            cls = Expected
            with R.dataflow():
                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
                    cls.fused_sums,
                    (x, p0),
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                R.output(gv1)
            return gv1

    _check(Module, Expected)
2255+
2256+
# Allow running this test file directly as a script.
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments