diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md index fea746bb02f4..264f69093709 100644 --- a/docs/api/python/symbol/symbol.md +++ b/docs/api/python/symbol/symbol.md @@ -612,6 +612,7 @@ Composite multiple symbols into a new one by an operator. random.normal random.poisson random.randint + random.randn random.shuffle random.uniform mxnet.random.seed diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py index f19c1e03202f..b0683b439c2a 100644 --- a/python/mxnet/ndarray/random.py +++ b/python/mxnet/ndarray/random.py @@ -220,8 +220,8 @@ def randn(*shape, **kwargs): dtype = kwargs.pop('dtype', _Null) ctx = kwargs.pop('ctx', None) out = kwargs.pop('out', None) - assert isinstance(loc, (int, float)) - assert isinstance(scale, (int, float)) + assert isinstance(loc, (int, float, NDArray)) + assert isinstance(scale, (int, float, NDArray)) return _random_helper(_internal._random_normal, _internal._sample_normal, [loc, scale], shape, dtype, ctx, out, kwargs) diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py index 4bdfe7045625..b2ff104ff0f3 100644 --- a/python/mxnet/symbol/random.py +++ b/python/mxnet/symbol/random.py @@ -22,7 +22,7 @@ from .symbol import Symbol -__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial', +__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'multinomial', 'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint'] @@ -113,6 +113,36 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs): [loc, scale], shape, dtype, kwargs) +def randn(*shape, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + + Parameters + ---------- + loc : float or Symbol, optional + Mean (centre) of the distribution. + scale : float or Symbol, optional + Standard deviation (spread or width) of the distribution. + shape : int or tuple of ints + The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and + `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` + are NDArrays with shape, e.g., `(x, y)`, then output will have shape + `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)` pair. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + """ + loc = kwargs.pop('loc', 0) + scale = kwargs.pop('scale', 1) + dtype = kwargs.pop('dtype', _Null) + assert isinstance(loc, (int, float, Symbol)) + assert isinstance(scale, (int, float, Symbol)) + return _random_helper(_internal._random_normal, _internal._sample_normal, + [loc, scale], shape, dtype, kwargs) + + def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs): """Draw random samples from a Poisson distribution. diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 720c25d2711e..fe276685bfe3 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -61,8 +61,10 @@ def check_with_device(device, dtype): }, { 'name': 'randn', + 'symbol': mx.sym.random.randn, 'ndop': mx.nd.random.randn, 'params': { 'loc': 10.0, 'scale': 0.5 }, + 'inputs': [ ('loc',[ [ 0.0, 2.5 ], [ -9.75, -7.0 ] ]) , ('scale',[ [ 1.0, 3.7 ], [ 4.2, 1.5 ] ]) ], 'checks': [ ('mean', lambda x, params: np.mean(x.astype(np.float64) - params['loc']), tol), ('std', lambda x, params: np.std(x.astype(np.float64)) - params['scale'], tol) @@ -250,6 +252,9 @@ def check_with_device(device, dtype): params = {'shape': shape, 'dtype': dtype, 'ctx': device} params.update({k : mx.nd.array(v, ctx=device, dtype=dtype) for k, v in symbdic['inputs']}) + if name == 'randn': + params.pop('shape') # randn does not accept shape param + args = shape mx.random.seed(128) ret1 = ndop(*args, **params).asnumpy() mx.random.seed(128) @@ -263,14 +268,12 @@ def check_with_device(device, dtype): err = np.abs(check_func(ret2[i,j], stats)) assert err < tol, "%f vs %f: symbolic test: %s check for `%s` did not pass" % (err, tol, check_name, name) - if 'symbol' not in symbdic: continue # randn does not have symbol - # check symbolic symbol = symbdic['symbol'] X = mx.sym.Variable("X") params = symbdic['params'].copy() params.update(shape=shape, dtype=dtype) - if name.endswith('_like'): + if name.endswith('_like') or name == 'randn': params['data'] = mx.sym.ones(params.pop('shape')) Y = symbol(**params) + X x = mx.nd.zeros(shape, dtype=dtype, ctx=device) @@ -298,7 +301,12 @@ def check_with_device(device, dtype): single_param = len(symbdic['inputs']) == 1 v1 = mx.sym.Variable('v1') v2 = mx.sym.Variable('v2') - Y = symbol(v1,**params) if single_param else symbol(v1,v2,**params) + if name == 'randn': + params.pop('shape') # randn does not accept shape param + args=shape + Y = symbol(v1, **params) if single_param else symbol(*args, loc=v1, scale=v2,**params) + else: + Y = symbol(v1,**params) if single_param else symbol(v1,v2,**params) bindings = { 'v1' : mx.nd.array(symbdic['inputs'][0][1]) } if not single_param : bindings.update({ 'v2' : mx.nd.array(symbdic['inputs'][1][1]) }) @@ -315,9 +323,10 @@ def check_with_device(device, dtype): for check_name, check_func, tol in symbdic['checks']: assert np.abs(check_func(samples, params)) < tol, "symbolic test: %s check for `%s` did not pass" % (check_name, name) + if 'pdfsymbol' not in symbdic: continue # randn not tested for pdf + # check pdfs with only a subset of the generated samples un1 = np.resize(un1, (un1.shape[0], un1.shape[1], pdfshape[0], pdfshape[1])) - print(name) symbol = symbdic['pdfsymbol'] pdffunc = symbdic['pdffunc'] v0 = mx.sym.Variable('v0') @@ -355,7 +364,6 @@ def check_with_device(device, dtype): check_symbolic_forward(test_pdf, [un1, p1, p2], [res], atol=forw_atol, rtol=forw_rtol, dtype=dtype) if dtype == np.float64: grad_nodes = ['v1', 'v2'] if symbdic['discrete'] else ['v0', 'v1', 'v2'] - print(backw_rtol) check_numeric_gradient(test_pdf, [un1, p1, p2], grad_nodes=grad_nodes, atol=backw_atol, rtol=backw_rtol, dtype=dtype) @with_seed(1000)