diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 2b027acd0745..bf7aba8a9a08 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -24,7 +24,7 @@ __all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "multivariate_normal", - "shuffle", 'gamma', 'beta', 'exponential', 'lognormal', 'weibull'] + "shuffle", 'gamma', 'beta', 'exponential', 'lognormal', 'weibull', 'pareto', 'power'] def randint(low, high=None, size=None, dtype=None, ctx=None, out=None): @@ -462,7 +462,7 @@ def exponential(scale, size): return _npi.exponential(scale=scale, size=size) -def weibull(a, size): +def weibull(a, size=None): r"""Draw samples from a 1-parameter Weibull distribution with given parameter a, via inversion. @@ -515,6 +515,92 @@ def weibull(a, size): return _npi.weibull(a=a, size=size) +def pareto(a, size=None): + r"""Draw samples from a Pareto II or Lomax distribution with specified shape a. + + Parameters + ---------- + a : float or array_like of floats + Shape of the distribution. Must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the Pareto distribution. + + Examples + -------- + >>> np.random.pareto(a=5) + array(0.12749612) + >>> mx.numpy.random.pareto(a=5, size=[2,3]) + array([[0.06933999, 0.0344373 , 0.10654891], + [0.0311172 , 0.12911797, 0.03370714]]) + >>> np.random.pareto(a=np.array([2,3]) + array([0.26636696, 0.15685666]) + + The probability density for the Pareto distribution is f(x) = \frac{am^a}{x^{a+1}} + where a is the shape and m the scale. Here m is assumed 1. The Pareto distribution + is a power law distribution. Pareto created it to describe the wealth in the economy. + """ + from ...numpy import ndarray as np_ndarray + tensor_type_name = np_ndarray + if size == (): + size = None + is_tensor = isinstance(a, tensor_type_name) + if is_tensor: + return _npi.pareto(a, a=None, size=size) + else: + return _npi.pareto(a=a, size=size) + + +def power(a, size=None): + r"""Draw samples in [0, 1] from a power distribution with given parameter a. + + Parameters + ---------- + a : float or array_like of floats + Shape of the distribution. Must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the power distribution. + + Examples + -------- + >>> np.random.power(a=5) + array(0.8602478) + >>> np.random.power(a=5, size=[2,3]) + array([[0.988391 , 0.5153122 , 0.9383134 ], + [0.9078098 , 0.87819266, 0.730635]]) + >>> np.random.power(a=np.array([2,3]) + array([0.7499419 , 0.88894516]) + + The probability density function is f(x; a) = ax^{a-1}, 0 \le x \le 1, a>0. + The power distribution is just the inverse of the Pareto distribution and + a special case of the Beta distribution. + """ + from ...numpy import ndarray as np_ndarray + tensor_type_name = np_ndarray + if size == (): + size = None + is_tensor = isinstance(a, tensor_type_name) + if is_tensor: + return _npi.powerd(a, a=None, size=size) + else: + return _npi.powerd(a=a, size=size) + + def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): """Draw samples from a Gamma distribution. diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 136c761162e2..4fc3688d2b39 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -22,7 +22,7 @@ __all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "multivariate_normal", - "shuffle", "randn", "gamma", 'beta', "exponential", "lognormal", "weibull"] + "shuffle", "randn", "gamma", 'beta', "exponential", "lognormal", "weibull", "pareto", "power"] def randint(low, high=None, size=None, dtype=None, ctx=None, out=None): @@ -529,6 +529,76 @@ def weibull(a, size=None): return _mx_nd_np.random.weibull(a, size) +def pareto(a, size=None): + r"""Draw samples from a Pareto II or Lomax distribution with specified shape a. + + Parameters + ---------- + a : float or array_like of floats + Shape of the distribution. Must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the Pareto distribution. + + Examples + -------- + >>> np.random.pareto(a=5) + array(0.12749612) + >>> mx.numpy.random.pareto(a=5, size=[2,3]) + array([[0.06933999, 0.0344373 , 0.10654891], + [0.0311172 , 0.12911797, 0.03370714]]) + >>> np.random.pareto(a=np.array([2,3]) + array([0.26636696, 0.15685666]) + + The probability density for the Pareto distribution is f(x) = \frac{am^a}{x^{a+1}} + where a is the shape and m the scale. Here m is assumed 1. The Pareto distribution + is a power law distribution. Pareto created it to describe the wealth in the economy. + """ + return _mx_nd_np.random.pareto(a, size) + + +def power(a, size=None): + r"""Draw samples in [0, 1] from a power distribution with given parameter a. + + Parameters + ---------- + a : float or array_like of floats + Shape of the distribution. Must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the power distribution. + + Examples + -------- + >>> np.random.power(a=5) + array(0.8602478) + >>> np.random.power(a=5, size=[2,3]) + array([[0.988391 , 0.5153122 , 0.9383134 ], + [0.9078098 , 0.87819266, 0.730635]]) + >>> np.random.power(a=np.array([2,3]) + array([0.7499419 , 0.88894516]) + + The probability density function is f(x; a) = ax^{a-1}, 0 \le x \le 1, a>0. + The power distribution is just the inverse of the Pareto distribution and + a special case of the Beta distribution. + """ + return _mx_nd_np.random.power(a, size) + + def shuffle(x): """ Modify a sequence in-place by shuffling its contents. diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 8e8611e7ef6a..78c2b76db0aa 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -23,7 +23,7 @@ __all__ = ['randint', 'uniform', 'normal', 'multivariate_normal', - 'rand', 'shuffle', 'gamma', 'beta', 'exponential', 'lognormal', 'weibull'] + 'rand', 'shuffle', 'gamma', 'beta', 'exponential', 'lognormal', 'weibull', 'pareto', 'power'] def randint(low, high=None, size=None, dtype=None, ctx=None, out=None): @@ -469,7 +469,7 @@ def exponential(scale=1.0, size=None): return _npi.exponential(scale=scale, size=size) -def weibull(a, size): +def weibull(a, size=None): r"""Draw samples from a 1-parameter Weibull distribution with given parameter a via inversion. @@ -524,6 +524,92 @@ def weibull(a, size): return _npi.weibull(a=a, size=size) +def pareto(a, size=None): + r"""Draw samples from a Pareto II or Lomax distribution with specified shape a. + + Parameters + ---------- + a : float or array_like of floats + Shape of the distribution. Must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns + ------- + out : _Symbol + Drawn samples from the Pareto distribution. + + Examples + -------- + >>> np.random.pareto(a=5) + array(0.12749612) + >>> mx.numpy.random.pareto(a=5, size=[2,3]) + array([[0.06933999, 0.0344373 , 0.10654891], + [0.0311172 , 0.12911797, 0.03370714]]) + >>> np.random.pareto(a=np.array([2,3]) + array([0.26636696, 0.15685666]) + + The probability density for the Pareto distribution is f(x) = \frac{am^a}{x^{a+1}} + where a is the shape and m the scale. Here m is assumed 1. The Pareto distribution + is a power law distribution. Pareto created it to describe the wealth in the economy. + """ + from ..numpy import _Symbol as np_symbol + tensor_type_name = np_symbol + if size == (): + size = None + is_tensor = isinstance(a, tensor_type_name) + if is_tensor: + return _npi.pareto(a, a=None, size=size) + else: + return _npi.pareto(a=a, size=size) + + +def power(a, size=None): + r"""Draw samples in [0, 1] from a power distribution with given parameter a. + + Parameters + ---------- + a : float or array_like of floats + Shape of the distribution. Must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns + ------- + out : _Symbol + Drawn samples from the power distribution. + + Examples + -------- + >>> np.random.power(a=5) + array(0.8602478) + >>> np.random.power(a=5, size=[2,3]) + array([[0.988391 , 0.5153122 , 0.9383134 ], + [0.9078098 , 0.87819266, 0.730635]]) + >>> np.random.power(a=np.array([2,3]) + array([0.7499419 , 0.88894516]) + + The probability density function is f(x; a) = ax^{a-1}, 0 \le x \le 1, a>0. + The power distribution is just the inverse of the Pareto distribution and + a special case of the Beta distribution. + """ + from ..numpy import _Symbol as np_symbol + tensor_type_name = np_symbol + if size == (): + size = None + is_tensor = isinstance(a, tensor_type_name) + if is_tensor: + return _npi.powerd(a, a=None, size=size) + else: + return _npi.powerd(a=a, size=size) + + def multivariate_normal(mean, cov, size=None, check_valid=None, tol=None): """ multivariate_normal(mean, cov, size=None, check_valid=None, tol=None) diff --git a/src/operator/numpy/random/np_pareto_op.cc b/src/operator/numpy/random/np_pareto_op.cc new file mode 100644 index 000000000000..df77448907fe --- /dev/null +++ b/src/operator/numpy/random/np_pareto_op.cc @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_pareto_op.cc + * \brief Operator for numpy sampling from Pareto distributions + */ + +#include "./np_pareto_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyParetoParam); + +NNVM_REGISTER_OP(_npi_pareto) +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyParetoParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.a.has_value()) { + num_inputs -= 1; + } + return num_inputs; + }) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyParetoParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.a.has_value()) { + num_inputs -= 1; + } + return (num_inputs == 0) ? std::vector() : std::vector{"input1"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", UnaryDistOpShape) +.set_attr("FInferType", + [](const nnvm::NodeAttrs &attrs, std::vector *in_attrs, std::vector *out_attrs) { + (*out_attrs)[0] = mshadow::kFloat32; + return true; + }) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyParetoForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyParetoParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_pareto_op.cu b/src/operator/numpy/random/np_pareto_op.cu new file mode 100644 index 000000000000..9af362cc69ce --- /dev/null +++ b/src/operator/numpy/random/np_pareto_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_pareto_op.cu + * \brief Operator for numpy sampling from pareto distributions + */ + +#include "./np_pareto_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_pareto) +.set_attr("FCompute", NumpyParetoForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_pareto_op.h b/src/operator/numpy/random/np_pareto_op.h new file mode 100644 index 000000000000..01bf29a6a8b0 --- /dev/null +++ b/src/operator/numpy/random/np_pareto_op.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_pareto_op.h + * \brief Operator for numpy sampling from pareto distribution. + */ + +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_PARETO_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_PARETO_OP_H_ + +#include +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +struct NumpyParetoParam : public dmlc::Parameter { + dmlc::optional a; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyParetoParam) { + DMLC_DECLARE_FIELD(a) + .set_default(dmlc::optional()); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe("Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + } +}; + +template +struct scalar_pareto_kernel { + MSHADOW_XINLINE static void Map(index_t i, float a, float *threshold, + DType *out) { + out[i] = exp(-log(threshold[i])/a) - DType(1); + } +}; + +namespace mxnet_op { + +template +struct check_legal_a_kernel { + MSHADOW_XINLINE static void Map(index_t i, IType *a, float* flag) { + if (a[i] <= 0.0) { + flag[0] = -1.0; + } + } +}; + + +template +struct pareto_kernel { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &stride, + const Shape &oshape, + IType *aparams, float* threshold, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + out[i] = exp(-log(threshold[i])/aparams[idx]) - IType(1); + } +}; + +} // namespace mxnet_op + +template +void NumpyParetoForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyParetoParam ¶m = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + index_t output_len = outputs[0].Size(); + Random *prnd = ctx.requested[0].get_random(s); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); + Tensor uniform_tensor = workspace.Slice(0, output_len); + Tensor indicator_device = workspace.Slice(output_len, output_len + 1); + float indicator_host = 1.0; + float *indicator_device_ptr = indicator_device.dptr_; + Kernel::Launch(s, 1, indicator_device_ptr); + prnd->SampleUniform(&workspace, 0.0, 1.0); + if (param.a.has_value()) { + CHECK_GT(param.a.value(), 0.0) << "ValueError: expect a > 0"; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.a.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr); + }); + _copy(s, &indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) << "ValueError: expect a > 0"; + mxnet::TShape new_lshape, new_oshape; + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), stride, oshape, inputs[0].dptr(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_PARETO_OP_H_ diff --git a/src/operator/numpy/random/np_power_op.cc b/src/operator/numpy/random/np_power_op.cc new file mode 100644 index 000000000000..c58511d2973e --- /dev/null +++ b/src/operator/numpy/random/np_power_op.cc @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_power_op.cc + * \brief Operator for numpy sampling from power distributions + */ + +#include "./np_power_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyPowerParam); + +NNVM_REGISTER_OP(_npi_powerd) +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyPowerParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.a.has_value()) { + num_inputs -= 1; + } + return num_inputs; + }) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyPowerParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.a.has_value()) { + num_inputs -= 1; + } + return (num_inputs == 0) ? std::vector() : std::vector{"input1"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", UnaryDistOpShape) +.set_attr("FInferType", + [](const nnvm::NodeAttrs &attrs, std::vector *in_attrs, std::vector *out_attrs) { + (*out_attrs)[0] = mshadow::kFloat32; + return true; + }) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyPowerForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyPowerParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_power_op.cu b/src/operator/numpy/random/np_power_op.cu new file mode 100644 index 000000000000..d5067f83bb02 --- /dev/null +++ b/src/operator/numpy/random/np_power_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_power_op.cu + * \brief Operator for numpy sampling from power distributions + */ + +#include "./np_power_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_powerd) +.set_attr("FCompute", NumpyPowerForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_power_op.h b/src/operator/numpy/random/np_power_op.h new file mode 100644 index 000000000000..a8835fd62957 --- /dev/null +++ b/src/operator/numpy/random/np_power_op.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_power_op.h + * \brief Operator for numpy sampling from power distribution. + */ + +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_POWER_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_POWER_OP_H_ + +#include +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +struct NumpyPowerParam : public dmlc::Parameter { + dmlc::optional a; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyPowerParam) { + DMLC_DECLARE_FIELD(a) + .set_default(dmlc::optional()); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe("Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + } +}; + +template +struct scalar_power_kernel { + MSHADOW_XINLINE static void Map(index_t i, float a, float *threshold, + DType *out) { + out[i] = powf(1 - threshold[i], DType(1.0/a)); + } +}; + +namespace mxnet_op { + +template +struct check_legal_a_kernel { + MSHADOW_XINLINE static void Map(index_t i, IType *a, float* flag) { + if (a[i] <= 0.0) { + flag[0] = -1.0; + } + } +}; + + +template +struct power_kernel { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &stride, + const Shape &oshape, + IType *aparams, float* threshold, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + out[i] = powf(1 - threshold[i], IType(1.0/aparams[idx])); + } +}; + +} // namespace mxnet_op + +template +void NumpyPowerForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyPowerParam ¶m = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + index_t output_len = outputs[0].Size(); + Random *prnd = ctx.requested[0].get_random(s); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); + Tensor uniform_tensor = workspace.Slice(0, output_len); + Tensor indicator_device = workspace.Slice(output_len, output_len + 1); + float indicator_host = 1.0; + float *indicator_device_ptr = indicator_device.dptr_; + Kernel::Launch(s, 1, indicator_device_ptr); + prnd->SampleUniform(&workspace, 0.0, 1.0); + if (param.a.has_value()) { + CHECK_GT(param.a.value(), 0.0) << "ValueError: expect a > 0"; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.a.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr); + }); + _copy(s, &indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) << "ValueError: expect a > 0"; + mxnet::TShape new_lshape, new_oshape; + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), stride, oshape, inputs[0].dptr(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_POWER_OP_H_ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index ee18ddec8045..726a6179c516 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3600,32 +3600,72 @@ def _test_exponential_exception(scale): @with_seed() @use_np -def test_np_random_weibull(): - class TestRandomWeibull(HybridBlock): - def __init__(self, shape): - super(TestRandomWeibull, self).__init__() +def test_np_random_a(): + op_names = ['pareto', 'power', 'weibull'] + # these distributions have one required parameter a + shapes = [(1,), (2, 3), (4, 0, 5), 6, (7, 8), (), None] + + def _test_random_x_range(output): + ge_zero = _np.all(output >= 0) + smaller_equal_one = _np.all(output <= 1) + return ge_zero and smaller_equal_one + + # test imperative size shapes + for [shape, op_name] in itertools.product(shapes, op_names): + op = getattr(np.random, op_name, None) + assert op is not None + out = op(1.0, size=shape) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert out.shape == expected_shape + # test range of generated values for power distribution + if op_name == 'power': + assert _test_random_x_range(out.asnumpy()) == True + + # test symbolic/hybridized size shapes + class TestRandomA(HybridBlock): + def __init__(self, shape, op_name): + super(TestRandomA, self).__init__() self._shape = shape + self._op_name = op_name def hybrid_forward(self, F, a): - return F.np.random.weibull(a, self._shape) - - shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] - for hybridize in [False, True]: - for shape in shapes: - test_weibull = TestRandomWeibull(shape) - if hybridize: - test_weibull.hybridize() - np_out = _np.random.weibull(1, size = shape) - mx_out = test_weibull(np.array([1])) - - for shape in shapes: - mx_out = np.random.weibull(np.array([1]), shape) - np_out = _np.random.weibull(np.array([1]).asnumpy(), shape) - assert_almost_equal(mx_out.asnumpy().shape, np_out.shape) + op = getattr(F.np.random, self._op_name, None) + assert op is not None + return op(a, size=self._shape) - def _test_weibull_exception(a): - output = np.random.weibull(a=a).asnumpy() - assertRaises(ValueError, _test_weibull_exception, -1) + hybridize = [False, True] + for [op_name, shape, hybridize] in itertools.product(op_names, shapes, hybridize): + test_op = TestRandomA(shape, op_name) + if hybridize: + test_op.hybridize() + mx_out = test_op(np.array(1.0)) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert mx_out.shape == expected_shape + + # test broadcasting of required parameter a shape when a is array-like + ashapes = [(1,), (2, 3), (4, 0, 5), 6, (7, 8)] + for shape in ashapes: + a = np.ones(shape) + for op_name in op_names: + op = getattr(np.random, op_name, None) + assert op is not None + mx_out = op(a, size=None) + expected_shape = a.shape + assert mx_out.shape == expected_shape + + # test illegal parameter values (as numpy produces) + def _test_exception(a): + output = op(a=a).asnumpy() + for op in op_names: + op = getattr(np.random, op_name, None) + if op is not None: + assertRaises(ValueError, _test_exception, -1) + if op in ['pareto', 'power']: + assertRaises(ValueError, _test_exception, 0) @with_seed()