Skip to content

Commit eb61148

Browse files
authored
[UnitTests] Parametrized test_topi_argwhere.py (#11651)
Refactored while debugging breakage of tests in #11646. Submitting as a separate PR, as it isn't necessary or related to the primary changes in that PR.
1 parent 9ecb571 commit eb61148

File tree

1 file changed

+34
-38
lines changed

1 file changed

+34
-38
lines changed

tests/python/topi/python/test_topi_argwhere.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
# under the License.
1717
"""Test for argwhere operator"""
1818
import numpy as np
19+
import pytest
1920

2021
import tvm
22+
import tvm.testing
2123
from tvm import te
2224
from tvm import topi
2325
import tvm.topi.testing
@@ -29,56 +31,50 @@
2931

3032
_argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere}
3133

34+
data_shape = tvm.testing.parameter(
35+
(1,),
36+
(100,),
37+
(1, 1),
38+
(5, 3),
39+
(32, 64),
40+
(128, 65),
41+
(200, 500),
42+
(6, 5, 3),
43+
(1, 1, 1),
44+
(1, 1, 1, 1),
45+
(6, 4, 5, 3),
46+
(1, 1, 1, 1, 1),
47+
(6, 4, 5, 3, 7),
48+
)
3249

33-
def verify_argwhere(data_shape):
50+
51+
@tvm.testing.parametrize_targets("llvm", "cuda")
52+
def test_argwhere(target, dev, data_shape):
3453
dtype = "int32"
3554
np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype)
3655
np_out = np.argwhere(np_data)
3756
out_shape = np_out.shape[0]
57+
3858
np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype)
3959

4060
out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype)
4161
condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype)
4262

43-
def check_device(target):
44-
dev = tvm.device(target, 0)
45-
if not dev.exist or target not in _argwhere_compute:
46-
return
47-
48-
with tvm.target.Target(target):
49-
out = _argwhere_compute[target](out_shape, condition)
50-
s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
51-
sch = s_func(out)
52-
53-
func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")
54-
55-
args = [tvm.nd.array(np_shape, dev)]
56-
args.append(tvm.nd.array(np_data, dev))
57-
args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
58-
func(*args)
59-
np.set_printoptions(threshold=np.inf)
60-
tvm.testing.assert_allclose(args[-1].numpy(), np.array(np_out))
61-
62-
for target, _ in tvm.testing.enabled_targets():
63-
check_device(target)
63+
with tvm.target.Target(target):
64+
out = _argwhere_compute[target](out_shape, condition)
65+
s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule)
66+
sch = s_func(out)
6467

68+
func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere")
6569

66-
@tvm.testing.uses_gpu
67-
def test_argwhere():
68-
verify_argwhere((1,))
69-
verify_argwhere((100,))
70-
verify_argwhere((1, 1))
71-
verify_argwhere((5, 3))
72-
verify_argwhere((32, 64))
73-
verify_argwhere((128, 65))
74-
verify_argwhere((200, 500))
75-
verify_argwhere((6, 5, 3))
76-
verify_argwhere((1, 1, 1))
77-
verify_argwhere((1, 1, 1, 1))
78-
verify_argwhere((6, 4, 5, 3))
79-
verify_argwhere((1, 1, 1, 1, 1))
80-
verify_argwhere((6, 4, 5, 3, 7))
70+
args = [tvm.nd.array(np_shape, dev)]
71+
args.append(tvm.nd.array(np_data, dev))
72+
args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype))
73+
func(*args)
74+
np.set_printoptions(threshold=np.inf)
75+
tvm_out = args[-1].numpy()
76+
tvm.testing.assert_allclose(tvm_out, np_out)
8177

8278

8379
if __name__ == "__main__":
84-
test_argwhere()
80+
tvm.testing.main()

0 commit comments

Comments
 (0)