diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index b663ef0179df..98e2536f4c9b 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -87,6 +87,156 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
   }
 }
 
+template<bool clip = true>
+struct CsrTakeDataKernel {
+  /*!
+   * \brief Map function for take on a csr array: copies the column indices and
+   *        values of the source row selected by idx_ptr[tid] into output row tid
+   * \param tid global thread id, used as the output row
+   * \param out_idx ptr to out idx
+   * \param out_data ptr to out data
+   * \param out_indptr ptr to out indptr
+   * \param src_idx ptr to original csr idx
+   * \param src_data ptr to original csr data
+   * \param src_indptr ptr to original csr indptr
+   * \param idx_ptr ptr to indices
+   * \param num_rows maximum number of rows in src array
+   */
+  template<typename IType, typename DType, typename RType>
+  MSHADOW_XINLINE static void Map(int tid, RType* out_idx, DType* out_data,
+                                  const RType* out_indptr, const RType* src_idx,
+                                  const DType* src_data, const RType* src_indptr,
+                                  const IType* idx_ptr, const nnvm::dim_t num_rows) {
+    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid]);
+    // clip mode
+    if (clip) {
+      if (idx < 0) idx = 0;
+      if (idx >= num_rows) idx = num_rows - 1;
+    } else {
+      // wrap mode
+      idx = idx % num_rows;
+      idx += (idx < 0) ? num_rows : 0;
+    }
+    // keep the count in RType so that indptr differences are never narrowed to int
+    const RType row_nnz = src_indptr[idx + 1] - src_indptr[idx];
+    for (RType i = 0; i < row_nnz; i++) {
+      out_data[out_indptr[tid] + i] = src_data[src_indptr[idx] + i];
+      out_idx[out_indptr[tid] + i] = src_idx[src_indptr[idx] + i];
+    }
+  }
+};
+
+template<bool clip = true>
+struct CsrTakeRowCountKernel {
+  /*!
+   * \brief Map function for take on a csr array: writes 0 to out_indptr[0] and,
+   *        for tid > 0, the number of non-zeros of the source row selected by
+   *        idx_ptr[tid - 1] to out_indptr[tid]
+   * \param tid global thread id, used as the position in out_indptr
+   * \param out_indptr ptr to out indptr
+   * \param src_indptr ptr to original csr indptr
+   * \param idx_ptr ptr to indices
+   * \param num_rows maximum number of rows in src array
+   */
+  template<typename RType, typename IType>
+  MSHADOW_XINLINE static void Map(int tid, RType* out_indptr,
+                                  const RType* src_indptr, const IType* idx_ptr,
+                                  const nnvm::dim_t num_rows) {
+    // thread 0 only writes the leading zero: falling through would read
+    // idx_ptr[-1] and overwrite out_indptr[0] with a row count
+    if (tid == 0) {
+      out_indptr[0] = 0;
+      return;
+    }
+    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
+    // clip mode
+    if (clip) {
+      if (idx < 0) idx = 0;
+      if (idx >= num_rows) idx = num_rows - 1;
+    } else {
+      // wrap mode
+      idx = idx % num_rows;
+      idx += (idx < 0) ? num_rows : 0;
+    }
+    out_indptr[tid] = src_indptr[idx + 1] - src_indptr[idx];
+  }
+};
+
+/*!
+ * \brief cpu implementation of take on a csr array along axis 0: gathers the
+ *        rows of arr selected by the one-dimensional idx into the csr output
+ */
+template<>
+void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
+                               const OpContext& ctx,
+                               const TBlob& idx,
+                               const NDArray& arr,
+                               OpReqType req,
+                               const NDArray& out) {
+  using namespace csr;
+  using namespace mxnet_op;
+  using nnvm::dim_t;
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  if (!arr.storage_initialized()) {
+    FillZerosCsrImpl(s, out);
+    return;
+  }
+  CHECK_EQ(idx.shape_.ndim(), 1U)
+    << "Take with CSR array only supports one-dimensional indices. "
+    << idx.shape_.ndim() << " dimensional input is given instead";
+  CHECK_EQ(req, kWriteTo) << "req = " << req << " is not supported for take(csr)";
+  auto axis = params.axis;
+  CHECK_EQ(axis, 0) << "axis = " << axis << " is not supported for take(csr)";
+  CHECK(params.mode == take_::kClip || params.mode == take_::kWrap)
+    << "mode = " << params.mode << " is not supported";
+  const dim_t num_rows = out.shape()[0];
+  const dim_t max_num_rows = arr.shape()[0];
+  out.CheckAndAllocAuxData(kIndPtr, {Shape1(num_rows + 1)});
+
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    MSHADOW_SGL_DBL_TYPE_SWITCH(arr.dtype(), DType, {
+      MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
+        RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
+        const RType* src_indptr = arr.aux_data(kIndPtr).dptr<RType>();
+        const IType* idx_ptr = idx.dptr<IType>();
+        // gather per row nnz information for output
+        bool clip = params.mode == take_::kClip;
+        if (clip) {
+          Kernel<CsrTakeRowCountKernel<true>, cpu>::Launch(s, num_rows + 1,
+              out_indptr, src_indptr, idx_ptr, max_num_rows);
+        } else {
+          Kernel<CsrTakeRowCountKernel<false>, cpu>::Launch(s, num_rows + 1,
+              out_indptr, src_indptr, idx_ptr, max_num_rows);
+        }
+        // calculate prefix sum with single thread
+        for (dim_t i = 0; i < num_rows; i++) {
+          out_indptr[i + 1] += out_indptr[i];
+        }
+        // total number of non-zero elements in the output
+        const dim_t nnz = out_indptr[num_rows];
+        if (nnz == 0) {
+          FillZerosCsrImpl(s, out);
+          return;
+        }
+        out.CheckAndAllocAuxData(kIdx, {Shape1(nnz)});
+        out.CheckAndAllocData(Shape1(nnz));
+        RType* out_idx = out.aux_data(kIdx).dptr<RType>();
+        DType* out_data = out.data().dptr<DType>();
+        const RType* src_idx = arr.aux_data(kIdx).dptr<RType>();
+        const DType* src_data = arr.data().dptr<DType>();
+        // copy indices and data for output
+        if (clip) {
+          Kernel<CsrTakeDataKernel<true>, cpu>::Launch(s, num_rows, out_idx,
+              out_data, out_indptr, src_idx, src_data, src_indptr, idx_ptr, max_num_rows);
+        } else {
+          Kernel<CsrTakeDataKernel<false>, cpu>::Launch(s, num_rows, out_idx,
+              out_data, out_indptr, src_idx, src_data, src_indptr, idx_ptr, max_num_rows);
+        }
+      });
+    });
+  });
+}
 
template<> inline void SparseEmbeddingOpBackwardRspImpl(const bool deterministic, @@ -400,6 +536,7 @@ dimension of data (by default outer-most one as axis=0) indexed by indices, and in an output tensor of rank q + (r - 1). Examples:: + x = [4. 5. 6.] // Trivial case, take the second element along the first axis. @@ -431,6 +568,11 @@ Examples:: [[ 3., 4.], [ 5., 6.]]] +The storage type of ``take`` output depends upon the input storage type: + + - take(default, default) = default + - take(csr, default, axis=0) = csr + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -441,11 +583,13 @@ Examples:: }) .set_attr("FInferShape", TakeOpShape) .set_attr("FInferType", TakeOpType) +.set_attr("FInferStorageType", TakeOpForwardStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", TakeOpForward) +.set_attr("FComputeEx", TakeOpForwardEx) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { return MakeNonlossGradNode("_backward_take", n, ograds, diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 1daf0a2cb18a..5282a7ea9a61 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -755,6 +755,71 @@ inline bool TakeOpType(const nnvm::NodeAttrs& attrs, return (*in_attrs)[0] != -1; } +// storage type inference function for take +inline bool TakeOpForwardStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int& idx_stype = in_attrs->at(take_::kIdx); + const int& arr_stype = in_attrs->at(take_::kArr); + int& out_stype = out_attrs->at(take_::kOut); + bool dispatched = false; + const TakeParam& param = nnvm::get(attrs.parsed); + if (!dispatched && idx_stype == kDefaultStorage && arr_stype == kDefaultStorage) { + // dns, dns 
-> dns + dispatched = storage_type_assign(&out_stype, kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } + if (!dispatched && idx_stype == kDefaultStorage && arr_stype == kCSRStorage && + param.axis == 0 && (param.mode == take_::kWrap || param.mode == take_::kClip)) { + // take(dns, csr, axis=0) -> csr + dispatched = storage_type_assign(&out_stype, kCSRStorage, + dispatch_mode, DispatchMode::kFComputeEx); + } + if (!dispatched) { + dispatched = dispatch_fallback(out_attrs, dispatch_mode); + } + return dispatched; +} + + +template +void TakeOpForwardCsrImpl(const TakeParam& params, + const OpContext& ctx, + const TBlob& idx, + const NDArray& arr, + OpReqType req, + const NDArray& output); + + +template +void TakeOpForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(req[take_::kOut], kWriteTo); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + const NDArray& idx = inputs[take_::kIdx]; + const NDArray& arr = inputs[take_::kArr]; + const NDArray& out = outputs[take_::kOut]; + const auto idx_stype = idx.storage_type(); + const auto arr_stype = arr.storage_type(); + const auto out_stype = out.storage_type(); + const auto params = nnvm::get(attrs.parsed); + if (idx_stype == kDefaultStorage && arr_stype == kCSRStorage && + out_stype == kCSRStorage) { + // dns, csr -> csr + TakeOpForwardCsrImpl(params, ctx, idx.data(), arr, req[0], out); + } else { + LogUnimplementedOp(attrs, ctx, inputs, req, outputs); + } +} + template void TakeOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 875dea7313ae..8dd250cf98ee 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -1013,6 +1013,24 @@ def check_sparse_fc(batch_size, dim_in, dim_out, stype): # test FC with row_sparse 
weight w/ density=1, csr data (fallback) check_sparse_fc(5, 10, 8, 'csr') +@with_seed() +def test_sparse_take(): + def check_sparse_take(density, mode): + data_shape = rand_shape_2d() + idx_shape = (np.random.randint(low=1, high=10),) + data = rand_ndarray(data_shape, 'csr', density=density) + idx = mx.nd.array(np.random.randint(low=-5, high=15, size=idx_shape)) + result = mx.nd.take(data, idx, mode=mode) + data_np = data.asnumpy() + idx_np = idx.asnumpy().astype('int32') + expected_result = np.take(data_np, idx_np, mode=mode, axis=0) + assert_almost_equal(result.asnumpy(), expected_result) + densities = [0, 0.5, 1] + modes = ['clip', 'wrap'] + for d in densities: + for m in modes: + check_sparse_take(d, m) + if __name__ == '__main__': import nose nose.runmodule()