Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix multinomial bug on gpu (#16204)
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 authored and haojin2 committed Sep 19, 2019
1 parent 66c4207 commit b3da7d2
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 41 deletions.
13 changes: 13 additions & 0 deletions src/operator/numpy/random/np_multinomial_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@
namespace mxnet {
namespace op {

template<typename DType>
void CheckPvalGPU(DType* input, int prob_length) {
  // Validates device-resident pvals by copying them to the host first;
  // `input` must point to `prob_length` DType values in GPU memory.
  std::vector<DType> pvals_(prob_length);
  CUDA_CALL(cudaMemcpy(&pvals_[0], input, sizeof(DType) * prob_length,
    cudaMemcpyDeviceToHost));
  DType sum = DType(0.0);
  // Match numpy semantics (and the error message below): only the first
  // (prob_length - 1) probabilities must sum to at most 1.0 — the last
  // probability is implied as the remainder, so it is excluded here.
  for (int i = 0; i < prob_length - 1; ++i) {
    sum += pvals_[i];
    // CHECK_LE prints both operands on failure, consistent with CheckPval.
    CHECK_LE(sum, DType(1.0))
      << "sum(pvals[:-1]) > 1.0";
  }
}

// Register the GPU FCompute implementation of the numpy-style multinomial op.
NNVM_REGISTER_OP(_npi_multinomial)
.set_attr<FCompute>("FCompute<gpu>", NumpyMultinomialForward<gpu>);

Expand Down
26 changes: 18 additions & 8 deletions src/operator/numpy/random/np_multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,19 @@ inline bool NumpyMultinomialOpType(const nnvm::NodeAttrs& attrs,
return true;
}

// Validates pvals that live in GPU memory. Defined in np_multinomial_op.cu,
// where the values are first copied device-to-host before being checked.
template<typename DType>
void CheckPvalGPU(DType* input, int prob_length);

template<typename DType>
void CheckPval(DType* input, int prob_length) {
  // Validates host-resident pvals. Match numpy semantics (and the error
  // message below): only the first (prob_length - 1) probabilities must
  // sum to at most 1.0 — the last probability is implied as the remainder,
  // so it is excluded from the running sum.
  DType sum = DType(0.0);
  for (int i = 0; i < prob_length - 1; ++i) {
    sum += input[i];
    CHECK_LE(sum, 1.0)
      << "sum(pvals[:-1]) > 1.0";
  }
}

struct multinomial_kernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i,
Expand Down Expand Up @@ -172,14 +185,11 @@ void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
s, num_output, num_exp, prob_length, pvals_, temp_tensor.dptr_, outputs[0].dptr<int64_t>());
} else {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
// check if sum of input(pvals) > 1.0
DType sum = DType(0);
DType* input = inputs[0].dptr<DType>();
for (int i = 0; i < prob_length; ++i) {
sum += input[i];
CHECK_LE(sum, 1.0)
<< "sum(pvals[:-1]) > 1.0";
}
if (std::is_same<xpu, cpu>::value) {
CheckPval<DType>(inputs[0].dptr<DType>(), prob_length);
} else {
CheckPvalGPU<DType>(inputs[0].dptr<DType>(), prob_length);
}
Kernel<multinomial_kernel, xpu>::Launch(
s, num_output, num_exp, prob_length,
inputs[0].dptr<DType>(), temp_tensor.dptr_, outputs[0].dptr<int64_t>());
Expand Down
90 changes: 57 additions & 33 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,47 +821,71 @@ def test_np_multinomial():
pvals_list = [[0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]]
sizes = [None, (), (3,), (2, 5, 7), (4, 9)]
experiements = 10000
for have_size in [False, True]:
for pvals in pvals_list:
if have_size:
for size in sizes:
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy() / _np.float32(experiements)
# for those cases that didn't need reshape
if size in [None, ()]:
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
for pvals_mx_np_array in [False, True]:
for have_size in [False, True]:
for pvals in pvals_list:
if pvals_mx_np_array:
pvals = mx.np.array(pvals)
if have_size:
for size in sizes:
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy() / _np.float32(experiements)
# for those cases that didn't need reshape
if size in [None, ()]:
if type(pvals) == np.ndarray:
mx.test_utils.assert_almost_equal(freq, pvals.asnumpy(), rtol=0.20, atol=1e-1)
else:
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
else:
# check the shape
assert freq.shape == size + (len(pvals),), 'freq.shape={}, size + (len(pvals))={}'.format(freq.shape, size + (len(pvals)))
freq = freq.reshape((-1, len(pvals)))
# check the value for each row
for i in range(freq.shape[0]):
if type(pvals) == np.ndarray:
mx.test_utils.assert_almost_equal(freq[i, :], pvals.asnumpy(), rtol=0.20, atol=1e-1)
else:
mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
else:
freq = mx.np.random.multinomial(experiements, pvals).asnumpy() / _np.float32(experiements)
if type(pvals) == np.ndarray:
mx.test_utils.assert_almost_equal(freq, pvals.asnumpy(), rtol=0.20, atol=1e-1)
else:
# check the shape
assert freq.shape == size + (len(pvals),), 'freq.shape={}, size + (len(pvals))={}'.format(freq.shape, size + (len(pvals)))
freq = freq.reshape((-1, len(pvals)))
# check the value for each row
for i in range(freq.shape[0]):
mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
else:
freq = mx.np.random.multinomial(experiements, pvals).asnumpy() / _np.float32(experiements)
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
# check the zero dimension
sizes = [(0), (0, 2), (4, 0, 2), (3, 0, 1, 2, 0)]
for pvals in pvals_list:
for size in sizes:
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy()
assert freq.size == 0
for pvals_mx_np_array in [False, True]:
for pvals in pvals_list:
for size in sizes:
if pvals_mx_np_array:
pvals = mx.np.array(pvals)
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy()
assert freq.size == 0
# check [] as pvals
for pvals in [[], ()]:
freq = mx.np.random.multinomial(experiements, pvals).asnumpy()
assert freq.size == 0
for size in sizes:
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy()
for pvals_mx_np_array in [False, True]:
for pvals in [[], ()]:
if pvals_mx_np_array:
pvals = mx.np.array(pvals)
freq = mx.np.random.multinomial(experiements, pvals).asnumpy()
assert freq.size == 0
for size in sizes:
freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy()
assert freq.size == 0
# test small experiment for github issue
# https://github.com/apache/incubator-mxnet/issues/15383
small_exp, total_exp = 20, 10000
for pvals in pvals_list:
x = np.random.multinomial(small_exp, pvals)
for i in range(total_exp // small_exp):
x = x + np.random.multinomial(20, pvals)
freq = (x.asnumpy() / _np.float32(total_exp)).reshape((-1, len(pvals)))
for i in range(freq.shape[0]):
mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
for pvals_mx_np_array in [False, True]:
for pvals in pvals_list:
if pvals_mx_np_array:
pvals = mx.np.array(pvals)
x = np.random.multinomial(small_exp, pvals)
for i in range(total_exp // small_exp):
x = x + np.random.multinomial(20, pvals)
freq = (x.asnumpy() / _np.float32(total_exp)).reshape((-1, len(pvals)))
for i in range(freq.shape[0]):
if type(pvals) == np.ndarray:
mx.test_utils.assert_almost_equal(freq[i, :], pvals.asnumpy(), rtol=0.20, atol=1e-1)
else:
mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)


if __name__ == '__main__':
Expand Down

0 comments on commit b3da7d2

Please sign in to comment.