Skip to content

Commit

Permalink
[Numpy] Numpy behavior normal distribution (apache#16109)
Browse files Browse the repository at this point in the history
* normal implemented

* numpy behavior normal imlemented

* retrigger CI

* retrigger CI

* regrigger ci

* add normal parameter check

* add raise for normal

* remove dead code
  • Loading branch information
xidulu authored and larroy committed Sep 28, 2019
1 parent 4de39b3 commit d33d728
Show file tree
Hide file tree
Showing 11 changed files with 482 additions and 42 deletions.
38 changes: 21 additions & 17 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray
from ...base import numeric_types


__all__ = ['randint', 'uniform', 'normal', "choice"]
Expand Down Expand Up @@ -145,7 +144,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
ctx=ctx, dtype=dtype, out=out)


def normal(loc=0.0, scale=1.0, size=None, **kwargs):
def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
Expand All @@ -166,31 +165,36 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : ndarray
Drawn samples from the parameterized normal distribution.
Notes
-----
This function currently does not support ``loc`` and ``scale`` as ndarrays.
"""
dtype = kwargs.pop('dtype', None)
from ...numpy import ndarray as np_ndarray
input_type = (isinstance(loc, np_ndarray), isinstance(scale, np_ndarray))
ctx = kwargs.pop('ctx', None)
out = kwargs.pop('out', None)
if dtype is None:
dtype = 'float32'
ctx = kwargs.pop('ctx', None)
if ctx is None:
ctx = current_context()
out = kwargs.pop('out', None)
if size is None and out is None:
size = ()
if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)):
raise NotImplementedError('np.random.normal only supports loc and scale of '
'numeric types for now')
return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)
if out is not None:
size = out.shape
if size == ():
size = None
if input_type == (True, True):
return _npi.normal(loc, scale, loc=None, scale=None, size=size,
ctx=ctx, dtype=dtype, out=out)
elif input_type == (False, True):
return _npi.normal(scale, loc=loc, scale=None, size=size,
ctx=ctx, dtype=dtype, out=out)
elif input_type == (True, False):
return _npi.normal(loc, loc=None, scale=scale, size=size,
ctx=ctx, dtype=dtype, out=out)
else:
return _npi.normal(loc=loc, scale=scale, size=size,
ctx=ctx, dtype=dtype, out=out)


def multinomial(n, pvals, size=None):
Expand Down
10 changes: 3 additions & 7 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out)


def normal(loc=0.0, scale=1.0, size=None, **kwargs):
def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
Expand All @@ -130,20 +130,16 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
Device context of output, default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : ndarray
Drawn samples from the parameterized normal distribution.
Notes
-----
This function currently does not support ``loc`` and ``scale`` as ndarrays.
"""
return _mx_nd_np.random.normal(loc, scale, size, **kwargs)
return _mx_nd_np.random.normal(loc, scale, size, dtype, **kwargs)


def multinomial(n, pvals, size=None, **kwargs):
Expand Down
39 changes: 22 additions & 17 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from __future__ import absolute_import
from ...context import current_context
from . import _internal as _npi
from ...base import numeric_types


__all__ = ['randint', 'uniform', 'normal']
Expand Down Expand Up @@ -144,7 +143,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
ctx=ctx, dtype=dtype, out=out)


def normal(loc=0.0, scale=1.0, size=None, **kwargs):
def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
Expand All @@ -162,34 +161,40 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
samples are drawn. If size is `None` (default), a scalar tensor containing
a single value is returned if loc and scale are both scalars.
dtype : {'float16', 'float32', 'float64'}, optional
Data type of output samples. Default is 'float32'
Data type of output samples. Default is 'float32'.
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs)
Drawn samples from the parameterized normal distribution.
Notes
-----
This function currently does not support ``loc`` and ``scale`` as `_Symbol`s.
"""
dtype = kwargs.pop('dtype', None)
from ._symbol import _Symbol as np_symbol
input_type = (isinstance(loc, np_symbol), isinstance(scale, np_symbol))
ctx = kwargs.pop('ctx', None)
out = kwargs.pop('out', None)
if dtype is None:
dtype = 'float32'
ctx = kwargs.pop('ctx', None)
if ctx is None:
ctx = current_context()
out = kwargs.pop('out', None)
if size is None and out is None:
size = ()
if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)):
raise NotImplementedError('np.random.normal only supports loc and scale of '
'numeric types for now')
return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)
if out is not None:
size = out.shape
if size == ():
size = None
if input_type == (True, True):
return _npi.normal(loc, scale, loc=None, scale=None, size=size,
ctx=ctx, dtype=dtype, out=out)
elif input_type == (False, True):
return _npi.normal(scale, loc=loc, scale=None, size=size,
ctx=ctx, dtype=dtype, out=out)
elif input_type == (True, False):
return _npi.normal(loc, loc=None, scale=scale, size=size,
ctx=ctx, dtype=dtype, out=out)
else:
return _npi.normal(loc=loc, scale=scale, size=size,
ctx=ctx, dtype=dtype, out=out)


def choice(a, size=None, replace=True, p=None, **kwargs):
Expand Down
43 changes: 43 additions & 0 deletions src/operator/numpy/random/dist_common.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2015 by Contributors
* \file dist_common.cc
* \brief Function definition of common functions for distributions
* \with two parameters.
*/

#include "./dist_common.h"

namespace mxnet {
namespace op {

template <>
void _copy<cpu>(float *dst, float *src) {
*dst = *src;
}

template <>
void _copy<cpu>(double *dst, double *src) {
*dst = *src;
}

} // namespace op
} // namespace mxnet
43 changes: 43 additions & 0 deletions src/operator/numpy/random/dist_common.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2015 by Contributors
* \file dist_common.cuh
* \brief Function definition of common functions for distributions
* \with two parameters.
*/

#include "./dist_common.h"

namespace mxnet {
namespace op {

template <>
void _copy<gpu>(float *dst, float *src) {
CUDA_CALL(cudaMemcpy(dst, src, sizeof(float), cudaMemcpyDeviceToHost));
}

template <>
void _copy<gpu>(double *dst, double *src) {
CUDA_CALL(cudaMemcpy(dst, src, sizeof(double), cudaMemcpyDeviceToHost));
}

} // namespace op
} // namespace mxnet
7 changes: 7 additions & 0 deletions src/operator/numpy/random/dist_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
namespace mxnet {
namespace op {

template <typename xpu>
void _copy(float *dst, float*src);

template <typename xpu>
void _copy(double *dst, double*src);


inline int FillShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape,
const mxnet::TShape &oshape, mxnet::TShape *new_lshape,
mxnet::TShape *new_rshape, mxnet::TShape *new_oshape) {
Expand Down
69 changes: 69 additions & 0 deletions src/operator/numpy/random/np_normal_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_random_op.cc
* \brief Operator for numpy sampling from normal distributions.
*/
#include "./np_normal_op.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyNormalParam);

NNVM_REGISTER_OP(_npi_normal)
.describe("Numpy behavior normal")
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyNormalParam& param = nnvm::get<NumpyNormalParam>(attrs.parsed);
int num_inputs = 2;
if (param.loc.has_value()) num_inputs -= 1;
if (param.scale.has_value()) num_inputs -= 1;
return num_inputs;
}
)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyNormalParam& param = nnvm::get<NumpyNormalParam>(attrs.parsed);
int num_inputs = 2;
if (param.loc.has_value()) num_inputs -= 1;
if (param.scale.has_value()) num_inputs -= 1;
if (num_inputs == 0) return std::vector<std::string>();
if (num_inputs == 1) return std::vector<std::string>{"input1"};
return std::vector<std::string>{"input1", "input2"};
})
.set_attr_parser(ParamParser<NumpyNormalParam>)
.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyNormalParam>)
.set_attr<nnvm::FInferType>("FInferType", NumpyNormalOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<ResourceRequest>{
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyNormalForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_argument("input2", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyNormalParam::__FIELDS__());

} // namespace op
} // namespace mxnet
35 changes: 35 additions & 0 deletions src/operator/numpy/random/np_normal_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_normal_op.cu
* \brief Operator for numpy sampling from normal distributions
*/

#include "./np_normal_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_normal)
.set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);

} // namespace op
} // namespace mxnet
Loading

0 comments on commit d33d728

Please sign in to comment.