Implementation for CSR slice on CPU (apache#36)
* CPU implementation for CSR

remove seg_len from csr slice

add some docs for slice csr

change indptr, values, etc. to be private members

bug fix in sparse embedding

update nnvm submodule

fix lint

update unit test for sparse nd

* add const for SliceCsrIndPtr kernel
eric-haibin-lin authored May 19, 2017
1 parent 5be8d94 commit 0d33d3f
Showing 6 changed files with 141 additions and 14 deletions.
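Before the per-file diffs, a minimal sketch of what the change enables: slicing a `csr` NDArray along the first dimension on CPU. It reuses the helpers the unit tests below exercise (`rand_sparse_ndarray` from the repo's test utilities and the `crop` operator); treat it as an illustration, not part of the commit.

    import mxnet as mx
    # rand_sparse_ndarray comes from the test utilities used further down;
    # mx.nd.crop is the operator the new FComputeEx path is registered on.
    A, _ = rand_sparse_ndarray((8, 4), 'csr')  # random CSR NDArray
    B = mx.nd.crop(A, begin=1, end=5)          # rows [1, 5); only first-dim slicing is supported
    print(B.asnumpy())                         # dense copy of the sliced rows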
39 changes: 33 additions & 6 deletions python/mxnet/sparse_ndarray.py
@@ -251,13 +251,13 @@ def _slice(self, start, stop):
>>> a[1:2].asnumpy()
array([[0, 0, 3]])
- >>> a[1:2].indptr.asnumpy()
+ >>> a[1:2]._indptr.asnumpy()
array([[2, 3]])
- >>> a[1:2].indicies.asnumpy()
+ >>> a[1:2]._indices.asnumpy()
array([0, 2, 2, 0, 1, 2])
- >>> a[1:2].values.asnumpy()
+ >>> a[1:2]._values.asnumpy()
array([1, 2, 3, 4, 5, 6])
"""
@@ -293,11 +293,27 @@ def _aux_type(self, i):
return _DTYPE_MX_TO_NP[aux_type.value]

@property
- def values(self):
+ def _values(self):
"""The values array of the SparseNDArray. This is a read-only view of the values array.
It reveals internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's values array.
"""
return self._data(0)

@property
- def indices(self):
+ def _indices(self):
"""The indices array of the SparseNDArray. This is a read-only view of the indices array.
It reveals internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's indices array.
"""
stype = self.storage_type
if stype == 'row_sparse':
return self._aux_data(0)
@@ -306,7 +322,16 @@ def indices(self):
raise Exception("unknown storage type " + stype)

@property
- def indptr(self):
+ def _indptr(self):
"""The indptr array of the SparseNDArray with `csr` storage type.
This is a read-only view of the indptr array.
It reveals internal implementation details and should be used with care.
Returns
-------
NDArray
This SparseNDArray's indptr array.
"""
stype = self.storage_type
if stype == 'csr':
return self._aux_data(0)
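For readers new to the storage format, a small worked example of how the three aux arrays relate for a `csr` matrix, in plain NumPy (illustrative only, independent of mxnet):

    import numpy as np
    # dense = [[0, 1, 0],
    #          [2, 0, 3],
    #          [0, 0, 0]]
    values  = np.array([1, 2, 3])      # non-zero entries in row-major order
    indices = np.array([1, 0, 2])      # column index of each value
    indptr  = np.array([0, 1, 3, 3])   # row i owns values[indptr[i]:indptr[i+1]]
    # reconstruct row 1 from the three arrays:
    row = np.zeros(3)
    row[indices[indptr[1]:indptr[2]]] = values[indptr[1]:indptr[2]]
    print(row)  # [2. 0. 3.]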
@@ -383,6 +408,7 @@ def _aux_data(self, i, writable=False):
SparseNDArray is not yet compacted, the returned result may include invalid values.
"""
self.wait_to_read()
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl)))
return NDArray(hdl, writable)
@@ -392,6 +418,7 @@ def _data(self, writable=False):
SparseNDArray is not yet compacted, the returned result may include invalid values.
"""
self.wait_to_read()
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl)))
return NDArray(hdl, writable)
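The `self.wait_to_read()` added in both methods makes the raw-handle access safe: it blocks until pending asynchronous operations writing to this array have finished, so `_data` and `_aux_data` never return a view of half-written memory.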
2 changes: 1 addition & 1 deletion src/operator/tensor/indexing_op.h
@@ -328,7 +328,7 @@ struct EmbeddingBackwardRsp {
size_t segment_end = (i + 1) * segment_len;
for (size_t y = 0; y < num_idx; y++) {
size_t j = idx[y];
- if (j > num_rows) j = num_rows - 1;
+ if (j >= num_rows) j = num_rows - 1;
if (j < segment_start || j >= segment_end) continue;
dst_idx[j] = j;
for (size_t k = 0; k < width; k++) {
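The one-character change above fixes an off-by-one: with `num_rows` rows the valid ids are `0 ... num_rows - 1`, so an incoming index equal to `num_rows` must be clamped as well; the old `>` test let it through and indexed one row past the end of the destination.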
79 changes: 79 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
@@ -7,6 +7,7 @@
#define MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_

#include <mxnet/operator_util.h>
#include <dmlc/omp.h>
#include <vector>
#include <algorithm>
#include <utility>
@@ -1217,6 +1218,84 @@ void Slice(const nnvm::NodeAttrs& attrs,
});
}

// slice the indptr of a csr
struct SliceCsrIndPtr {
template<typename IType>
MSHADOW_XINLINE static void Map(int i, IType* out, const IType* in, const IType* base) {
KERNEL_ASSIGN(out[i], kWriteTo, in[i] - *base);
}
};

/*
* Slice a CSR NDArray
* Only implemented for CPU
*/
template<typename xpu>
void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
const NDArray &in, OpReqType req, const NDArray &out) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK((std::is_same<xpu, cpu>::value)) << "Slice for CSR input only implemented for CPU";
if (req == kNullOp) return;
CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported";
CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported";
Stream<xpu> *s = ctx.get_stream<xpu>();
int begin = *param.begin[0];
int end = *param.end[0];
int indptr_len = end - begin + 1;
out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
if (!in.storage_initialized()) {
out.SetAuxShape(kIndPtr, Shape1(0));
return;
}
CHECK_EQ(in.aux_type(kIndPtr), in.aux_type(kIdx))
<< "The type for indptr and indices are different. This is not implemented yet.";
// assume idx and indptr share the same type
NDARRAY_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
auto in_indptr = in.aux_data(kIndPtr).dptr<IType>();
auto out_indptr = out.aux_data(kIndPtr).dptr<IType>();
int num_threads = omp_get_num_threads();
int segment_len = (indptr_len + num_threads - 1) / num_threads;
Kernel<SliceCsrIndPtr, xpu>::Launch(s, indptr_len, out_indptr, in_indptr + begin,
in_indptr + begin);
// retrieve nnz (CPU implementation)
int nnz = out_indptr[indptr_len - 1] - out_indptr[0];
// copy indices and values
out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));
auto in_idx = in.aux_data(kIdx).dptr<IType>();
auto out_idx = out.aux_data(kIdx).dptr<IType>();
auto in_data = in.data().dptr<DType>();
auto out_data = out.data().dptr<DType>();
int offset = in_indptr[begin];
// this is also a CPU-only implementation
memcpy(out_idx, in_idx + offset, nnz * sizeof(IType));
memcpy(out_data, in_data + offset, nnz * sizeof(DType));
});
});
}
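The indptr arithmetic above is easiest to verify in a pure-NumPy restatement of the same three steps (re-base `indptr`, derive `nnz`, copy the `indices`/`values` segment); this is a sketch of the logic, not the shipped kernel:

    import numpy as np

    def slice_csr(indptr, indices, values, begin, end):
        # keep rows [begin, end); mirrors SliceCsrImpl on CPU
        out_indptr = indptr[begin:end + 1] - indptr[begin]  # SliceCsrIndPtr step
        nnz = out_indptr[-1] - out_indptr[0]                # retained non-zeros
        offset = indptr[begin]                              # first retained entry
        return out_indptr, indices[offset:offset + nnz], values[offset:offset + nnz]

    # rows [1, 3) of the 3x3 example earlier: indptr [0, 1, 3, 3] -> [0, 2, 2], nnz = 2
    print(slice_csr(np.array([0, 1, 3, 3]), np.array([1, 0, 2]),
                    np.array([1, 2, 3]), 1, 3))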

template<typename xpu>
void SliceEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1);
CHECK_EQ(outputs.size(), 1);
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
auto in_stype = inputs[0].storage_type();
CHECK_NE(in_stype, kDefaultStorage)
<< "SliceEx is not expected to execute for input with default storage type";
if (in_stype == kCSRStorage) {
SliceCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
} else {
LOG(FATAL) << "Slice not implemented for storage type " << in_stype;
}
}

inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
5 changes: 5 additions & 0 deletions src/operator/tensor/matrix_op.cc
@@ -232,6 +232,9 @@ and ``end=(e_1, e_2, ... e_n)`` indices will result in an array with the shape
The resulting array's *k*-th dimension contains elements
from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``.
For an input array of non-default storage type (e.g. `csr` or `row_sparse`), it only supports slicing on the first dimension.
Example::
x = [[ 1., 2., 3., 4.],
@@ -245,8 +248,10 @@ Example::
.set_attr_parser(ParamParser<SliceParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SliceShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
.set_attr<FCompute>("FCompute<cpu>", Slice<cpu>)
.set_attr<FComputeEx>(FCOMP_EX_CPU, SliceEx<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(SliceParam::__FIELDS__());
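Registering both entry points is the design choice worth noting: `FCompute<cpu>` keeps the existing dense `Slice<cpu>` kernel, while the new `FComputeEx` registration routes `csr` inputs through `SliceEx<cpu>` (and from there `SliceCsrImpl`), letting the executor dispatch on the input's storage type at run time.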

12 changes: 6 additions & 6 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -75,10 +75,10 @@ def check_sparse_nd_prop_rsp():
shape = (rnd.randint(1, 2), rnd.randint(1, 2))
nd, (v, idx) = rand_sparse_ndarray(shape, storage_type)
assert(nd._num_aux == 1)
- assert(nd.indices.dtype == np.int32)
+ assert(nd._indices.dtype == np.int32)
assert(nd.storage_type == 'row_sparse')
assert_almost_equal(nd._data().asnumpy(), v)
- assert_almost_equal(nd._aux_data(0).asnumpy(), idx)
+ assert_almost_equal(nd._indices.asnumpy(), idx)

def test_sparse_nd_basic():
def check_rsp_creation(values, indices, shape):
@@ -88,13 +88,13 @@ def check_rsp_creation(values, indices, shape):
dns[3] = mx.nd.array(values[1])
assert_almost_equal(rsp.asnumpy(), dns.asnumpy())
indices = mx.nd.array(indices).asnumpy()
- assert_almost_equal(rsp.indices.asnumpy(), indices)
+ assert_almost_equal(rsp._indices.asnumpy(), indices)

def check_csr_creation(shape):
csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr')
- assert_almost_equal(csr.indptr.asnumpy(), indptr)
- assert_almost_equal(csr.indices.asnumpy(), indices)
- assert_almost_equal(csr.values.asnumpy(), values)
+ assert_almost_equal(csr._indptr.asnumpy(), indptr)
+ assert_almost_equal(csr._indices.asnumpy(), indices)
+ assert_almost_equal(csr._values.asnumpy(), values)

shape = (4,2)
values = np.random.rand(2,2)
18 changes: 17 additions & 1 deletion tests/python/unittest/test_sparse_operator.py
@@ -170,11 +170,27 @@ def test_sparse_embedding():
grad = mx.nd.zeros(np_grad.shape)
grad[:] = np_grad
exe_test.backward([grad])
- assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad))
+ assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), atol=1e-5)

def test_sparse_slice():
def check_csr_slice(shape, sliced_input):
storage_type = 'csr'
A, _ = rand_sparse_ndarray(shape, storage_type)
A = A._slice(1, shape[0] - 1) if sliced_input else A
A2 = A.asnumpy()
begin = rnd.randint(0, A.shape[0] - 1)
end = rnd.randint(begin + 1, A.shape[0])
A_slice = mx.nd.crop(A, begin=begin, end=end)
assert same(A_slice.asnumpy(), A2[begin:end]), (A_slice.asnumpy(), A2[begin:end])

shape = (rnd.randint(7, 15), rnd.randint(1, 10))
check_csr_slice(shape, True)
check_csr_slice(shape, False)
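Note that the `sliced_input=True` case feeds `crop` an array that is itself the output of `_slice`, that is, a non-compacted CSR whose aux arrays still reference the parent's data, which is exactly the situation the `_aux_data` docstring warns about.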

if __name__ == '__main__':
test_elemwise_add_ex()
test_elemwise_add_ex_multiple_stages()
test_cast_storage_ex()
test_sparse_dot()
test_sparse_embedding()
test_sparse_slice()

