diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 9730d0096e58..7f69395d1c87 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -236,6 +236,20 @@ NNVM_REGISTER_OP(_backward_copy)
     return std::vector<bool>{true};
   });
 
+NNVM_REGISTER_OP(_backward_reshape)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  });
+
 MXNET_OPERATOR_REGISTER_UNARY(BlockGrad)
 MXNET_ADD_SPARSE_OP_ALIAS(stop_gradient)
 .add_alias("stop_gradient")
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index c28934e94658..14f2be02ab1a 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -68,6 +68,10 @@ NNVM_REGISTER_OP(_copy)
 .set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::IdentityComputeEx<gpu>);
 
 NNVM_REGISTER_OP(_backward_copy)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::IdentityComputeEx<gpu>);
+
+NNVM_REGISTER_OP(_backward_reshape)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
 NNVM_REGISTER_OP(BlockGrad)
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 2ffeabc11ae3..db8efa454385 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -223,7 +223,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .set_attr_parser(ParamParser<ReshapeParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ReshapeShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
@@ -415,7 +415,7 @@ will return a new array with shape ``(2,1,3,4)``.
   [](const NodeAttrs& attrs){
     return std::vector<bool>{true};
   })
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(ExpandDimParam::__FIELDS__());
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 57808248b081..05175bb435f2 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -2306,6 +2306,48 @@ def check_sparse_quadratic_function(a, b, c, expected_stype):
     check_sparse_quadratic_function(a, b, 0.0, 'csr')
     check_sparse_quadratic_function(a, b, 1.0, 'default')
 
+def test_reshape_backward_fallback():
+    """
+     out
+     | \
+    w_x  x
+     /
+    w
+    in which x is a sparse tensor.
+    Due to the sparse gradient optimization in sym.dot, grad(w_x) is sparse.
+    Although sym.reshape itself does not have a sparse version,
+    grad(w) can still be made sparse, e.g.,
+        - by setting args_grad in symbol.bind, or
+        - by adding out_y = sym.dot(sparse_y, w), so that grad(w) is inferred as sparse,
+    so reshape backward (from w_x to w) needs to know how to handle sparse inputs.
+ """ + ctx = default_context() + w_shape = (12, 4) + w_x_shape = (1, 48) + x_nd = rand_ndarray((4, 1), 'csr') + + w_nd = rand_ndarray(w_shape) + + w_x_nd = w_nd.reshape(w_x_shape) + out_x_nd = mx.nd.dot(x_nd, w_x_nd) + + w_x_backward_grad = mx.nd.dot(x_nd, out_x_nd, transpose_a=True).asnumpy() + expected_grad_nd = w_x_backward_grad.reshape(w_shape) + + x = mx.sym.Variable('x', stype='csr') + w = mx.sym.Variable('w') + + w_x = mx.sym.reshape(w, w_x_shape, name="w_x") + out = mx.sym.sparse.dot(x, w_x, name='out_x') + + grad_w_nd = rand_ndarray(w_shape, 'row_sparse') + executor = out.bind(ctx=ctx, args={"x": x_nd, "w": w_nd}, + args_grad={"w": grad_w_nd}) + executor.forward(is_train=True) + executor.backward(out_x_nd) + + assert_almost_equal(grad_w_nd.asnumpy(), expected_grad_nd) + if __name__ == '__main__': import nose nose.runmodule()