|
16 | 16 | # under the License. |
17 | 17 | """Test for argwhere operator""" |
18 | 18 | import numpy as np |
| 19 | +import pytest |
19 | 20 |
|
20 | 21 | import tvm |
| 22 | +import tvm.testing |
21 | 23 | from tvm import te |
22 | 24 | from tvm import topi |
23 | 25 | import tvm.topi.testing |
|
29 | 31 |
|
30 | 32 | _argwhere_compute = {"llvm": topi.argwhere, "cuda": topi.cuda.argwhere} |
31 | 33 |
|
| 34 | +data_shape = tvm.testing.parameter( |
| 35 | + (1,), |
| 36 | + (100,), |
| 37 | + (1, 1), |
| 38 | + (5, 3), |
| 39 | + (32, 64), |
| 40 | + (128, 65), |
| 41 | + (200, 500), |
| 42 | + (6, 5, 3), |
| 43 | + (1, 1, 1), |
| 44 | + (1, 1, 1, 1), |
| 45 | + (6, 4, 5, 3), |
| 46 | + (1, 1, 1, 1, 1), |
| 47 | + (6, 4, 5, 3, 7), |
| 48 | +) |
32 | 49 |
|
33 | | -def verify_argwhere(data_shape): |
| 50 | + |
| 51 | +@tvm.testing.parametrize_targets("llvm", "cuda") |
| 52 | +def test_argwhere(target, dev, data_shape): |
34 | 53 | dtype = "int32" |
35 | 54 | np_data = np.random.choice([0, 1, 2, 3], size=data_shape).astype(dtype) |
36 | 55 | np_out = np.argwhere(np_data) |
37 | 56 | out_shape = np_out.shape[0] |
| 57 | + |
38 | 58 | np_shape = np.ones(shape=(out_shape, len(data_shape)), dtype=dtype) |
39 | 59 |
|
40 | 60 | out_shape = te.placeholder(shape=(out_shape, len(data_shape)), name="out_shape", dtype=dtype) |
41 | 61 | condition = te.placeholder(shape=data_shape, name="condition", dtype=dtype) |
42 | 62 |
|
43 | | - def check_device(target): |
44 | | - dev = tvm.device(target, 0) |
45 | | - if not dev.exist or target not in _argwhere_compute: |
46 | | - return |
47 | | - |
48 | | - with tvm.target.Target(target): |
49 | | - out = _argwhere_compute[target](out_shape, condition) |
50 | | - s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule) |
51 | | - sch = s_func(out) |
52 | | - |
53 | | - func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere") |
54 | | - |
55 | | - args = [tvm.nd.array(np_shape, dev)] |
56 | | - args.append(tvm.nd.array(np_data, dev)) |
57 | | - args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype)) |
58 | | - func(*args) |
59 | | - np.set_printoptions(threshold=np.inf) |
60 | | - tvm.testing.assert_allclose(args[-1].numpy(), np.array(np_out)) |
61 | | - |
62 | | - for target, _ in tvm.testing.enabled_targets(): |
63 | | - check_device(target) |
| 63 | + with tvm.target.Target(target): |
| 64 | + out = _argwhere_compute[target](out_shape, condition) |
| 65 | + s_func = tvm.topi.testing.dispatch(target, _argwhere_schedule) |
| 66 | + sch = s_func(out) |
64 | 67 |
|
| 68 | + func = tvm.build(sch, [out_shape, condition, out], target, name="argwhere") |
65 | 69 |
|
66 | | -@tvm.testing.uses_gpu |
67 | | -def test_argwhere(): |
68 | | - verify_argwhere((1,)) |
69 | | - verify_argwhere((100,)) |
70 | | - verify_argwhere((1, 1)) |
71 | | - verify_argwhere((5, 3)) |
72 | | - verify_argwhere((32, 64)) |
73 | | - verify_argwhere((128, 65)) |
74 | | - verify_argwhere((200, 500)) |
75 | | - verify_argwhere((6, 5, 3)) |
76 | | - verify_argwhere((1, 1, 1)) |
77 | | - verify_argwhere((1, 1, 1, 1)) |
78 | | - verify_argwhere((6, 4, 5, 3)) |
79 | | - verify_argwhere((1, 1, 1, 1, 1)) |
80 | | - verify_argwhere((6, 4, 5, 3, 7)) |
| 70 | + args = [tvm.nd.array(np_shape, dev)] |
| 71 | + args.append(tvm.nd.array(np_data, dev)) |
| 72 | + args.append(tvm.nd.empty(out.shape, device=dev, dtype=condition.dtype)) |
| 73 | + func(*args) |
| 74 | + np.set_printoptions(threshold=np.inf) |
| 75 | + tvm_out = args[-1].numpy() |
| 76 | + tvm.testing.assert_allclose(tvm_out, np_out) |
81 | 77 |
|
82 | 78 |
|
83 | 79 | if __name__ == "__main__": |
84 | | - test_argwhere() |
| 80 | + tvm.testing.main() |
0 commit comments