Skip to content

Commit 85c545c

Browse files
masahitqchen
authored andcommitted
Add rocm target to topi tests (#548)
* add masahi to contributors * enable rocm target in topi tests
1 parent 74b0ca8 commit 85c545c

13 files changed

+43
-38
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ List of Contributors
3434
- To contributors: please add your name to the list.
3535
- [Qiao Zhang](https://github.com/zhangqiaorjc)
3636
- [Jian Weng](https://github.com/were)
37+
- [Masahiro Masuda](https://github.com/masahi)

topi/tests/python/test_topi_broadcast.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def check_device(device):
1313
if not tvm.module.enabled(device):
1414
print("Skip because %s is not enabled" % device)
1515
return
16-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
16+
ctx = tvm.context(device, 0)
1717
foo = tvm.build(s, [A, B], device, name="broadcast_to")
1818
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
1919
out_npy = np.broadcast_to(data_npy, out_shape)
@@ -27,6 +27,7 @@ def check_device(device):
2727
check_device("opencl")
2828
check_device("cuda")
2929
check_device("metal")
30+
check_device("rocm")
3031

3132

3233
def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
@@ -52,7 +53,7 @@ def check_device(device):
5253
if not tvm.module.enabled(device):
5354
print("Skip because %s is not enabled" % device)
5455
return
55-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
56+
ctx = tvm.context(device, 0)
5657
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
5758
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
5859
rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
@@ -81,7 +82,7 @@ def check_device(device):
8182
check_device("opencl")
8283
check_device("cuda")
8384
check_device("metal")
84-
85+
check_device("rocm")
8586

8687
def test_broadcast_to():
8788
verify_broadcast_to_ele((1,), (10,))

topi/tests/python/test_topi_conv2d_hwcn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,22 @@ def check_device(device):
3434
if not tvm.module.enabled(device):
3535
print("Skip because %s is not enabled" % device)
3636
return
37-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
37+
ctx = tvm.context(device, 0)
3838
a = tvm.nd.array(a_np, ctx)
3939
w = tvm.nd.array(w_np, ctx)
4040
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
4141
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
4242
with tvm.build_config(auto_unroll_max_step=32,
4343
auto_unroll_min_depth=0,
44-
unroll_explicit=False):
44+
unroll_explicit=device == 'rocm'):
4545
func1 = tvm.build(s1, [A, W, B], device)
4646
func2 = tvm.build(s2, [A, W, C], device)
4747
func1(a, w, b)
4848
func2(a, w, c)
4949
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
5050
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
5151

52-
for device in ['cuda', 'opencl', 'metal']:
52+
for device in ['cuda', 'opencl', 'metal', 'rocm']:
5353
check_device(device)
5454

5555

topi/tests/python/test_topi_conv2d_nchw.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,22 @@ def check_device(device):
3535
if not tvm.module.enabled(device):
3636
print("Skip because %s is not enabled" % device)
3737
return
38-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
38+
ctx = tvm.context(device, 0)
3939
a = tvm.nd.array(a_np, ctx)
4040
w = tvm.nd.array(w_np, ctx)
4141
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
4242
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
4343
with tvm.build_config(auto_unroll_max_step=32,
4444
auto_unroll_min_depth=0,
45-
unroll_explicit=False):
45+
unroll_explicit=device == 'rocm'):
4646
func1 = tvm.build(s1, [A, W, B], device)
4747
func2 = tvm.build(s2, [A, W, C], device)
4848
func1(a, w, b)
4949
func2(a, w, c)
5050
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
5151
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
5252

53-
for device in ['cuda', 'opencl', 'metal']:
53+
for device in ['cuda', 'opencl', 'metal', 'rocm']:
5454
check_device(device)
5555

5656

topi/tests/python/test_topi_dense.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def check_device(device):
3333
if not tvm.module.enabled(device):
3434
print("Skip because %s is not enabled" % device)
3535
return
36-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
36+
ctx = tvm.context(device, 0)
3737
a = tvm.nd.array(a_np, ctx)
3838
b = tvm.nd.array(b_np, ctx)
3939
c = tvm.nd.array(c_np, ctx)
@@ -42,7 +42,7 @@ def check_device(device):
4242
f(a, b, c, d)
4343
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
4444

45-
for device in ['cuda', 'opencl', 'metal']:
45+
for device in ['cuda', 'opencl', 'metal', 'rocm']:
4646
check_device(device)
4747

4848
def test_dense():

topi/tests/python/test_topi_depthwise_conv2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def get_ref_data():
8787
check_device("opencl")
8888
check_device("cuda")
8989
check_device("metal")
90-
90+
check_device("rocm")
91+
9192
def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
9293
in_width = in_height
9394
filter_channel = in_channel
@@ -170,7 +171,7 @@ def get_ref_data():
170171
check_device("opencl")
171172
check_device("cuda")
172173
check_device("metal")
173-
174+
check_device("rocm")
174175

175176
def test_depthwise_conv2d():
176177
print("testing nchw")

topi/tests/python/test_topi_depthwise_conv2d_back_input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_ref_data():
8383
check_device("opencl")
8484
check_device("cuda")
8585
check_device("metal")
86-
86+
check_device("rocm")
8787

8888
def test_topi_depthwise_conv2d_backward_input_nhwc():
8989
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)

topi/tests/python/test_topi_depthwise_conv2d_back_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_ref_data():
7676
check_device("opencl")
7777
check_device("cuda")
7878
check_device("metal")
79-
79+
check_device("rocm")
8080

8181
def test_topi_depthwise_conv2d_backward_weight_nhwc():
8282
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)

topi/tests/python/test_topi_pooling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ def check_device(device):
3636
if not tvm.module.enabled(device):
3737
print("Skip because %s is not enabled" % device)
3838
return
39-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
39+
ctx = tvm.context(device, 0)
4040
a = tvm.nd.array(a_np, ctx)
4141
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
4242
f = tvm.build(s, [A, B], device)
4343
f(a, b)
4444
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
4545

46-
for device in ['cuda', 'opencl', 'metal']:
46+
for device in ['cuda', 'opencl', 'metal', 'rocm']:
4747
check_device(device)
4848

4949
def test_pool():
@@ -70,14 +70,14 @@ def check_device(device):
7070
if not tvm.module.enabled(device):
7171
print("Skip because %s is not enabled" % device)
7272
return
73-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
73+
ctx = tvm.context(device, 0)
7474
a = tvm.nd.array(a_np, ctx)
7575
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
7676
f = tvm.build(s, [A, B], device)
7777
f(a, b)
7878
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
7979

80-
for device in ['cuda', 'opencl', 'metal']:
80+
for device in ['cuda', 'opencl', 'metal', 'rocm']:
8181
check_device(device)
8282

8383
def test_global_pool():

topi/tests/python/test_topi_reduce.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def check_device(device):
5050
if not tvm.module.enabled(device):
5151
print("Skip because %s is not enabled" % device)
5252
return
53-
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
53+
ctx = tvm.context(device, 0)
5454
foo = tvm.build(s, [A, B], device, name="sum")
5555
# Test
5656
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
@@ -76,7 +76,7 @@ def check_device(device):
7676
check_device("opencl")
7777
check_device("cuda")
7878
check_device("metal")
79-
79+
check_device("rocm")
8080

8181
def test_reduce_map():
8282
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),

0 commit comments

Comments
 (0)