From 8aff04726f3d68d172152d79f203d7ec894693e9 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 3 Dec 2019 09:03:12 +0000
Subject: [PATCH] numpy_bincount_m

---
 python/mxnet/symbol/numpy/_symbol.py    |   1 +
 src/operator/numpy/np_bincount_op-inl.h |   5 -
 src/operator/numpy/np_bincount_op.cc    |  45 ++++----
 src/operator/numpy/np_bincount_op.cu    | 140 ++++++++++++------------
 tests/python/unittest/test_numpy_op.py  |   2 +-
 5 files changed, 92 insertions(+), 101 deletions(-)

diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 030c51fd4992..51ec52eecf3c 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -5160,6 +5160,7 @@ def load_json(json_str):
     check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle)))
     return _Symbol(handle)
 
+
 @set_module('mxnet.symbol.numpy')
 def bincount(x, weights=None, minlength=0):
     """
diff --git a/src/operator/numpy/np_bincount_op-inl.h b/src/operator/numpy/np_bincount_op-inl.h
index 6fd9532bead9..43fc1814f2f8 100644
--- a/src/operator/numpy/np_bincount_op-inl.h
+++ b/src/operator/numpy/np_bincount_op-inl.h
@@ -33,10 +33,6 @@
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
 #include "np_broadcast_reduce_op.h"
-#ifdef __CUDACC__
-#include <thrust/device_ptr.h>
-#include <thrust/extrema.h>
-#endif
 
 namespace mxnet {
 namespace op {
@@ -109,7 +105,6 @@ void NumpyBincountForwardImpl(const OpContext &ctx,
                               const int &minlength);
 
 template<typename xpu>
-
 void NumpyBincountForward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const std::vector<NDArray> &inputs,
diff --git a/src/operator/numpy/np_bincount_op.cc b/src/operator/numpy/np_bincount_op.cc
index 7cd3accb0661..6256db176977 100644
--- a/src/operator/numpy/np_bincount_op.cc
+++ b/src/operator/numpy/np_bincount_op.cc
@@ -19,8 +19,8 @@
 
 /*!
  * Copyright (c) 2019 by Contributors
- * \file np_bicount_op-inl.h
- * \brief numpy compatible bincount operator GPU registration
+ * \file np_bicount_op.cc
+ * \brief numpy compatible bincount operator CPU registration
  */
 
 #include "./np_bincount_op-inl.h"
@@ -68,17 +68,17 @@ void NumpyBincountForwardImpl<cpu>(const OpContext &ctx,
                                    const NDArray &out,
                                    const size_t &data_n,
                                    const int &minlength) {
-using namespace mxnet_op;
-BinNumberCount(data, minlength, out, data_n);
-mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
-  MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
-    size_t out_size = out.shape()[0];
-    Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
-    BincountCpuWeights(data.data().dptr<DType>(), weights.data().dptr<OType>(),
-                       out.data().dptr<OType>(), data_n);
+  using namespace mxnet_op;
+  BinNumberCount(data, minlength, out, data_n);
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
+      size_t out_size = out.shape()[0];
+      Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
+      BincountCpuWeights(data.data().dptr<DType>(), weights.data().dptr<OType>(),
+                         out.data().dptr<OType>(), data_n);
+    });
   });
-  });
 }
 
 template<>
@@ -87,24 +87,21 @@ void NumpyBincountForwardImpl<cpu>(const OpContext &ctx,
                                    const NDArray &out,
                                    const size_t &data_n,
                                    const int &minlength) {
-using namespace mxnet_op;
-BinNumberCount(data, minlength, out, data_n);
-mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
-  MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
-    size_t out_size = out.shape()[0];
-    Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
-    BincountCpu(data.data().dptr<DType>(), out.data().dptr<OType>(), data_n);
+  using namespace mxnet_op;
+  BinNumberCount(data, minlength, out, data_n);
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
+      size_t out_size = out.shape()[0];
+      Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
+      BincountCpu(data.data().dptr<DType>(), out.data().dptr<OType>(), data_n);
+    });
   });
-  });
 }
 
 DMLC_REGISTER_PARAMETER(NumpyBincountParam);
 
 NNVM_REGISTER_OP(_npi_bincount)
-.describe(R"code(
-Count number of occurrences of each value in array of non-negative ints.
-)code" ADD_FILELINE)
 .set_attr_parser(ParamParser<NumpyBincountParam>)
 .set_num_inputs([](const NodeAttrs& attrs) {
   const NumpyBincountParam& params =
diff --git a/src/operator/numpy/np_bincount_op.cu b/src/operator/numpy/np_bincount_op.cu
index c3a1a13166fa..e37e9af435b5 100644
--- a/src/operator/numpy/np_bincount_op.cu
+++ b/src/operator/numpy/np_bincount_op.cu
@@ -19,18 +19,20 @@
 
 /*!
  * Copyright (c) 2019 by Contributors
- * \file np_bicount_op-inl.h
+ * \file np_bicount_op.cu
  * \brief numpy compatible bincount operator GPU registration
  */
 
 #include "./np_bincount_op-inl.h"
+#include <thrust/device_ptr.h>
+#include <thrust/extrema.h>
 #include "../tensor/util/tensor_util-inl.cuh"
 #include "../tensor/util/tensor_util-inl.h"
 
 namespace mxnet {
 namespace op {
 
-struct BincountFusedKernel{
+struct BincountFusedKernel {
   template<typename DType, typename OType>
   static MSHADOW_XINLINE void Map(int i, const DType* data, OType* out) {
     int idx = data[i];
@@ -72,42 +74,41 @@ void NumpyBincountForwardImpl<gpu>(const OpContext &ctx,
                                    const NDArray &out,
                                    const size_t &data_n,
                                    const int &minlength) {
-using namespace mxnet_op;
-mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-
-MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
-  DType* d_bin;
-  DType bin;
-  DType* d_ptr;
-  d_ptr = data.data().dptr<DType>();
-  Tensor<gpu, 1, char> workspace = ctx.requested[0]
-    .get_space_typed<gpu, 1, char>(Shape1(1), s);
-  char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
-  bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
-  CHECK(is_valid) << "Input should be nonnegative number";  // check invalid input
-
-  cudaMalloc(&d_bin, sizeof(DType));
-  thrust::device_ptr<DType> dptr_s = thrust::device_pointer_cast(d_ptr);
-  thrust::device_ptr<DType> dptr_e = thrust::device_pointer_cast(d_ptr + data_n);
-  d_bin = thrust::raw_pointer_cast(thrust::max_element(dptr_s, dptr_e));
-  CUDA_CALL(cudaMemcpyAsync(&bin, d_bin, sizeof(DType), cudaMemcpyDeviceToHost,
-                            mshadow::Stream<gpu>::GetStream(s)));
-  CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
+  using namespace mxnet_op;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+
+  MXNET_NO_FLOAT16_TYPE_SWITCH(data.dtype(), DType, {
+    DType* h_ptr;
+    DType* d_ptr;
+    int bin = minlength;
+    d_ptr = data.data().dptr<DType>();
+    Tensor<gpu, 1, char> workspace = ctx.requested[0]
+      .get_space_typed<gpu, 1, char>(Shape1(1), s);
+    char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
+    bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
+    CHECK(is_valid) << "Input should be nonnegative number";  // check invalid input
+
+    h_ptr = reinterpret_cast<DType*>(malloc(data_n*sizeof(DType)));
+    CUDA_CALL(cudaMemcpyAsync(h_ptr, d_ptr, data_n*sizeof(DType), cudaMemcpyDeviceToHost,
+                              mshadow::Stream<gpu>::GetStream(s)));
+    CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
+    for (size_t i = 0; i < data_n; i++) {
+      if (h_ptr[i] + 1 > bin) bin = h_ptr[i] + 1;
+    }
+    free(h_ptr);
+    mxnet::TShape s(1, bin);
+    const_cast<NDArray &>(out).Init(s);  // set the output shape forcefully
+  });
 
-  bin = std::max(static_cast<int>(bin+1), minlength);
-  mxnet::TShape s(1, bin);
-  const_cast<NDArray &>(out).Init(s);  // set the output shape forcefully
-});
-
-MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
-  MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
-    size_t out_size = out.shape().Size();
-    Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
-    Kernel<BincountFusedKernel, gpu>::Launch(
-      s, data_n, data.data().dptr<DType>(), weights.data().dptr<OType>(),
-      out.data().dptr<OType>());
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
+      size_t out_size = out.shape().Size();
+      Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
+      Kernel<BincountFusedKernel, gpu>::Launch(
+        s, data_n, data.data().dptr<DType>(), weights.data().dptr<OType>(),
+        out.data().dptr<OType>());
+    });
   });
-  });
 }
 
 template<>
@@ -116,43 +117,40 @@ void NumpyBincountForwardImpl<gpu>(const OpContext &ctx,
                                    const NDArray &out,
                                    const size_t &data_n,
                                    const int &minlength) {
-using namespace mxnet_op;
-mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-
-MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
-  DType* d_bin;
-  DType bin;
-  DType* d_ptr;
-  d_ptr = data.data().dptr<DType>();
-  Tensor<gpu, 1, char> workspace = ctx.requested[0]
-    .get_space_typed<gpu, 1, char>(Shape1(1), s);
-  char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
-  bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
-  CHECK(is_valid) << "Input should be nonnegative number";  // check invalid input
-
-  Tensor<gpu, 1, DType> workspace1 = ctx.requested[0]
-    .get_space_typed<gpu, 1, DType>(Shape1(1), s);
-  d_bin = reinterpret_cast<DType*>(workspace1.dptr_);
-  thrust::device_ptr<DType> dptr_s = thrust::device_pointer_cast(d_ptr);
-  thrust::device_ptr<DType> dptr_e = thrust::device_pointer_cast(d_ptr + data_n);
-  d_bin = thrust::raw_pointer_cast(thrust::max_element(dptr_s, dptr_e));
-  CUDA_CALL(cudaMemcpyAsync(&bin, d_bin, sizeof(DType), cudaMemcpyDeviceToHost,
-                            mshadow::Stream<gpu>::GetStream(s)));
-  CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
+  using namespace mxnet_op;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+
+  MXNET_NO_FLOAT16_TYPE_SWITCH(data.dtype(), DType, {
+    DType* h_ptr;
+    DType* d_ptr;
+    int bin = minlength;
+    d_ptr = data.data().dptr<DType>();
+    Tensor<gpu, 1, char> workspace = ctx.requested[0]
+      .get_space_typed<gpu, 1, char>(Shape1(1), s);
+    char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
+    bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
+    CHECK(is_valid) << "Input should be nonnegative number";  // check invalid input
+
+    h_ptr = reinterpret_cast<DType*>(malloc(data_n*sizeof(DType)));
+    CUDA_CALL(cudaMemcpyAsync(h_ptr, d_ptr, data_n*sizeof(DType), cudaMemcpyDeviceToHost,
+                              mshadow::Stream<gpu>::GetStream(s)));
+    CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
+    for (size_t i = 0; i < data_n; i++) {
+      if (h_ptr[i] + 1 > bin) bin = h_ptr[i] + 1;
+    }
+    free(h_ptr);
+    mxnet::TShape s(1, bin);
+    const_cast<NDArray &>(out).Init(s);  // set the output shape forcefully
+  });
 
-  bin = std::max(static_cast<int>(bin+1), minlength);
-  mxnet::TShape s(1, bin);
-  const_cast<NDArray &>(out).Init(s);  // set the output shape forcefully
-});
-
-MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
-  MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
-    size_t out_size = out.shape().Size();
-    Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
-    Kernel<BincountFusedKernel, gpu>::Launch(
-      s, data_n, data.data().dptr<DType>(), out.data().dptr<OType>());
+  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
+    MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
+      size_t out_size = out.shape().Size();
+      Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
+      Kernel<BincountFusedKernel, gpu>::Launch(
+        s, data_n, data.data().dptr<DType>(), out.data().dptr<OType>());
+    });
   });
-  });
 }
 
 NNVM_REGISTER_OP(_npi_bincount)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 551b81e73385..6e7311a59e49 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -5092,7 +5092,7 @@ def hybrid_forward(self, F, a, weights):
         np_out = _np.bincount(data.asnumpy(), weights_np, minlength)
         assert mx_out.shape == np_out.shape
         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
-        # No bacward operation for operator bincount at this moment
+        # No backward operation for operator bincount at this moment
 
         # Test imperative once again
         mx_out = np.bincount(data, weights, minlength)
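
Usage note (illustration only, not part of the patch): with this change the GPU
path derives the number of bins by copying the input to the host and scanning
for the maximum instead of calling thrust::max_element, and float16 input is
rejected on GPU via MXNET_NO_FLOAT16_TYPE_SWITCH. A minimal sketch of the
operator's expected front-end behavior, assuming an MXNet build with the numpy
interface enabled; the expected values in the comments follow the NumPy
semantics that the test above compares against:

    from mxnet import np, npx
    npx.set_np()  # enable NumPy-compatible array semantics

    data = np.array([0, 1, 1, 3, 2, 1, 7], dtype='int32')
    # Output length is max(data) + 1 = 8 unless minlength is larger.
    print(np.bincount(data))                      # counts per bin: [1, 3, 1, 1, 0, 0, 0, 1]
    # minlength only pads the output; it never truncates it.
    print(np.bincount(data, minlength=10).shape)  # (10,)
    # With weights, each bin accumulates the weights of its occurrences.
    w = np.array([0.3, 0.5, 0.2, 0.7, 1.0, 0.5, 0.1])
    print(np.bincount(data, weights=w))           # [0.3, 1.2, 1.0, 0.7, 0, 0, 0, 0.1]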