diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 913ceaaff097..6d1bda902068 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -23,7 +23,7 @@ from ..ndarray import NDArray -__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle", 'gamma'] +__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle", 'gamma', 'exponential'] def randint(low, high=None, size=None, dtype=None, ctx=None, out=None): @@ -319,6 +319,34 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) +def exponential(scale, size): + r"""Draw samples from an exponential distribution. + Parameters + ---------- + scale : float or array_like of floats + The scale parameter, :math:`\beta = 1/\lambda`. Must be + non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``scale`` is a scalar. Otherwise, + ``np.array(scale).size`` samples are drawn. + Returns + ------- + out : ndarray or scalar + Drawn samples from the parameterized exponential distribution. + """ + from ...numpy import ndarray as np_ndarray + tensor_type_name = np_ndarray + if size == (): + size = None + is_tensor = isinstance(scale, tensor_type_name) + if is_tensor: + return _npi.exponential(scale, scale=None, size=size) + else: + return _npi.exponential(scale=scale, size=size) + + def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): """Draw samples from a Gamma distribution. diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 198f2fcb4389..fe98a1296ee5 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -20,8 +20,9 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np + __all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "shuffle", "randn", - "gamma"] + "gamma", "exponential"] def randint(low, high=None, size=None, dtype=None, ctx=None, out=None): @@ -324,6 +325,28 @@ def rand(*size, **kwargs): return _mx_nd_np.random.uniform(0, 1, size=output_shape, **kwargs) +def exponential(scale=1.0, size=None): + r"""Draw samples from an exponential distribution. + + Parameters + ---------- + scale : float or array_like of floats + The scale parameter, :math:`\beta = 1/\lambda`. Must be + non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``scale`` is a scalar. Otherwise, + ``np.array(scale).size`` samples are drawn. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the parameterized exponential distribution. + """ + return _mx_nd_np.random.exponential(scale, size) + + def shuffle(x): """ Modify a sequence in-place by shuffling its contents. diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index c6b23b507d87..33e57f346403 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -21,7 +21,8 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle', 'gamma'] + +__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle', 'gamma', 'exponential'] def randint(low, high=None, size=None, dtype=None, ctx=None, out=None): @@ -347,6 +348,36 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): raise ValueError("Distribution parameters must be either _Symbol or numbers") +def exponential(scale=1.0, size=None): + r"""Draw samples from an exponential distribution. + + Parameters + ---------- + scale : float or array_like of floats + The scale parameter, :math:`\beta = 1/\lambda`. Must be + non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``scale`` is a scalar. Otherwise, + ``np.array(scale).size`` samples are drawn. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the parameterized exponential distribution. + """ + from ..numpy import _Symbol as np_symbol + tensor_type_name = np_symbol + if size == (): + size = None + is_tensor = isinstance(scale, tensor_type_name) + if is_tensor: + return _npi.exponential(scale, scale=None, size=size) + else: + return _npi.exponential(scale=scale, size=size) + + def shuffle(x): """ Modify a sequence in-place by shuffling its contents. diff --git a/src/operator/numpy/random/np_exponential_op.cc b/src/operator/numpy/random/np_exponential_op.cc new file mode 100644 index 000000000000..cc79fd8cb080 --- /dev/null +++ b/src/operator/numpy/random/np_exponential_op.cc @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_exponential_op.cc + * \brief Operator for numpy sampling from exponential distributions + */ + +#include "./np_exponential_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyExponentialParam); + +NNVM_REGISTER_OP(_npi_exponential) +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyExponentialParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.scale.has_value()) { + num_inputs -= 1; + } + return num_inputs; + }) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyExponentialParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.scale.has_value()) { + num_inputs -= 1; + } + return (num_inputs == 0) ? std::vector() : std::vector{"input1"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", UnaryDistOpShape) +.set_attr("FInferType", + [](const nnvm::NodeAttrs &attrs, std::vector *in_attrs, std::vector *out_attrs) { + (*out_attrs)[0] = mshadow::kFloat32; + return true; + }) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyExponentialForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyExponentialParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_exponential_op.cu b/src/operator/numpy/random/np_exponential_op.cu new file mode 100644 index 000000000000..1c0ff1266429 --- /dev/null +++ b/src/operator/numpy/random/np_exponential_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_exponential_op.cu + * \brief Operator for numpy sampling from exponential distributions + */ + +#include "./np_exponential_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_exponential) +.set_attr("FCompute", NumpyExponentialForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_exponential_op.h b/src/operator/numpy/random/np_exponential_op.h new file mode 100644 index 000000000000..6f644292d6ac --- /dev/null +++ b/src/operator/numpy/random/np_exponential_op.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_exponential_op.h + * \brief Operator for numpy sampling from exponential distribution. + */ + +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_ + +#include +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +struct NumpyExponentialParam : public dmlc::Parameter { + dmlc::optional scale; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyExponentialParam) { + DMLC_DECLARE_FIELD(scale) + .set_default(dmlc::optional(1.0)); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe("Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + } +}; + +template +struct scalar_exponential_kernel { + MSHADOW_XINLINE static void Map(index_t i, float scale, float *threshold, + DType *out) { + out[i] = -scale * log(threshold[i]); + } +}; + +namespace mxnet_op { + +template +struct check_legal_scale_kernel { + MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) { + if (scalar[i] < 0.0) { + flag[0] = -1.0; + } + } +}; + + +template +struct exponential_kernel { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &stride, + const Shape &oshape, + IType *scales, float* threshold, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + out[i] = -scales[idx] * log(threshold[i]); + } +}; + +} // namespace mxnet_op + +template +void NumpyExponentialForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyExponentialParam ¶m = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + index_t output_len = outputs[0].Size(); + Random *prnd = ctx.requested[0].get_random(s); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); + Tensor uniform_tensor = workspace.Slice(0, output_len); + Tensor indicator_device = workspace.Slice(output_len, output_len + 1); + float indicator_host = 1.0; + float *indicator_device_ptr = indicator_device.dptr_; + Kernel::Launch(s, 1, indicator_device_ptr); + prnd->SampleUniform(&workspace, 0.0, 1.0); + if (param.scale.has_value()) { + CHECK_GE(param.scale.value(), 0.0) << "ValueError: expect scale >= 0"; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.scale.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr); + }); + _copy(s, &indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) << "ValueError: expect scale >= 0"; + mxnet::TShape new_lshape, new_oshape; + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), stride, oshape, inputs[0].dptr(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_ diff --git a/tests/nightly/test_np_random.py b/tests/nightly/test_np_random.py index d086ac4b08d6..8b604262563e 100644 --- a/tests/nightly/test_np_random.py +++ b/tests/nightly/test_np_random.py @@ -39,6 +39,22 @@ import scipy.stats as ss +@retry(5) +@with_seed() +@use_np +def test_np_exponential(): + samples = 1000000 + # Generation test + trials = 8 + num_buckets = 5 + for scale in [1.0, 5.0]: + buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.expon.ppf(x, scale=scale), num_buckets) + buckets = np.array(buckets, dtype="float32").tolist() + probs = [(buckets[i][1] - buckets[i][0])/scale for i in range(num_buckets)] + generator_mx_np = lambda x: mx.np.random.exponential(size=x).asnumpy() + verify_generator(generator=generator_mx_np, buckets=buckets, probs=probs, nsamples=samples, nrepeat=trials) + + @retry(5) @with_seed() @use_np diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 8d54a53b8bc4..273e5206621f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3431,6 +3431,36 @@ def hybrid_forward(self, F, x): assert out.shape == expected_shape +@with_seed() +@use_np +def test_np_exponential(): + class TestRandomExp(HybridBlock): + def __init__(self, shape): + super(TestRandomExp, self).__init__() + self._shape = shape + + def hybrid_forward(self, F, scale): + return F.np.random.exponential(scale, self._shape) + + shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] + for hybridize in [False, True]: + for shape in shapes: + test_exponential = TestRandomExp(shape) + if hybridize: + test_exponential.hybridize() + np_out = _np.random.exponential(size = shape) + mx_out = test_exponential(np.array([1])) + + for shape in shapes: + mx_out = np.random.exponential(np.array([1]), shape) + np_out = _np.random.exponential(np.array([1]).asnumpy(), shape) + assert_almost_equal(mx_out.asnumpy().shape, np_out.shape) + + def _test_exponential_exception(scale): + output = np.random.exponential(scale=scale).asnumpy() + assertRaises(ValueError, _test_exponential_exception, -1) + + @with_seed() @use_np def test_np_randn():