From 0d33d3fc5666fc7888f8d799018f8839693cda83 Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Thu, 18 May 2017 23:21:52 -0700
Subject: [PATCH] implementation for Csr slice on cpu (#36)

* CPU implementation for CSR

remove seg_len from csr slice

add some docs for slice csr

change indptr, values, etc to be private member

bug fix in sparse embedding

update nnvm submodule

fix lint

update unit test for sparse nd

* add const for SliceCsrIndPtr kernel
---
 python/mxnet/sparse_ndarray.py                | 39 +++++++--
 src/operator/tensor/indexing_op.h             |  2 +-
 src/operator/tensor/matrix_op-inl.h           | 79 +++++++++++++++++++
 src/operator/tensor/matrix_op.cc              |  5 ++
 tests/python/unittest/test_sparse_ndarray.py  | 12 +--
 tests/python/unittest/test_sparse_operator.py | 18 ++++-
 6 files changed, 141 insertions(+), 14 deletions(-)

diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py
index a4da8f673324..e21ece416fba 100644
--- a/python/mxnet/sparse_ndarray.py
+++ b/python/mxnet/sparse_ndarray.py
@@ -251,13 +251,13 @@ def _slice(self, start, stop):
         >>> a[1:2].asnumpy()
         array([[0, 0, 3]])
 
-        >>> a[1:2].indptr.asnumpy()
+        >>> a[1:2]._indptr.asnumpy()
         array([[2, 3]])
 
-        >>> a[1:2].indicies.asnumpy()
+        >>> a[1:2]._indices.asnumpy()
         array([0, 2, 2, 0, 1, 2])
 
-        >>> a[1:2].values.asnumpy()
+        >>> a[1:2]._values.asnumpy()
         array([1, 2, 3, 4, 5, 6])
         """
@@ -293,11 +293,27 @@ def _aux_type(self, i):
         return _DTYPE_MX_TO_NP[aux_type.value]
 
     @property
-    def values(self):
+    def _values(self):
+        """The values array of the SparseNDArray. This is a read-only view of the values array.
+        It reveals internal implementation details and should be used with care.
+
+        Returns
+        -------
+        NDArray
+            This SparseNDArray's values array.
+        """
         return self._data(0)
 
     @property
-    def indices(self):
+    def _indices(self):
+        """The indices array of the SparseNDArray. This is a read-only view of the indices array.
+        It reveals internal implementation details and should be used with care.
+
+        Returns
+        -------
+        NDArray
+            This SparseNDArray's indices array.
+        """
         stype = self.storage_type
         if stype == 'row_sparse':
             return self._aux_data(0)
@@ -306,7 +322,16 @@ def indices(self):
         raise Exception("unknown storage type " + stype)
 
     @property
-    def indptr(self):
+    def _indptr(self):
+        """The indptr array of the SparseNDArray with `csr` storage type.
+        This is a read-only view of the indptr array.
+        It reveals internal implementation details and should be used with care.
+
+        Returns
+        -------
+        NDArray
+            This SparseNDArray's indptr array.
+        """
         stype = self.storage_type
         if stype == 'csr':
             return self._aux_data(0)
@@ -383,6 +408,7 @@ def _aux_data(self, i, writable=False):
         SparseNDArray is not yet compacted, the returned result may include invalid values.
         """
+        self.wait_to_read()
         hdl = NDArrayHandle()
         check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl)))
         return NDArray(hdl, writable)
@@ -392,6 +418,7 @@ def _data(self, writable=False):
         SparseNDArray is not yet compacted, the returned result may include invalid values.
""" + self.wait_to_read() hdl = NDArrayHandle() check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl))) return NDArray(hdl, writable) diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index b677eedcc6fc..c8b2c95c8a15 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -328,7 +328,7 @@ struct EmbeddingBackwardRsp { size_t segment_end = (i + 1) * segment_len; for (size_t y = 0; y < num_idx; y++) { size_t j = idx[y]; - if (j > num_rows) j = num_rows - 1; + if (j >= num_rows) j = num_rows - 1; if (j < segment_start || j >= segment_end) continue; dst_idx[j] = j; for (size_t k = 0; k < width; k++) { diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 5bb27acd3daf..51e4869e94f8 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -7,6 +7,7 @@ #define MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_ #include +#include #include #include #include @@ -1217,6 +1218,84 @@ void Slice(const nnvm::NodeAttrs& attrs, }); } +// slice the indptr of a csr +struct SliceCsrIndPtr { + template + MSHADOW_XINLINE static void Map(int i, IType* out, const IType* in, const IType* base) { + KERNEL_ASSIGN(out[i], kWriteTo, in[i] - *base); + } +}; + +/* + * Slice a CSR NDArray + * Only implemented for CPU + */ +template +void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, + const NDArray &in, OpReqType req, const NDArray &out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace csr; + CHECK((std::is_same::value)) << "Slice for CSR input only implemented for CPU"; + if (req == kNullOp) return; + CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; + CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; + Stream *s = ctx.get_stream(); + int begin = *param.begin[0]; + int end = *param.end[0]; + int indptr_len = end - begin + 1; + out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); + if (!in.storage_initialized()) { + out.SetAuxShape(kIndPtr, Shape1(0)); + return; + } + CHECK_EQ(in.aux_type(kIndPtr), in.aux_type(kIdx)) + << "The type for indptr and indices are different. 
This is not implemented yet."; + // assume idx indptr share the same type + NDARRAY_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), IType, { + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + auto in_indptr = in.aux_data(kIndPtr).dptr(); + auto out_indptr = out.aux_data(kIndPtr).dptr(); + int num_threads = omp_get_num_threads(); + int segment_len = (indptr_len + num_threads - 1) / num_threads; + Kernel::Launch(s, indptr_len, out_indptr, in_indptr + begin, + in_indptr + begin); + // retrieve nnz (CPU implementation) + int nnz = out_indptr[indptr_len - 1] - out_indptr[0]; + // copy indices and values + out.CheckAndAllocAuxData(kIdx, Shape1(nnz)); + out.CheckAndAllocData(Shape1(nnz)); + auto in_idx = in.aux_data(kIdx).dptr(); + auto out_idx = out.aux_data(kIdx).dptr(); + auto in_data = in.data().dptr(); + auto out_data = out.data().dptr(); + int offset = in_indptr[begin]; + // this is also a CPU-only implementation + memcpy(out_idx, in_idx + offset, nnz * sizeof(IType)); + memcpy(out_data, in_data + offset, nnz * sizeof(DType)); + }); + }); +} + +template +void SliceEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + const SliceParam& param = nnvm::get(attrs.parsed); + auto in_stype = inputs[0].storage_type(); + CHECK_NE(in_stype, kDefaultStorage) + << "SliceEx is not expected to execute for input with default storage type"; + if (in_stype == kCSRStorage) { + SliceCsrImpl(param, ctx, inputs[0], req[0], outputs[0]); + } else { + LOG(FATAL) << "Slice not implemented for storage type" << in_stype; + } +} + inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index f4104d5f710d..7b73947aee34 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -232,6 +232,9 @@ and ``end=(e_1, e_2, ... e_n)`` indices will result in an array with the shape The resulting array's *k*-th dimension contains elements from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``. +For an input array of non-default storage type(e.g. `csr` or `row_sparse`), it only supports +slicing on the first dimension. 
+
 
 Example::
 
   x = [[  1.,   2.,   3.,   4.],
       [  5.,   6.,   7.,   8.],
       [  9.,  10.,  11.,  12.]]
 
   slice(x, begin=(0,1), end=(2,4)) = [[ 2.,  3.,  4.],
                                       [ 6.,  7.,  8.]]
 
@@ -245,8 +248,10 @@ Example::
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<SliceParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", SliceShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
 .set_attr<FCompute>("FCompute<cpu>", Slice<cpu>)
+.set_attr<FComputeEx>(FCOMP_EX_CPU, SliceEx<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(SliceParam::__FIELDS__());
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 1c9fc9eb8127..89323fe2dbfd 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -75,10 +75,10 @@ def check_sparse_nd_prop_rsp():
     shape = (rnd.randint(1, 2), rnd.randint(1, 2))
     nd, (v, idx) = rand_sparse_ndarray(shape, storage_type)
     assert(nd._num_aux == 1)
-    assert(nd.indices.dtype == np.int32)
+    assert(nd._indices.dtype == np.int32)
     assert(nd.storage_type == 'row_sparse')
     assert_almost_equal(nd._data().asnumpy(), v)
-    assert_almost_equal(nd._aux_data(0).asnumpy(), idx)
+    assert_almost_equal(nd._indices.asnumpy(), idx)
 
 def test_sparse_nd_basic():
     def check_rsp_creation(values, indices, shape):
@@ -88,13 +88,13 @@ def check_rsp_creation(values, indices, shape):
         dns[3] = mx.nd.array(values[1])
         assert_almost_equal(rsp.asnumpy(), dns.asnumpy())
         indices = mx.nd.array(indices).asnumpy()
-        assert_almost_equal(rsp.indices.asnumpy(), indices)
+        assert_almost_equal(rsp._indices.asnumpy(), indices)
 
     def check_csr_creation(shape):
         csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr')
-        assert_almost_equal(csr.indptr.asnumpy(), indptr)
-        assert_almost_equal(csr.indices.asnumpy(), indices)
-        assert_almost_equal(csr.values.asnumpy(), values)
+        assert_almost_equal(csr._indptr.asnumpy(), indptr)
+        assert_almost_equal(csr._indices.asnumpy(), indices)
+        assert_almost_equal(csr._values.asnumpy(), values)
 
     shape = (4,2)
     values = np.random.rand(2,2)
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 925fe3903b0f..8a34d4d591a0 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -170,7 +170,22 @@ def test_sparse_embedding():
     grad = mx.nd.zeros(np_grad.shape)
     grad[:] = np_grad
     exe_test.backward([grad])
-    assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad))
+    assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), atol=1e-5)
+
+def test_sparse_slice():
+    def check_csr_slice(shape, sliced_input):
+        storage_type = 'csr'
+        A, _ = rand_sparse_ndarray(shape, storage_type)
+        A = A._slice(1, shape[0] - 1) if sliced_input else A
+        A2 = A.asnumpy()
+        begin = rnd.randint(0, A.shape[0] - 1)
+        end = rnd.randint(begin + 1, A.shape[0])
+        A_slice = mx.nd.crop(A, begin=begin, end=end)
+        assert same(A_slice.asnumpy(), A2[begin:end]), (A_slice.asnumpy(), A2[begin:end])
+
+    shape = (rnd.randint(7, 15), rnd.randint(1, 10))
+    check_csr_slice(shape, True)
+    check_csr_slice(shape, False)
 
 if __name__ == '__main__':
     test_elemwise_add_ex()
@@ -178,3 +193,4 @@ def test_sparse_embedding():
     test_cast_storage_ex()
     test_sparse_dot()
     test_sparse_embedding()
+    test_sparse_slice()
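
For illustration, the row-slicing algorithm in `SliceCsrImpl` above can be summarized as: rebase the `indptr` segment `[begin, end]` by subtracting its first entry (what the `SliceCsrIndPtr` kernel does element-wise), read the slice's `nnz` off the rebased `indptr`, then block-copy the corresponding ranges of `indices` and `values`. The NumPy sketch below mirrors that logic; the function name `slice_csr_rows` and the plain-array CSR representation are illustrative assumptions, not part of the patch::

    import numpy as np

    def slice_csr_rows(indptr, indices, values, begin, end):
        """Return rows [begin, end) of a CSR matrix, mirroring SliceCsrImpl."""
        # rebase the indptr segment so the output slice starts at zero
        # (this is the SliceCsrIndPtr kernel, applied per element)
        out_indptr = indptr[begin:end + 1] - indptr[begin]
        # the slice's nnz can be read off the rebased indptr
        nnz = out_indptr[-1] - out_indptr[0]
        # block-copy the matching ranges of indices and values
        # (the two memcpy calls in SliceCsrImpl)
        offset = indptr[begin]
        out_indices = indices[offset:offset + nnz].copy()
        out_values = values[offset:offset + nnz].copy()
        return out_indptr, out_indices, out_values

    # CSR form of [[1, 0, 2], [0, 0, 3], [4, 0, 5]]; take rows [1, 3)
    indptr = np.array([0, 2, 3, 5])
    indices = np.array([0, 2, 2, 0, 2])
    values = np.array([1., 2., 3., 4., 5.])
    print(slice_csr_rows(indptr, indices, values, 1, 3))
    # (array([0, 1, 3]), array([2, 0, 2]), array([3., 4., 5.]))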
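
From the frontend, the new kernel is reached through the existing `slice`/`crop` operator, exactly as `test_sparse_slice` above does with `mx.nd.crop(A, begin=begin, end=end)`. A minimal usage sketch, assuming `rand_sparse_ndarray` is importable from `mxnet.test_utils` (the import location of this test helper is an assumption)::

    import mxnet as mx
    from mxnet.test_utils import rand_sparse_ndarray  # assumed home of the test helper

    A, _ = rand_sparse_ndarray((8, 3), 'csr')  # random 8x3 NDArray with csr storage
    A_slice = mx.nd.crop(A, begin=1, end=4)    # rows [1, 4), sliced on the first dimension
    print(A_slice.asnumpy())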