Skip to content

Commit

Permalink
sparse support for take(csr, axis=0) (apache#12889)
Browse files Browse the repository at this point in the history
* initial commit

* add test cases for mode

* fix bug

* add comment

* more comments
  • Loading branch information
eric-haibin-lin authored and lanking520 committed Oct 24, 2018
1 parent 58d35c5 commit d96bafa
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 0 deletions.
144 changes: 144 additions & 0 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,142 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
}
}

/*!
 * \brief Kernel that copies the non-zero entries of the selected source rows
 *        into the output csr arrays. Out-of-range indices are clipped when
 *        `clip` is true and wrapped around otherwise.
 */
template<bool clip>
struct CsrTakeDataKernel {
  /*!
   * \brief Map function for take on a csr array: fills one output row
   * \param tid global thread id, used as the output row number
   * \param out_idx ptr to out idx
   * \param out_data ptr to out data
   * \param out_indptr ptr to out indptr (only read here)
   * \param src_idx ptr to original csr idx
   * \param src_data ptr to original csr data
   * \param src_indptr ptr to original csr indptr
   * \param idx_ptr ptr to indices
   * \param num_rows maximum number of rows in src array
   */
  template<typename IType, typename DType, typename RType>
  MSHADOW_XINLINE static void Map(int tid, RType* out_idx, DType* out_data,
                                  const RType* out_indptr, const RType* src_idx,
                                  const DType* src_data, const RType* src_indptr,
                                  const IType* idx_ptr, const nnvm::dim_t num_rows) {
    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid]);
    if (clip) {
      // clip mode: clamp the index into [0, num_rows - 1]
      if (idx < 0) idx = 0;
      if (idx >= num_rows) idx = num_rows - 1;
    } else {
      // wrap mode: the remainder may be negative, so shift it back into range
      idx = idx % num_rows;
      idx += (idx < 0) ? num_rows : 0;
    }
    // Keep every offset in RType: indptr values may be 64-bit, and storing
    // them in a plain int would silently narrow them.
    const RType src_begin = src_indptr[idx];
    const RType row_nnz = src_indptr[idx + 1] - src_begin;
    const RType out_begin = out_indptr[tid];
    for (RType i = 0; i < row_nnz; ++i) {
      out_data[out_begin + i] = src_data[src_begin + i];
      out_idx[out_begin + i] = src_idx[src_begin + i];
    }
  }
};

/*!
 * \brief Kernel that writes, for every requested index, the number of
 *        non-zero entries of the selected source row into out_indptr.
 *        Out-of-range indices are clipped when `clip` is true and wrapped
 *        around otherwise.
 */
template<bool clip>
struct CsrTakeRowCountKernel {
  /*!
   * \brief Map function for take on a csr array: counts nnz of one row
   *
   * Thread 0 only writes the leading zero of the indptr array. Every other
   * thread `tid` handles the index stored at idx_ptr[tid - 1] and writes the
   * nnz of the selected row to out_indptr[tid]. The values written are
   * per-row counts, not cumulative offsets.
   *
   * \param tid global thread id
   * \param out_indptr ptr to out indptr
   * \param src_indptr ptr to original csr indptr
   * \param idx_ptr ptr to indices
   * \param num_rows maximum number of rows in src array
   */
  template<typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int tid, RType* out_indptr,
                                  const RType* src_indptr, const IType* idx_ptr,
                                  const nnvm::dim_t num_rows) {
    if (tid == 0) {
      out_indptr[0] = 0;
      // Must stop here: falling through would read idx_ptr[-1] and then
      // overwrite out_indptr[0] with the nnz of an arbitrary row.
      return;
    }
    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
    if (clip) {
      // clip mode: clamp the index into [0, num_rows - 1]
      if (idx < 0) idx = 0;
      if (idx >= num_rows) idx = num_rows - 1;
    } else {
      // wrap mode: the remainder may be negative, so shift it back into range
      idx = idx % num_rows;
      idx += (idx < 0) ? num_rows : 0;
    }
    out_indptr[tid] = src_indptr[idx + 1] - src_indptr[idx];
  }
};

/*!
 * \brief CPU implementation of take(csr, indices, axis=0) producing a csr output.
 *
 * The work is done in three steps: count the non-zeros of every selected row,
 * turn the counts into the output indptr with a prefix sum, then copy the
 * column indices and values of every selected row.
 */
template<>
void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
                               const OpContext& ctx,
                               const TBlob& idx,
                               const NDArray& arr,
                               OpReqType req,
                               const NDArray& out) {
  using namespace csr;
  using namespace mxnet_op;
  using nnvm::dim_t;
  Stream<cpu>* stream = ctx.get_stream<cpu>();
  if (req == kNullOp) return;
  // a source without initialized storage yields an all-zero output
  if (!arr.storage_initialized()) {
    FillZerosCsrImpl(stream, out);
    return;
  }
  CHECK_EQ(idx.shape_.ndim(), 1U)
    << "Take with CSR array only supports one-dimensional indices. "
    << idx.shape_.ndim() << " dimensional input is given instead";
  CHECK_EQ(req, kWriteTo) << "req = " << req << " is not supported for take(csr)";
  CHECK_EQ(params.axis, 0) << "axis = " << params.axis << " is not supported for take(csr)";
  CHECK(params.mode == take_::kClip || params.mode == take_::kWrap)
    << "mode = " << params.mode << " is not supported";
  const bool use_clip = (params.mode == take_::kClip);
  const dim_t out_rows = out.shape()[0];
  const dim_t src_rows = arr.shape()[0];
  out.CheckAndAllocAuxData(kIndPtr, {Shape1(out_rows + 1)});

  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
    MSHADOW_SGL_DBL_TYPE_SWITCH(arr.dtype(), DType, {
      MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
        const IType* indices = idx.dptr<IType>();
        const RType* in_indptr = arr.aux_data(kIndPtr).dptr<RType>();
        RType* res_indptr = out.aux_data(kIndPtr).dptr<RType>();
        // step 1: per-row nnz of the output, one extra slot for the leading zero
        if (use_clip) {
          Kernel<CsrTakeRowCountKernel<true>, cpu>::Launch(stream, out_rows + 1,
            res_indptr, in_indptr, indices, src_rows);
        } else {
          Kernel<CsrTakeRowCountKernel<false>, cpu>::Launch(stream, out_rows + 1,
            res_indptr, in_indptr, indices, src_rows);
        }
        // step 2: sequential in-place prefix sum turns counts into offsets
        for (dim_t row = 1; row <= out_rows; ++row) {
          res_indptr[row] += res_indptr[row - 1];
        }
        // the last offset is the total number of non-zero elements in the output
        const dim_t total_nnz = res_indptr[out_rows];
        if (total_nnz == 0) {
          FillZerosCsrImpl(stream, out);
          return;
        }
        out.CheckAndAllocAuxData(kIdx, {Shape1(total_nnz)});
        out.CheckAndAllocData(Shape1(total_nnz));
        const DType* in_data = arr.data().dptr<DType>();
        const RType* in_idx = arr.aux_data(kIdx).dptr<RType>();
        DType* res_data = out.data().dptr<DType>();
        RType* res_idx = out.aux_data(kIdx).dptr<RType>();
        // step 3: copy column indices and values row by row
        if (use_clip) {
          Kernel<CsrTakeDataKernel<true>, cpu>::Launch(stream, out_rows, res_idx,
            res_data, res_indptr, in_idx, in_data, in_indptr, indices, src_rows);
        } else {
          Kernel<CsrTakeDataKernel<false>, cpu>::Launch(stream, out_rows, res_idx,
            res_data, res_indptr, in_idx, in_data, in_indptr, indices, src_rows);
        }
      });
    });
  });
}

template<>
inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
Expand Down Expand Up @@ -400,6 +536,7 @@ dimension of data (by default outer-most one as axis=0) indexed by indices, and
in an output tensor of rank q + (r - 1).
Examples::
x = [4. 5. 6.]
// Trivial case, take the second element along the first axis.
Expand Down Expand Up @@ -431,6 +568,11 @@ Examples::
[[ 3., 4.],
[ 5., 6.]]]
The storage type of ``take`` output depends upon the input storage type:
- take(default, default) = default
- take(csr, default, axis=0) = csr
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
Expand All @@ -441,11 +583,13 @@ Examples::
})
.set_attr<nnvm::FInferShape>("FInferShape", TakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", TakeOpType)
.set_attr<FInferStorageType>("FInferStorageType", TakeOpForwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", TakeOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode("_backward_take", n, ograds,
Expand Down
65 changes: 65 additions & 0 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,71 @@ inline bool TakeOpType(const nnvm::NodeAttrs& attrs,
return (*in_attrs)[0] != -1;
}

/*!
 * \brief Storage type inference for take.
 *
 * Dense indices with dense data dispatch to FCompute with a dense output.
 * Dense indices with csr data, axis 0 and wrap/clip mode dispatch to
 * FComputeEx with a csr output. Everything else goes to the fallback.
 */
inline bool TakeOpForwardStorageType(const nnvm::NodeAttrs& attrs,
                                     const int dev_mask,
                                     DispatchMode* dispatch_mode,
                                     std::vector<int>* in_attrs,
                                     std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int& idx_stype = in_attrs->at(take_::kIdx);
  const int& arr_stype = in_attrs->at(take_::kArr);
  int& out_stype = out_attrs->at(take_::kOut);
  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);

  const bool dense_indices = (idx_stype == kDefaultStorage);
  const bool wrap_or_clip = (param.mode == take_::kWrap || param.mode == take_::kClip);

  bool dispatched = false;
  if (dense_indices && arr_stype == kDefaultStorage) {
    // dns, dns -> dns
    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  } else if (dense_indices && arr_stype == kCSRStorage &&
             param.axis == 0 && wrap_or_clip) {
    // take(dns, csr, axis=0) -> csr
    dispatched = storage_type_assign(&out_stype, kCSRStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
  // the fallback is only consulted when no rule above succeeded
  return dispatched || dispatch_fallback(out_attrs, dispatch_mode);
}


/*!
 * \brief Forward pass of take for a csr data array, declared here and
 *        specialized per device type (TakeOpForwardEx calls it with xpu).
 * \param params parameters of the take operator (axis and mode are consulted
 *        by the cpu specialization)
 * \param ctx operator context
 * \param idx indices to take, passed as a TBlob
 * \param arr source array (TakeOpForwardEx only calls this for csr storage)
 * \param req write request type for the output
 * \param output destination array
 */
template<typename xpu>
void TakeOpForwardCsrImpl(const TakeParam& params,
const OpContext& ctx,
const TBlob& idx,
const NDArray& arr,
OpReqType req,
const NDArray& output);


/*!
 * \brief FComputeEx entry point of take. Only the combination
 *        (dense indices, csr data) -> csr output is implemented; any other
 *        combination is reported through LogUnimplementedOp.
 */
template<typename xpu>
void TakeOpForwardEx(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<NDArray>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<NDArray>& outputs) {
  CHECK_EQ(req[take_::kOut], kWriteTo);
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  const NDArray& indices = inputs[take_::kIdx];
  const NDArray& source = inputs[take_::kArr];
  const NDArray& result = outputs[take_::kOut];
  const auto params = nnvm::get<TakeParam>(attrs.parsed);

  // dns, csr -> csr is the only supported storage combination
  const bool csr_take = indices.storage_type() == kDefaultStorage &&
                        source.storage_type() == kCSRStorage &&
                        result.storage_type() == kCSRStorage;
  if (!csr_take) {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
    return;
  }
  TakeOpForwardCsrImpl<xpu>(params, ctx, indices.data(), source, req[0], result);
}

template<typename xpu>
void TakeOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,24 @@ def check_sparse_fc(batch_size, dim_in, dim_out, stype):
# test FC with row_sparse weight w/ density=1, csr data (fallback)
check_sparse_fc(5, 10, 8, 'csr')

@with_seed()
def test_sparse_take():
    # Compare mx.nd.take on a csr array against np.take along axis 0 for
    # several densities and both out-of-range handling modes.
    def check_sparse_take(density, mode):
        shape = rand_shape_2d()
        num_indices = np.random.randint(low=1, high=10)
        csr = rand_ndarray(shape, 'csr', density=density)
        # indices deliberately include negative and possibly too-large values
        raw_indices = np.random.randint(low=-5, high=15, size=(num_indices,))
        indices = mx.nd.array(raw_indices)
        actual = mx.nd.take(csr, indices, mode=mode)
        dense = csr.asnumpy()
        int_indices = indices.asnumpy().astype('int32')
        expected = np.take(dense, int_indices, mode=mode, axis=0)
        assert_almost_equal(actual.asnumpy(), expected)

    for density in [0, 0.5, 1]:
        for mode in ['clip', 'wrap']:
            check_sparse_take(density, mode)

# When this file is executed as a script, hand the module to nose.runmodule().
if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit d96bafa

Please sign in to comment.