From c1063e10dd3e3ee200c154fb44ae02cd2310d99a Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 4 Mar 2019 13:26:02 -0800 Subject: [PATCH 1/4] Relax type requirements in reshape_like --- src/operator/tensor/elemwise_unary_op_basic.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 4aaf4dfd33c4..0f835be5ab1e 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -481,7 +481,16 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or ` [](const NodeAttrs& attrs) { return std::vector(1, 1); }) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FInferShape", ReshapeLikeShapeCompute) -.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name; + std::vector checked_in_attrs = { (*in_attrs)[0] }; + bool ret = !type_is_none((*in_attrs)[1]) && + ElemwiseType<1,1>(attrs, &checked_in_attrs, out_attrs); + (*in_attrs)[0] = checked_in_attrs[0]; + return ret; + }) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { From 1483c503aa3ceb41c773a78104b2cfb8f8555e7c Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 4 Mar 2019 13:54:12 -0800 Subject: [PATCH 2/4] Add test --- tests/python/unittest/test_operator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ae7dc86d566c..f5bee9bc5faa 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2529,6 +2529,16 @@ def test_slice_like_different_types(): z = mx.nd.slice_like(x, y) assert_allclose(z.asnumpy(), [[1,2,3],[5,6,7]]) +@with_seed() +def test_reshape_like_different_types(): + x = mx.nd.zeros((2, 3)) + + y = mx.nd.array([[1, 2], [3, 4], [5, 6]]) + + y = mx.nd.array(y).astype('int32') + z = mx.nd.reshape_like(x, y) + assert_allclose(z.asnumpy(), [[0,0],[0,0],[0,0]]) + @with_seed() def test_flip(): for ndim in range(1, 6): From 3f90ca668e1014dd4e6fde48aeaa07c9dd850772 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 4 Mar 2019 14:17:23 -0800 Subject: [PATCH 3/4] Fix lint --- src/operator/tensor/elemwise_unary_op_basic.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 0f835be5ab1e..19a9ac8359eb 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -487,7 +487,7 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or ` CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name; std::vector checked_in_attrs = { (*in_attrs)[0] }; bool ret = !type_is_none((*in_attrs)[1]) && - ElemwiseType<1,1>(attrs, &checked_in_attrs, out_attrs); + ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs); (*in_attrs)[0] = checked_in_attrs[0]; return ret; }) From 357ce9c90b87b8d59776e7fc7c5d56fdc6feb99b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 5 Mar 2019 08:40:30 -0800 Subject: [PATCH 4/4] Retrigger CI