Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add np.random.normal and npx.seed
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Aug 28, 2019
1 parent 656dc57 commit 941aa64
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 7 deletions.
52 changes: 49 additions & 3 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from __future__ import absolute_import
from ...context import current_context
from . import _internal as _npi
from ...base import numeric_types


__all__ = ['randint', 'uniform']
__all__ = ['randint', 'uniform', 'normal']


def randint(low, high=None, size=None, dtype=None, **kwargs):
Expand Down Expand Up @@ -141,5 +142,50 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
return _npi.uniform(low=low, high=high, size=size,
ctx=ctx, dtype=dtype, out=out)

raise ValueError(
"Distribution parameters must be either mxnet.numpy.ndarray or numbers")

def normal(loc=0.0, scale=1.0, size=None, **kwargs):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
by *loc* (mean) and *scale* (standard deviation).
Parameters
----------
loc : float, optional
Mean (centre) of the distribution.
scale : float, optional
Standard deviation (spread or "width") of the distribution.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k`
samples are drawn. If size is `None` (default), a scalar tensor containing
a single value is returned if loc and scale are both scalars.
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : ndarray
Drawn samples from the parameterized normal distribution.
Notes
-----
This function currently does not support ``loc`` and ``scale`` as ndarrays.
"""
dtype = kwargs.pop('dtype', None)
if dtype is None:
dtype = 'float32'
ctx = kwargs.pop('ctx', None)
if ctx is None:
ctx = current_context()
out = kwargs.pop('out', None)
if size is None and out is None:
size = ()
if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)):
raise NotImplementedError('np.random.normal only supports loc and scale of '
'numeric types for now')
return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)
38 changes: 37 additions & 1 deletion python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..ndarray import numpy as _mx_nd_np


__all__ = ["randint", "uniform"]
__all__ = ["randint", "uniform", "normal"]


def randint(low, high=None, size=None, dtype=None, **kwargs):
Expand Down Expand Up @@ -108,3 +108,39 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
Drawn samples from the parameterized uniform distribution.
"""
return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out)


def normal(loc=0.0, scale=1.0, size=None, **kwargs):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
by *loc* (mean) and *scale* (standard deviation).
Parameters
----------
loc : float, optional
Mean (centre) of the distribution.
scale : float, optional
Standard deviation (spread or "width") of the distribution.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k`
samples are drawn. If size is `None` (default), a scalar tensor containing
a single value is returned if loc and scale are both scalars.
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : ndarray
Drawn samples from the parameterized normal distribution.
Notes
-----
This function currently does not support ``loc`` and ``scale`` as ndarrays.
"""
return _mx_nd_np.random.normal(loc, scale, size, **kwargs)
1 change: 1 addition & 0 deletions python/mxnet/numpy_extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
from ..util import is_np_shape, is_np_array, set_np, reset_np
from ..ndarray import waitall
from .utils import * # pylint: disable=wildcard-import
from .random import * # pylint: disable=wildcard-import

__all__ = []
74 changes: 74 additions & 0 deletions python/mxnet/numpy_extension/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for ops used in imperative programming."""

from __future__ import absolute_import
from .. import random as _mx_rand


__all__ = ['seed']


def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name
"""Seeds the random number generators in MXNet.
This affects the behavior of modules in MXNet that uses random number generators,
like the dropout operator and `ndarray`'s random sampling operators.
Parameters
----------
seed : int
The random number seed.
ctx : Context
The device context of the generator. The default is "all" which means seeding random
number generators of all devices.
Notes
-----
Random number generators in MXNet are device specific.
`mx.random.seed(seed_state)` sets the state of each generator using `seed_state` and the
device id. Therefore, random numbers generated from different devices can be different
even if they are seeded using the same seed.
To produce identical random number sequences independent of the device id,
set optional `ctx` argument. This produces the same sequence of random numbers independent
of the device id, but the sequence can be different on different kind of devices as MXNet's
random number generators for CPU and GPU use different algorithms.
Example
-------
>>> from mxnet import np, npx
>>> npx.set_np()
>>> npx.random.seed(0)
>>> np.random.uniform()
array(0.5488135)
>>> npx.random.seed(128)
>>> np.random.uniform()
array(0.03812965)
>>> npx.random.seed(128)
>>> np.random.uniform()
array(0.03812965)
>>> npx.random.seed(128)
>>> np.random.uniform(ctx=npx.gpu(0))
array(0.9894903, ctx=gpu(0))
>>> npx.random.seed(128)
>>> np.random.uniform(ctx=npx.gpu(0))
array(0.9894903, ctx=gpu(0))
"""
_mx_rand.seed(seed_state=seed, ctx=ctx)
52 changes: 49 additions & 3 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from __future__ import absolute_import
from ...context import current_context
from . import _internal as _npi
from ...base import numeric_types


__all__ = ['randint', 'uniform']
__all__ = ['randint', 'uniform', 'normal']


def randint(low, high=None, size=None, dtype=None, **kwargs):
Expand Down Expand Up @@ -142,5 +143,50 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
return _npi.uniform(low=low, high=high, size=size,
ctx=ctx, dtype=dtype, out=out)

raise ValueError(
"Distribution parameters must be either mxnet.numpy.ndarray or numbers")

def normal(loc=0.0, scale=1.0, size=None, **kwargs):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
by *loc* (mean) and *scale* (standard deviation).
Parameters
----------
loc : float, optional
Mean (centre) of the distribution.
scale : float, optional
Standard deviation (spread or "width") of the distribution.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k`
samples are drawn. If size is `None` (default), a scalar tensor containing
a single value is returned if loc and scale are both scalars.
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs)
Drawn samples from the parameterized normal distribution.
Notes
-----
This function currently does not support ``loc`` and ``scale`` as `_Symbol`s.
"""
dtype = kwargs.pop('dtype', None)
if dtype is None:
dtype = 'float32'
ctx = kwargs.pop('ctx', None)
if ctx is None:
ctx = current_context()
out = kwargs.pop('out', None)
if size is None and out is None:
size = ()
if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)):
raise NotImplementedError('np.random.normal only supports loc and scale of '
'numeric types for now')
return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)
34 changes: 34 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_index.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_broadcast_reduce_op_index.cu
* \brief GPU Implementation of broadcast and reduce functions based on index.
*/
#include "./np_broadcast_reduce_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_argmax)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);

} // namespace op
} // namespace mxnet
47 changes: 47 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,53 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False)


@with_seed()
@use_np
def test_np_random():
shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None]
dtypes = ['float16', 'float32', 'float64']
op_names = ['uniform', 'normal']
op_names = ['normal']
for shape in shapes:
for dtype in dtypes:
for op_name in op_names:
print('-------------------------------')
print(op_name)
print(shape)
print(dtype)
op = getattr(np.random, op_name, None)
assert op is not None
out = op(size=shape, dtype=dtype)
expected_shape = shape
if not isinstance(shape, tuple):
expected_shape = () if shape is None else (shape,)
assert out.shape == expected_shape

class TestRandom(HybridBlock):
def __init__(self, shape, op_name):
super(TestRandom, self).__init__()
self._shape = shape
self._op_name = op_name

def hybrid_forward(self, F, x):
op = getattr(F.np.random, self._op_name, None)
assert op is not None
return x + op(size=shape)

x = np.ones(())
for op_name in op_names:
for shape in shapes:
for hybridize in [False, True]:
net = TestRandom(shape, op_name)
if hybridize:
net.hybridize()
out = net(x)
expected_shape = shape
if not isinstance(shape, tuple):
expected_shape = () if shape is None else (shape,)
assert out.shape == expected_shape


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 941aa64

Please sign in to comment.