Skip to content

Commit

Permalink
Support negative axis for gather (apache#7600)
Browse files Browse the repository at this point in the history
* Fix negative axis in gather

* Clang Format

* Black

* Empty Commit

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
2 people authored and Lokiiiiii committed Mar 5, 2021
1 parent 9759b6b commit 9a9282a
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 70 deletions.
3 changes: 3 additions & 0 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,9 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
size_t ndim_i = indices->shape.size();
ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
ICHECK_EQ(ndim_d, ndim_i);
if (axis < 0) {
axis += ndim_d;
}
ICHECK_GE(axis, 0);
ICHECK_LT(axis, ndim_d);
size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,7 @@ def gather(data, axis, indices):
The input data to the operator.
axis: int
The axis along which to index.
The axis along which to index. negative axis is supported.
indices: relay.Expr
The indices of values to gather.
Expand Down
3 changes: 3 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3231,6 +3231,9 @@ bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const auto ndim_indices = indices->shape.size();
int axis = param->axis->value;
ICHECK_EQ(ndim_data, ndim_indices);
if (axis < 0) {
axis += ndim_data;
}
ICHECK_GE(axis, 0);
ICHECK_LT(axis, ndim_data);

Expand Down
6 changes: 3 additions & 3 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3661,13 +3661,13 @@ def test_fn(dim, descending):

inp = torch.randn(100)
verify_model(test_fn(0, True), [inp])
verify_model(test_fn(0, False), [inp])
verify_model(test_fn(-1, False), [inp])

inp = torch.randn(100, 100)
verify_model(test_fn(0, True), [inp])
verify_model(test_fn(0, False), [inp])
verify_model(test_fn(-2, False), [inp])
verify_model(test_fn(1, True), [inp])
verify_model(test_fn(1, False), [inp])
verify_model(test_fn(-1, False), [inp])


def test_logical_and():
Expand Down
223 changes: 157 additions & 66 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,12 +1105,166 @@ def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"):


@tvm.testing.uses_gpu
def test_gather():
@pytest.mark.parametrize(
"data, axis, indices, ref_res",
[
([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]),
([[1, 2], [3, 4]], -1, [[0, 0], [1, 0]], [[1, 1], [4, 3]]),
(
[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
0,
[[[1, 0, 1], [1, 1, 0]]],
[[[6, 1, 8], [9, 10, 5]]],
),
(
[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
-3,
[[[1, 0, 1], [1, 1, 0]]],
[[[6, 1, 8], [9, 10, 5]]],
),
(
[
[
[-0.2321, -0.2024, -1.7624],
[-0.3829, -0.4246, 0.2448],
[0.1822, 0.2360, -0.8965],
[0.4497, -0.2224, 0.6103],
],
[
[0.0408, -0.7667, -0.4303],
[-0.3216, 0.7489, -0.1502],
[0.0144, -0.4699, -0.0064],
[-0.0768, -1.6064, 1.3390],
],
],
1,
[[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
[
[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
[[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]],
],
),
(
[
[
[-0.2321, -0.2024, -1.7624],
[-0.3829, -0.4246, 0.2448],
[0.1822, 0.2360, -0.8965],
[0.4497, -0.2224, 0.6103],
],
[
[0.0408, -0.7667, -0.4303],
[-0.3216, 0.7489, -0.1502],
[0.0144, -0.4699, -0.0064],
[-0.0768, -1.6064, 1.3390],
],
],
-2,
[[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
[
[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
[[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]],
],
),
(
[
[
[-0.2321, -0.2024, -1.7624],
[-0.3829, -0.4246, 0.2448],
[0.1822, 0.2360, -0.8965],
[0.4497, -0.2224, 0.6103],
],
[
[0.0408, -0.7667, -0.4303],
[-0.3216, 0.7489, -0.1502],
[0.0144, -0.4699, -0.0064],
[-0.0768, -1.6064, 1.3390],
],
],
-2,
[[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
[
[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
[[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]],
],
),
(
[
[
[0.3050, 1.6986, 1.1034],
[0.7020, -0.6960, -2.1818],
[0.3116, -0.5773, -0.9912],
[0.0835, -1.3915, -1.0720],
],
[
[0.1694, -0.6091, -0.6539],
[-0.5234, -0.1218, 0.5084],
[0.2374, -1.9537, -2.0078],
[-0.5700, -1.0302, 0.1558],
],
],
2,
[
[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
[[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]],
],
[
[
[1.6986, 1.6986, 0.3050, 1.6986],
[0.7020, 0.7020, -2.1818, -2.1818],
[-0.5773, -0.9912, -0.5773, -0.9912],
[-1.0720, -1.0720, -1.3915, 0.0835],
],
[
[0.1694, 0.1694, -0.6091, -0.6539],
[0.5084, 0.5084, -0.1218, -0.5234],
[-1.9537, -2.0078, 0.2374, 0.2374],
[-0.5700, 0.1558, -0.5700, 0.1558],
],
],
),
(
[
[
[0.3050, 1.6986, 1.1034],
[0.7020, -0.6960, -2.1818],
[0.3116, -0.5773, -0.9912],
[0.0835, -1.3915, -1.0720],
],
[
[0.1694, -0.6091, -0.6539],
[-0.5234, -0.1218, 0.5084],
[0.2374, -1.9537, -2.0078],
[-0.5700, -1.0302, 0.1558],
],
],
-1,
[
[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
[[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]],
],
[
[
[1.6986, 1.6986, 0.3050, 1.6986],
[0.7020, 0.7020, -2.1818, -2.1818],
[-0.5773, -0.9912, -0.5773, -0.9912],
[-1.0720, -1.0720, -1.3915, 0.0835],
],
[
[0.1694, 0.1694, -0.6091, -0.6539],
[0.5084, 0.5084, -0.1218, -0.5234],
[-1.9537, -2.0078, 0.2374, 0.2374],
[-0.5700, 0.1558, -0.5700, 0.1558],
],
],
),
],
)
def test_gather(data, axis, indices, ref_res):
def verify_gather(data, axis, indices, ref_res):
data = np.asarray(data, dtype="float32")
indices = np.asarray(indices, dtype="int32")
ref_res = np.asarray(ref_res)

d = relay.var("x", relay.TensorType(data.shape, "float32"))
i = relay.var("y", relay.TensorType(indices.shape, "int32"))
z = relay.gather(d, axis, i)
Expand All @@ -1123,70 +1277,7 @@ def verify_gather(data, axis, indices, ref_res):
op_res = intrp.evaluate(func)(data, indices)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]], [[1, 1], [4, 3]])
verify_gather(
[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
0,
[[[1, 0, 1], [1, 1, 0]]],
[[[6, 1, 8], [9, 10, 5]]],
)
verify_gather(
[
[
[-0.2321, -0.2024, -1.7624],
[-0.3829, -0.4246, 0.2448],
[0.1822, 0.2360, -0.8965],
[0.4497, -0.2224, 0.6103],
],
[
[0.0408, -0.7667, -0.4303],
[-0.3216, 0.7489, -0.1502],
[0.0144, -0.4699, -0.0064],
[-0.0768, -1.6064, 1.3390],
],
],
1,
[[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
[
[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
[[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]],
],
)
verify_gather(
[
[
[0.3050, 1.6986, 1.1034],
[0.7020, -0.6960, -2.1818],
[0.3116, -0.5773, -0.9912],
[0.0835, -1.3915, -1.0720],
],
[
[0.1694, -0.6091, -0.6539],
[-0.5234, -0.1218, 0.5084],
[0.2374, -1.9537, -2.0078],
[-0.5700, -1.0302, 0.1558],
],
],
2,
[
[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
[[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]],
],
[
[
[1.6986, 1.6986, 0.3050, 1.6986],
[0.7020, 0.7020, -2.1818, -2.1818],
[-0.5773, -0.9912, -0.5773, -0.9912],
[-1.0720, -1.0720, -1.3915, 0.0835],
],
[
[0.1694, 0.1694, -0.6091, -0.6539],
[0.5084, 0.5084, -0.1218, -0.5234],
[-1.9537, -2.0078, 0.2374, 0.2374],
[-0.5700, 0.1558, -0.5700, 0.1558],
],
],
)
verify_gather(data, axis, indices, ref_res)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 9a9282a

Please sign in to comment.