Skip to content

Commit

Permalink
numpy multinomial op (apache#15878)
Browse files Browse the repository at this point in the history
* numpy multinomial op

* address the comment

* retrigger CI

* retrigger CI
  • Loading branch information
stu1130 authored and gyshi committed Sep 7, 2019
1 parent 779dbf6 commit ad1d278
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 0 deletions.
54 changes: 54 additions & 0 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

"""Namespace for operators used in Gluon dispatched by F=ndarray."""
from __future__ import absolute_import
import numpy as np
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray
from ...base import numeric_types


Expand Down Expand Up @@ -189,3 +191,55 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
raise NotImplementedError('np.random.normal only supports loc and scale of '
'numeric types for now')
return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)


def multinomial(n, pvals, size=None):
    """multinomial(n, pvals, size=None)

    Draw samples from a multinomial distribution.

    The multinomial distribution generalises the binomial distribution to
    experiments with ``p`` possible outcomes (for instance rolling a die,
    where ``p`` is 6). Every drawn sample summarises ``n`` such experiments
    as a vector of ``p`` counts, one per outcome.

    Parameters
    ----------
    n : int
        Number of experiments.
    pvals : NDArray or sequence of floats, length p
        Probabilities of each of the ``p`` outcomes. These should sum to 1.
        An ``NDArray`` is handed to the backend operator as its tensor input;
        any other sequence is handed over as the ``pvals`` operator attribute.
    size : int or tuple of ints, optional
        Leading dimensions of the result. With ``size=(m, n, k)``,
        ``m * n * k`` samples are drawn. Default is None, in which case a
        single sample is returned.

    Returns
    -------
    out : ndarray
        The drawn counts. The last axis always has length ``p``; when ``size``
        is given it supplies the leading axes, otherwise the result has shape
        ``(p,)``.

    Raises
    ------
    ValueError
        If ``pvals`` is a ``numpy.ndarray``, or if any element of ``pvals`` is
        itself a ``list`` (nested probabilities are rejected).

    Examples
    --------
    Throw a die 1000 times, and 1000 times again (outputs are random and will
    differ between runs):

    >>> np.random.multinomial(1000, [1/6.]*6, size=2)
    array([[164, 161, 179, 158, 150, 188],
           [178, 162, 177, 143, 163, 177]])

    A loaded die is more likely to land on number 6:

    >>> np.random.multinomial(100, [1/7.]*5 + [2/7.])
    array([19, 14, 12, 11, 21, 23])
    >>> np.random.multinomial(100, [1.0 / 3, 2.0 / 3])
    array([32, 68])
    """
    # Tensor probabilities travel as the operator input, so the attribute
    # form of ``pvals`` is explicitly cleared.
    if isinstance(pvals, NDArray):
        return _npi.multinomial(pvals, pvals=None, n=n, size=size)

    if isinstance(pvals, np.ndarray):
        raise ValueError('numpy ndarray is not supported!')

    # Only a flat sequence is accepted; stop at the first nested list.
    has_nested_list = False
    for entry in pvals:
        if isinstance(entry, list):
            has_nested_list = True
            break
    if has_nested_list:
        raise ValueError('object too deep for desired array')

    return _npi.multinomial(n=n, pvals=pvals, size=size)
36 changes: 36 additions & 0 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,39 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
This function currently does not support ``loc`` and ``scale`` as ndarrays.
"""
return _mx_nd_np.random.normal(loc, scale, size, **kwargs)


def multinomial(n, pvals, size=None, **kwargs):
    """multinomial(n, pvals, size=None)

    Draw samples from a multinomial distribution.

    The multinomial distribution generalises the binomial distribution to
    experiments with ``p`` possible outcomes (for instance rolling a die,
    where ``p`` is 6). Every drawn sample summarises ``n`` such experiments
    as a vector of ``p`` counts, one per outcome.

    This function only forwards its arguments to
    ``_mx_nd_np.random.multinomial``.

    Parameters
    ----------
    n : int
        Number of experiments.
    pvals : sequence of floats, length p
        Probabilities of each of the ``p`` different outcomes. These should
        sum to 1.
    size : int or tuple of ints, optional
        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn. Default is None, in which case a
        single sample is returned.
    **kwargs
        Forwarded unchanged to ``_mx_nd_np.random.multinomial``.
        NOTE(review): confirm that the forwarded-to function accepts extra
        keyword arguments; otherwise passing any will raise ``TypeError``.

    Returns
    -------
    out : ndarray
        Whatever ``_mx_nd_np.random.multinomial`` returns: the drawn counts,
        with a trailing axis of length ``p``.

    Examples
    --------
    Throw a die 1000 times, and 1000 times again (outputs are random and will
    differ between runs):

    >>> np.random.multinomial(1000, [1/6.]*6, size=2)
    array([[164, 161, 179, 158, 150, 188],
           [178, 162, 177, 143, 163, 177]])

    A loaded die is more likely to land on number 6:

    >>> np.random.multinomial(100, [1/7.]*5 + [2/7.])
    array([19, 14, 12, 11, 21, 23])
    >>> np.random.multinomial(100, [1.0 / 3, 2.0 / 3])
    array([32, 68])
    """
    samples = _mx_nd_np.random.multinomial(n, pvals, size, **kwargs)
    return samples
61 changes: 61 additions & 0 deletions src/operator/numpy/random/np_multinomial_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_multinomial_op.cc
* \brief Operator for numpy sampling from multinomial distributions
*/
#include "./np_multinomial_op.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyMultinomialParam);

// Registration of the `_npi_multinomial` operator (CPU side).
//
// The description is a single raw string literal: everything between
// `R"code(` and `)code"` is emitted verbatim, so it must not contain the
// quote characters used for adjacent-literal concatenation.
NNVM_REGISTER_OP(_npi_multinomial)
.describe(R"code(Draw samples from a multinomial distribution.
The multinomial distribution is a multivariate generalisation of the binomial distribution.
Take an experiment with one of p possible outcomes.
An example of such an experiment is throwing a dice, where the outcome can be 1 through 6.
Each sample drawn from the distribution represents n such experiments.
Its values, X_i = [X_0, X_1, ..., X_p], represent the number of times the outcome was i.
)code")
// The operator has no tensor input when the probabilities are supplied through
// the `pvals` attribute, and exactly one input otherwise.
.set_num_inputs(
  [](const nnvm::NodeAttrs& attrs) {
    const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
    return param.pvals.has_value() ? 0U : 1U;
  }
)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyMultinomialParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyMultinomialOpShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyMultinomialOpType)
// The forward function reads requested[0] as the random generator and
// requested[1] as temporary space, so the order of the requests matters.
.set_attr<FResourceRequest>("FResourceRequest",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{
      ResourceRequest::kRandom, ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", NumpyMultinomialForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("a", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyMultinomialParam::__FIELDS__());

} // namespace op
} // namespace mxnet
34 changes: 34 additions & 0 deletions src/operator/numpy/random/np_multinomial_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_multinomial_op.cu
* \brief Operator for numpy sampling from multinomial distributions
*/
#include "./np_multinomial_op.h"

namespace mxnet {
namespace op {

// Attach the GPU instantiation of the forward function to the
// `_npi_multinomial` operator; only the "FCompute<gpu>" attribute is set here.
NNVM_REGISTER_OP(_npi_multinomial)
.set_attr<FCompute>("FCompute<gpu>", NumpyMultinomialForward<gpu>);

} // namespace op
} // namespace mxnet
193 changes: 193 additions & 0 deletions src/operator/numpy/random/np_multinomial_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_multinomial_op.h
* \brief Operator for sampling from multinomial distributions
*/
#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
#define MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_

#include <mxnet/operator_util.h>
#include <vector>
#include "../../mshadow_op.h"
#include "../../mxnet_op.h"
#include "../../operator_common.h"
#include "../../elemwise_op_common.h"

namespace mxnet {
namespace op {

/*!
 * \brief Attributes of the numpy-style multinomial sampling operator.
 *
 * The probabilities come either from the `pvals` attribute declared here or
 * from a tensor input; code using this struct tests `pvals.has_value()` to
 * tell the two cases apart.
 */
struct NumpyMultinomialParam : public dmlc::Parameter<NumpyMultinomialParam> {
  // Number of experiments summarised by each drawn sample (no default: required).
  int n;
  // Outcome probabilities given as an attribute; unset by default.
  dmlc::optional<mxnet::Tuple<double>> pvals;
  // Requested leading output dimensions; unset by default.
  dmlc::optional<mxnet::Tuple<int>> size;
  DMLC_DECLARE_PARAMETER(NumpyMultinomialParam) {
    DMLC_DECLARE_FIELD(n)
      .describe("Number of experiments.");
    DMLC_DECLARE_FIELD(pvals)
      .set_default(dmlc::optional<mxnet::Tuple<double>>())
      // Every fragment ends with ". " or " " so the concatenated literals
      // read as separate sentences.
      .describe("Probabilities of each of the p different outcomes. "
      "These should sum to 1 (however, the last element is always assumed to "
      "account for the remaining probability, as long as sum(pvals[:-1]) <= 1). "
      "Note that this is for internal usage only. "
      "This operator will only have either input mx.ndarray or this list of pvals");
    DMLC_DECLARE_FIELD(size)
      .set_default(dmlc::optional<mxnet::Tuple<int>>())
      .describe("Output shape. If the given shape is, "
      "e.g., (m, n, k), then m * n * k samples are drawn. "
      "Default is None, in which case a single value is returned.");
  }
};

/*!
 * \brief Shape inference for the multinomial operator.
 *
 * The output shape is the requested `size` dimensions (if any) followed by
 * one trailing dimension equal to the number of outcome probabilities.
 *
 * \param attrs node attributes holding a parsed NumpyMultinomialParam
 * \param in_attrs input shapes: empty when `pvals` is an attribute, otherwise
 *        exactly one shape, which must be one-dimensional
 * \param out_attrs output shapes; must hold exactly one entry, which is assigned
 * \return whether the assigned output shape has a non-zero number of dimensions
 */
inline bool NumpyMultinomialOpShape(const nnvm::NodeAttrs& attrs,
                                    std::vector<TShape> *in_attrs,
                                    std::vector<TShape> *out_attrs) {
  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
  CHECK_EQ(out_attrs->size(), 1U);

  // Number of outcomes: taken from the attribute tuple or the input tensor.
  dim_t num_outcomes;
  if (!param.pvals.has_value()) {
    // pvals is from input ndarray
    CHECK_EQ(in_attrs->size(), 1U);
    const TShape& pvals_shape = (*in_attrs)[0];
    // only a one-dimensional probability vector is accepted
    CHECK_EQ(pvals_shape.ndim(), 1U)
      << "object too deep for desired array";
    num_outcomes = pvals_shape[0];
  } else {
    CHECK_EQ(in_attrs->size(), 0U);
    num_outcomes = param.pvals.value().ndim();
  }

  // Leading dimensions come from `size`; the outcome axis is always last.
  std::vector<dim_t> out_dims;
  if (param.size.has_value()) {
    const mxnet::Tuple<int>& requested = param.size.value();
    for (int axis = 0; axis < requested.ndim(); ++axis) {
      out_dims.push_back(requested[axis]);
    }
  }
  out_dims.push_back(num_outcomes);

  SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(out_dims));
  return out_attrs->at(0).ndim() != 0U;
}

/*!
 * \brief Type inference for the multinomial operator.
 *
 * Checks the input/output counts and sets the single output type to
 * mshadow::kInt64 unconditionally.
 *
 * \param attrs node attributes holding a parsed NumpyMultinomialParam
 * \param in_attrs input types: empty when `pvals` is an attribute, one entry otherwise
 * \param out_attrs output types; must hold exactly one entry
 * \return always true
 */
inline bool NumpyMultinomialOpType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int>* in_attrs,
                                   std::vector<int>* out_attrs) {
  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
  const size_t expected_inputs = param.pvals.has_value() ? 0U : 1U;
  CHECK_EQ(in_attrs->size(), expected_inputs);
  CHECK_EQ(out_attrs->size(), 1U);

  out_attrs->front() = mshadow::kInt64;
  return true;
}

/*!
 * \brief Kernel that turns uniform samples into multinomial counts.
 *
 * Work item `i` produces one output row of `prob_length` counters. It reads
 * `num_exp` uniform samples starting at `uniform[i * num_exp]` and, for each,
 * increments the counter of the first outcome whose cumulative probability
 * exceeds the sample. When no cumulative sum exceeds the sample, the last
 * outcome is incremented.
 *
 * The counters are only incremented, so `out` must be zeroed beforehand.
 */
struct multinomial_kernel {
  /*!
   * \param i index of the output row handled by this call
   * \param num_exp number of experiments (uniform samples) per row
   * \param prob_length number of outcomes
   * \param pvals outcome probabilities, `prob_length` entries
   * \param uniform uniform samples, `num_exp` entries per row
   * \param out counters, `prob_length` entries per row
   */
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i,
                                  const int num_exp,
                                  const int prob_length,
                                  DType* pvals,
                                  double* uniform,
                                  int64_t* out) {
    // 64-bit offsets: i * num_exp and i * prob_length can exceed the int range.
    const int64_t sample_base = static_cast<int64_t>(i) * num_exp;
    int64_t* counts = out + static_cast<int64_t>(i) * prob_length;
    for (int j = 0; j < num_exp; ++j) {
      // Compare in double precision. Narrowing the sample to DType (e.g. a
      // half-precision type) could round it up to the cumulative total and
      // push the draw into the fallback bucket.
      const double loc = uniform[sample_base + j];
      double acc = 0.0;
      // Fallback: the last outcome absorbs whatever probability remains.
      int outcome = prob_length - 1;
      for (int k = 0; k < prob_length; ++k) {
        acc += static_cast<double>(pvals[k]);
        if (acc > loc) {
          outcome = k;
          break;
        }
      }
      counts[outcome] += 1;
    }
  }
};

/*!
 * \brief Forward computation of the multinomial operator.
 *
 * Draws `n` uniform samples per output row into temporary space, zeroes the
 * output and launches multinomial_kernel to accumulate the counts. The
 * probabilities come from the `pvals` attribute when it is set, otherwise
 * from inputs[0].
 *
 * \param attrs node attributes holding a parsed NumpyMultinomialParam
 * \param ctx operator context; requested[0] is used as the random generator
 *        and requested[1] as temporary space
 * \param inputs empty when `pvals` is an attribute, one blob otherwise
 * \param req write requests (not consulted by this function)
 * \param outputs single blob of int64 counts
 */
template<typename xpu>
void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(inputs.size(), (param.pvals.has_value()) ? 0U : 1U);

  int prob_length = (param.pvals.has_value())
    ? param.pvals.value().ndim() : inputs[0].shape_[0];
  // if input is [] or size contains 0 dimension
  if (prob_length == 0 || outputs[0].shape_.Size() == 0) return;
  // A negative experiment count would wrap into a huge temp-space request.
  CHECK_GE(param.n, 0) << "n < 0";
  int num_output = outputs[0].Size() / prob_length;
  int num_exp = param.n;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  Random<xpu, double> *prnd = ctx.requested[0].get_random<xpu, double>(s);
  // Compute the element count in size_t; the int product num_output * n can overflow.
  const size_t num_uniform = static_cast<size_t>(num_output) * static_cast<size_t>(num_exp);
  // When pvals is an attribute, prob_length extra slots hold a copy of it
  // directly after the uniform samples.
  size_t temp_space_ = (param.pvals.has_value())
    ? num_uniform + static_cast<size_t>(prob_length) : num_uniform;
  Tensor<xpu, 1, double> temp_tensor =
    ctx.requested[1].get_space_typed<xpu, 1, double>(Shape1(temp_space_), s);

  prnd->SampleUniform(&temp_tensor, 0, 1);
  // set zero for the outputs
  Kernel<set_zero, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<int64_t>());
  // Probabilities that mathematically sum to 1 can land slightly above 1.0
  // after floating-point accumulation; allow that rounding slack.
  const double kSumTolerance = 1e-12;
  // NOTE(review): both branches below touch temp_tensor / inputs[0] memory from
  // host code. For xpu == gpu those buffers are presumably device memory, in
  // which case this access pattern would need a device-side copy/check — confirm.
  if (param.pvals.has_value()) {
    // create a tensor to copy the param.pvals tuple to avoid
    // error: calling a __host__ function from a __host__ __device__ function is not allowed
    // reuse the uniform temp space to create pval tensor
    double* pvals_ = temp_tensor.dptr_ + num_uniform;
    // check if sum of input(pvals) > 1.0
    double sum = 0.0;
    for (int i = 0; i < prob_length; ++i) {
      sum += param.pvals.value()[i];
      // copy the tuple to data for later kernel usage
      pvals_[i] = param.pvals.value()[i];
      CHECK_LE(sum, 1.0 + kSumTolerance)
        << "sum(pvals[:-1]) > 1.0";
    }
    Kernel<multinomial_kernel, xpu>::Launch(
      s, num_output, num_exp, prob_length, pvals_, temp_tensor.dptr_, outputs[0].dptr<int64_t>());
  } else {
    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
      // check if sum of input(pvals) > 1.0
      DType sum = DType(0);
      DType* input = inputs[0].dptr<DType>();
      for (int i = 0; i < prob_length; ++i) {
        sum += input[i];
        CHECK_LE(sum, 1.0 + kSumTolerance)
          << "sum(pvals[:-1]) > 1.0";
      }
      Kernel<multinomial_kernel, xpu>::Launch(
        s, num_output, num_exp, prob_length,
        inputs[0].dptr<DType>(), temp_tensor.dptr_, outputs[0].dptr<int64_t>());
    });
  }
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_
Loading

0 comments on commit ad1d278

Please sign in to comment.