diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index e67c766c6bdf..913ceaaff097 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -23,7 +23,7 @@
 from ..ndarray import NDArray
 
-__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle"]
+__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle", 'gamma']
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -319,6 +319,63 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
         return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)
 
 
+def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
+    """Draw samples from a Gamma distribution.
+
+    Samples are drawn from a Gamma distribution with specified parameters,
+    `shape` (sometimes designated "k") and `scale` (sometimes designated
+    "theta"), where both parameters are > 0.
+
+    Parameters
+    ----------
+    shape : float or array_like of floats
+        The shape of the gamma distribution. Should be greater than zero.
+    scale : float or array_like of floats, optional
+        The scale of the gamma distribution. Should be greater than zero.
+        Default is equal to 1.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn. If size is ``None`` (default),
+        a single value is returned if ``shape`` and ``scale`` are both scalars.
+        Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ndarray, optional
+        The output ndarray (default is `None`).
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Drawn samples from the parameterized gamma distribution.
+
+    The Gamma distribution is often used to model the times to failure of
+    electronic components, and arises naturally in processes for which the
+    waiting times between Poisson distributed events are relevant.
+    """
+    from ...numpy import ndarray as np_ndarray
+    input_type = (isinstance(shape, np_ndarray), isinstance(scale, np_ndarray))
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if out is not None:
+        size = out.shape
+    if size == ():
+        size = None
+    if input_type == (True, True):
+        return _npi.gamma(shape, scale, shape=None, scale=None, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+    elif input_type == (False, True):
+        return _npi.gamma(scale, shape=shape, scale=None, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+    elif input_type == (True, False):
+        return _npi.gamma(shape, shape=None, scale=scale, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+    elif input_type == (False, False):
+        return _npi.gamma(shape=shape, scale=scale, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+
+    raise ValueError("Distribution parameters must be either mxnet.numpy.ndarray or numbers")
+
+
 def rand(*size, **kwargs):
     r"""Random values in a given shape.
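For orientation, here is a minimal usage sketch of the new imperative API. This assumes a build with the patch applied; the printed values are illustrative only:

```python
import mxnet as mx
from mxnet import np, npx
npx.set_np()  # enable NumPy-compatible array semantics

# Scalar parameters: `size` alone determines the output shape.
s = np.random.gamma(2.0, 3.0, size=(2, 3))   # float32, shape (2, 3)

# ndarray parameter: drawn per-element, broadcast against `size`.
k = np.array([0.5, 1.0, 2.0])
s = np.random.gamma(k, 1.0, size=(10000, 3), dtype='float64')
print(s.mean(axis=0))  # approaches [0.5, 1.0, 2.0], since E[X] = k * theta
```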
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index 95719a005cec..198f2fcb4389 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -20,7 +20,8 @@
 from __future__ import absolute_import
 from ..ndarray import numpy as _mx_nd_np
 
-__all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "shuffle", "randn"]
+__all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "shuffle", "randn",
+           "gamma"]
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -359,6 +360,40 @@ def shuffle(x):
     _mx_nd_np.random.shuffle(x)
 
 
+def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
+    """Draw samples from a Gamma distribution.
+
+    Samples are drawn from a Gamma distribution with specified parameters,
+    `shape` (sometimes designated "k") and `scale` (sometimes designated
+    "theta"), where both parameters are > 0.
+
+    Parameters
+    ----------
+    shape : float or array_like of floats
+        The shape of the gamma distribution. Should be greater than zero.
+    scale : float or array_like of floats, optional
+        The scale of the gamma distribution. Should be greater than zero.
+        Default is equal to 1.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn. If size is ``None`` (default),
+        a single value is returned if ``shape`` and ``scale`` are both scalars.
+        Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ndarray, optional
+        The output ndarray (default is `None`).
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Drawn samples from the parameterized gamma distribution.
+
+    The Gamma distribution is often used to model the times to failure of
+    electronic components, and arises naturally in processes for which the
+    waiting times between Poisson distributed events are relevant.
+    """
+    return _mx_nd_np.random.gamma(shape, scale, size, dtype, ctx, out)
+
+
 def randn(*size, **kwargs):
     r"""Return a sample (or samples) from the "standard normal" distribution.
     If positive, int_like or int-convertible arguments are provided,
diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py
index 94c29f407acc..c6b23b507d87 100644
--- a/python/mxnet/symbol/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -21,7 +21,7 @@
 from ...context import current_context
 from . import _internal as _npi
 
-__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle']
+__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle', 'gamma']
 
 
 def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
@@ -290,6 +290,63 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
         return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)
 
 
+def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
+    """Draw samples from a Gamma distribution.
+
+    Samples are drawn from a Gamma distribution with specified parameters,
+    `shape` (sometimes designated "k") and `scale` (sometimes designated
+    "theta"), where both parameters are > 0.
+
+    Parameters
+    ----------
+    shape : float or array_like of floats
+        The shape of the gamma distribution. Should be greater than zero.
+    scale : float or array_like of floats, optional
+        The scale of the gamma distribution. Should be greater than zero.
+        Default is equal to 1.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn. If size is ``None`` (default),
+        a single value is returned if ``shape`` and ``scale`` are both scalars.
+        Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : _Symbol, optional
+        The output symbol (default is `None`).
+
+    Returns
+    -------
+    out : _Symbol
+        Drawn samples from the parameterized gamma distribution.
+
+    The Gamma distribution is often used to model the times to failure of
+    electronic components, and arises naturally in processes for which the
+    waiting times between Poisson distributed events are relevant.
+    """
+    from ._symbol import _Symbol as np_symbol
+    input_type = (isinstance(shape, np_symbol), isinstance(scale, np_symbol))
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if out is not None:
+        size = out.shape
+    if size == ():
+        size = None
+    if input_type == (True, True):
+        return _npi.gamma(shape, scale, shape=None, scale=None, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+    elif input_type == (False, True):
+        return _npi.gamma(scale, shape=shape, scale=None, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+    elif input_type == (True, False):
+        return _npi.gamma(shape, shape=None, scale=scale, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+    elif input_type == (False, False):
+        return _npi.gamma(shape=shape, scale=scale, size=size,
+                          ctx=ctx, dtype=dtype, out=out)
+
+    raise ValueError("Distribution parameters must be either _Symbol or numbers")
+
+
 def shuffle(x):
     """
     Modify a sequence in-place by shuffling its contents.
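Because the sampler is also registered on the symbolic front end, it works inside hybridized Gluon blocks. A minimal sketch (the block and its names are illustrative, not part of the patch):

```python
from mxnet import np, npx
from mxnet.gluon import HybridBlock
npx.set_np()

class GammaNoise(HybridBlock):
    """Adds fixed-shape Gamma(2, 1) noise to its input."""
    def hybrid_forward(self, F, x):
        # F resolves to the ndarray or the symbol namespace depending on
        # whether the block has been hybridized.
        return x + F.np.random.gamma(2.0, 1.0, size=(2, 2))

net = GammaNoise()
net.hybridize()           # exercises the symbol-side `gamma` defined above
print(net(np.zeros((2, 2))))
```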
diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc
new file mode 100644
index 000000000000..72e337b1642b
--- /dev/null
+++ b/src/operator/numpy/random/np_gamma_op.cc
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_gamma_op.cc
+ * \brief Operator for random sampling from gamma distribution
+ */
+
+#include "./np_gamma_op.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(NumpyGammaParam);
+
+inline bool NumpyGammaOpType(const nnvm::NodeAttrs& attrs,
+                             std::vector<int>* in_attrs,
+                             std::vector<int>* out_attrs) {
+  const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
+  int otype = param.dtype;
+  if (otype != -1) {
+    (*out_attrs)[0] = otype;
+  } else {
+    (*out_attrs)[0] = mshadow::kFloat32;
+  }
+  return true;
+}
+
+NNVM_REGISTER_OP(_npi_gamma)
+.describe("Numpy behavior gamma")
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
+    int num_inputs = 2;
+    if (param.shape.has_value()) num_inputs -= 1;
+    if (param.scale.has_value()) num_inputs -= 1;
+    return num_inputs;
+  }
+)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
+    int num_inputs = 2;
+    if (param.scale.has_value()) num_inputs -= 1;
+    if (param.shape.has_value()) num_inputs -= 1;
+    if (num_inputs == 0) return std::vector<std::string>();
+    if (num_inputs == 1) return std::vector<std::string>{"input1"};
+    return std::vector<std::string>{"input1", "input2"};
+  })
+.set_attr_parser(ParamParser<NumpyGammaParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyGammaParam>)
+.set_attr<nnvm::FInferType>("FInferType", NumpyGammaOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{
+        ResourceRequest::kRandom,
+        ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", NumpyGammaForward<cpu, double>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("input1", "NDArray-or-Symbol", "Source input")
+.add_argument("input2", "NDArray-or-Symbol", "Source input")
+.add_arguments(NumpyGammaParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
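The variable input count above mirrors the front-end dispatch: each of the two distribution parameters arrives either as a tensor input or as a scalar attribute, so the operator takes 2, 1, or 0 tensor inputs. Schematically (calls are illustrative; `k` and `t` stand for ndarray parameters):

```python
# scalar k, scalar theta -> 0 tensor inputs, both passed as attributes
_npi.gamma(shape=2.0, scale=3.0, size=size, ctx=ctx, dtype=dtype, out=out)
# tensor k, scalar theta -> 1 tensor input ("input1"), scale as attribute
_npi.gamma(k, shape=None, scale=3.0, size=size, ctx=ctx, dtype=dtype, out=out)
# tensor k, tensor theta -> 2 tensor inputs ("input1", "input2")
_npi.gamma(k, t, shape=None, scale=None, size=size, ctx=ctx, dtype=dtype, out=out)
```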
diff --git a/src/operator/numpy/random/np_gamma_op.cu b/src/operator/numpy/random/np_gamma_op.cu
new file mode 100644
index 000000000000..5be15c7b9d13
--- /dev/null
+++ b/src/operator/numpy/random/np_gamma_op.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_gamma_op.cu
+ * \brief Operator for random sampling from gamma distribution
+ */
+
+#include "./np_gamma_op.h"
+#include <cuda_runtime.h>
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(_npi_gamma)
+.set_attr<FCompute>("FCompute", NumpyGammaForward<gpu, double>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h
new file mode 100644
index 000000000000..83e8f1f5242b
--- /dev/null
+++ b/src/operator/numpy/random/np_gamma_op.h
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_gamma_op.h
+ * \brief Operator for random sampling from gamma distribution
+ */
+
+#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_GAMMA_OP_H_
+#define MXNET_OPERATOR_NUMPY_RANDOM_NP_GAMMA_OP_H_
+
+#include <mxnet/operator_util.h>
+#include <algorithm>
+#include <string>
+#include <vector>
+#include <cmath>
+#include "./dist_common.h"
+#include "../../elemwise_op_common.h"
+#include "../../tensor/elemwise_binary_broadcast_op.h"
+#include "../../mshadow_op.h"
+#include "../../mxnet_op.h"
+#include "../../operator_common.h"
+
+// Number of random variates consumed per output sample: one (uniform, normal)
+// pair for the Marsaglia-Tsang transform, plus one extra uniform used by the
+// shape < 1 boost in GammaReduce.
+#define M 2
+
+namespace mxnet {
+namespace op {
+
+struct NumpyGammaParam : public dmlc::Parameter<NumpyGammaParam> {
+  dmlc::optional<float> shape;
+  dmlc::optional<float> scale;
+  std::string ctx;
+  int dtype;
+  dmlc::optional<mxnet::Tuple<int>> size;
+  DMLC_DECLARE_PARAMETER(NumpyGammaParam) {
+    DMLC_DECLARE_FIELD(shape);
+    DMLC_DECLARE_FIELD(scale);
+    DMLC_DECLARE_FIELD(size)
+      .set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("Output shape. If the given shape is, "
+                "e.g., (m, n, k), then m * n * k samples are drawn. "
+                "Default is None, in which case a single value is returned.");
+    DMLC_DECLARE_FIELD(ctx)
+      .set_default("cpu")
+      .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
+                " Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype)
+      .add_enum("float32", mshadow::kFloat32)
+      .add_enum("float64", mshadow::kFloat64)
+      .add_enum("float16", mshadow::kFloat16)
+      .set_default(mshadow::kFloat32)
+      .describe("DType of the output in case this can't be inferred. "
+                "Defaults to float32 if not defined (dtype=None).");
+  }
+};
+
+
+namespace mxnet_op {
+
+template <typename IType, typename FType>
+MSHADOW_XINLINE void GammaTransform(IType a, IType b,
+                                    FType* uniforms, FType* normals) {
+  // Marsaglia-Tsang: propose d * (1 + c * n)^3 and accept/reject with u.
+  FType d = a < 1 ? a + 2.0 / 3.0 : a - 1.0 / 3.0;
+  FType k = sqrt(9.0 * d);
+  FType c = 1.0 / k;
+  for (size_t i = 0; i < M - 1; i++) {
+    FType u = uniforms[i];
+    FType n = normals[i];
+    uniforms[i] = FType(-1);  // -1 marks a rejected (not yet accepted) sample
+    FType ocn = 1 + c * n;
+    FType v = ocn * ocn * ocn;
+    if (v > 0) {
+      if (u <= (1 - 0.0331 * (n * n) * (n * n))) {
+        // "squeeze" test: a cheap bound that accepts the vast majority of
+        // proposals without evaluating the logarithms below.
+        uniforms[i] = FType(d * v * b);
+      }
+      if (logf(u) < 0.5 * (n * n) + d * (1 - v + logf(v))) {
+        // full acceptance test; it only matters for the small fraction of
+        // proposals that the squeeze above does not already accept.
+        uniforms[i] = FType(d * v * b);
+      }
+    }
+  }
+}
+
+
+template <typename IType, typename FType>
+MSHADOW_XINLINE FType GammaReduce(IType a, FType* uniforms) {
+  FType u2 = uniforms[M - 1];
+  for (size_t i = 0; i < M - 1; i++) {
+    FType sample = uniforms[i];
+    if (sample > 0) {
+      // boost for shape < 1: Gamma(a) = Gamma(a + 1) * U^(1/a)
+      return a < 1 ? sample * powf(u2, FType(1.0 / a)) : sample;
+    }
+  }
+  return -1;  // all proposals rejected; the caller will resample
+}
+
+template <typename OType, typename FType>
+struct CheckSuccessKernel {
+  MSHADOW_XINLINE static void Map(int i, OType* out, FType* flag) {
+    if (out[i] < 0) {
+      flag[0] = -1.0;
+    }
+  }
+};
This is the "squeeze" + uniforms[i] = FType(d * v * b); + } + if (logf(u) < 0.5 * (n * n) + d * (1 - v + logf(v))) { + // rejection sample. The second operation should be + // performed with low probability. This is the "squeeze" + uniforms[i] = FType(d * v * b); + } + } + } +} + + +template +MSHADOW_XINLINE FType GammaReduce(IType a, FType* uniforms) { + FType u2 = uniforms[M - 1]; + for (size_t i = 0; i < M - 1; i++) { + FType sample = uniforms[i]; + if (sample > 0) { + return a < 1 ? sample * powf(u2, FType(1.0 / a)) : sample; + } + } + return -1; +} + +template +struct CheckSuccessKernel { + MSHADOW_XINLINE static void Map(int i, OType* out, FType* flag) { + if (out[i] < 0) { + flag[0] = -1.0; + } + } +}; + +template +struct gamma_kernel { + MSHADOW_XINLINE static void Map(index_t i, const Shape &lstride, + const Shape &hstride, + const Shape &oshape, IType *shape, + IType *scale, FType *uniforms, FType *normals, + OType *out, FType *flag = nullptr) { + // We know the sampling procedure is in its first stage, if `flag` is + // nullptr, i.e. there is no need for reseting the indicator + // variable(flag[0] = 1) nor checking whether a specific element is sampled + // successfully (out[i] < 0). + bool in_first_stage = (flag == nullptr); + if (!in_first_stage) { + flag[0] = 1; + } + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType shape_value = shape[lidx]; + IType scale_value = scale[hidx]; + if (in_first_stage || out[i] < 0) { + // map phase + GammaTransform(shape_value, scale_value, uniforms + i * M, + normals + i * M); + // reduce phase + OType sample = + (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; + } + } +}; + +template +struct gamma_one_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, + const Shape &stride, + const Shape &oshape, IType *array, + float scalar, FType *uniforms, FType *normals, + OType *out, FType *flag = nullptr) { + // We know the sampling procedure is in its first stage, if `flag` is + // nullptr, i.e. there is no need for reseting the indicator + // variable(flag[0] = 1) nor checking whether a specific element is sampled + // successfully (out[i] < 0). + bool in_first_stage = (flag == nullptr); + if (!in_first_stage) { + flag[0] = 1; + } + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType shape_value; + IType scale_value; + if (scalar_pos == 0) { + shape_value = scalar; + scale_value = array[idx]; + } else { + shape_value = array[idx]; + scale_value = scalar; + } + if (in_first_stage || out[i] < 0) { + // map phase + GammaTransform(shape_value, scale_value, uniforms + i * M, + normals + i * M); + // reduce phase + OType sample = + (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; + } + } +}; + +template +struct gamma_two_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, float shape_value, + float scale_value, FType *uniforms_origin, + FType *normals_origin, OType *out, + FType *flag = nullptr) { + // We know the sampling procedure is in its first stage, if `flag` is + // nullptr, i.e. there is no need for reseting the indicator + // variable(flag[0] = 1) nor checking whether a specific element is sampled + // successfully (out[i] < 0). 
+    bool in_first_stage = (flag == nullptr);
+    if (!in_first_stage) {
+      flag[0] = 1;
+    }
+    FType *uniforms = uniforms_origin + i * M;
+    FType *normals = normals_origin + i * M;
+    if (in_first_stage || out[i] < 0) {
+      // map phase
+      GammaTransform(shape_value, scale_value, uniforms, normals);
+      // reduce phase
+      OType sample = (OType)GammaReduce(shape_value, uniforms);
+      out[i] = sample;
+    }
+  }
+};
+}  // namespace mxnet_op
+
+template <typename xpu, typename FType>
+void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  const NumpyGammaParam &param = nnvm::get<NumpyGammaParam>(attrs.parsed);
+  CHECK_EQ(outputs.size(), 1);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  // Generate the base random numbers.
+  Random<xpu, FType> *prnd = ctx.requested[0].get_random<xpu, FType>(s);
+  index_t output_len = outputs[0].Size();
+  // Workspace layout: [M uniforms per sample | M normals per sample | flag].
+  Tensor<xpu, 1, FType> random_tensor =
+      ctx.requested[1].get_space_typed<xpu, 1, FType>(
+          Shape1(output_len * 2 * M + 1), s);
+  Tensor<xpu, 1, FType> uniform_tensor = random_tensor.Slice(0, output_len * M);
+  Tensor<xpu, 1, FType> normal_tensor =
+      random_tensor.Slice(output_len * M, output_len * 2 * M);
+  prnd->SampleUniform(&uniform_tensor, 0, 1);
+  prnd->SampleGaussian(&normal_tensor, 0, 1);
+  mxnet::TShape new_lshape, new_hshape, new_oshape;
+  FType failure_indicator = 1.0;
+  Tensor<xpu, 1, FType> failure_indic_workspace =
+      random_tensor.Slice(output_len * 2 * M, output_len * 2 * M + 1);
+  FType *failure_indicator_device = failure_indic_workspace.dptr_;
+  // [scalar scalar] case
+  if (inputs.size() == 0U) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+      bool in_resample_stage = false;
+      do {
+        if (in_resample_stage) {
+          prnd->SampleUniform(&uniform_tensor, 0, 1);
+          prnd->SampleGaussian(&normal_tensor, 0, 1);
+        }
+        Kernel<gamma_two_scalar_kernel<OType, FType>, xpu>::Launch(
+            s, outputs[0].Size(), param.shape.value(), param.scale.value(),
+            uniform_tensor.dptr_, normal_tensor.dptr_, outputs[0].dptr<OType>(),
+            in_resample_stage ? failure_indicator_device : nullptr);
+        failure_indicator = 1.0;
+        Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
+            s, outputs[0].Size(), outputs[0].dptr<OType>(),
+            failure_indicator_device);
+        _copy(s, &failure_indicator, failure_indicator_device);
+        in_resample_stage = true;
+      } while (failure_indicator < 0);
+    });
+  } else if (inputs.size() == 1U) {
+    // [scalar tensor], [tensor scalar] case
+    int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_,
+                         &new_lshape, &new_lshape, &new_oshape);
+    int scalar_pos;
+    float scalar_value;
+    if (param.shape.has_value()) {
+      scalar_pos = 0;
+      scalar_value = param.shape.value();
+    } else {
+      scalar_pos = 1;
+      scalar_value = param.scale.value();
+    }
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+          mshadow::Shape<NDim> stride = calc_stride(new_lshape.get<NDim>());
+          bool in_resample_stage = false;
+          do {
+            if (in_resample_stage) {
+              prnd->SampleUniform(&uniform_tensor, 0, 1);
+              prnd->SampleGaussian(&normal_tensor, 0, 1);
+            }
+            Kernel<gamma_one_scalar_kernel<NDim, IType, OType, FType>, xpu>::Launch(
+                s, outputs[0].Size(), scalar_pos, stride, oshape,
+                inputs[0].dptr<IType>(), scalar_value,
+                uniform_tensor.dptr_, normal_tensor.dptr_,
+                outputs[0].dptr<OType>(),
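The kernels above are a map/reduce formulation of Marsaglia and Tsang's rejection sampler: each pass draws a fixed budget of (uniform, normal) pairs per element, marks rejected elements with -1, and the host-side do/while loop redraws until `CheckSuccessKernel` reports no failures. The NumPy sketch below shows the same algorithm in scalar-parameter form; it omits the squeeze shortcut and per-element buffer bookkeeping, so it illustrates the math rather than the operator's exact numerics:

```python
import numpy as np

def marsaglia_tsang_gamma(a, b, size, rng=np.random):
    """Draw Gamma(shape=a, scale=b) samples by Marsaglia-Tsang rejection."""
    # For a < 1, sample Gamma(a + 1) and boost by U^(1/a) afterwards,
    # exactly as GammaTransform/GammaReduce do with d = a + 2/3.
    d = a + 2.0 / 3.0 if a < 1 else a - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    out = np.full(size, -1.0)             # -1 marks "not yet accepted"
    while np.any(out < 0):
        todo = out < 0                    # retry only the rejected elements
        n = rng.standard_normal(todo.sum())
        u = rng.uniform(size=todo.sum())
        v = (1.0 + c * n) ** 3
        safe_v = np.where(v > 0, v, 1.0)  # guard log() against v <= 0
        accept = (v > 0) & (np.log(u) < 0.5 * n * n + d * (1.0 - v + np.log(safe_v)))
        out[todo] = np.where(accept, d * v * b, -1.0)
    if a < 1:                             # boost: Gamma(a) = Gamma(a+1) * U^(1/a)
        out = out * rng.uniform(size=size) ** (1.0 / a)
    return out

# Quick check: the sample mean should approach a * b.
print(marsaglia_tsang_gamma(0.5, 2.0, (100000,)).mean())  # ~1.0
```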
+                in_resample_stage ? failure_indicator_device : nullptr);
+            failure_indicator = 1.0;
+            Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
+                s, outputs[0].Size(), outputs[0].dptr<OType>(),
+                failure_indicator_device);
+            _copy(s, &failure_indicator, failure_indicator_device);
+            in_resample_stage = true;
+          } while (failure_indicator < 0);
+        });
+      });
+    });
+  } else if (inputs.size() == 2U) {
+    // [tensor tensor] case
+    int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
+                         &new_lshape, &new_hshape, &new_oshape);
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+          mshadow::Shape<NDim> lstride = calc_stride(new_lshape.get<NDim>());
+          mshadow::Shape<NDim> hstride = calc_stride(new_hshape.get<NDim>());
+          bool in_resample_stage = false;
+          do {
+            // Fresh randomness is only needed when retrying rejected samples;
+            // the first pass reuses the batch drawn above.
+            if (in_resample_stage) {
+              prnd->SampleUniform(&uniform_tensor, 0, 1);
+              prnd->SampleGaussian(&normal_tensor, 0, 1);
+            }
+            Kernel<gamma_kernel<NDim, IType, OType, FType>, xpu>::Launch(
+                s, outputs[0].Size(), lstride, hstride, oshape,
+                inputs[0].dptr<IType>(), inputs[1].dptr<IType>(),
+                uniform_tensor.dptr_, normal_tensor.dptr_,
+                outputs[0].dptr<OType>(),
+                in_resample_stage ? failure_indicator_device : nullptr);
+            failure_indicator = 1.0;
+            Kernel<CheckSuccessKernel<OType, FType>, xpu>::Launch(
+                s, outputs[0].Size(), outputs[0].dptr<OType>(),
+                failure_indicator_device);
+            _copy(s, &failure_indicator, failure_indicator_device);
+            in_resample_stage = true;
+          } while (failure_indicator < 0);
+        });
+      });
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_RANDOM_NP_GAMMA_OP_H_
diff --git a/tests/nightly/test_np_random.py b/tests/nightly/test_np_random.py
index 345fb86b1222..d086ac4b08d6 100644
--- a/tests/nightly/test_np_random.py
+++ b/tests/nightly/test_np_random.py
@@ -78,6 +78,33 @@ def test_np_normal():
         verify_generator(generator=generator_mx_np, buckets=buckets, probs=probs,
                          nsamples=samples, nrepeat=trials)
 
 
+@retry(5)
+@with_seed()
+@use_np
+def test_np_gamma():
+    types = [None, "float32", "float64"]
+    ctx = mx.context.current_context()
+    samples = 1000000
+    # Generation test
+    trials = 8
+    num_buckets = 5
+    for dtype in types:
+        for alpha, beta in [(2.0, 3.0), (0.5, 1.0)]:
+            buckets, probs = gen_buckets_probs_with_ppf(
+                lambda x: ss.gamma.ppf(x, a=alpha, loc=0, scale=beta), num_buckets)
+            buckets = np.array(buckets).tolist()
+            def generator_mx(x): return np.random.gamma(
+                alpha, beta, size=samples, dtype=dtype, ctx=ctx).asnumpy()
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs,
+                             nsamples=samples, nrepeat=trials)
+            generator_mx_same_seed =\
+                lambda x: _np.concatenate(
+                    [np.random.gamma(alpha, beta, size=(x // 10), dtype=dtype, ctx=ctx).asnumpy()
+                     for _ in range(10)])
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs,
+                             nsamples=samples, nrepeat=trials)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index d60b86a57d35..b25c69385e1e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3279,34 +3279,44 @@ def hybrid_forward(self, F, param1, param2):
 def test_np_random():
     shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None]
     dtypes = ['float16', 'float32', 'float64']
-    op_names = ['uniform', 'normal']
+    op_names = ['uniform', 'normal', 'gamma']
     for shape in shapes:
         for dtype in dtypes:
             for op_name in op_names:
                 op = getattr(np.random, op_name, None)
                 assert op is not None
-                out = op(size=shape, dtype=dtype)
+                if op_name == 'gamma':
+                    out = op(1, size=shape, dtype=dtype)
+                else:
+                    out = op(size=shape, dtype=dtype)
                 expected_shape = shape
                 if not isinstance(shape, tuple):
                     expected_shape = () if shape is None else (shape,)
                 assert out.shape == expected_shape
 
     class TestRandom(HybridBlock):
-        def __init__(self, shape, op_name):
+        def __init__(self, shape, op_name, param=None):
             super(TestRandom, self).__init__()
             self._shape = shape
             self._op_name = op_name
+            # `param` is the first positional distribution parameter, for ops
+            # (like gamma) whose parameters are not optional.
+            self._param = param
 
         def hybrid_forward(self, F, x):
             op = getattr(F.np.random, self._op_name, None)
             assert op is not None
-            return x + op(size=shape)
+            if self._param is not None:
+                return x + op(self._param, size=self._shape)
+            return x + op(size=self._shape)
 
     x = np.ones(())
     for op_name in op_names:
         for shape in shapes:
             for hybridize in [False, True]:
-                net = TestRandom(shape, op_name)
+                if op_name == 'gamma':
+                    net = TestRandom(shape, op_name, 1)
+                else:
+                    net = TestRandom(shape, op_name)
                 if hybridize:
                     net.hybridize()
                 out = net(x)
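Beyond the bucketed goodness-of-fit checks in the nightly test, a quick local smoke test can compare sample moments against the closed forms E[X] = kθ and Var[X] = kθ². A sketch (the tolerances are illustrative, chosen loosely for one million samples):

```python
from mxnet import np, npx
npx.set_np()

k, theta = 2.0, 3.0
s = np.random.gamma(k, theta, size=(1000000,)).asnumpy()
assert abs(s.mean() - k * theta) < 0.05       # E[X]   = k * theta   = 6
assert abs(s.var() - k * theta ** 2) < 0.5    # Var[X] = k * theta^2 = 18
```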