Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add symbol api for randn and fix shape issue for randn ndarray and symbol api #15772

Merged
merged 11 commits into from
Aug 25, 2019
2 changes: 0 additions & 2 deletions python/mxnet/ndarray/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ def randn(*shape, **kwargs):
dtype = kwargs.pop('dtype', _Null)
ctx = kwargs.pop('ctx', None)
out = kwargs.pop('out', None)
assert isinstance(loc, (int, float))
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(scale, (int, float))
return _random_helper(_internal._random_normal, _internal._sample_normal,
[loc, scale], shape, dtype, ctx, out, kwargs)

Expand Down
25 changes: 25 additions & 0 deletions python/mxnet/symbol/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,31 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs):
[loc, scale], shape, dtype, kwargs)


def randn(*shape, **kwargs):
"""Draw random samples from a normal (Gaussian) distribution.

Samples are distributed according to a normal distribution parametrized
by *loc* (mean) and *scale* (standard deviation).


Parameters
----------
loc : float or NDArray
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
Mean (centre) of the distribution.
scale : float or NDArray
Standard deviation (spread or width) of the distribution.
shape : int or tuple of ints
The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and
`scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale`
are NDArrays with shape, e.g., `(x, y)`, then output will have shape
`(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair.
dtype : {'float16', 'float32', 'float64'}
Data type of output samples. Default is 'float32'
"""
return _random_helper(_internal._random_normal, _internal._sample_normal,
[loc, scale], shape, dtype, ctx, out, kwargs)
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved


def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs):
"""Draw random samples from a Poisson distribution.

Expand Down