diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index d892ccdaca73..9ea2ef0f5405 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -17,8 +17,10 @@
 """Namespace for operators used in Gluon dispatched by F=ndarray."""
 from __future__ import absolute_import
+import numpy as np
 from ...context import current_context
 from . import _internal as _npi
+from ..ndarray import NDArray
 from ...base import numeric_types
@@ -189,3 +191,55 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
         raise NotImplementedError('np.random.normal only supports loc and scale of '
                                   'numeric types for now')
     return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)
+
+
+def multinomial(n, pvals, size=None):
+    """multinomial(n, pvals, size=None)
+
+    Draw samples from a multinomial distribution.
+
+    The multinomial distribution is a multivariate generalisation of the binomial distribution.
+    Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a dice,
+    where the outcome can be 1 through 6. Each sample drawn from the distribution represents n such experiments.
+    Its values, ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the outcome was ``i``.
+
+    Parameters
+    ----------
+    n : int
+        Number of experiments.
+    pvals : sequence of floats, length p
+        Probabilities of each of the p different outcomes. These should sum to 1.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` samples
+        are drawn. Default is None, in which case a single value is returned.
+
+    Returns
+    -------
+    out : ndarray
+        The drawn samples, of shape size, if that was provided. If not, the shape is ``(N,)``.
+        In other words, each entry ``out[i,j,...,:]`` is an N-dimensional value drawn from the distribution.
+
+    Examples
+    --------
+    Throw a dice 1000 times, and 1000 times again:
+
+    >>> np.random.multinomial(1000, [1/6.]*6, size=2)
+    array([[164, 161, 179, 158, 150, 188],
+           [178, 162, 177, 143, 163, 177]])
+
+    A loaded die is more likely to land on number 6:
+
+    >>> np.random.multinomial(100, [1/7.]*5 + [2/7.])
+    array([19, 14, 12, 11, 21, 23])
+
+    >>> np.random.multinomial(100, [1.0 / 3, 2.0 / 3])
+    array([32, 68])
+    """
+    if isinstance(pvals, NDArray):
+        return _npi.multinomial(pvals, pvals=None, n=n, size=size)
+    else:
+        if isinstance(pvals, np.ndarray):
+            raise ValueError('numpy ndarray is not supported!')
+        if any(isinstance(i, list) for i in pvals):
+            raise ValueError('object too deep for desired array')
+        return _npi.multinomial(n=n, pvals=pvals, size=size)
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index dc6107476f81..bd534f0b6d97 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -144,3 +144,39 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
     This function currently does not support ``loc`` and ``scale`` as ndarrays.
     """
     return _mx_nd_np.random.normal(loc, scale, size, **kwargs)
+
+
+def multinomial(n, pvals, size=None, **kwargs):
+    """multinomial(n, pvals, size=None)
+    Draw samples from a multinomial distribution.
+    The multinomial distribution is a multivariate generalisation of the binomial distribution.
+    Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a dice,
+    where the outcome can be 1 through 6. Each sample drawn from the distribution represents n such experiments.
+    Its values, ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the outcome was ``i``.
+    Parameters
+    ----------
+    n : int
+        Number of experiments.
+    pvals : sequence of floats, length p
+        Probabilities of each of the p different outcomes. These should sum to 1.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` samples
+        are drawn. Default is None, in which case a single value is returned.
+    Returns
+    -------
+    out : ndarray
+        The drawn samples, of shape size, if that was provided. If not, the shape is ``(N,)``.
+        In other words, each entry ``out[i,j,...,:]`` is an N-dimensional value drawn from the distribution.
+    Examples
+    --------
+    Throw a dice 1000 times, and 1000 times again:
+    >>> np.random.multinomial(1000, [1/6.]*6, size=2)
+    array([[164, 161, 179, 158, 150, 188],
+           [178, 162, 177, 143, 163, 177]])
+    A loaded die is more likely to land on number 6:
+    >>> np.random.multinomial(100, [1/7.]*5 + [2/7.])
+    array([19, 14, 12, 11, 21, 23])
+    >>> np.random.multinomial(100, [1.0 / 3, 2.0 / 3])
+    array([32, 68])
+    """
+    return _mx_nd_np.random.multinomial(n, pvals, size, **kwargs)
diff --git a/src/operator/numpy/random/np_multinomial_op.cc b/src/operator/numpy/random/np_multinomial_op.cc
new file mode 100644
index 000000000000..bf4f88c591cf
--- /dev/null
+++ b/src/operator/numpy/random/np_multinomial_op.cc
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_multinomial_op.cc
+ * \brief Operator for numpy sampling from multinomial distributions
+ */
+#include "./np_multinomial_op.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyMultinomialParam);
+
+NNVM_REGISTER_OP(_npi_multinomial)
+.describe(R"code(Draw samples from a multinomial distribution.
+The multinomial distribution is a multivariate generalisation of the binomial distribution.
+Take an experiment with one of p possible outcomes.
+An example of such an experiment is throwing a dice, where the outcome can be 1 through 6.
+Each sample drawn from the distribution represents n such experiments.
+Its values, X_i = [X_0, X_1, ..., X_p], represent the number of times the outcome was i.
+)code")
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+    return param.pvals.has_value() ? 0U : 1U;
+  }
+)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyMultinomialParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyMultinomialOpShape)
+.set_attr<nnvm::FInferType>("FInferType", NumpyMultinomialOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{
+      ResourceRequest::kRandom, ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", NumpyMultinomialForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("a", "NDArray-or-Symbol", "Source input")
+.add_arguments(NumpyMultinomialParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/random/np_multinomial_op.cu b/src/operator/numpy/random/np_multinomial_op.cu
new file mode 100644
index 000000000000..a80926024735
--- /dev/null
+++ b/src/operator/numpy/random/np_multinomial_op.cu
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_multinomial_op.cu
+ * \brief Operator for numpy sampling from multinomial distributions
+ */
+#include "./np_multinomial_op.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_multinomial)
+.set_attr<FCompute>("FCompute", NumpyMultinomialForward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/random/np_multinomial_op.h b/src/operator/numpy/random/np_multinomial_op.h
new file mode 100644
index 000000000000..7115f2761202
--- /dev/null
+++ b/src/operator/numpy/random/np_multinomial_op.h
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_multinomial_op.h
+ * \brief Operator for sampling from multinomial distributions
+ */
+#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
+#define MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include "../../mshadow_op.h"
+#include "../../mxnet_op.h"
+#include "../../operator_common.h"
+#include "../../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct NumpyMultinomialParam : public dmlc::Parameter<NumpyMultinomialParam> {
+  int n;
+  dmlc::optional<mxnet::Tuple<double>> pvals;
+  dmlc::optional<mxnet::Tuple<int>> size;
+  DMLC_DECLARE_PARAMETER(NumpyMultinomialParam) {
+    DMLC_DECLARE_FIELD(n)
+      .describe("Number of experiments.");
+    DMLC_DECLARE_FIELD(pvals)
+      .set_default(dmlc::optional<mxnet::Tuple<double>>())
+      .describe("Probabilities of each of the p different outcomes. "
+                "These should sum to 1 (however, the last element is always assumed to "
+                "account for the remaining probability, as long as sum(pvals[:-1]) <= 1). "
+                "Note that this is for internal usage only. "
+                "This operator will only have either input mx.ndarray or this list of pvals.");
+    DMLC_DECLARE_FIELD(size)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Output shape. If the given shape is, "
+                "e.g., (m, n, k), then m * n * k samples are drawn. "
+                "Default is None, in which case a single value is returned.");
+  }
+};
+
+inline bool NumpyMultinomialOpShape(const nnvm::NodeAttrs& attrs,
+                                    std::vector<TShape> *in_attrs,
+                                    std::vector<TShape> *out_attrs) {
+  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  std::vector<dim_t> oshape_vec;
+  dim_t pvals_length;
+  if (param.pvals.has_value()) {
+    CHECK_EQ(in_attrs->size(), 0U);
+    pvals_length = param.pvals.value().ndim();
+  } else {
+    // pvals is from input ndarray
+    CHECK_EQ(in_attrs->size(), 1U);
+    const TShape& ishape = (*in_attrs)[0];
+    // check the input shape is only one dimension
+    CHECK_EQ(ishape.ndim(), 1U)
+      << "object too deep for desired array";
+    pvals_length = ishape[0];
+  }
+  if (param.size.has_value()) {
+    const mxnet::Tuple<int>& size = param.size.value();
+    for (int i = 0; i < size.ndim(); ++i) {
+      oshape_vec.emplace_back(size[i]);
+    }
+  }
+  oshape_vec.emplace_back(pvals_length);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
+  return out_attrs->at(0).ndim() != 0U;
+}
+
+inline bool NumpyMultinomialOpType(const nnvm::NodeAttrs& attrs,
+                                   std::vector<int>* in_attrs,
+                                   std::vector<int>* out_attrs) {
+  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), (param.pvals.has_value()) ? 0U : 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  (*out_attrs)[0] = mshadow::kInt64;
+  return true;
+}
+
+struct multinomial_kernel {
+  template <typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  const int num_exp,
+                                  const int prob_length,
+                                  DType* pvals,
+                                  double* uniform,
+                                  int64_t* out) {
+    for (int j = 0; j < num_exp; ++j) {
+      DType loc = static_cast<DType>(uniform[i * num_exp + j]);
+      DType acc = 0.0;
+      bool found = false;
+      for (int k = 0; k < prob_length; ++k) {
+        acc += pvals[k];
+        if (acc > loc) {
+          found = true;
+          out[i * prob_length + k] += 1;
+          break;
+        }
+      }
+      if (!found) {
+        out[i * prob_length + (prob_length - 1)] += 1;
+      }
+    }
+  }
+};
+
+template<typename xpu>
+void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(inputs.size(), (param.pvals.has_value()) ? 0U : 1U);
+
+  int prob_length = (param.pvals.has_value())
+    ? param.pvals.value().ndim() : inputs[0].shape_[0];
+  // if input is [] or size contains 0 dimension
+  if (prob_length == 0U || outputs[0].shape_.Size() == 0) return;
+  int num_output = outputs[0].Size() / prob_length;
+  int num_exp = param.n;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  Random<xpu, double> *prnd = ctx.requested[0].get_random<xpu, double>(s);
+  size_t temp_space_ = (param.pvals.has_value())
+    ? num_output * param.n + prob_length : num_output * param.n;
+  Tensor<xpu, 1, double> temp_tensor =
+    ctx.requested[1].get_space_typed<xpu, 1, double>(Shape1(temp_space_), s);
+
+  prnd->SampleUniform(&temp_tensor, 0, 1);
+  // set zero for the outputs
+  Kernel<set_zero, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<int64_t>());
+  if (param.pvals.has_value()) {
+    // create a tensor to copy the param.pvals tuple to avoid
+    // error: calling a __host__ function from a __host__ __device__ function is not allowed
+    // reuse the uniform temp space to create pval tensor
+    double* pvals_ = temp_tensor.dptr_ + num_output * param.n;
+    // check if sum of input(pvals) > 1.0
+    double sum = 0.0;
+    for (int i = 0; i < prob_length; ++i) {
+      sum += param.pvals.value()[i];
+      // copy the tuple to data for later kernel usage
+      pvals_[i] = param.pvals.value()[i];
+      CHECK_LE(sum, 1.0)
+        << "sum(pvals[:-1]) > 1.0";
+    }
+    Kernel<multinomial_kernel, xpu>::Launch(
+      s, num_output, num_exp, prob_length, pvals_, temp_tensor.dptr_, outputs[0].dptr<int64_t>());
+  } else {
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      // check if sum of input(pvals) > 1.0
+      DType sum = DType(0);
+      DType* input = inputs[0].dptr<DType>();
+      for (int i = 0; i < prob_length; ++i) {
+        sum += input[i];
+        CHECK_LE(sum, 1.0)
+          << "sum(pvals[:-1]) > 1.0";
+      }
+      Kernel<multinomial_kernel, xpu>::Launch(
+        s, num_output, num_exp, prob_length,
+        inputs[0].dptr<DType>(), temp_tensor.dptr_, outputs[0].dptr<int64_t>());
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index fafa5a827c2f..883060466836 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -767,6 +767,56 @@ def test_np_uniform():
     mx.test_utils.assert_almost_equal(uniform_samples.asnumpy().mean(0), expect_mean.asnumpy(), rtol=0.20, atol=1e-1)
 
 
+@retry(5)
+@with_seed()
+@use_np
+def test_np_multinomial():
+    pvals_list = [[0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]]
+    sizes = [None, (), (3,), (2, 5, 7), (4, 9)]
+    experiments = 10000
+    for have_size in [False, True]:
+        for pvals in pvals_list:
+            if have_size:
+                for size in sizes:
+                    freq = mx.np.random.multinomial(experiments, pvals, size=size).asnumpy() / _np.float32(experiments)
+                    # for those cases that didn't need reshape
+                    if size in [None, ()]:
+                        mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
+                    else:
+                        # check the shape
+                        assert freq.shape == size + (len(pvals),), 'freq.shape={}, size + (len(pvals),)={}'.format(freq.shape, size + (len(pvals),))
+                        freq = freq.reshape((-1, len(pvals)))
+                        # check the value for each row
+                        for i in range(freq.shape[0]):
+                            mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
+            else:
+                freq = mx.np.random.multinomial(experiments, pvals).asnumpy() / _np.float32(experiments)
+                mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1)
+    # check the zero dimension
+    sizes = [(0,), (0, 2), (4, 0, 2), (3, 0, 1, 2, 0)]
+    for pvals in pvals_list:
+        for size in sizes:
+            freq = mx.np.random.multinomial(experiments, pvals, size=size).asnumpy()
+            assert freq.size == 0
+    # check [] as pvals
+    for pvals in [[], ()]:
+        freq = mx.np.random.multinomial(experiments, pvals).asnumpy()
+        assert freq.size == 0
+        for size in sizes:
+            freq = mx.np.random.multinomial(experiments, pvals, size=size).asnumpy()
+            assert freq.size == 0
+    # test small experiment for github issue
+    # https://github.com/apache/incubator-mxnet/issues/15383
+    small_exp, total_exp = 20, 10000
+    for pvals in pvals_list:
+        x = np.random.multinomial(small_exp, pvals)
+        for i in range(total_exp // small_exp):
+            x = x + np.random.multinomial(20, pvals)
+        freq = (x.asnumpy() / _np.float32(total_exp)).reshape((-1, len(pvals)))
+        for i in range(freq.shape[0]):
+            mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
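Usage sketch (illustrative, not part of the patch): a minimal example of how the new frontend added above could be exercised once MXNet is built with this change. The mx.np.random.multinomial call mirrors the docstring examples; the mx.npx.set_np() call is an assumption about how NumPy semantics are enabled (the unit test above uses the @use_np decorator for the same purpose), and the printed frequencies will vary from run to run.

import mxnet as mx
import numpy as onp  # official NumPy, used only to post-process the result

mx.npx.set_np()  # enable NumPy semantics (assumed requirement, analogous to @use_np in the test)

pvals = [1 / 6.] * 6  # a fair six-sided die
# Two independent rows of 1000 throws each; the output shape is (2, 6).
counts = mx.np.random.multinomial(1000, pvals, size=(2,))
freq = counts.asnumpy() / onp.float32(1000)

# Each row of a multinomial draw sums to the number of experiments,
# and the empirical frequencies should be close to 1/6 per face.
assert counts.asnumpy().sum(axis=-1).tolist() == [1000, 1000]
print(freq.round(3))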