
Commit: numpy_bincount_m

Ubuntu authored and Tommliu committed Dec 6, 2019
1 parent 8107ef8 commit 8aff047

Showing 5 changed files with 92 additions and 101 deletions.
1 change: 1 addition & 0 deletions python/mxnet/symbol/numpy/_symbol.py
@@ -5160,6 +5160,7 @@ def load_json(json_str):
check_call(_LIB.MXSymbolCreateFromJSON(c_str(json_str), ctypes.byref(handle)))
return _Symbol(handle)


@set_module('mxnet.symbol.numpy')
def bincount(x, weights=None, minlength=0):
"""
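For orientation, the NumPy-compatible front end registered above is used like this; a minimal sketch assuming a standard MXNet numpy setup (not part of the commit):

from mxnet import np, npx
npx.set_np()

data = np.array([0, 1, 1, 3, 2, 1, 7], dtype=int)
weights = np.array([0.3, 0.5, 0.2, 0.7, 1.0, 0.6, 0.1])

# out[i] counts how many times the value i occurs in data.
print(np.bincount(data))
# With weights, out[i] is the sum of weights[j] over all j with data[j] == i.
print(np.bincount(data, weights=weights))
# minlength guarantees at least that many output bins.
print(np.bincount(data, minlength=10))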
5 changes: 0 additions & 5 deletions src/operator/numpy/np_bincount_op-inl.h
@@ -33,10 +33,6 @@
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "np_broadcast_reduce_op.h"
#ifdef __CUDACC__
#include <thrust/device_ptr.h>
#include <thrust/extrema.h>
#endif

namespace mxnet {
namespace op {
@@ -109,7 +105,6 @@ void NumpyBincountForwardImpl(const OpContext &ctx,
const int &minlength);

template<typename xpu>

void NumpyBincountForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
45 changes: 21 additions & 24 deletions src/operator/numpy/np_bincount_op.cc
@@ -19,8 +19,8 @@

/*!
* Copyright (c) 2019 by Contributors
* \file np_bicount_op-inl.h
* \brief numpy compatible bincount operator GPU registration
* \file np_bicount_op.cc
* \brief numpy compatible bincount operator CPU registration
*/

#include "./np_bincount_op-inl.h"
@@ -68,17 +68,17 @@ void NumpyBincountForwardImpl<cpu>(const OpContext &ctx,
const NDArray &out,
const size_t &data_n,
const int &minlength) {
using namespace mxnet_op;
BinNumberCount(data, minlength, out, data_n);
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
size_t out_size = out.shape()[0];
Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
BincountCpuWeights(data.data().dptr<DType>(), weights.data().dptr<OType>(),
out.data().dptr<OType>(), data_n);
using namespace mxnet_op;
BinNumberCount(data, minlength, out, data_n);
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
size_t out_size = out.shape()[0];
Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
BincountCpuWeights(data.data().dptr<DType>(), weights.data().dptr<OType>(),
out.data().dptr<OType>(), data_n);
});
});
});
}

template<>
@@ -87,24 +87,21 @@ void NumpyBincountForwardImpl<cpu>(const OpContext &ctx,
const NDArray &out,
const size_t &data_n,
const int &minlength) {
using namespace mxnet_op;
BinNumberCount(data, minlength, out, data_n);
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
size_t out_size = out.shape()[0];
Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
BincountCpu(data.data().dptr<DType>(), out.data().dptr<OType>(), data_n);
using namespace mxnet_op;
BinNumberCount(data, minlength, out, data_n);
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
size_t out_size = out.shape()[0];
Kernel<set_zero, cpu>::Launch(s, out_size, out.data().dptr<OType>());
BincountCpu(data.data().dptr<DType>(), out.data().dptr<OType>(), data_n);
});
});
});
}

DMLC_REGISTER_PARAMETER(NumpyBincountParam);

NNVM_REGISTER_OP(_npi_bincount)
.describe(R"code(
Count number of occurrences of each value in array of non-negative ints.
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<NumpyBincountParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
const NumpyBincountParam& params =
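Both CPU specializations above follow the same recipe: compute the number of output bins, zero the output with Kernel<set_zero, cpu>, then accumulate either a count of 1 or the matching weight into out[data[i]]. A rough NumPy restatement of that accumulation, purely illustrative (the helper name is made up, not the operator's code):

import numpy as np

def bincount_cpu_sketch(data, weights=None, minlength=0):
    data = np.asarray(data, dtype=np.int64)
    if data.size and data.min() < 0:
        raise ValueError("input must be non-negative")
    # Number of bins: max(data) + 1, but never fewer than minlength.
    num_bins = max(int(data.max()) + 1 if data.size else 0, minlength)
    out_dtype = np.int64 if weights is None else np.asarray(weights).dtype
    out = np.zeros(num_bins, dtype=out_dtype)    # Kernel<set_zero, cpu>
    for i, v in enumerate(data):                 # BincountCpu / BincountCpuWeights
        out[v] += 1 if weights is None else weights[i]
    return out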
140 changes: 69 additions & 71 deletions src/operator/numpy/np_bincount_op.cu
@@ -19,18 +19,20 @@

/*!
* Copyright (c) 2019 by Contributors
* \file np_bicount_op-inl.h
* \file np_bicount_op.cu
* \brief numpy compatible bincount operator GPU registration
*/

#include "./np_bincount_op-inl.h"
#include <thrust/device_ptr.h>
#include <thrust/extrema.h>
#include "../tensor/util/tensor_util-inl.cuh"
#include "../tensor/util/tensor_util-inl.h"

namespace mxnet {
namespace op {

struct BincountFusedKernel{
struct BincountFusedKernel {
template<typename DType, typename OType>
static MSHADOW_XINLINE void Map(int i, const DType* data, OType* out) {
int idx = data[i];
@@ -72,42 +74,41 @@ void NumpyBincountForwardImpl<gpu>(const OpContext &ctx,
const NDArray &out,
const size_t &data_n,
const int &minlength) {
using namespace mxnet_op;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();

MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
DType* d_bin;
DType bin;
DType* d_ptr;
d_ptr = data.data().dptr<DType>();
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(1), s);
char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
CHECK(is_valid) << "Input should be nonnegative number"; // check invalid input

cudaMalloc(&d_bin, sizeof(DType));
thrust::device_ptr<DType> dptr_s = thrust::device_pointer_cast(d_ptr);
thrust::device_ptr<DType> dptr_e = thrust::device_pointer_cast(d_ptr + data_n);
d_bin = thrust::raw_pointer_cast(thrust::max_element(dptr_s, dptr_e));
CUDA_CALL(cudaMemcpyAsync(&bin, d_bin, sizeof(DType), cudaMemcpyDeviceToHost,
mshadow::Stream<gpu>::GetStream(s)));
CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
using namespace mxnet_op;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();

MXNET_NO_FLOAT16_TYPE_SWITCH(data.dtype(), DType, {
DType* h_ptr;
DType* d_ptr;
int bin = minlength;
d_ptr = data.data().dptr<DType>();
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(1), s);
char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
CHECK(is_valid) << "Input should be nonnegative number"; // check invalid input

h_ptr = reinterpret_cast<DType*>(malloc(data_n*sizeof(DType)));
CUDA_CALL(cudaMemcpyAsync(h_ptr, d_ptr, data_n*sizeof(DType), cudaMemcpyDeviceToHost,
mshadow::Stream<gpu>::GetStream(s)));
CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
for (size_t i = 0; i < data_n; i++) {
if (h_ptr[i] + 1 > bin) bin = h_ptr[i] + 1;
}
free(h_ptr);
mxnet::TShape s(1, bin);
const_cast<NDArray &>(out).Init(s); // set the output shape forcefully
});

bin = std::max(static_cast<int>(bin+1), minlength);
mxnet::TShape s(1, bin);
const_cast<NDArray &>(out).Init(s); // set the output shape forcefully
});

MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
size_t out_size = out.shape().Size();
Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
Kernel<BincountFusedKernel, gpu>::Launch(
s, data_n, data.data().dptr<DType>(), weights.data().dptr<OType>(),
out.data().dptr<OType>());
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(weights.dtype(), OType, {
size_t out_size = out.shape().Size();
Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
Kernel<BincountFusedKernel, gpu>::Launch(
s, data_n, data.data().dptr<DType>(), weights.data().dptr<OType>(),
out.data().dptr<OType>());
});
});
});
}

template<>
@@ -116,43 +117,40 @@ void NumpyBincountForwardImpl<gpu>(const OpContext &ctx,
const NDArray &out,
const size_t &data_n,
const int &minlength) {
using namespace mxnet_op;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();

MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
DType* d_bin;
DType bin;
DType* d_ptr;
d_ptr = data.data().dptr<DType>();
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(1), s);
char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
CHECK(is_valid) << "Input should be nonnegative number"; // check invalid input

Tensor<gpu, 1, DType> workspace1 = ctx.requested[0]
.get_space_typed<gpu, 1, DType>(Shape1(1), s);
d_bin = reinterpret_cast<DType*>(workspace1.dptr_);
thrust::device_ptr<DType> dptr_s = thrust::device_pointer_cast(d_ptr);
thrust::device_ptr<DType> dptr_e = thrust::device_pointer_cast(d_ptr + data_n);
d_bin = thrust::raw_pointer_cast(thrust::max_element(dptr_s, dptr_e));
CUDA_CALL(cudaMemcpyAsync(&bin, d_bin, sizeof(DType), cudaMemcpyDeviceToHost,
mshadow::Stream<gpu>::GetStream(s)));
CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
using namespace mxnet_op;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();

MXNET_NO_FLOAT16_TYPE_SWITCH(data.dtype(), DType, {
DType* h_ptr;
DType* d_ptr;
int bin = minlength;
d_ptr = data.data().dptr<DType>();
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(1), s);
char* is_valid_ptr = reinterpret_cast<char*>(workspace.dptr_);
bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr);
CHECK(is_valid) << "Input should be nonnegative number"; // check invalid input

h_ptr = reinterpret_cast<DType*>(malloc(data_n*sizeof(DType)));
CUDA_CALL(cudaMemcpyAsync(h_ptr, d_ptr, data_n*sizeof(DType), cudaMemcpyDeviceToHost,
mshadow::Stream<gpu>::GetStream(s)));
CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
for (size_t i = 0; i < data_n; i++) {
if (h_ptr[i] + 1 > bin) bin = h_ptr[i] + 1;
}
free(h_ptr);
mxnet::TShape s(1, bin);
const_cast<NDArray &>(out).Init(s); // set the output shape forcefully
});

bin = std::max(static_cast<int>(bin+1), minlength);
mxnet::TShape s(1, bin);
const_cast<NDArray &>(out).Init(s); // set the output shape forcefully
});

MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
size_t out_size = out.shape().Size();
Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
Kernel<BincountFusedKernel, gpu>::Launch(
s, data_n, data.data().dptr<DType>(), out.data().dptr<OType>());
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
size_t out_size = out.shape().Size();
Kernel<set_zero, gpu>::Launch(s, out_size, out.data().dptr<OType>());
Kernel<BincountFusedKernel, gpu>::Launch(
s, data_n, data.data().dptr<DType>(), out.data().dptr<OType>());
});
});
});
}

NNVM_REGISTER_OP(_npi_bincount)
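The substantive GPU change above is how the number of bins is found: instead of a thrust::max_element reduction on the device, the new code copies the input back to the host with cudaMemcpyAsync, scans it starting from minlength, and only then Init()s the output NDArray to that length before launching set_zero and BincountFusedKernel. A small sketch of just that host-side scan (hypothetical helper, not the operator code):

import numpy as np

def num_bins_host_scan(host_data, minlength):
    # Mirrors the loop in NumpyBincountForwardImpl<gpu>: start from
    # minlength and grow until every value in the data is covered.
    bins = minlength
    for v in host_data:
        bins = max(bins, int(v) + 1)
    return bins

assert num_bins_host_scan(np.array([0, 2, 5]), minlength=0) == 6
assert num_bins_host_scan(np.array([0, 2, 5]), minlength=10) == 10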
2 changes: 1 addition & 1 deletion tests/python/unittest/test_numpy_op.py
@@ -5092,7 +5092,7 @@ def hybrid_forward(self, F, a, weights):
np_out = _np.bincount(data.asnumpy(), weights_np, minlength)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
# No bacward operation for operator bincount at this moment
# No backward operation for operator bincount at this moment

# Test imperative once again
mx_out = np.bincount(data, weights, minlength)
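The unit test continues to check the forward result against NumPy; a condensed sketch of that parity check, assuming the helpers the test file already imports:

import numpy as _np
from mxnet import np, npx
from mxnet.test_utils import assert_almost_equal
npx.set_np()

data = np.array([0, 2, 2, 5, 1, 1], dtype=int)
weights = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
minlength = 8

mx_out = np.bincount(data, weights, minlength)
np_out = _np.bincount(data.asnumpy().astype('int64'), weights.asnumpy(), minlength)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)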
