Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy]add op random.lognormal #17415

Merged
merged 3 commits into from
Feb 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "multivariate_normal",
"shuffle", 'gamma', 'beta', 'exponential']
"shuffle", 'gamma', 'beta', 'exponential', 'lognormal']


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -194,6 +194,39 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
ctx=ctx, dtype=dtype, out=out)


def lognormal(mean=0.0, sigma=1.0, size=None, dtype=None, ctx=None, out=None):
r"""Draw samples from a log-normal distribution.
Draw samples from a log-normal distribution with specified mean,
standard deviation, and array shape. Note that the mean and standard
deviation are not the values for the distribution itself, but of the
underlying normal distribution it is derived from.
Parameters
----------
mean : float or array_like of floats, optional
Mean value of the underlying normal distribution. Default is 0.
sigma : float or array_like of floats, optional
Standard deviation of the underlying normal distribution. Must be
non-negative. Default is 1.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``mean`` and ``sigma`` are both scalars.
Otherwise, ``np.broadcast(mean, sigma).size`` samples are drawn.
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized log-normal distribution.
"""
from . import _op as _mx_np_op
return _mx_np_op.exp(normal(loc=mean, scale=sigma, size=size, dtype=dtype, ctx=ctx, out=out))


def multinomial(n, pvals, size=None):
r"""multinomial(n, pvals, size=None)

Expand Down
61 changes: 60 additions & 1 deletion python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


__all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "multivariate_normal",
"shuffle", "randn", "gamma", 'beta', "exponential"]
"shuffle", "randn", "gamma", 'beta', "exponential", "lognormal"]


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -201,6 +201,65 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
return _mx_nd_np.random.normal(loc, scale, size, dtype, ctx, out)


def lognormal(mean=0.0, sigma=1.0, size=None, dtype=None, ctx=None, out=None):
r"""Draw samples from a log-normal distribution.
Draw samples from a log-normal distribution with specified mean,
standard deviation, and array shape. Note that the mean and standard
deviation are not the values for the distribution itself, but of the
underlying normal distribution it is derived from.
Parameters
----------
mean : float or array_like of floats, optional
Mean value of the underlying normal distribution. Default is 0.
sigma : float or array_like of floats, optional
Standard deviation of the underlying normal distribution. Must be
non-negative. Default is 1.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``mean`` and ``sigma`` are both scalars.
Otherwise, ``np.broadcast(mean, sigma).size`` samples are drawn.
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized log-normal distribution.
Notes
-----
A variable `x` has a log-normal distribution if `log(x)` is normally
distributed. The probability density function for the log-normal
distribution is:
.. math:: p(x) = \frac{1}{\sigma x \sqrt{2\pi}}
e^{(-\frac{(ln(x)-\mu)^2}{2\sigma^2})}
where :math:`\mu` is the mean and :math:`\sigma` is the standard
deviation of the normally distributed logarithm of the variable.
A log-normal distribution results if a random variable is the *product*
of a large number of independent, identically-distributed variables in
the same way that a normal distribution results if the variable is the
*sum* of a large number of independent, identically-distributed
variables.
References
----------
.. [1] Limpert, E., Stahel, W. A., and Abbt, M., "Log-normal
Distributions across the Sciences: Keys and Clues,"
BioScience, Vol. 51, No. 5, May, 2001.
https://stat.ethz.ch/~stahel/lognormal/bioscience.pdf
.. [2] Reiss, R.D. and Thomas, M., "Statistical Analysis of Extreme
Values," Basel: Birkhauser Verlag, 2001, pp. 31-32.
Examples
--------
Draw samples from the distribution:
>>> mu, sigma = 3., 1. # mean and standard deviation
>>> s = np.random.lognormal(mu, sigma, 1000)
"""
return _mx_nd_np.random.lognormal(mean, sigma, size, dtype, ctx, out)


def multinomial(n, pvals, size=None, **kwargs):
r"""
Draw samples from a multinomial distribution.
Expand Down
35 changes: 34 additions & 1 deletion python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


__all__ = ['randint', 'uniform', 'normal', 'multivariate_normal',
'rand', 'shuffle', 'gamma', 'beta', 'exponential']
'rand', 'shuffle', 'gamma', 'beta', 'exponential', 'lognormal']


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -218,6 +218,39 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
ctx=ctx, dtype=dtype, out=out)


def lognormal(mean=0.0, sigma=1.0, size=None, dtype=None, ctx=None, out=None):
r"""Draw samples from a log-normal distribution.
Draw samples from a log-normal distribution with specified mean,
standard deviation, and array shape. Note that the mean and standard
deviation are not the values for the distribution itself, but of the
underlying normal distribution it is derived from.
Parameters
----------
mean : float or array_like of floats, optional
Mean value of the underlying normal distribution. Default is 0.
sigma : float or array_like of floats, optional
Standard deviation of the underlying normal distribution. Must be
non-negative. Default is 1.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``mean`` and ``sigma`` are both scalars.
Otherwise, ``np.broadcast(mean, sigma).size`` samples are drawn.
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized log-normal distribution.
"""
from . import _symbol as _mx_np_symbol
return _mx_np_symbol.exp(normal(loc=mean, scale=sigma, size=size, dtype=dtype, ctx=ctx, out=out))


def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
r"""Generates a random sample from a given 1-D array

Expand Down
50 changes: 50 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,56 @@ def hybrid_forward(self, F, loc, scale):
assert_almost_equal(loc.grad.asnumpy().sum(), _np.ones(out_shape).sum(), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_lognormal_grad():
class TestLognormalGrad(HybridBlock):
def __init__(self, shape):
super(TestLognormalGrad, self).__init__()
self._shape = shape

def hybrid_forward(self, F, mean, sigma):
return F.np.random.lognormal(mean, sigma, self._shape)

param_shape = [
[(3, 2), (3, 2)],
[(3, 2, 2), (3, 2, 2)],
[(3, 4, 5), (4, 1)],
]
output_shapes = [
(3, 2),
(4, 3, 2, 2),
(3, 4, 5)
]
for hybridize in [False, True]:
for ((shape1, shape2), out_shape) in zip(param_shape, output_shapes):
test_lognormal_grad = TestLognormalGrad(out_shape)
if hybridize:
test_lognormal_grad.hybridize()
mean = np.zeros(shape1)
mean.attach_grad()
sigma = np.ones(shape2)
sigma.attach_grad()
with mx.autograd.record():
mx_out = test_lognormal_grad(mean, sigma)
np_out = _np.random.lognormal(mean = mean.asnumpy(),
sigma = sigma.asnumpy(), size = out_shape)
assert np_out.shape == mx_out.shape
mx_out.backward()
assert mean.grad.shape == shape1
assert sigma.grad.shape == shape2
assert_almost_equal(mean.grad.asnumpy().sum(), mx_out.asnumpy().sum(), rtol=1e-3, atol=1e-5)

for ((shape1, shape2), out_shape) in zip(param_shape, output_shapes):
mx_out = np.random.lognormal(np.zeros(shape1), np.ones(shape2), out_shape)
np_out = _np.random.lognormal(np.zeros(shape1).asnumpy(), np.ones(shape2).asnumpy(), out_shape)
assert_almost_equal(mx_out.asnumpy().shape, np_out.shape)

def _test_lognormal_exception(sigma):
output = np.random.lognormal(sigma=sigma).asnumpy()
assertRaises(ValueError, _test_lognormal_exception, -1)


@with_seed()
@use_np
def test_npx_sample_n():
Expand Down