Fallback to dense version for grad(reshape), grad(expand_dims) (#13599)
* fallback to dense version for grad(reshape), grad(expand_dims)

* add _backward_reshape gpu version

* reshape test case comments

* fix gpu test

* remove mkldnn support for _backward_reshape
yzhliu authored and eric-haibin-lin committed Dec 20, 2018
1 parent ebd5f68 commit 59f4395
Showing 4 changed files with 62 additions and 2 deletions.
14 changes: 14 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -236,6 +236,20 @@ NNVM_REGISTER_OP(_backward_copy)
    return std::vector<bool>{true};
  });

NNVM_REGISTER_OP(_backward_reshape)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
  [](const NodeAttrs& attrs){
    return std::vector<std::pair<int, int> >{{0, 0}};
  })
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
  [](const NodeAttrs& attrs){
    return std::vector<bool>{true};
  });

MXNET_OPERATOR_REGISTER_UNARY(BlockGrad)
MXNET_ADD_SPARSE_OP_ALIAS(stop_gradient)
.add_alias("stop_gradient")
4 changes: 4 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
@@ -68,6 +68,10 @@ NNVM_REGISTER_OP(_copy)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::IdentityComputeEx<gpu>);

NNVM_REGISTER_OP(_backward_copy)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::IdentityComputeEx<gpu>);

NNVM_REGISTER_OP(_backward_reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);

NNVM_REGISTER_OP(BlockGrad)
4 changes: 2 additions & 2 deletions src/operator/tensor/matrix_op.cc
@@ -223,7 +223,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
.set_attr_parser(ParamParser<ReshapeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ReshapeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
@@ -415,7 +415,7 @@ will return a new array with shape ``(2,1,3,4)``.
  [](const NodeAttrs& attrs){
    return std::vector<bool>{true};
  })
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(ExpandDimParam::__FIELDS__());
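
Both hunks above swap the FGradient registration of Reshape and expand_dims from _backward_copy to the new _backward_reshape, so a sparse head gradient now takes the dense fallback path. For expand_dims the gradient is likewise an identity copy squeezed back to the input shape; a minimal sketch, not part of the commit and assuming the MXNet 1.x NDArray/autograd API:

import mxnet as mx

a = mx.nd.ones((2, 3, 4))
a.attach_grad()
with mx.autograd.record():
    b = mx.nd.expand_dims(a, axis=1)   # shape (2, 1, 3, 4), same data
    loss = b.sum()
loss.backward()
# grad(expand_dims) copies the head gradient back and drops the inserted axis,
# so a.grad is a dense (2, 3, 4) array of ones.
print(a.grad)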
42 changes: 42 additions & 0 deletions tests/python/unittest/test_sparse_operator.py
@@ -2306,6 +2306,48 @@ def check_sparse_quadratic_function(a, b, c, expected_stype):
    check_sparse_quadratic_function(a, b, 0.0, 'csr')
    check_sparse_quadratic_function(a, b, 1.0, 'default')

def test_reshape_backward_fallback():
    """
      out
      | \
    w_x  x
     /
    w
    in which x is a sparse tensor.
    Due to the sparse-gradient optimization in sym.dot, grad(w_x) is sparse.
    Although sym.reshape itself does not have a sparse version,
    if we somehow make grad(w) sparse as well, e.g.,
    - by setting args_grad in symbol.bind,
    - or by computing out_y = sym.dot(sparse_y, w), so that grad(w) is inferred as sparse,
    then reshape backward (from w_x to w) needs to know how to handle sparse inputs.
    """
    ctx = default_context()
    w_shape = (12, 4)
    w_x_shape = (1, 48)
    x_nd = rand_ndarray((4, 1), 'csr')

    w_nd = rand_ndarray(w_shape)

    w_x_nd = w_nd.reshape(w_x_shape)
    out_x_nd = mx.nd.dot(x_nd, w_x_nd)

    w_x_backward_grad = mx.nd.dot(x_nd, out_x_nd, transpose_a=True).asnumpy()
    expected_grad_nd = w_x_backward_grad.reshape(w_shape)

    x = mx.sym.Variable('x', stype='csr')
    w = mx.sym.Variable('w')

    w_x = mx.sym.reshape(w, w_x_shape, name="w_x")
    out = mx.sym.sparse.dot(x, w_x, name='out_x')

    grad_w_nd = rand_ndarray(w_shape, 'row_sparse')
    executor = out.bind(ctx=ctx, args={"x": x_nd, "w": w_nd},
                        args_grad={"w": grad_w_nd})
    executor.forward(is_train=True)
    executor.backward(out_x_nd)

    assert_almost_equal(grad_w_nd.asnumpy(), expected_grad_nd)

if __name__ == '__main__':
    import nose
    nose.runmodule()
