Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

sparse support for take(csr, axis=0) #12889

Merged
merged 5 commits into from
Oct 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,142 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
}
}

template<bool clip>
struct CsrTakeDataKernel {
/*!
* \brief Map function for general case of take grad
* \param tid global thread id
* \param out_idx ptr to out idx
* \param out_data ptr to out data
* \param out_indptr ptr to out indptr
* \param src_data ptr to original csr data
* \param src_idx ptr to original csr idx
* \param idx_ptr ptr to indices
* \param num_rows maximum number of rows in src array
*/
template<typename IType, typename DType, typename RType>
MSHADOW_XINLINE static void Map(int tid, RType* out_idx, DType* out_data,
const RType* out_indptr, const RType* src_idx,
const DType* src_data, const RType* src_indptr,
const IType* idx_ptr, const nnvm::dim_t num_rows) {
nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid]);
// clip mode
if (clip) {
if (idx < 0) idx = 0;
if (idx >= num_rows) idx = num_rows - 1;
} else {
// wrap mode
idx = idx % num_rows;
idx += (idx < 0) ? num_rows : 0;
}
int row_nnz = src_indptr[idx + 1] - src_indptr[idx];
for (int i = 0; i < row_nnz; i++) {
out_data[out_indptr[tid] + i] = src_data[src_indptr[idx] + i];
out_idx[out_indptr[tid] + i] = src_idx[src_indptr[idx] + i];
}
}
};

template<bool clip>
struct CsrTakeRowCountKernel {
/*!
* \brief Map function for general case of take grad
* \param tid global thread id
* \param out_indptr ptr to out indptr
* \param src_indptr ptr to original csr indptr
* \param idx_ptr ptr to indices
* \param num_rows maximum number of rows in src array
*/
template<typename IType, typename RType>
MSHADOW_XINLINE static void Map(int tid, RType* out_indptr,
const RType* src_indptr, const IType* idx_ptr,
const nnvm::dim_t num_rows) {
if (tid == 0) out_indptr[0] = 0;
nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
// clip mode
if (clip) {
if (idx < 0) idx = 0;
if (idx >= num_rows) idx = num_rows - 1;
} else {
// wrap mode
idx = idx % num_rows;
idx += (idx < 0) ? num_rows : 0;
}
out_indptr[tid] = src_indptr[idx + 1] - src_indptr[idx];
}
};

template<>
void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
const OpContext& ctx,
const TBlob& idx,
const NDArray& arr,
OpReqType req,
const NDArray& out) {
using namespace csr;
using namespace mxnet_op;
using nnvm::dim_t;
Stream<cpu> *s = ctx.get_stream<cpu>();
if (req == kNullOp) return;
if (!arr.storage_initialized()) {
FillZerosCsrImpl(s, out);
return;
}
CHECK_EQ(idx.shape_.ndim(), 1U)
<< "Take with CSR array only supports one-dimensional indices. "
<< idx.shape_.ndim() << " dimensional input is given instead";
CHECK_EQ(req, kWriteTo) << "req = " << req << " is not supported for take(csr)";
auto axis = params.axis;
CHECK_EQ(axis, 0) << "axis = " << axis << " is not supported for take(csr)";
CHECK(params.mode == take_::kClip || params.mode == take_::kWrap)
<< "mode = " << params.mode << " is not supported";
const dim_t num_rows = out.shape()[0];
const dim_t max_num_rows = arr.shape()[0];
out.CheckAndAllocAuxData(kIndPtr, {Shape1(num_rows + 1)});

MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
MSHADOW_SGL_DBL_TYPE_SWITCH(arr.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
const RType* src_indptr = arr.aux_data(kIndPtr).dptr<RType>();
const IType* idx_ptr = idx.dptr<IType>();
// gather per row nnz information for output
bool clip = params.mode == take_::kClip;
if (clip) {
Kernel<CsrTakeRowCountKernel<true>, cpu>::Launch(s, num_rows + 1,
out_indptr, src_indptr, idx_ptr, max_num_rows);
} else {
Kernel<CsrTakeRowCountKernel<false>, cpu>::Launch(s, num_rows + 1,
out_indptr, src_indptr, idx_ptr, max_num_rows);
}
// calculate prefix sum with single thread
for (dim_t i = 0; i < num_rows; i++) {
out_indptr[i + 1] += out_indptr[i];
}
// total number of non-zero rows
const dim_t nnz = out_indptr[num_rows];
if (nnz == 0) {
FillZerosCsrImpl(s, out);
return;
}
out.CheckAndAllocAuxData(kIdx, {Shape1(nnz)});
out.CheckAndAllocData(Shape1(nnz));
RType* out_idx = out.aux_data(kIdx).dptr<RType>();
DType* out_data = out.data().dptr<DType>();
const RType* src_idx = arr.aux_data(kIdx).dptr<RType>();
const DType* src_data = arr.data().dptr<DType>();
// copy indices and data for output
if (clip) {
Kernel<CsrTakeDataKernel<true>, cpu>::Launch(s, num_rows, out_idx,
out_data, out_indptr, src_idx, src_data, src_indptr, idx_ptr, max_num_rows);
} else {
Kernel<CsrTakeDataKernel<false>, cpu>::Launch(s, num_rows, out_idx,
out_data, out_indptr, src_idx, src_data, src_indptr, idx_ptr, max_num_rows);
}
});
});
});
}

template<>
inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
Expand Down Expand Up @@ -400,6 +536,7 @@ dimension of data (by default outer-most one as axis=0) indexed by indices, and
in an output tensor of rank q + (r - 1).
Examples::
x = [4. 5. 6.]
// Trivial case, take the second element along the first axis.
Expand Down Expand Up @@ -431,6 +568,11 @@ Examples::
[[ 3., 4.],
[ 5., 6.]]]
The storage type of ``take`` output depends upon the input storage type:
- take(default, default) = default
- take(csr, default, axis=0) = csr
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
Expand All @@ -441,11 +583,13 @@ Examples::
})
.set_attr<nnvm::FInferShape>("FInferShape", TakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", TakeOpType)
.set_attr<FInferStorageType>("FInferStorageType", TakeOpForwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", TakeOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode("_backward_take", n, ograds,
Expand Down
65 changes: 65 additions & 0 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,71 @@ inline bool TakeOpType(const nnvm::NodeAttrs& attrs,
return (*in_attrs)[0] != -1;
}

// storage type inference function for take
inline bool TakeOpForwardStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const int& idx_stype = in_attrs->at(take_::kIdx);
const int& arr_stype = in_attrs->at(take_::kArr);
int& out_stype = out_attrs->at(take_::kOut);
bool dispatched = false;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
if (!dispatched && idx_stype == kDefaultStorage && arr_stype == kDefaultStorage) {
// dns, dns -> dns
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched && idx_stype == kDefaultStorage && arr_stype == kCSRStorage &&
param.axis == 0 && (param.mode == take_::kWrap || param.mode == take_::kClip)) {
// take(dns, csr, axis=0) -> csr
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}


template<typename xpu>
void TakeOpForwardCsrImpl(const TakeParam& params,
const OpContext& ctx,
const TBlob& idx,
const NDArray& arr,
OpReqType req,
const NDArray& output);


template<typename xpu>
void TakeOpForwardEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(req[take_::kOut], kWriteTo);
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const NDArray& idx = inputs[take_::kIdx];
const NDArray& arr = inputs[take_::kArr];
const NDArray& out = outputs[take_::kOut];
const auto idx_stype = idx.storage_type();
const auto arr_stype = arr.storage_type();
const auto out_stype = out.storage_type();
const auto params = nnvm::get<TakeParam>(attrs.parsed);
if (idx_stype == kDefaultStorage && arr_stype == kCSRStorage &&
out_stype == kCSRStorage) {
// dns, csr -> csr
TakeOpForwardCsrImpl<xpu>(params, ctx, idx.data(), arr, req[0], out);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}

template<typename xpu>
void TakeOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,24 @@ def check_sparse_fc(batch_size, dim_in, dim_out, stype):
# test FC with row_sparse weight w/ density=1, csr data (fallback)
check_sparse_fc(5, 10, 8, 'csr')

@with_seed()
def test_sparse_take():
def check_sparse_take(density, mode):
data_shape = rand_shape_2d()
idx_shape = (np.random.randint(low=1, high=10),)
data = rand_ndarray(data_shape, 'csr', density=density)
idx = mx.nd.array(np.random.randint(low=-5, high=15, size=idx_shape))
result = mx.nd.take(data, idx, mode=mode)
data_np = data.asnumpy()
idx_np = idx.asnumpy().astype('int32')
expected_result = np.take(data_np, idx_np, mode=mode, axis=0)
assert_almost_equal(result.asnumpy(), expected_result)
densities = [0, 0.5, 1]
modes = ['clip', 'wrap']
for d in densities:
for m in modes:
check_sparse_take(d, m)

if __name__ == '__main__':
import nose
nose.runmodule()