Skip to content

Commit

Permalink
Relaxing type requirements for reshape_like op (apache#14325)
Browse files Browse the repository at this point in the history
* Relax type requirements in reshape_like

* Add test

* Fix lint

* Retrigger CI
  • Loading branch information
ptrendx authored and haohuw committed Jun 23, 2019
1 parent c6bdce8 commit 2712200
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,16 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or `
[](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 1); })
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<mxnet::FInferShape>("FInferShape", ReshapeLikeShapeCompute)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name;
std::vector<int> checked_in_attrs = { (*in_attrs)[0] };
bool ret = !type_is_none((*in_attrs)[1]) &&
ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs);
(*in_attrs)[0] = checked_in_attrs[0];
return ret;
})
.set_attr<nnvm::FGradient>(
"FGradient", [](const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2529,6 +2529,16 @@ def test_slice_like_different_types():
z = mx.nd.slice_like(x, y)
assert_allclose(z.asnumpy(), [[1,2,3],[5,6,7]])

@with_seed()
def test_reshape_like_different_types():
x = mx.nd.zeros((2, 3))

y = mx.nd.array([[1, 2], [3, 4], [5, 6]])

y = mx.nd.array(y).astype('int32')
z = mx.nd.reshape_like(x, y)
assert_allclose(z.asnumpy(), [[0,0],[0,0],[0,0]])

@with_seed()
def test_flip():
for ndim in range(1, 6):
Expand Down

0 comments on commit 2712200

Please sign in to comment.