@@ -1159,6 +1159,89 @@ def leaky_relu(var_rxplaceholder: T.handle, var_compute: T.handle):
11591159 tvm .ir .assert_structural_equal (mod , Expected )
11601160
11611161
def test_prelu():
    """Check that LegalizeOps lowers a static-shape R.nn.prelu to a TIR kernel.

    The expected lowering is a private prim_func that (1) broadcasts the
    single-element slope tensor ``y`` across the channel axis, then
    (2) computes ``x if x > 0 else x * slope`` elementwise over the
    (2, 3) input.
    """
    # fmt: off
    @tvm.script.ir_module
    class PRelu:
        @R.function
        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
            # After legalization the high-level op becomes a call_tir into
            # the generated prim_func below.
            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            slope_broadcasted = T.alloc_buffer((T.int64(3),))
            # Step 1: replicate the scalar slope y[0] across all 3 channels.
            for c in range(T.int64(3)):
                with T.block("slope_broadcasted"):
                    v_c = T.axis.spatial(T.int64(3), c)
                    T.reads(y[T.int64(0)])
                    T.writes(slope_broadcasted[v_c])
                    slope_broadcasted[v_c] = y[T.int64(0)]
            # Step 2: elementwise PReLU — positives pass through, negatives
            # are scaled by the per-channel slope.
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
    # fmt: on

    mod = LegalizeOps()(PRelu)
    tvm.ir.assert_structural_equal(mod, Expected)
1200+
def test_prelu_symbolic():
    """Check LegalizeOps on R.nn.prelu with a symbolic batch dimension.

    Same lowering as the static test, but the first input dimension is the
    symbolic var ``m``, so the generated prim_func takes opaque handles and
    binds them with T.match_buffer to shapes containing ``m``.
    """
    # fmt: off
    @tvm.script.ir_module
    class PRelu:
        @R.function
        def main(x: R.Tensor(("m", 7), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor(("m", 7), "float32"):
            m = T.int64()
            gv: R.Tensor((m, 7), "float32") = R.nn.prelu(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"):
            m = T.int64()
            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            # Symbolic-shape buffers are passed as handles and matched here.
            x = T.match_buffer(var_x, (m, T.int64(7)))
            compute = T.match_buffer(var_compute, (m, T.int64(7)))
            # with T.block("root"):
            slope_broadcasted = T.alloc_buffer((T.int64(7),))
            # Step 1: replicate the scalar slope y[0] across all 7 channels.
            for c in range(T.int64(7)):
                with T.block("slope_broadcasted"):
                    v_c = T.axis.spatial(T.int64(7), c)
                    T.reads(y[T.int64(0)])
                    T.writes(slope_broadcasted[v_c])
                    slope_broadcasted[v_c] = y[T.int64(0)]
            # Step 2: elementwise PReLU over the (m, 7) grid.
            for i0, i1 in T.grid(m, T.int64(7)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
    # fmt: on

    mod = LegalizeOps()(PRelu)
    tvm.ir.assert_structural_equal(mod, Expected)
1244+
11621245def test_gelu ():
11631246 # fmt: off
11641247 @tvm .script .ir_module
0 commit comments