@@ -36,14 +36,14 @@ def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)

 def test_pool():
@@ -70,14 +70,14 @@ def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         f = tvm.build(s, [A, B], device)
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-    for device in ['cuda', 'opencl', 'metal']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm']:
         check_device(device)

 def test_global_pool():
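Both hunks make the same change: the device context is obtained through tvm.context(device, 0) instead of the hard-coded tvm.gpu(0)/tvm.cl(0) branch, which only understood CUDA and OpenCL, and 'rocm' is added to the device loop. Below is a minimal standalone sketch (not part of this commit, assuming the same tvm.context(device_string, device_id) API used in the diff) of why the generic lookup is enough to cover the extra backend:

import tvm

# Hypothetical standalone loop mirroring the tests' device iteration.
for device in ['cuda', 'opencl', 'metal', 'rocm']:
    if not tvm.module.enabled(device):
        print("Skip because %s is not enabled" % device)
        continue
    # tvm.context maps the target string to the matching TVMContext,
    # e.g. "rocm" resolves to the ROCm device without any per-backend
    # branching in the test code.
    ctx = tvm.context(device, 0)
    print(device, ctx)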