 import tvm.script
 import tvm.testing
 from tvm import IRModule, relax, tir, topi
-from tvm.relax import DynTensorType
-from tvm.script import ir as I
-from tvm.script import relax as R
-from tvm.script import tir as T
+from tvm.script.parser import ir as I
+from tvm.script.parser import relax as R
+from tvm.script.parser import tir as T


 def _check(
@@ -202,6 +201,23 @@ def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
     _check(foo, bb.get()["foo"])


+def test_relax_base_op():
+    @R.function
+    def foo(x: R.Tensor((4, 4), "float32")):
+        alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32")
+        shape = R.shape_of(alloc)
+        return shape
+
+    x = relax.Var("x", R.Tensor((4, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (x,)):
+        alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0))
+        shape = bb.emit(relax.op.shape_of(alloc))
+        bb.emit_func_output(shape)
+    # TODO(yongwww): this check is commented out because 0 was changed to R.prim_value(0) in the printed IR
+    # _check(foo, bb.get()["foo"])
+
+
 def test_symbolic_shape():
     @R.function
     def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
@@ -274,7 +290,7 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
             y0 = R.match_cast(y, R.Tensor([n], "float32"))
             gv = y0
             R.output(gv)
-        return (x0, (m, n * 2))
+        return (x0, R.shape([m, n * 2]))

     x = relax.Var("x", R.Tensor("float32"))
     y = relax.Var("y", R.Tensor("float32"))
@@ -314,7 +330,7 @@ def test_tuple_return_2():
     def foo(x: R.Tensor("float32", ndim=2)):
         n, m = T.var("int64"), T.var("int64")
         x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
-        return (x0, (n + 1, m, 1))
+        return (x0, R.shape([n + 1, m, 1]))

     x = relax.Var("x", R.Tensor("float32", ndim=2))
     n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
@@ -332,7 +348,7 @@ def foo(x: R.Tensor("float32", ndim=2)):
         n, m = T.var("int64"), T.var("int64")
         x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
         t0 = (x, x0)
-        t1 = (x, (n, m), t0)
+        t1 = (x, R.shape([n, m]), t0)
         return t1

     x = relax.Var("x", R.Tensor("float32", ndim=2))
@@ -965,9 +981,9 @@ def test_vm_ops():
     def foo(x: R.Tensor(("m", "n"), dtype="float32")):
         m = T.var("int64")
         n = T.var("int64")
-        storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0)
-        alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32")
-        tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0)
+        storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0)
+        alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32")
+        tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0)
         _ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n)))
         gv = tensor
         return alloc, gv