diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 261fdf9970a3..3ad230560f3a 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1134,6 +1134,9 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, size_t ndim_i = indices->shape.size(); ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; ICHECK_EQ(ndim_d, ndim_i); + if (axis < 0) { + axis += ndim_d; + } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_d); size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index f2e3850a8f67..4129b610cb7c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1047,7 +1047,7 @@ def gather(data, axis, indices): The input data to the operator. axis: int - The axis along which to index. + The axis along which to index. negative axis is supported. indices: relay.Expr The indices of values to gather. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 941f43a5a2c4..e3929bf8b77e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3179,6 +3179,9 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, const auto ndim_indices = indices->shape.size(); int axis = param->axis->value; ICHECK_EQ(ndim_data, ndim_indices); + if (axis < 0) { + axis += ndim_data; + } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_data); diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 54bf2fd49acb..6491a0a464d6 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3704,13 +3704,13 @@ def test_fn(dim, descending): inp = torch.randn(100) verify_model(test_fn(0, True), [inp]) - verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(-1, False), [inp]) inp = torch.randn(100, 100) verify_model(test_fn(0, True), [inp]) - verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(-2, False), [inp]) verify_model(test_fn(1, True), [inp]) - verify_model(test_fn(1, False), [inp]) + verify_model(test_fn(-1, False), [inp]) def test_logical_and(): diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 31b95b0b49ae..d2a5090943c3 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1075,12 +1075,166 @@ def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"): @tvm.testing.uses_gpu -def test_gather(): +@pytest.mark.parametrize( + "data, axis, indices, ref_res", + [ + ([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]), + ([[1, 2], [3, 4]], -1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]), + ( + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], + 0, + [[[1, 0, 1], [1, 1, 0]]], + [[[6, 1, 8], [9, 10, 5]]], + ), + ( + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], + -3, + [[[1, 0, 1], [1, 1, 0]]], + [[[6, 1, 8], [9, 10, 5]]], + ), + ( + [ + [ + [-0.2321, -0.2024, -1.7624], + [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], + [0.4497, -0.2224, 0.6103], + ], + [ + [0.0408, -0.7667, -0.4303], + [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], + [-0.0768, -1.6064, 1.3390], + ], + ], + 1, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [ + [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], + ], + ), + ( + [ + [ + [-0.2321, -0.2024, -1.7624], + [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], + [0.4497, -0.2224, 0.6103], + ], + [ + [0.0408, -0.7667, -0.4303], + [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], + [-0.0768, -1.6064, 1.3390], + ], + ], + -2, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [ + [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], + ], + ), + ( + [ + [ + [-0.2321, -0.2024, -1.7624], + [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], + [0.4497, -0.2224, 0.6103], + ], + [ + [0.0408, -0.7667, -0.4303], + [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], + [-0.0768, -1.6064, 1.3390], + ], + ], + -2, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [ + [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], + ], + ), + ( + [ + [ + [0.3050, 1.6986, 1.1034], + [0.7020, -0.6960, -2.1818], + [0.3116, -0.5773, -0.9912], + [0.0835, -1.3915, -1.0720], + ], + [ + [0.1694, -0.6091, -0.6539], + [-0.5234, -0.1218, 0.5084], + [0.2374, -1.9537, -2.0078], + [-0.5700, -1.0302, 0.1558], + ], + ], + 2, + [ + [[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], + [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]], + ], + [ + [ + [1.6986, 1.6986, 0.3050, 1.6986], + [0.7020, 0.7020, -2.1818, -2.1818], + [-0.5773, -0.9912, -0.5773, -0.9912], + [-1.0720, -1.0720, -1.3915, 0.0835], + ], + [ + [0.1694, 0.1694, -0.6091, -0.6539], + [0.5084, 0.5084, -0.1218, -0.5234], + [-1.9537, -2.0078, 0.2374, 0.2374], + [-0.5700, 0.1558, -0.5700, 0.1558], + ], + ], + ), + ( + [ + [ + [0.3050, 1.6986, 1.1034], + [0.7020, -0.6960, -2.1818], + [0.3116, -0.5773, -0.9912], + [0.0835, -1.3915, -1.0720], + ], + [ + [0.1694, -0.6091, -0.6539], + [-0.5234, -0.1218, 0.5084], + [0.2374, -1.9537, -2.0078], + [-0.5700, -1.0302, 0.1558], + ], + ], + -1, + [ + [[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], + [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]], + ], + [ + [ + [1.6986, 1.6986, 0.3050, 1.6986], + [0.7020, 0.7020, -2.1818, -2.1818], + [-0.5773, -0.9912, -0.5773, -0.9912], + [-1.0720, -1.0720, -1.3915, 0.0835], + ], + [ + [0.1694, 0.1694, -0.6091, -0.6539], + [0.5084, 0.5084, -0.1218, -0.5234], + [-1.9537, -2.0078, 0.2374, 0.2374], + [-0.5700, 0.1558, -0.5700, 0.1558], + ], + ], + ), + ], +) +def test_gather(data, axis, indices, ref_res): def verify_gather(data, axis, indices, ref_res): data = np.asarray(data, dtype="float32") indices = np.asarray(indices, dtype="int32") ref_res = np.asarray(ref_res) - d = relay.var("x", relay.TensorType(data.shape, "float32")) i = relay.var("y", relay.TensorType(indices.shape, "int32")) z = relay.gather(d, axis, i) @@ -1093,70 +1247,7 @@ def verify_gather(data, axis, indices, ref_res): op_res = intrp.evaluate(func)(data, indices) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]) - verify_gather( - [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], - 0, - [[[1, 0, 1], [1, 1, 0]]], - [[[6, 1, 8], [9, 10, 5]]], - ) - verify_gather( - [ - [ - [-0.2321, -0.2024, -1.7624], - [-0.3829, -0.4246, 0.2448], - [0.1822, 0.2360, -0.8965], - [0.4497, -0.2224, 0.6103], - ], - [ - [0.0408, -0.7667, -0.4303], - [-0.3216, 0.7489, -0.1502], - [0.0144, -0.4699, -0.0064], - [-0.0768, -1.6064, 1.3390], - ], - ], - 1, - [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], - [ - [[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], - [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]], - ], - ) - verify_gather( - [ - [ - [0.3050, 1.6986, 1.1034], - [0.7020, -0.6960, -2.1818], - [0.3116, -0.5773, -0.9912], - [0.0835, -1.3915, -1.0720], - ], - [ - [0.1694, -0.6091, -0.6539], - [-0.5234, -0.1218, 0.5084], - [0.2374, -1.9537, -2.0078], - [-0.5700, -1.0302, 0.1558], - ], - ], - 2, - [ - [[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], - [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]], - ], - [ - [ - [1.6986, 1.6986, 0.3050, 1.6986], - [0.7020, 0.7020, -2.1818, -2.1818], - [-0.5773, -0.9912, -0.5773, -0.9912], - [-1.0720, -1.0720, -1.3915, 0.0835], - ], - [ - [0.1694, 0.1694, -0.6091, -0.6539], - [0.5084, 0.5084, -0.1218, -0.5234], - [-1.9537, -2.0078, 0.2374, 0.2374], - [-0.5700, 0.1558, -0.5700, 0.1558], - ], - ], - ) + verify_gather(data, axis, indices, ref_res) @tvm.testing.uses_gpu