Skip to content

Commit 0fb5ae2

Browse files
[Op][Topi] Gather, GatherND, Take can accept unsigned integers as indices (#10080)
* take rel * gather and more tests * gathernd case * lint * remove test which invalidates take preconditions * re-add test * fix dumb test failure oopsie
1 parent 21154c2 commit 0fb5ae2

File tree

4 files changed

+65
-42
lines changed

4 files changed

+65
-42
lines changed

include/tvm/topi/transform.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,7 +1321,7 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
13211321
size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
13221322
ICHECK_GE(indices_dim_i, 1);
13231323
}
1324-
ICHECK(indices->dtype.is_int());
1324+
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint());
13251325

13261326
Array<PrimExpr> out_shape;
13271327
for (size_t i = 0; i < ndim_i; ++i) {
@@ -1388,7 +1388,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim
13881388
}
13891389
for (size_t i = 0; i < indices_dim0; ++i) {
13901390
indices_position.Set(0, make_const(DataType::Int(32), i));
1391-
if (indices->dtype.is_int()) {
1391+
if (indices->dtype.is_int() || indices->dtype.is_uint()) {
13921392
real_indices.push_back(indices(indices_position));
13931393
} else {
13941394
real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));

src/relay/op/tensor/transform.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,8 @@ bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
12761276
if (indices == nullptr) {
12771277
return false;
12781278
}
1279-
ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
1279+
ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
1280+
<< "indices of take must be tensor of integer";
12801281
const auto param = attrs.as<TakeAttrs>();
12811282
ICHECK(param != nullptr);
12821283

tests/python/relay/test_op_level3.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,15 @@
2121

2222
import numpy as np
2323
import pytest
24-
2524
import tvm
2625
import tvm.testing
27-
2826
from tvm import relay, te
2927
from tvm.error import TVMError
3028
from tvm.relay import create_executor, transform
3129
from tvm.relay.testing import check_grad, run_infer_type
3230

3331
from utils import ref_funcs
3432

35-
3633
executor_kind = tvm.testing.parameter("graph", "debug")
3734

3835

@@ -426,31 +423,36 @@ def test_take(self, dshape, indices_shape, oshape, axis):
426423

427424

428425
class TestTake:
429-
src_shape, indices_src, axis, mode = tvm.testing.parameters(
430-
((4,), [1], None, "clip"),
431-
((4,), [[0, 1, 2, 3]], None, "clip"),
432-
((3, 3, 3), [[11, 25]], None, "clip"),
433-
((4,), [[0, 1], [2, 3]], None, "clip"),
434-
((4,), [1], 0, "clip"),
435-
((2, 2), [[[1, 0], [0, 1]]], 0, "clip"),
436-
((2, 2), [[[1, 0], [0, 1]]], 1, "clip"),
437-
((4, 3, 5, 6), [[2, 1, 0, 0]], -2, "clip"),
438-
((3, 4), [-5, 20], None, "clip"),
439-
((3, 4), [-5, 20], None, "wrap"),
440-
((3, 4), [-1, 2], 0, "clip"),
441-
((3, 4), [-1, 2], 0, "wrap"),
442-
((3, 4), [-1, 2], 1, "clip"),
443-
((3, 4), [-1, 2], 1, "wrap"),
444-
((3, 3, 3), [[11, 25]], None, "fast"),
445-
((3, 4), [0, 2], 0, "fast"),
446-
((3, 4), [0, 2], 1, "fast"),
426+
src_shape, indices_src, axis, mode, indices_dtype = tvm.testing.parameters(
427+
((4,), [1], None, "clip", "int32"),
428+
((4,), [[0, 1, 2, 3]], None, "clip", "int32"),
429+
((3, 3, 3), [[11, 25]], None, "clip", "int32"),
430+
((4,), [[0, 1], [2, 3]], None, "clip", "int32"),
431+
((4,), [1], 0, "clip", "int32"),
432+
((2, 2), [[[1, 0], [0, 1]]], 0, "clip", "int32"),
433+
((2, 2), [[[1, 0], [0, 1]]], 1, "clip", "int32"),
434+
((4, 3, 5, 6), [[2, 1, 0, 0]], -2, "clip", "int32"),
435+
((3, 4), [-5, 20], None, "clip", "int32"),
436+
((3, 4), [-5, 20], None, "wrap", "int32"),
437+
((3, 4), [-1, 2], 0, "clip", "int32"),
438+
((3, 4), [-1, 2], 0, "wrap", "int32"),
439+
((3, 4), [-1, 2], 1, "clip", "int32"),
440+
((3, 4), [-1, 2], 1, "wrap", "int32"),
441+
((3, 3, 3), [[11, 25]], None, "fast", "int32"),
442+
((3, 4), [0, 2], 0, "fast", "int32"),
443+
((3, 4), [0, 2], 1, "fast", "int32"),
444+
((3, 4), [1, 2], 1, "clip", "uint32"),
445+
((3, 4), [1, 2], 1, "wrap", "uint16"),
446+
((3, 3, 3), [1, 2], None, "fast", "uint16"),
447+
((3, 4), [0, 2], 0, "fast", "uint8"),
447448
)
448449

449450
# Incorrect numeric output in some cases on vulkan
450451
@tvm.testing.known_failing_targets("vulkan")
451-
def test_take(self, target, dev, executor_kind, src_shape, indices_src, axis, mode):
452+
def test_take(
453+
self, target, dev, executor_kind, src_shape, indices_src, axis, mode, indices_dtype
454+
):
452455
src_dtype = "float32"
453-
indices_dtype = "int32"
454456
indices_src = np.array(indices_src, dtype=indices_dtype)
455457
x = relay.var("x", relay.TensorType(src_shape, src_dtype))
456458
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
@@ -459,11 +461,16 @@ def test_take(self, target, dev, executor_kind, src_shape, indices_src, axis, mo
459461
func = relay.Function([x, indices], z)
460462
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
461463
np_mode = "raise" if mode == "fast" else mode
462-
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)
463464

464465
op_res = relay.create_executor(executor_kind, device=dev, target=target).evaluate(func)(
465466
x_data, indices_src
466467
)
468+
469+
# Old versions of numpy has take internally cast inside take which may violate
470+
# safety rules. We have such version in i386 CI image.
471+
indices_src = indices_src.astype("int32")
472+
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)
473+
467474
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
468475

469476

@@ -1267,12 +1274,12 @@ def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype):
12671274
],
12681275
)
12691276
def test_gather(target, dev, executor_kind, data, axis, indices, ref_res):
1270-
def verify_gather(data, axis, indices, ref_res):
1277+
def verify_gather(data, axis, indices, ref_res, indices_dtype="int32"):
12711278
data = np.asarray(data, dtype="float32")
1272-
indices = np.asarray(indices, dtype="int32")
1279+
indices = np.asarray(indices, dtype=indices_dtype)
12731280
ref_res = np.asarray(ref_res)
12741281
d = relay.var("x", relay.TensorType(data.shape, "float32"))
1275-
i = relay.var("y", relay.TensorType(indices.shape, "int32"))
1282+
i = relay.var("y", relay.TensorType(indices.shape, indices_dtype))
12761283
z = relay.gather(d, axis, i)
12771284

12781285
func = relay.Function([d, i], z)
@@ -1283,22 +1290,25 @@ def verify_gather(data, axis, indices, ref_res):
12831290
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
12841291

12851292
verify_gather(data, axis, indices, ref_res)
1293+
verify_gather(data, axis, indices, ref_res, indices_dtype="uint32")
1294+
1295+
verify_gather(data, axis, indices, ref_res)
12861296

12871297

12881298
def test_gather_nd(target, dev, executor_kind):
1289-
def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
1299+
def verify_gather_nd(xshape, yshape, y_data, batch_dims=0, indices_dtype="int32"):
12901300
x = relay.var("x", relay.TensorType(xshape, "float32"))
1291-
y = relay.var("y", relay.TensorType(yshape, "int32"))
1301+
y = relay.var("y", relay.TensorType(yshape, indices_dtype))
12921302
z = relay.gather_nd(x, y, batch_dims)
12931303

12941304
func = relay.Function([x, y], z)
12951305

12961306
x_data = np.random.uniform(size=xshape).astype("float32")
12971307

12981308
if y_data:
1299-
y_data = np.array(y_data, dtype="int32")
1309+
y_data = np.array(y_data, dtype=indices_dtype)
13001310
else:
1301-
y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")
1311+
y_data = np.random.randint(low=0, high=2, size=yshape, dtype=indices_dtype)
13021312

13031313
ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)
13041314

@@ -1335,6 +1345,9 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
13351345
verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2)
13361346
verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2)
13371347

1348+
verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2, indices_dtype="uint8")
1349+
verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1, indices_dtype="uint32")
1350+
13381351

13391352
def _verify_infiniteness_ops(relay_op, ref_op):
13401353
for dtype in ["float32", "float16", "float16", "int32", "int16"]:

tests/python/topi/python/test_topi_transform.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
import numpy as np
1919
import pytest
2020
import tvm
21-
from tvm import te
22-
from tvm import topi
23-
from tvm import relay
21+
import tvm.testing
2422
import tvm.topi.testing
23+
from tvm import relay, te, topi
2524
from tvm.contrib.nvcc import have_fp16
2625

27-
import tvm.testing
28-
2926

3027
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
3128
A = te.placeholder(shape=in_shape, name="A")
@@ -356,9 +353,8 @@ def check_device(target, dev):
356353
)
357354

358355

359-
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
356+
def verify_take(src_shape, indices_src, axis=None, mode="clip", indices_dtype="int32"):
360357
src_dtype = "float32"
361-
indices_dtype = "int32"
362358
indices_src = np.array(indices_src, dtype=indices_dtype)
363359
A = te.placeholder(shape=src_shape, dtype=src_dtype, name="A")
364360
indices = te.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
@@ -999,6 +995,9 @@ def test_take():
999995
verify_take((3, 3, 3), [[11, 25]], mode="fast")
1000996
verify_take((3, 4), [0, 2], axis=0, mode="fast")
1001997
verify_take((3, 4), [0, 2], axis=1, mode="fast")
998+
verify_take((3, 4), [1, 2], axis=1, indices_dtype="uint32")
999+
verify_take((3, 4), [1, 2], axis=1, mode="wrap", indices_dtype="uint16")
1000+
verify_take((3, 3, 3), [[11, 20]], mode="fast", indices_dtype="uint8")
10021001

10031002

10041003
@tvm.testing.uses_gpu
@@ -1010,11 +1009,21 @@ def test_gather():
10101009
verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
10111010
verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
10121011
verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))
1012+
verify_gather(
1013+
np.random.randn(4, 7, 5),
1014+
2,
1015+
np.random.randint(low=0, high=5, size=(4, 7, 10)).astype("uint32"),
1016+
)
1017+
verify_gather(
1018+
np.random.randn(4, 7, 5),
1019+
2,
1020+
np.random.randint(low=0, high=5, size=(4, 7, 10)).astype("uint8"),
1021+
)
10131022

10141023

10151024
@tvm.testing.uses_gpu
10161025
def test_gather_nd():
1017-
for indices_dtype in ["int32", "float32"]:
1026+
for indices_dtype in ["int32", "float32", "uint8"]:
10181027
verify_gather_nd((4,), [[1.8]], indices_dtype)
10191028
verify_gather_nd((4,), [[1, 3, 2]], indices_dtype)
10201029
verify_gather_nd((2, 3), [[1]], indices_dtype)

0 commit comments

Comments
 (0)