From 9f1136c82c8d33dcbf07337c782f80a529f560b5 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 6 Sep 2019 03:26:56 +0000 Subject: [PATCH 1/8] normal implemented --- python/mxnet/ndarray/numpy/random.py | 33 ++-- python/mxnet/numpy/random.py | 8 +- python/mxnet/symbol/numpy/random.py | 33 ++-- src/operator/numpy/random/np_normal_op.cc | 69 ++++++++ src/operator/numpy/random/np_normal_op.cu | 36 ++++ src/operator/numpy/random/np_normal_op.h | 207 ++++++++++++++++++++++ tests/python/unittest/test_numpy_op.py | 2 +- 7 files changed, 353 insertions(+), 35 deletions(-) create mode 100644 src/operator/numpy/random/np_normal_op.cc create mode 100644 src/operator/numpy/random/np_normal_op.cu create mode 100644 src/operator/numpy/random/np_normal_op.h diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 9ea2ef0f5405..41e72c7ea210 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -145,7 +145,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, **kwargs): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -173,24 +173,29 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): ------- out : ndarray Drawn samples from the parameterized normal distribution. - - Notes - ----- - This function currently does not support ``loc`` and ``scale`` as ndarrays. """ - dtype = kwargs.pop('dtype', None) + from ...numpy import ndarray as np_ndarray + input_type = (isinstance(loc, np_ndarray), isinstance(scale, np_ndarray)) if dtype is None: dtype = 'float32' - ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() - out = kwargs.pop('out', None) - if size is None and out is None: - size = () - if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)): - raise NotImplementedError('np.random.normal only supports loc and scale of ' - 'numeric types for now') - return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.normal(loc, scale, loc=None, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.normal(scale, loc=loc, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.normal(loc, loc=None, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.normal(loc=loc, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) def multinomial(n, pvals, size=None): diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index bd534f0b6d97..573877b8a6eb 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -110,7 +110,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, **kwargs): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -138,12 +138,8 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): ------- out : ndarray Drawn samples from the parameterized normal distribution. - - Notes - ----- - This function currently does not support ``loc`` and ``scale`` as ndarrays. """ - return _mx_nd_np.random.normal(loc, scale, size, **kwargs) + return _mx_nd_np.random.normal(loc, scale, size, dtype, **kwargs) def multinomial(n, pvals, size=None, **kwargs): diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index c5b8e1dc4906..a4064e08d994 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -144,7 +144,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, **kwargs): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -172,21 +172,26 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): ------- out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs) Drawn samples from the parameterized normal distribution. - - Notes - ----- - This function currently does not support ``loc`` and ``scale`` as `_Symbol`s. """ - dtype = kwargs.pop('dtype', None) + from ._symbol import _Symbol as np_symbol + input_type = (isinstance(loc, np_symbol), isinstance(scale, np_symbol)) if dtype is None: dtype = 'float32' - ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() - out = kwargs.pop('out', None) - if size is None and out is None: - size = () - if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)): - raise NotImplementedError('np.random.normal only supports loc and scale of ' - 'numeric types for now') - return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.normal(loc, scale, loc=None, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.normal(scale, loc=loc, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.normal(loc, loc=None, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.normal(loc=loc, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) diff --git a/src/operator/numpy/random/np_normal_op.cc b/src/operator/numpy/random/np_normal_op.cc new file mode 100644 index 000000000000..0201bac864c8 --- /dev/null +++ b/src/operator/numpy/random/np_normal_op.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_random_op.cc + * \brief Operator for numpy sampling from normal distributions + */ +#include "./np_normal_op.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyNormalParam); + +NNVM_REGISTER_OP(_npi_normal) +.describe("numpy behavior normal") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyNormalParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.loc.has_value()) num_inputs -= 1; + if (param.scale.has_value()) num_inputs -= 1; + return num_inputs; + } +) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyNormalParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.loc.has_value()) num_inputs -= 1; + if (param.scale.has_value()) num_inputs -= 1; + if (num_inputs == 0) return std::vector(); + if (num_inputs == 1) return std::vector{"input1"}; + return std::vector{"input1", "input2"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", TwoparamsDistOpShape) +.set_attr("FInferType", NumpyNormalOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyNormalForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_argument("input2", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyNormalParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_normal_op.cu b/src/operator/numpy/random/np_normal_op.cu new file mode 100644 index 000000000000..eaa8a04cb415 --- /dev/null +++ b/src/operator/numpy/random/np_normal_op.cu @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_normal_op.cu + * \brief Operator for numpy sampling from normal distributions + */ + + #include "./np_normal_op.h" + + namespace mxnet { + namespace op { + + NNVM_REGISTER_OP(_npi_normal) + .set_attr("FCompute", NumpyNormalForward); + + } // namespace op + } // namespace mxnet + \ No newline at end of file diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h new file mode 100644 index 000000000000..863636dfac76 --- /dev/null +++ b/src/operator/numpy/random/np_normal_op.h @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_normal_op.h + * \brief Operator for numpy sampling from normal distributions + */ +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ + +#include +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +struct NumpyNormalParam : public dmlc::Parameter { + dmlc::optional loc; + dmlc::optional scale; + std::string ctx; + int dtype; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyNormalParam) { + DMLC_DECLARE_FIELD(loc); + DMLC_DECLARE_FIELD(scale); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe( + "Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe( + "Context of output, in format [cpu|gpu|cpu_pinned](n)." + " Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::kFloat32) + .describe( + "DType of the output in case this can't be inferred. " + "Defaults to float32 if not defined (dtype=None)."); + } +}; + +inline bool NumpyNormalOpType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + int otype = param.dtype; + if (otype != -1) { + (*out_attrs)[0] = otype; + } else { + (*out_attrs)[0] = mshadow::kFloat32; + } + return true; +} + +namespace mxnet_op { +template +struct normal_kernel { + MSHADOW_XINLINE static void Map(index_t i, const Shape &lstride, + const Shape &hstride, + const Shape &oshape, IType *loc, + IType *scale, float *normals, OType *out) { + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType loc_value = loc[lidx]; + IType scale_value = scale[hidx]; + out[i] = loc_value + normals[i] * scale_value; + } +}; + +template +struct normal_one_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, + const Shape &stride, + const Shape &oshape, IType *array, + float scalar, float *normals, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType loc_value; + IType scale_value; + if (scalar_pos == 0) { + loc_value = scalar; + scale_value = array[idx]; + } else { + loc_value = array[idx]; + scale_value = scalar; + } + out[i] = loc_value + normals[i] * scale_value; + } +}; + +template +struct normal_two_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, float loc, float scale, + float *normals, OType *out) { + out[i] = loc + normals[i] * scale; + } +}; +} // namespace mxnet_op + +template +void NumpyNormalForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + + // Generate base random number. + Random *prnd = ctx.requested[0].get_random(s); + Tensor normal_tensor = + ctx.requested[1].get_space_typed(Shape1(outputs[0].Size()), + s); + prnd->SampleGaussian(&normal_tensor, 0, 1); + mxnet::TShape new_lshape, new_hshape, new_oshape; + + // [scalar scalar] case + if (inputs.size() == 0U) { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.loc.value(), param.scale.value(), + normal_tensor.dptr_, outputs[0].dptr()); + }); + } else if (inputs.size() == 1U) { + // [scalar tensor], [tensor scalar] case + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + int scalar_pos; + float scalar_value; + // int type_flag = param.t; + if (param.loc.has_value()) { + scalar_pos = 0; + scalar_value = param.loc.value(); + } else { + scalar_pos = 1; + scalar_value = param.scale.value(); + } + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, normal_tensor.dptr_, + outputs[0].dptr()); + }); + }); + }); + } else if (inputs.size() == 2U) { + // [tensor tensor] case + int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_hshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape lstride = calc_stride(new_lshape.get()); + Shape hstride = calc_stride(new_hshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), lstride, hstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + normal_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 51aa9054451c..6ed067fe33b5 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1554,7 +1554,7 @@ def test_np_random(): shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] dtypes = ['float16', 'float32', 'float64'] op_names = ['uniform', 'normal'] - op_names = ['normal'] + # op_names = ['normal'] for shape in shapes: for dtype in dtypes: for op_name in op_names: From dadec55c09dde0fb34afbff8e511543c54024f34 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 6 Sep 2019 06:06:48 +0000 Subject: [PATCH 2/8] numpy behavior normal imlemented --- python/mxnet/ndarray/numpy/random.py | 15 +++++++-------- python/mxnet/symbol/numpy/random.py | 15 +++++++-------- src/operator/numpy/random/np_normal_op.cu | 19 +++++++++---------- src/operator/numpy/random/np_normal_op.h | 2 +- src/operator/numpy/random/np_uniform_op.cc | 6 ++++++ 5 files changed, 30 insertions(+), 27 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 41e72c7ea210..86310e4e065e 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -21,7 +21,6 @@ from ...context import current_context from . import _internal as _npi from ..ndarray import NDArray -from ...base import numeric_types __all__ = ['randint', 'uniform', 'normal'] @@ -145,7 +144,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -166,8 +165,6 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. - out : ``ndarray``, optional - Store output to an existing ``ndarray``. Returns ------- @@ -176,6 +173,8 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): """ from ...numpy import ndarray as np_ndarray input_type = (isinstance(loc, np_ndarray), isinstance(scale, np_ndarray)) + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) if dtype is None: dtype = 'float32' if ctx is None: @@ -186,16 +185,16 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): size = None if input_type == (True, True): return _npi.normal(loc, scale, loc=None, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) elif input_type == (False, True): return _npi.normal(scale, loc=loc, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) elif input_type == (True, False): return _npi.normal(loc, loc=None, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) else: return _npi.normal(loc=loc, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) def multinomial(n, pvals, size=None): diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index a4064e08d994..754b9dac6294 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -20,7 +20,6 @@ from __future__ import absolute_import from ...context import current_context from . import _internal as _npi -from ...base import numeric_types __all__ = ['randint', 'uniform', 'normal'] @@ -144,7 +143,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -165,8 +164,6 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. - out : ``ndarray``, optional - Store output to an existing ``ndarray``. Returns ------- @@ -175,6 +172,8 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): """ from ._symbol import _Symbol as np_symbol input_type = (isinstance(loc, np_symbol), isinstance(scale, np_symbol)) + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) if dtype is None: dtype = 'float32' if ctx is None: @@ -185,13 +184,13 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None): size = None if input_type == (True, True): return _npi.normal(loc, scale, loc=None, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) elif input_type == (False, True): return _npi.normal(scale, loc=loc, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) elif input_type == (True, False): return _npi.normal(loc, loc=None, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) else: return _npi.normal(loc=loc, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) diff --git a/src/operator/numpy/random/np_normal_op.cu b/src/operator/numpy/random/np_normal_op.cu index eaa8a04cb415..6cdf9f9c4eae 100644 --- a/src/operator/numpy/random/np_normal_op.cu +++ b/src/operator/numpy/random/np_normal_op.cu @@ -23,14 +23,13 @@ * \brief Operator for numpy sampling from normal distributions */ - #include "./np_normal_op.h" +#include "./np_normal_op.h" - namespace mxnet { - namespace op { - - NNVM_REGISTER_OP(_npi_normal) - .set_attr("FCompute", NumpyNormalForward); - - } // namespace op - } // namespace mxnet - \ No newline at end of file +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_normal) + .set_attr("FCompute", NumpyNormalForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index 863636dfac76..2ab693d08b56 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -204,4 +204,4 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ diff --git a/src/operator/numpy/random/np_uniform_op.cc b/src/operator/numpy/random/np_uniform_op.cc index 394626d07596..7307b7744d5e 100644 --- a/src/operator/numpy/random/np_uniform_op.cc +++ b/src/operator/numpy/random/np_uniform_op.cc @@ -43,6 +43,12 @@ NNVM_REGISTER_OP(_npi_uniform) .set_num_outputs(1) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { + const NumpyUniformParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.low.has_value()) num_inputs -= 1; + if (param.high.has_value()) num_inputs -= 1; + if (num_inputs == 0) return std::vector(); + if (num_inputs == 1) return std::vector{"input1"}; return std::vector{"input1", "input2"}; }) .set_attr_parser(ParamParser) From ef6266c2e00a75e3a1992741fdf23c6f7fb9f29a Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 9 Sep 2019 01:07:33 +0000 Subject: [PATCH 3/8] retrigger CI --- src/operator/numpy/random/np_normal_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/numpy/random/np_normal_op.cc b/src/operator/numpy/random/np_normal_op.cc index 0201bac864c8..33c8884b0309 100644 --- a/src/operator/numpy/random/np_normal_op.cc +++ b/src/operator/numpy/random/np_normal_op.cc @@ -20,7 +20,7 @@ /*! * Copyright (c) 2019 by Contributors * \file np_random_op.cc - * \brief Operator for numpy sampling from normal distributions + * \brief Operator for numpy sampling from normal distributions. */ #include "./np_normal_op.h" From 6a3d9ccd17de4385bfa8181b5ccd5620fb5532ea Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 10 Sep 2019 06:23:46 +0000 Subject: [PATCH 4/8] retrigger CI --- python/mxnet/symbol/numpy/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index c48fd8134931..a5ccbc1d1995 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -161,7 +161,7 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): samples are drawn. If size is `None` (default), a scalar tensor containing a single value is returned if loc and scale are both scalars. dtype : {'float16', 'float32', 'float64'}, optional - Data type of output samples. Default is 'float32' + Data type of output samples. Default is 'float32'. ctx : Context, optional Device context of output. Default is current context. From edc23a59a0f77b102f5f3bf8600bcf5ad1e9b626 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 11 Sep 2019 08:51:50 +0000 Subject: [PATCH 5/8] regrigger ci --- python/mxnet/numpy/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 2330b31e486a..f89b8cac9871 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -130,7 +130,7 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional - Device context of output. Default is current context. + Device context of output, default is current context. out : ``ndarray``, optional Store output to an existing ``ndarray``. From b4900ed0ac97cf9ca447d06b207b66668c35ed74 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 17 Sep 2019 11:34:55 +0000 Subject: [PATCH 6/8] add normal parameter check --- src/operator/numpy/random/dist_common.cc | 43 +++++++++++++++++++++++ src/operator/numpy/random/dist_common.cu | 43 +++++++++++++++++++++++ src/operator/numpy/random/dist_common.h | 7 ++++ src/operator/numpy/random/np_normal_op.cc | 2 +- src/operator/numpy/random/np_normal_op.h | 40 ++++++++++++++++++--- 5 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 src/operator/numpy/random/dist_common.cc create mode 100644 src/operator/numpy/random/dist_common.cu diff --git a/src/operator/numpy/random/dist_common.cc b/src/operator/numpy/random/dist_common.cc new file mode 100644 index 000000000000..f9c5b3d660b3 --- /dev/null +++ b/src/operator/numpy/random/dist_common.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file dist_common.cc + * \brief Function definition of common functions for distributions + * \with two parameters. + */ + +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +template <> +void _copy(float *dst, float *src) { + *dst = *src; +} + +template <> +void _copy(double *dst, double *src) { + *dst = *src; +} + +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/numpy/random/dist_common.cu b/src/operator/numpy/random/dist_common.cu new file mode 100644 index 000000000000..c74ee7578270 --- /dev/null +++ b/src/operator/numpy/random/dist_common.cu @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file dist_common.cuh + * \brief Function definition of common functions for distributions + * \with two parameters. + */ + +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +template <> +void _copy(float *dst, float *src) { +CUDA_CALL(cudaMemcpy(dst, src, sizeof(float), cudaMemcpyDeviceToHost)); +} + +template <> +void _copy(double *dst, double *src) { +CUDA_CALL(cudaMemcpy(dst, src, sizeof(double), cudaMemcpyDeviceToHost)); +} + +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index 42d7507f7e74..a3c48f5b4d5c 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -41,6 +41,13 @@ namespace mxnet { namespace op { +template +void _copy(float *dst, float*src); + +template +void _copy(double *dst, double*src); + + inline int FillShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape, const mxnet::TShape &oshape, mxnet::TShape *new_lshape, mxnet::TShape *new_rshape, mxnet::TShape *new_oshape) { diff --git a/src/operator/numpy/random/np_normal_op.cc b/src/operator/numpy/random/np_normal_op.cc index 33c8884b0309..1a5df95ff9e8 100644 --- a/src/operator/numpy/random/np_normal_op.cc +++ b/src/operator/numpy/random/np_normal_op.cc @@ -30,7 +30,7 @@ namespace op { DMLC_REGISTER_PARAMETER(NumpyNormalParam); NNVM_REGISTER_OP(_npi_normal) -.describe("numpy behavior normal") +.describe("Numpy behavior normal") .set_num_inputs( [](const nnvm::NodeAttrs& attrs) { const NumpyNormalParam& param = nnvm::get(attrs.parsed); diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index 2ab693d08b56..4d79650d6df8 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -36,6 +36,7 @@ #include "../../operator_common.h" #include "../../tensor/elemwise_binary_broadcast_op.h" #include "./dist_common.h" +#include namespace mxnet { namespace op { @@ -126,6 +127,16 @@ struct normal_two_scalar_kernel { out[i] = loc + normals[i] * scale; } }; + +template +struct check_legal_scale_kernel { + MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) { + if (scalar[i] < 0) { + flag[0] = -1.0; + } + } +}; + } // namespace mxnet_op template @@ -142,14 +153,20 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, // Generate base random number. Random *prnd = ctx.requested[0].get_random(s); - Tensor normal_tensor = - ctx.requested[1].get_space_typed(Shape1(outputs[0].Size()), - s); - prnd->SampleGaussian(&normal_tensor, 0, 1); + index_t output_len = outputs[0].Size(); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); + Tensor normal_tensor = workspace.Slice(0, output_len); + Tensor indicator_device = workspace.Slice(output_len, output_len + 1); + float indicator_host = 1.0; + float *indicator_device_ptr = indicator_device.dptr_; + prnd->SampleGaussian(&normal_tensor, 0.0, 1.0); mxnet::TShape new_lshape, new_hshape, new_oshape; // [scalar scalar] case if (inputs.size() == 0U) { + // printf("scale value:%f", param.scale.value()); + CHECK_GE(param.scale.value(), 0.0) << "ValueError: scale < 0"; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { Kernel, xpu>::Launch( s, outputs[0].Size(), param.loc.value(), param.scale.value(), @@ -161,13 +178,20 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, &new_lshape, &new_lshape, &new_oshape); int scalar_pos; float scalar_value; - // int type_flag = param.t; if (param.loc.has_value()) { scalar_pos = 0; scalar_value = param.loc.value(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr + ); + }); + _copy(&indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) << "ValueError: scale < 0"; } else { scalar_pos = 1; scalar_value = param.scale.value(); + CHECK_GE(scalar_value, 0.0) << "ValueError: scale < 0"; } MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { @@ -183,6 +207,12 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, }); } else if (inputs.size() == 2U) { // [tensor tensor] case + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[1].Size(), inputs[1].dptr(), indicator_device_ptr); + }); + _copy(&indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) << "ValueError: scale < 0"; int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_hshape, &new_oshape); MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { From 9ba0a68d937611eb22a0fffba7363c02bed8d642 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 18 Sep 2019 05:26:53 +0000 Subject: [PATCH 7/8] add raise for normal --- src/operator/numpy/random/dist_common.cc | 2 +- src/operator/numpy/random/dist_common.cu | 2 +- src/operator/numpy/random/np_normal_op.h | 8 ++------ 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/operator/numpy/random/dist_common.cc b/src/operator/numpy/random/dist_common.cc index f9c5b3d660b3..925565654a11 100644 --- a/src/operator/numpy/random/dist_common.cc +++ b/src/operator/numpy/random/dist_common.cc @@ -40,4 +40,4 @@ void _copy(double *dst, double *src) { } } // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace mxnet diff --git a/src/operator/numpy/random/dist_common.cu b/src/operator/numpy/random/dist_common.cu index c74ee7578270..7dde0123e099 100644 --- a/src/operator/numpy/random/dist_common.cu +++ b/src/operator/numpy/random/dist_common.cu @@ -40,4 +40,4 @@ CUDA_CALL(cudaMemcpy(dst, src, sizeof(double), cudaMemcpyDeviceToHost)); } } // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace mxnet diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index 4d79650d6df8..678f0ec2fb47 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -25,7 +25,6 @@ #ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ #define MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ -#include #include #include #include @@ -36,7 +35,6 @@ #include "../../operator_common.h" #include "../../tensor/elemwise_binary_broadcast_op.h" #include "./dist_common.h" -#include namespace mxnet { namespace op { @@ -154,7 +152,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, // Generate base random number. Random *prnd = ctx.requested[0].get_random(s); index_t output_len = outputs[0].Size(); - Tensor workspace = + Tensor workspace = ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); Tensor normal_tensor = workspace.Slice(0, output_len); Tensor indicator_device = workspace.Slice(output_len, output_len + 1); @@ -165,7 +163,6 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, // [scalar scalar] case if (inputs.size() == 0U) { - // printf("scale value:%f", param.scale.value()); CHECK_GE(param.scale.value(), 0.0) << "ValueError: scale < 0"; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { Kernel, xpu>::Launch( @@ -183,8 +180,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, scalar_value = param.loc.value(); MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { Kernel, xpu>::Launch( - s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr - ); + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr); }); _copy(&indicator_host, indicator_device_ptr); CHECK_GE(indicator_host, 0.0) << "ValueError: scale < 0"; From 2bf54d46e13e61dc47890a2f672862f571b61d69 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 18 Sep 2019 08:01:46 +0000 Subject: [PATCH 8/8] remove dead code --- tests/python/unittest/test_numpy_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b53e4c03080d..62a0b6fb50dd 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1554,7 +1554,6 @@ def test_np_random(): shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] dtypes = ['float16', 'float32', 'float64'] op_names = ['uniform', 'normal'] - # op_names = ['normal'] for shape in shapes: for dtype in dtypes: for op_name in op_names: