@@ -1088,5 +1088,93 @@ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="fl
             return gv
 
 
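+# The CUDA graph rewrite pass should lift the two intermediate static-shape
+# allocations into a cached allocation function and capture the middle "dummy"
+# call, even though the static input `w` has a symbolic shape ("m",).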
+class TestStaticInputWithSymbolicShape(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((8,), "float16"), w: R.Tensor(("m",))):
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True, "num_input": 1})
+            storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
+            alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float16")
+            _ = R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
+            storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
+            alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float16")
+            _1 = R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
+            storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
+            alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float16")
+            _2 = R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
+            gv = (alloc3,)
+            return gv
+
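+    # Expected: storage1 and storage2 are allocated once via cuda_graph_alloc and
+    # reused across calls, while storage3, which backs the returned tensor, stays
+    # outside the captured region.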
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
+            )
+            storage2: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage1, storage2
+            return gv
+
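+        # Only the middle "dummy" call is captured; the symbolic shape is passed
+        # as an explicit R.Shape parameter, presumably so `m` can be rebound when
+        # the captured graph is replayed.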
+        @R.function(private=True)
+        def main_cuda_graph_capture(
+            alloc1: R.Tensor((8,), dtype="float16"),
+            w: R.Tensor(("m",)),
+            alloc2: R.Tensor((8,), dtype="float16"),
+            shape_expr: R.Shape(["m"]),
+        ) -> R.Tuple:
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True})
+            R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
+            R.tuple()
+            return R.tuple()
+
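+        # The first and last "dummy" calls stay eager, likely because they touch
+        # the external input `x` and the returned tensor `alloc3`; run_or_capture
+        # replays the captured graph, keyed by R.shape([m]).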
+        @R.function
+        def main(
+            x: R.Tensor((8,), dtype="float16"), w: R.Tensor(("m",))
+        ) -> R.Tuple(R.Tensor((8,), dtype="float16")):
+            m = T.int64()
+            R.func_attr({"num_input": 1, "relax.force_pure": True})
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage1: R.Object = gv[0]
+            alloc1: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([8]), R.dtype("float16")
+            )
+            R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
+            storage2: R.Object = gv[1]
+            alloc2: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
+                storage2, R.prim_value(0), R.shape([8]), R.dtype("float16")
+            )
+            R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (
+                    cls.main_cuda_graph_capture,
+                    (alloc1, w, alloc2, R.shape([m])),
+                    R.prim_value(0),
+                    R.shape([m]),
+                ),
+                sinfo_args=(R.Tuple,),
+            )
+            storage3: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
+            )
+            alloc3: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
+                storage3, R.prim_value(0), R.shape([8]), R.dtype("float16")
+            )
+            R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
+            gv_1: R.Tuple(R.Tensor((8,), dtype="float16")) = (alloc3,)
+            return gv_1
+
+
 if __name__ == "__main__":
     tvm.testing.main()