Skip to content

Commit

Permalink
Relaxing type requirements for slice_like op (apache#14097)
Browse files Browse the repository at this point in the history
* Relaxing types for slice_like op

* Added test

* Fix typo in test

* Fix lint
  • Loading branch information
ptrendx authored and stephenrawls committed Feb 16, 2019
1 parent f750ffa commit dc3d336
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,16 @@ Example::
return std::vector<std::string>{"data", "shape_like"};
})
.set_attr<nnvm::FInferShape>("FInferShape", SliceLikeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name;
std::vector<int> checked_in_attrs = { (*in_attrs)[0] };
bool ret = !type_is_none((*in_attrs)[1]) &&
ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs);
(*in_attrs)[0] = checked_in_attrs[0];
return ret;
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice_like"})
.set_attr<FCompute>("FCompute<cpu>", SliceLikeForward<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Source input")
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2515,6 +2515,20 @@ def test_slice_like():
assert_allclose(xx, xgrad.asnumpy())
assert_allclose(xgrad1.asnumpy(), mx.nd.zeros_like(xgrad1).asnumpy())

@with_seed()
def test_slice_like_different_types():
x = [[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]]

y = [[ 0., 0., 0.],
[ 0., 0., 0.]]

x = mx.nd.array(x)
y = mx.nd.array(y).astype('int32')
z = mx.nd.slice_like(x, y)
assert_allclose(z.asnumpy(), [[1,2,3],[5,6,7]])

@with_seed()
def test_flip():
for ndim in range(1, 6):
Expand Down

0 comments on commit dc3d336

Please sign in to comment.