From d33d7286517a2fbf44bf2716a6a1a74e7ca4a5f8 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Thu, 19 Sep 2019 12:00:15 +0800 Subject: [PATCH] [Numpy] Numpy behavior normal distribution (#16109) * normal implemented * numpy behavior normal imlemented * retrigger CI * retrigger CI * regrigger ci * add normal parameter check * add raise for normal * remove dead code --- python/mxnet/ndarray/numpy/random.py | 38 ++-- python/mxnet/numpy/random.py | 10 +- python/mxnet/symbol/numpy/random.py | 39 ++-- src/operator/numpy/random/dist_common.cc | 43 ++++ src/operator/numpy/random/dist_common.cu | 43 ++++ src/operator/numpy/random/dist_common.h | 7 + src/operator/numpy/random/np_normal_op.cc | 69 ++++++ src/operator/numpy/random/np_normal_op.cu | 35 ++++ src/operator/numpy/random/np_normal_op.h | 233 +++++++++++++++++++++ src/operator/numpy/random/np_uniform_op.cc | 6 + tests/python/unittest/test_numpy_op.py | 1 - 11 files changed, 482 insertions(+), 42 deletions(-) create mode 100644 src/operator/numpy/random/dist_common.cc create mode 100644 src/operator/numpy/random/dist_common.cu create mode 100644 src/operator/numpy/random/np_normal_op.cc create mode 100644 src/operator/numpy/random/np_normal_op.cu create mode 100644 src/operator/numpy/random/np_normal_op.h diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 9372beaf1e92..883c56e12393 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -21,7 +21,6 @@ from ...context import current_context from . import _internal as _npi from ..ndarray import NDArray -from ...base import numeric_types __all__ = ['randint', 'uniform', 'normal', "choice"] @@ -145,7 +144,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, **kwargs): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -166,31 +165,36 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): Data type of output samples. Default is 'float32' ctx : Context, optional Device context of output. Default is current context. - out : ``ndarray``, optional - Store output to an existing ``ndarray``. Returns ------- out : ndarray Drawn samples from the parameterized normal distribution. - - Notes - ----- - This function currently does not support ``loc`` and ``scale`` as ndarrays. """ - dtype = kwargs.pop('dtype', None) + from ...numpy import ndarray as np_ndarray + input_type = (isinstance(loc, np_ndarray), isinstance(scale, np_ndarray)) + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) if dtype is None: dtype = 'float32' - ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() - out = kwargs.pop('out', None) - if size is None and out is None: - size = () - if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)): - raise NotImplementedError('np.random.normal only supports loc and scale of ' - 'numeric types for now') - return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.normal(loc, scale, loc=None, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.normal(scale, loc=loc, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.normal(loc, loc=None, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.normal(loc=loc, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) def multinomial(n, pvals, size=None): diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index aace767c8d55..f89b8cac9871 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -110,7 +110,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, **kwargs): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -130,7 +130,7 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): dtype : {'float16', 'float32', 'float64'}, optional Data type of output samples. Default is 'float32' ctx : Context, optional - Device context of output. Default is current context. + Device context of output, default is current context. out : ``ndarray``, optional Store output to an existing ``ndarray``. @@ -138,12 +138,8 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): ------- out : ndarray Drawn samples from the parameterized normal distribution. - - Notes - ----- - This function currently does not support ``loc`` and ``scale`` as ndarrays. """ - return _mx_nd_np.random.normal(loc, scale, size, **kwargs) + return _mx_nd_np.random.normal(loc, scale, size, dtype, **kwargs) def multinomial(n, pvals, size=None, **kwargs): diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 523983bac20a..a5ccbc1d1995 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -20,7 +20,6 @@ from __future__ import absolute_import from ...context import current_context from . import _internal as _npi -from ...base import numeric_types __all__ = ['randint', 'uniform', 'normal'] @@ -144,7 +143,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): ctx=ctx, dtype=dtype, out=out) -def normal(loc=0.0, scale=1.0, size=None, **kwargs): +def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): """Draw random samples from a normal (Gaussian) distribution. Samples are distributed according to a normal distribution parametrized @@ -162,34 +161,40 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): samples are drawn. If size is `None` (default), a scalar tensor containing a single value is returned if loc and scale are both scalars. dtype : {'float16', 'float32', 'float64'}, optional - Data type of output samples. Default is 'float32' + Data type of output samples. Default is 'float32'. ctx : Context, optional Device context of output. Default is current context. - out : ``ndarray``, optional - Store output to an existing ``ndarray``. Returns ------- out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs) Drawn samples from the parameterized normal distribution. - - Notes - ----- - This function currently does not support ``loc`` and ``scale`` as `_Symbol`s. """ - dtype = kwargs.pop('dtype', None) + from ._symbol import _Symbol as np_symbol + input_type = (isinstance(loc, np_symbol), isinstance(scale, np_symbol)) + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) if dtype is None: dtype = 'float32' - ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() out = kwargs.pop('out', None) - if size is None and out is None: - size = () - if (not isinstance(loc, numeric_types)) or (not isinstance(scale, numeric_types)): - raise NotImplementedError('np.random.normal only supports loc and scale of ' - 'numeric types for now') - return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.normal(loc, scale, loc=None, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.normal(scale, loc=loc, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.normal(loc, loc=None, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.normal(loc=loc, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) def choice(a, size=None, replace=True, p=None, **kwargs): diff --git a/src/operator/numpy/random/dist_common.cc b/src/operator/numpy/random/dist_common.cc new file mode 100644 index 000000000000..925565654a11 --- /dev/null +++ b/src/operator/numpy/random/dist_common.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file dist_common.cc + * \brief Function definition of common functions for distributions + * \with two parameters. + */ + +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +template <> +void _copy(float *dst, float *src) { + *dst = *src; +} + +template <> +void _copy(double *dst, double *src) { + *dst = *src; +} + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/dist_common.cu b/src/operator/numpy/random/dist_common.cu new file mode 100644 index 000000000000..7dde0123e099 --- /dev/null +++ b/src/operator/numpy/random/dist_common.cu @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file dist_common.cuh + * \brief Function definition of common functions for distributions + * \with two parameters. + */ + +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +template <> +void _copy(float *dst, float *src) { +CUDA_CALL(cudaMemcpy(dst, src, sizeof(float), cudaMemcpyDeviceToHost)); +} + +template <> +void _copy(double *dst, double *src) { +CUDA_CALL(cudaMemcpy(dst, src, sizeof(double), cudaMemcpyDeviceToHost)); +} + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index 42d7507f7e74..a3c48f5b4d5c 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -41,6 +41,13 @@ namespace mxnet { namespace op { +template +void _copy(float *dst, float*src); + +template +void _copy(double *dst, double*src); + + inline int FillShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape, const mxnet::TShape &oshape, mxnet::TShape *new_lshape, mxnet::TShape *new_rshape, mxnet::TShape *new_oshape) { diff --git a/src/operator/numpy/random/np_normal_op.cc b/src/operator/numpy/random/np_normal_op.cc new file mode 100644 index 000000000000..1a5df95ff9e8 --- /dev/null +++ b/src/operator/numpy/random/np_normal_op.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_random_op.cc + * \brief Operator for numpy sampling from normal distributions. + */ +#include "./np_normal_op.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyNormalParam); + +NNVM_REGISTER_OP(_npi_normal) +.describe("Numpy behavior normal") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyNormalParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.loc.has_value()) num_inputs -= 1; + if (param.scale.has_value()) num_inputs -= 1; + return num_inputs; + } +) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyNormalParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.loc.has_value()) num_inputs -= 1; + if (param.scale.has_value()) num_inputs -= 1; + if (num_inputs == 0) return std::vector(); + if (num_inputs == 1) return std::vector{"input1"}; + return std::vector{"input1", "input2"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", TwoparamsDistOpShape) +.set_attr("FInferType", NumpyNormalOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyNormalForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_argument("input2", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyNormalParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_normal_op.cu b/src/operator/numpy/random/np_normal_op.cu new file mode 100644 index 000000000000..6cdf9f9c4eae --- /dev/null +++ b/src/operator/numpy/random/np_normal_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_normal_op.cu + * \brief Operator for numpy sampling from normal distributions + */ + +#include "./np_normal_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_normal) + .set_attr("FCompute", NumpyNormalForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h new file mode 100644 index 000000000000..678f0ec2fb47 --- /dev/null +++ b/src/operator/numpy/random/np_normal_op.h @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_normal_op.h + * \brief Operator for numpy sampling from normal distributions + */ +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ + +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +struct NumpyNormalParam : public dmlc::Parameter { + dmlc::optional loc; + dmlc::optional scale; + std::string ctx; + int dtype; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyNormalParam) { + DMLC_DECLARE_FIELD(loc); + DMLC_DECLARE_FIELD(scale); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe( + "Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe( + "Context of output, in format [cpu|gpu|cpu_pinned](n)." + " Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::kFloat32) + .describe( + "DType of the output in case this can't be inferred. " + "Defaults to float32 if not defined (dtype=None)."); + } +}; + +inline bool NumpyNormalOpType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + int otype = param.dtype; + if (otype != -1) { + (*out_attrs)[0] = otype; + } else { + (*out_attrs)[0] = mshadow::kFloat32; + } + return true; +} + +namespace mxnet_op { +template +struct normal_kernel { + MSHADOW_XINLINE static void Map(index_t i, const Shape &lstride, + const Shape &hstride, + const Shape &oshape, IType *loc, + IType *scale, float *normals, OType *out) { + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType loc_value = loc[lidx]; + IType scale_value = scale[hidx]; + out[i] = loc_value + normals[i] * scale_value; + } +}; + +template +struct normal_one_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, + const Shape &stride, + const Shape &oshape, IType *array, + float scalar, float *normals, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType loc_value; + IType scale_value; + if (scalar_pos == 0) { + loc_value = scalar; + scale_value = array[idx]; + } else { + loc_value = array[idx]; + scale_value = scalar; + } + out[i] = loc_value + normals[i] * scale_value; + } +}; + +template +struct normal_two_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, float loc, float scale, + float *normals, OType *out) { + out[i] = loc + normals[i] * scale; + } +}; + +template +struct check_legal_scale_kernel { + MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) { + if (scalar[i] < 0) { + flag[0] = -1.0; + } + } +}; + +} // namespace mxnet_op + +template +void NumpyNormalForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyNormalParam ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + + // Generate base random number. + Random *prnd = ctx.requested[0].get_random(s); + index_t output_len = outputs[0].Size(); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); + Tensor normal_tensor = workspace.Slice(0, output_len); + Tensor indicator_device = workspace.Slice(output_len, output_len + 1); + float indicator_host = 1.0; + float *indicator_device_ptr = indicator_device.dptr_; + prnd->SampleGaussian(&normal_tensor, 0.0, 1.0); + mxnet::TShape new_lshape, new_hshape, new_oshape; + + // [scalar scalar] case + if (inputs.size() == 0U) { + CHECK_GE(param.scale.value(), 0.0) << "ValueError: scale < 0"; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.loc.value(), param.scale.value(), + normal_tensor.dptr_, outputs[0].dptr()); + }); + } else if (inputs.size() == 1U) { + // [scalar tensor], [tensor scalar] case + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + int scalar_pos; + float scalar_value; + if (param.loc.has_value()) { + scalar_pos = 0; + scalar_value = param.loc.value(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr); + }); + _copy(&indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) << "ValueError: scale < 0"; + } else { + scalar_pos = 1; + scalar_value = param.scale.value(); + CHECK_GE(scalar_value, 0.0) << "ValueError: scale < 0"; + } + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, normal_tensor.dptr_, + outputs[0].dptr()); + }); + }); + }); + } else if (inputs.size() == 2U) { + // [tensor tensor] case + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[1].Size(), inputs[1].dptr(), indicator_device_ptr); + }); + _copy(&indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) << "ValueError: scale < 0"; + int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_hshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape lstride = calc_stride(new_lshape.get()); + Shape hstride = calc_stride(new_hshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), lstride, hstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + normal_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_NORMAL_OP_H_ diff --git a/src/operator/numpy/random/np_uniform_op.cc b/src/operator/numpy/random/np_uniform_op.cc index 394626d07596..7307b7744d5e 100644 --- a/src/operator/numpy/random/np_uniform_op.cc +++ b/src/operator/numpy/random/np_uniform_op.cc @@ -43,6 +43,12 @@ NNVM_REGISTER_OP(_npi_uniform) .set_num_outputs(1) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { + const NumpyUniformParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.low.has_value()) num_inputs -= 1; + if (param.high.has_value()) num_inputs -= 1; + if (num_inputs == 0) return std::vector(); + if (num_inputs == 1) return std::vector{"input1"}; return std::vector{"input1", "input2"}; }) .set_attr_parser(ParamParser) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1656716b4a35..07bd2864cfb8 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1579,7 +1579,6 @@ def test_np_random(): shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] dtypes = ['float16', 'float32', 'float64'] op_names = ['uniform', 'normal'] - op_names = ['normal'] for shape in shapes: for dtype in dtypes: for op_name in op_names: