From 809ebafba0b8ac0bce98e60b48ee87c4398e5ba6 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 5 Jun 2017 10:24:24 -0700 Subject: [PATCH 1/4] Initial checkin --- src/operator/tensor/indexing_op.cc | 75 +++++++++++ src/operator/tensor/indexing_op.cu | 6 + src/operator/tensor/indexing_op.h | 192 +++++++++++++++++++++++++++++ 3 files changed, 273 insertions(+) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index e85645f59506..f303a54bc908 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -264,5 +264,80 @@ Examples:: .add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value") .add_arguments(OneHotParam::__FIELDS__()); + +NNVM_REGISTER_OP(one_hot) +.describe(R"code(Returns a one-hot array. + +The locations represented by `indices` take value `on_value`, while all +other locations take value `off_value`. + +`one_hot` operation with `indices` of shape ``(i0, i1)`` and `depth` of ``d`` would result +in an output array of shape ``(i0, i1, d)`` with:: + + output[i,j,:] = off_value + output[i,j,indices[i,j]] = on_value + +Examples:: + + one_hot([1,0,2,0], 3) = [[ 0. 1. 0.] + [ 1. 0. 0.] + [ 0. 0. 1.] + [ 1. 0. 0.]] + + one_hot([1,0,2,0], 3, on_value=8, off_value=1, + dtype='int32') = [[1 8 1] + [8 1 1] + [1 1 8] + [8 1 1]] + + one_hot([[1,0],[1,0],[2,0]], 3) = [[[ 0. 1. 0.] + [ 1. 0. 0.]] + + [[ 0. 1. 0.] + [ 1. 0. 0.]] + + [[ 0. 0. 1.] + [ 1. 0. 0.]]] +)code" ADD_FILELINE) +.set_num_outputs(1) +.set_num_inputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"indices"}; + }) +.set_attr("FInferShape", OneHotOpShape) +.set_attr("FInferType", OneHotOpType) +.set_attr("FCompute", OneHotOpForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value") +.add_arguments(OneHotParam::__FIELDS__()); + +NNVM_REGISTER_OP(sparse_retain) +.describe(R"code(pick rows specified by user input index array from a row sparse matrix +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "indices"}; + }) +.set_attr("FInferShape", SparseRetainOpShape) +.set_attr("FInferType", SparseRetainOpType) +.set_attr("FComputeEx", SparseRetainOpForwardEx) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_sparse_retain", n, ograds, + {n->inputs[sr::kIdx]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", "The input array for sparse_retain operator.") +.add_argument("indices", "NDArray-or-Symbol", "The index array of rows ids that will be retained."); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 287ec25d70be..4378bd574932 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -26,6 +26,12 @@ NNVM_REGISTER_OP(batch_take) NNVM_REGISTER_OP(one_hot) .set_attr("FCompute", OneHotOpForward); +NNVM_REGISTER_OP(sparse_retain) +.set_attr("FComputeEx", SparseRetainOpForwardEx); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 12523e237cf2..f84b64253604 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -795,6 +795,198 @@ void OneHotOpForward(const nnvm::NodeAttrs& attrs, }); } +/*! + * \brief sparse retain namespace + */ +namespace sr { +enum SparseRetainOpInputs {kArr, kIdx}; +enum SparseRetainOpOutputs {kOut}; +} // namespace sr + +inline bool SparseRetainOpShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U) + << "sparse_retain operator takes 2 arguments (" << in_attrs->size() << " given)"; + CHECK_EQ(out_attrs->size(), 1U); + + TShape tshape((*in_attrs)[sr::kArr]); + shape_assign(&tshape, (*out_attrs)[sr::kOut]); + SHAPE_ASSIGN_CHECK(*in_attrs, sr::kArr, tshape); + SHAPE_ASSIGN_CHECK(*out_attrs, sr::kOut, tshape); + return true; +} + +inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE((*in_attrs)[sr::kIdx], -1) << "Index type must be set for sparse_retain operator"; + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[sr::kArr]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[sr::kOut]); + return (*in_attrs)[0] != -1; +} + +inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + if (kRowSparseStorage == in_attrs->at(sr::kArr)) { + out_attrs->at(sr::kOut) = kRowSparseStorage; + } + return true; +} + +inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + out_attrs->at(sr::kOut) = kRowSparseStorage; + return true; +} + +struct SparseRetainRspForward { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, RType* out_idx, + const DType* in_data, const RType* in_idx, + const IType* idx, const size_t nnr, + const size_t num_cols) { + const RType irow = idx[i]; + int j = -1, left = 0, right = nnr - 1; + while (left <= right) { + int m = left + (right - left) / 2; + const auto in_idx_m = in_idx[m]; + if (in_idx_m == irow) { + j = m; + break; + } else if (in_idx_m < irow) { + left = m + 1; + } else { + right = m - 1; + } + } + out_idx[i] = idx[i]; + if (j >= 0) { + const size_t in_offset = j * num_cols; + const size_t out_offset = i * num_cols; + for (size_t k = 0; k < num_cols; ++k) { + out_data[out_offset+k] = in_data[in_offset+k]; + } + } + } +}; + +template +void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'"; + + CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain operator only takes row sparse NDArray as input"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain operator only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage) + << "sparse_retain operator only outputs row sparse NDArray"; + + const NDArray& input_nd = inputs[sr::kArr]; + const TBlob idx_data = inputs[sr::kIdx].data(); + + if (req[sr::kOut] == kNullOp + || !input_nd.storage_initialized() + || idx_data.Size() == 0U) return; + + const TBlob input_data = input_nd.data(); + if (input_data.shape_[0] == 0) return; + const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx); + + NDArray output_nd = outputs[sr::kOut]; + output_nd.CheckAndAlloc({mshadow::Shape2(idx_data.Size(), output_nd.shape()[1])}); + TBlob output_data = output_nd.data(); + TBlob output_idx = output_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type + MSHADOW_INT_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + Kernel::Launch(s, output_data.Size(), output_data.dptr()); + Kernel::Launch(s, idx_data.Size(), output_data.dptr(), + output_idx.dptr(), input_data.dptr(), input_idx.dptr(), + idx_data.dptr(), input_data.shape_[0], input_data.shape_[1]); + }); + }); + }); +} + +template +struct SparseRetainRspBackward { + template + MSHADOW_XINLINE static void Map(int i, DType* in_grad, RType* in_grad_idx, + const DType* out_grad, const IType* idx, + const size_t num_cols) { + const RType irow = idx[i]; + in_grad_idx[i] = irow; + const size_t out_offset = irow * num_cols; + const size_t in_offset = i * num_cols; + for (size_t j = 0; j < num_cols; ++j) { + KERNEL_ASSIGN(in_grad[in_offset+j], req, out_grad[out_offset+j]); + } + } +}; + +template +void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_NE(req[sr::kArr], kWriteInplace); + CHECK_EQ(req[sr::kIdx], kNullOp) + << "sparse_retain backward does not support the gradient of the index array"; + + CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as ograd"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain backward only outputs row sparse NDArray as grad of input"; + + const TBlob out_grad_data = inputs[sr::kOut].data(); + const TBlob idx_data = inputs[sr::kIdx].data(); + + NDArray in_grad_nd = outputs[sr::kArr]; + in_grad_nd.CheckAndAlloc({mshadow::Shape2(idx_data.Size(), out_grad_data.shape_[1])}); + TBlob in_grad_data = in_grad_nd.data(); + TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, { // output data type + MSHADOW_INT_TYPE_SWITCH(in_grad_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, { + Kernel, xpu>::Launch( + s, in_grad_idx.Size(), in_grad_data.dptr(), in_grad_idx.dptr(), + out_grad_data.dptr(), idx_data.dptr(), out_grad_data.shape_[1]); + }); + }); + }); + }); +} + } // namespace op } // namespace mxnet #ifdef __CUDACC__ From 4cf1f938fc8d16ad48c24a4421046ad28265d16e Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 5 Jun 2017 11:46:24 -0700 Subject: [PATCH 2/4] Fix bugs --- src/operator/tensor/indexing_op.cc | 53 ++---------------------------- src/operator/tensor/indexing_op.h | 8 ++--- 2 files changed, 6 insertions(+), 55 deletions(-) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index f303a54bc908..a7b40cf035c5 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -216,55 +216,6 @@ Examples:: .add_argument("a", "NDArray-or-Symbol", "The input array") .add_argument("indices", "NDArray-or-Symbol", "The index array"); -NNVM_REGISTER_OP(one_hot) -.describe(R"code(Returns a one-hot array. - -The locations represented by `indices` take value `on_value`, while all -other locations take value `off_value`. - -`one_hot` operation with `indices` of shape ``(i0, i1)`` and `depth` of ``d`` would result -in an output array of shape ``(i0, i1, d)`` with:: - - output[i,j,:] = off_value - output[i,j,indices[i,j]] = on_value - -Examples:: - - one_hot([1,0,2,0], 3) = [[ 0. 1. 0.] - [ 1. 0. 0.] - [ 0. 0. 1.] - [ 1. 0. 0.]] - - one_hot([1,0,2,0], 3, on_value=8, off_value=1, - dtype='int32') = [[1 8 1] - [8 1 1] - [1 1 8] - [8 1 1]] - - one_hot([[1,0],[1,0],[2,0]], 3) = [[[ 0. 1. 0.] - [ 1. 0. 0.]] - - [[ 0. 1. 0.] - [ 1. 0. 0.]] - - [[ 0. 0. 1.] - [ 1. 0. 0.]]] -)code" ADD_FILELINE) -.set_num_outputs(1) -.set_num_inputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"indices"}; - }) -.set_attr("FInferShape", OneHotOpShape) -.set_attr("FInferType", OneHotOpType) -.set_attr("FCompute", OneHotOpForward) -.set_attr("FGradient", MakeZeroGradNodes) -.add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value") -.add_arguments(OneHotParam::__FIELDS__()); - - NNVM_REGISTER_OP(one_hot) .describe(R"code(Returns a one-hot array. @@ -324,6 +275,7 @@ NNVM_REGISTER_OP(sparse_retain) }) .set_attr("FInferShape", SparseRetainOpShape) .set_attr("FInferType", SparseRetainOpType) +.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) .set_attr("FComputeEx", SparseRetainOpForwardEx) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { @@ -335,8 +287,9 @@ NNVM_REGISTER_OP(sparse_retain) NNVM_REGISTER_OP(_backward_sparse_retain) .set_num_inputs(2) -.set_num_outputs(2) +.set_num_outputs(1) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) .set_attr("FComputeEx", SparseRetainOpBackwardEx); } // namespace op diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index f84b64253604..1bedbe3864b1 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -910,7 +910,7 @@ void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx); NDArray output_nd = outputs[sr::kOut]; - output_nd.CheckAndAlloc({mshadow::Shape2(idx_data.Size(), output_nd.shape()[1])}); + output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); TBlob output_data = output_nd.data(); TBlob output_idx = output_nd.aux_data(rowsparse::kIdx); @@ -952,10 +952,8 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 2U); - CHECK_EQ(req.size(), 2U); + CHECK_EQ(req.size(), 1U); CHECK_NE(req[sr::kArr], kWriteInplace); - CHECK_EQ(req[sr::kIdx], kNullOp) - << "sparse_retain backward does not support the gradient of the index array"; CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage) << "sparse_retain backward only takes default NDArray as ograd"; @@ -968,7 +966,7 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, const TBlob idx_data = inputs[sr::kIdx].data(); NDArray in_grad_nd = outputs[sr::kArr]; - in_grad_nd.CheckAndAlloc({mshadow::Shape2(idx_data.Size(), out_grad_data.shape_[1])}); + in_grad_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); TBlob in_grad_data = in_grad_nd.data(); TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx); From 857b18337f7887dc5c4f13fc709fabd9e2644297 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 5 Jun 2017 20:36:38 -0700 Subject: [PATCH 3/4] Add unit test for sparse_retain --- include/mxnet/ndarray.h | 22 +++++++------- python/mxnet/test_utils.py | 16 +++++++--- src/operator/tensor/indexing_op.cc | 2 +- src/operator/tensor/indexing_op.h | 9 ++++-- tests/python/unittest/test_sparse_operator.py | 30 +++++++++++++++++++ 5 files changed, 60 insertions(+), 19 deletions(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index dd53b8f1764f..0612784cc218 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -115,44 +115,44 @@ class NDArray { } /*! \brief constructor for NDArray with storage type */ - NDArray(const NDArrayStorageType storage_type, const TShape &shape, Context ctx, + NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx, bool delay_alloc = true, int dtype = mshadow::default_type_flag, std::vector aux_types = {}, std::vector aux_shapes = {}, TShape storage_shape = TShape(mshadow::Shape1(0))) : shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) { // Assign default aux types if not given if (aux_types.size() == 0) { - if (storage_type == kRowSparseStorage) { + if (stype == kRowSparseStorage) { aux_types = {ROW_SPARSE_IDX_TYPE}; - } else if (storage_type == kCSRStorage) { + } else if (stype == kCSRStorage) { aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE}; } else { - LOG(FATAL) << "Unknown storage type" << storage_type; + LOG(FATAL) << "Unknown storage type " << stype; } } // Assign default shapes if not given // unknown shapes are intialized as {0} such that Size() would return 0 if (aux_shapes.size() == 0) { - if (storage_type == kRowSparseStorage) { + if (stype == kRowSparseStorage) { aux_shapes = {TShape(mshadow::Shape1(0))}; - } else if (storage_type == kCSRStorage) { + } else if (stype == kCSRStorage) { // aux shapes for indptr and indices aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))}; } else { - LOG(FATAL) << "Unknown storage type" << storage_type; + LOG(FATAL) << "Unknown storage type " << stype; } } if (storage_shape.Size() == 0) { - if (storage_type == kRowSparseStorage) { + if (stype == kRowSparseStorage) { storage_shape = shape; storage_shape[0] = aux_shapes[rowsparse::kIdx][0]; - } else if (storage_type == kCSRStorage) { + } else if (stype == kCSRStorage) { storage_shape = aux_shapes[csr::kIdx]; } else { - LOG(FATAL) << "Unknown storage type" << storage_type; + LOG(FATAL) << "Unknown storage type " << stype; } } - ptr_ = std::make_shared(storage_type, storage_shape, ctx, delay_alloc, + ptr_ = std::make_shared(stype, storage_shape, ctx, delay_alloc, dtype, aux_types, aux_shapes); #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 37e346bdf638..5decf50a4b33 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -16,7 +16,7 @@ import numpy.random as rnd import mxnet as mx from .context import Context -from .ndarray import array +from .ndarray import array, _STORAGE_TYPE_STR_TO_ID from .symbol import Symbol try: import requests @@ -457,7 +457,8 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2, - atol=None, grad_nodes=None, use_forward_train=True, ctx=None): + atol=None, grad_nodes=None, use_forward_train=True, ctx=None, + grad_stype_dict=None): """Verify an operation by checking backward pass via finite difference method. Based on Theano's `theano.gradient.verify_grad` [1] @@ -474,7 +475,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto - if type is dict of str -> numpy.ndarray maps the name of arguments to the corresponding numpy.ndarray. *In either case, value of all the arguments must be provided.* - aux_states : ist or tuple or dict, optional + aux_states : list or tuple or dict, optional The auxiliary states required when generating the executor for the symbol. numeric_eps : float, optional Delta for the finite difference method that approximates the gradient. @@ -486,6 +487,8 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto Whether to use is_train=True when computing the finite-difference. ctx : Context, optional Check the gradient computation on the specified device. + grad_stype_dict : dict of str->str, optional + Storage type dictionary for gradient ndarrays. References --------- ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py @@ -509,7 +512,7 @@ def random_projection(shape): location_npy = {k:v.asnumpy() for k, v in location.items()} aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) if aux_states is not None: - aux_states_npy = {k:v.asnumpy() for k, v in aux_states.items()} + aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()} else: aux_states_npy = None if grad_nodes is None: @@ -536,6 +539,11 @@ def random_projection(shape): + [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))]) args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + if grad_stype_dict is not None: + assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" + for k, v in grad_stype_dict.items(): + if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default': + args_grad[k] = mx.nd.cast_storage(args_grad[k], storage_type=v) executor = out.bind(ctx, grad_req=grad_req, args=location, args_grad=args_grad, aux_states=aux_states) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index a7b40cf035c5..833af40a0367 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -287,7 +287,7 @@ NNVM_REGISTER_OP(sparse_retain) NNVM_REGISTER_OP(_backward_sparse_retain) .set_num_inputs(2) -.set_num_outputs(1) +.set_num_outputs(2) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) .set_attr("FComputeEx", SparseRetainOpBackwardEx); diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 1bedbe3864b1..81b219f7c2c9 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -844,8 +844,9 @@ inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - out_attrs->at(sr::kOut) = kRowSparseStorage; + CHECK_EQ(out_attrs->size(), 2U); + out_attrs->at(sr::kArr) = kRowSparseStorage; + out_attrs->at(sr::kIdx) = kDefaultStorage; return true; } @@ -952,8 +953,10 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 2U); - CHECK_EQ(req.size(), 1U); + CHECK_EQ(req.size(), 2U); CHECK_NE(req[sr::kArr], kWriteInplace); + CHECK_EQ(req[sr::kIdx], kNullOp) + << "sparse_retain does not support calculating gradients of indices"; CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage) << "sparse_retain backward only takes default NDArray as ograd"; diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 1f4e2e8cc2c7..fca3a9b91100 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -5,6 +5,7 @@ from numpy.testing import assert_allclose from mxnet.test_utils import * + def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None): lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) @@ -69,6 +70,7 @@ def test_elemwise_add_ex_multiple_stages(): exec_test.backward(out_grads=exec_test.outputs) assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) + # TODO(haibin) also add test for backward pass. Check if exception is thrown def test_cast_storage_ex(): def test_rsp_to_dns(shape): @@ -133,6 +135,7 @@ def test_dot_csr_dns(csr_shape, dns_shape, trans_csr): test_dot_csr_dns(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), False) test_dot_csr_dns(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), True) + def test_sparse_embedding(): in_dim = 10 out_dim = 4 @@ -160,6 +163,7 @@ def test_sparse_embedding(): exe_test.backward([grad]) assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), atol=1e-5) + def test_sparse_slice(): def check_csr_slice(shape, slice_input): storage_type = 'csr' @@ -175,6 +179,32 @@ def check_csr_slice(shape, slice_input): check_csr_slice(shape, True) check_csr_slice(shape, False) + +def test_sparse_retain(): + for num_rows in range(1, 40): + num_cols = 3 + shape = (num_rows, num_cols) + rsp, _ = rand_sparse_ndarray(shape=shape, storage_type='row_sparse', density=0.5) + length = np.random.randint(1, num_rows + 1) + import random + idx = random.sample(range(0, num_rows), length) + idx.sort() + dns = rsp.asnumpy() + tensor_retained_expected = np.zeros(shape) + for i in idx: + tensor_retained_expected[i][:] = dns[i] + indices = mx.nd.array(idx) + rsp_retained = mx.nd.sparse_retain(rsp, indices=indices) + assert same(tensor_retained_expected, rsp_retained.asnumpy()) + + # check numeric gradient + data = mx.symbol.Variable('data') + idx = mx.symbol.Variable('indices') + sym = mx.sym.sparse_retain(data=data, indices=idx) + check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'}) + + + if __name__ == '__main__': import nose nose.runmodule() From 18f11faa4e5846bf1cc6c06e4b6cc6658445399e Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 6 Jun 2017 14:08:16 -0700 Subject: [PATCH 4/4] Add example and modify test --- python/mxnet/test_utils.py | 9 +++++++++ src/operator/tensor/indexing_op.cc | 13 +++++++++++++ tests/python/unittest/test_sparse_operator.py | 14 ++++---------- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 5decf50a4b33..c0b3c12fc95e 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -67,6 +67,15 @@ def random_arrays(*shapes): return arrays[0] return arrays + +def random_sample(population, k): + """Return a k length list of the elements chosen from the population sequence.""" + assert 0 <= k <= len(population) + population_copy = population[:] + np.random.shuffle(population_copy) + return population_copy[0:k] + + # TODO(haibin) also include types in arguments def rand_sparse_ndarray(shape, storage_type, density=None): """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """ diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 833af40a0367..8cf00c0eb7b4 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -266,6 +266,19 @@ Examples:: NNVM_REGISTER_OP(sparse_retain) .describe(R"code(pick rows specified by user input index array from a row sparse matrix +and save them in the output sparse matrix. + +Example:: + + data = [[1, 2], [3, 4], [5, 6]] + indices = [0, 1, 3] + shape = (4, 2) + rsp_in = row_sparse(data, indices) + to_retain = [0, 3] + rsp_out = sparse_retain(rsp_in, to_retain) + rsp_out.values = [[1, 2], [5, 6]] + rsp_out.indices = [0, 3] + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index fca3a9b91100..57767c612565 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1,8 +1,3 @@ -# pylint: skip-file -import numpy as np -import mxnet as mx -import scipy.sparse as sp -from numpy.testing import assert_allclose from mxnet.test_utils import * @@ -181,13 +176,12 @@ def check_csr_slice(shape, slice_input): def test_sparse_retain(): - for num_rows in range(1, 40): - num_cols = 3 - shape = (num_rows, num_cols) + for _ in range(10): + shape = rand_shape_2d() + num_rows = shape[0] rsp, _ = rand_sparse_ndarray(shape=shape, storage_type='row_sparse', density=0.5) length = np.random.randint(1, num_rows + 1) - import random - idx = random.sample(range(0, num_rows), length) + idx = random_sample(range(0, num_rows), length) idx.sort() dns = rsp.asnumpy() tensor_retained_expected = np.zeros(shape)