Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add interface for rand
Browse files Browse the repository at this point in the history
add relevant tests

address comments.

* fix document string -> Returns description.

Fix
  • Loading branch information
kshitij12345 authored and reminisce committed Oct 20, 2019
1 parent 5accae0 commit b949716
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 10 deletions.
29 changes: 28 additions & 1 deletion python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..ndarray import NDArray


__all__ = ['randint', 'uniform', 'normal', "choice"]
__all__ = ['randint', 'uniform', 'normal', "choice", "rand"]


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -317,3 +317,30 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out)
else:
return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)


def rand(*size, **kwargs):
r"""Random values in a given shape.
Create an array of the given shape and populate it with random
samples from a uniform distribution over [0, 1).
Parameters
----------
d0, d1, ..., dn : int, optional
The dimensions of the returned array, should be all positive.
If no argument is given a single Python float is returned.
Returns
-------
out : ndarray
Random values.
Examples
--------
>>> np.random.rand(3,2)
array([[ 0.14022471, 0.96360618], #random
[ 0.37601032, 0.25528411], #random
[ 0.49313049, 0.94909878]]) #random
"""
output_shape = ()
for s in size:
output_shape += (s,)
return uniform(0, 1, size=output_shape, **kwargs)
30 changes: 28 additions & 2 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np


__all__ = ["randint", "uniform", "normal", "choice"]
__all__ = ["randint", "uniform", "normal", "choice", "rand"]


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -231,3 +230,30 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
array([2, 3, 0])
"""
return _mx_nd_np.random.choice(a, size, replace, p, ctx, out)


def rand(*size, **kwargs):
r"""Random values in a given shape.
Create an array of the given shape and populate it with random
samples from a uniform distribution over [0, 1).
Parameters
----------
d0, d1, ..., dn : int, optional
The dimensions of the returned array, should be all positive.
If no argument is given a single Python float is returned.
Returns
-------
out : ndarray
Random values.
Examples
--------
>>> np.random.rand(3,2)
array([[ 0.14022471, 0.96360618], #random
[ 0.37601032, 0.25528411], #random
[ 0.49313049, 0.94909878]]) #random
"""
output_shape = ()
for s in size:
output_shape += (s,)
return _mx_nd_np.random.uniform(0, 1, size=output_shape, **kwargs)
30 changes: 28 additions & 2 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from ...context import current_context
from . import _internal as _npi


__all__ = ['randint', 'uniform', 'normal']
__all__ = ['randint', 'uniform', 'normal', 'rand']


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -86,6 +85,33 @@ def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
return _npi.random_randint(low, high, shape=size, dtype=dtype, ctx=ctx, out=out)


def rand(*size, **kwargs):
r"""Random values in a given shape.
Create an array of the given shape and populate it with random
samples from a uniform distribution over [0, 1).
Parameters
----------
d0, d1, ..., dn : int, optional
The dimensions of the returned array, should be all positive.
If no argument is given a single Python float is returned.
Returns
-------
out : ndarray
Random values.
Examples
--------
>>> np.random.rand(3,2)
array([[ 0.14022471, 0.96360618], #random
[ 0.37601032, 0.25528411], #random
[ 0.49313049, 0.94909878]]) #random
"""
output_shape = ()
for s in size:
output_shape += (s,)
return uniform(0, 1, size=output_shape, **kwargs)


def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
"""Draw samples from a uniform distribution.
Expand Down
58 changes: 53 additions & 5 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,19 @@
import sys
import unittest
import numpy as _np
import platform
import mxnet as mx
import scipy.stats as ss
from mxnet import np, npx
from mxnet.gluon import HybridBlock
from mxnet.base import MXNetError
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
from mxnet.test_utils import check_numeric_gradient, use_np, collapse_sum_like
from common import assertRaises, with_seed
import random
import scipy.stats as ss
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
from mxnet.runtime import Features
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
from mxnet.numpy_op_signature import _get_builtin_op
from mxnet.test_utils import current_context, verify_generator, gen_buckets_probs_with_ppf
from mxnet.test_utils import is_op_runnable, has_tvm_ops
import platform


@with_seed()
Expand Down Expand Up @@ -3450,6 +3448,56 @@ def dbg(name, data):
assert_almost_equal(grad[0][iop], grad[1][iop], rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_rand():
# Test shapes.
shapes = [
(3, 3),
(3, 4),
(0, 0),
(3, 3, 3),
(0, 0, 0),
(2, 2, 4, 3),
(2, 2, 4, 3),
(2, 0, 3, 0),
(2, 0, 2, 3)
]
dtypes = ['float16', 'float32', 'float64']
for dtype in dtypes:
for shape in shapes:
data_mx = np.random.rand(*shape, dtype=dtype)
assert data_mx.shape == shape

# Test random generator.
ctx = mx.context.current_context()
samples = 1000000
trials = 8
num_buckets = 10
lower = 0.0
upper = 1.0
for dtype in ['float16', 'float32', 'float64']:
buckets, probs = gen_buckets_probs_with_ppf(
lambda x: ss.uniform.ppf(x, lower, upper), num_buckets)
# Quantize bucket boundaries to reflect the actual dtype
# and adjust probs accordingly
buckets = np.array(buckets, dtype=dtype).tolist()
probs = [(ss.uniform.cdf(buckets[i][1], lower, upper) -
ss.uniform.cdf(buckets[i][0], lower, upper))
for i in range(num_buckets)]

def generator_mx(x): return np.random.rand(
samples, ctx=ctx, dtype=dtype).asnumpy()
verify_generator(generator=generator_mx, buckets=buckets,
probs=probs, nsamples=samples, nrepeat=trials)
generator_mx_same_seed =\
lambda x: _np.concatenate(
[np.random.rand(x // 10, ctx=ctx, dtype=dtype).asnumpy()
for _ in range(10)])
verify_generator(generator=generator_mx_same_seed, buckets=buckets,
probs=probs, nsamples=samples, nrepeat=trials)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit b949716

Please sign in to comment.