From 93c2291c62e4588ef18ef0b28184473b35d49c1c Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 11 Sep 2019 08:49:37 +0000 Subject: [PATCH 1/6] gamma done --- python/mxnet/ndarray/numpy/random.py | 27 +++++++++++++++++++++++++++ python/mxnet/numpy/random.py | 4 ++++ 2 files changed, 31 insertions(+) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 9372beaf1e92..40c93ab3cb76 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -319,3 +319,30 @@ def choice(a, size=None, replace=True, p=None, **kwargs): return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out) else: return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) + + +def gamma(shape, scale, size=None, ctx=None, dtype=None, out=None): + from ...numpy import ndarray as np_ndarray + input_type = (isinstance(shape, np_ndarray), isinstance(scale, np_ndarray)) + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.gamma(shape, scale, shape=None, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.gamma(scale, shape=shape, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.gamma(shape, shape=None, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.gamma(shape=shape, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + + raise ValueError("Distribution parameters must be either mxnet.numpy.ndarray or numbers") diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index aace767c8d55..017e9c67d3b8 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -235,3 +235,7 @@ def choice(a, size=None, replace=True, p=None, **kwargs): array([2, 3, 0]) """ return _mx_nd_np.random.choice(a, size, replace, p, **kwargs) + + +def gamma(shape, scale=1.0, size=None, ctx=None, dtype=None, out=None): + return _mx_nd_np.random.gamma(shape, scale, size, ctx, dtype, out) From d70e4612ce01f67cc3f86831ab3b2a01a088b168 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Thu, 12 Sep 2019 09:41:35 +0000 Subject: [PATCH 2/6] gamma finished --- python/mxnet/ndarray/numpy/random.py | 43 ++- python/mxnet/numpy/random.py | 34 +- python/mxnet/symbol/numpy/random.py | 60 +++ src/operator/numpy/random/np_gamma_op.cc | 81 ++++ src/operator/numpy/random/np_gamma_op.cu | 46 +++ src/operator/numpy/random/np_gamma_op.h | 461 +++++++++++++++++++++++ tests/python/unittest/test_numpy_op.py | 46 +++ 7 files changed, 764 insertions(+), 7 deletions(-) create mode 100644 src/operator/numpy/random/np_gamma_op.cc create mode 100644 src/operator/numpy/random/np_gamma_op.cu create mode 100644 src/operator/numpy/random/np_gamma_op.h diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 40c93ab3cb76..8db25f8f75e1 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -321,9 +321,42 @@ def choice(a, size=None, replace=True, p=None, **kwargs): return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) -def gamma(shape, scale, size=None, ctx=None, dtype=None, out=None): +def gamma(shape, scale, size=None, **kwargs): + """Draw samples from a Gamma distribution. 
+ + Samples are drawn from a Gamma distribution with specified parameters, + `shape` (sometimes designated "k") and `scale` (sometimes designated + "theta"), where both parameters are > 0. + + Parameters + ---------- + shape : float or array_like of floats + The shape of the gamma distribution. Should be greater than zero. + scale : float or array_like of floats, optional + The scale of the gamma distribution. Should be greater than zero. + Default is equal to 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``shape`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn. + ctx : Context, optional + Device context of output. Default is current context. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the parameterized gamma distribution. + + The Gamma distribution is often used to model the times to failure of + electronic components, and arises naturally in processes for which the + waiting times between Poisson distributed events are relevant. + """ from ...numpy import ndarray as np_ndarray input_type = (isinstance(shape, np_ndarray), isinstance(scale, np_ndarray)) + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + dtype = kwargs.pop('dtype', None) if dtype is None: dtype = 'float32' if ctx is None: @@ -334,15 +367,15 @@ def gamma(shape, scale, size=None, ctx=None, dtype=None, out=None): size = None if input_type == (True, True): return _npi.gamma(shape, scale, shape=None, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) elif input_type == (False, True): return _npi.gamma(scale, shape=shape, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) elif input_type == (True, False): return _npi.gamma(shape, shape=None, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) else: return _npi.gamma(shape=shape, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + ctx=ctx, dtype=dtype, out=out) raise ValueError("Distribution parameters must be either mxnet.numpy.ndarray or numbers") diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 017e9c67d3b8..e1691b60d3de 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -237,5 +237,35 @@ def choice(a, size=None, replace=True, p=None, **kwargs): return _mx_nd_np.random.choice(a, size, replace, p, **kwargs) -def gamma(shape, scale=1.0, size=None, ctx=None, dtype=None, out=None): - return _mx_nd_np.random.gamma(shape, scale, size, ctx, dtype, out) +def gamma(shape, scale=1.0, size=None, **kwargs): + """Draw samples from a Gamma distribution. + + Samples are drawn from a Gamma distribution with specified parameters, + `shape` (sometimes designated "k") and `scale` (sometimes designated + "theta"), where both parameters are > 0. + + Parameters + ---------- + shape : float or array_like of floats + The shape of the gamma distribution. Should be greater than zero. + scale : float or array_like of floats, optional + The scale of the gamma distribution. Should be greater than zero. + Default is equal to 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``shape`` and ``scale`` are both scalars. 
+ Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn. + ctx : Context, optional + Device context of output. Default is current context. + + Returns + ------- + out : ndarray or scalar + Drawn samples from the parameterized gamma distribution. + + The Gamma distribution is often used to model the times to failure of + electronic components, and arises naturally in processes for which the + waiting times between Poisson distributed events are relevant. + """ + return _mx_nd_np.random.gamma(shape, scale, size, **kwargs) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 523983bac20a..a32bc5701f65 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -267,3 +267,63 @@ def choice(a, size=None, replace=True, p=None, **kwargs): return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out) else: return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) + + +def gamma(shape, scale, size=None, **kwargs): + """Draw samples from a Gamma distribution. + + Samples are drawn from a Gamma distribution with specified parameters, + `shape` (sometimes designated "k") and `scale` (sometimes designated + "theta"), where both parameters are > 0. + + Parameters + ---------- + shape : float or array_like of floats + The shape of the gamma distribution. Should be greater than zero. + scale : float or array_like of floats, optional + The scale of the gamma distribution. Should be greater than zero. + Default is equal to 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``shape`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn. + ctx : Context, optional + Device context of output. Default is current context. + + Returns + ------- + out : _Symbol + Drawn samples from the parameterized gamma distribution. + + The Gamma distribution is often used to model the times to failure of + electronic components, and arises naturally in processes for which the + waiting times between Poisson distributed events are relevant. 
+ """ + from ._symbol import _Symbol as np_symbol + input_type = (isinstance(shape, np_symbol), isinstance(scale, np_symbol)) + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + dtype = kwargs.pop('dtype', None) + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.gamma(shape, scale, shape=None, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.gamma(scale, shape=shape, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.gamma(shape, shape=None, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.gamma(shape=shape, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) + + raise ValueError("Distribution parameters must be either _Symbol or numbers") diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc new file mode 100644 index 000000000000..d154024369ee --- /dev/null +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file np_gamma_op.cc + * \brief Operator for random sampling from gamma distribution + */ + +#include "./np_gamma_op.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyGammaParam); + +template <> +void _copy(float *dst, float *src) { + *dst = *src; +} + +template <> +void _copy(double *dst, double *src) { + *dst = *src; +} + +NNVM_REGISTER_OP(_npi_gamma) +.describe("numpy behavior gamma") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyGammaParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.shape.has_value()) num_inputs -= 1; + if (param.scale.has_value()) num_inputs -= 1; + return num_inputs; + } +) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyGammaParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.scale.has_value()) num_inputs -= 1; + if (param.shape.has_value()) num_inputs -= 1; + if (num_inputs == 0) return std::vector(); + if (num_inputs == 1) return std::vector{"input1"}; + return std::vector{"input1", "input2"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", TwoparamsDistOpShape) +.set_attr("FInferType", NumpyGammaOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, + ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyGammaForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_argument("input2", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyGammaParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_gamma_op.cu b/src/operator/numpy/random/np_gamma_op.cu new file mode 100644 index 000000000000..2b04b8b2b0e7 --- /dev/null +++ b/src/operator/numpy/random/np_gamma_op.cu @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file np_gamma_op.cu + * \brief Operator for random sampling from gamma distribution + */ + +#include "./np_gamma_op.h" +#include + +namespace mxnet { +namespace op { + +template <> +void _copy(float *dst, float *src) { + CUDA_CALL(cudaMemcpy(dst, src, sizeof(float), cudaMemcpyDeviceToHost)); +} + +template <> +void _copy(double *dst, double *src) { + CUDA_CALL(cudaMemcpy(dst, src, sizeof(double), cudaMemcpyDeviceToHost)); +} + +NNVM_REGISTER_OP(_npi_gamma) +.set_attr("FCompute", NumpyGammaForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h new file mode 100644 index 000000000000..197291e2e306 --- /dev/null +++ b/src/operator/numpy/random/np_gamma_op.h @@ -0,0 +1,461 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_gamma_op.h + * \brief Operator for random sampling from gamma distribution + */ + +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_GAMMA_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_GAMMA_OP_H_ + +#include +#include +#include +#include +#include +#include "./dist_common.h" +#include "../../elemwise_op_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" + +#define M 2 + +namespace mxnet { +namespace op { + +struct NumpyGammaParam : public dmlc::Parameter { + dmlc::optional shape; + dmlc::optional scale; + std::string ctx; + int dtype; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyGammaParam) { + DMLC_DECLARE_FIELD(shape); + DMLC_DECLARE_FIELD(scale); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe("Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + DMLC_DECLARE_FIELD(ctx) + .set_default("xpu") + .describe("Context of output, in format [xpu|xpu|xpu_pinned](n)." + " Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::kFloat32) + .describe("DType of the output in case this can't be inferred. 
" + "Defaults to float32 if not defined (dtype=None)."); + } +}; + + +inline bool NumpyGammaOpType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyGammaParam ¶m = nnvm::get(attrs.parsed); + int otype = param.dtype; + if (otype != -1) { + (*out_attrs)[0] = otype; + } else { + (*out_attrs)[0] = mshadow::kFloat32; + } + return true; +} + +// template +// void _copy(FType *dst, FType *src) { +// #if USE_CUDA == 1 +// CUDA_CALL(cudaMemcpy(dst, src, sizeof(FType, cudaMemcpyDeviceToHost))) +// #else +// *dst = *src; +// #endif +// } + +// #if USE_CUDA == 1 +// template +// void _copy(context::xpu device, FType *dst, FType *src); +// #endif + +template +void _copy(float *dst, float*src); + +template +void _copy(double *dst, double*src); + +namespace mxnet_op { + +template +MSHADOW_XINLINE void GammaTransform(IType a, IType b, + FType* uniforms, FType* normals) { + // start + FType d = a < 1 ? a + 2.0 / 3.0 : a - 1.0 / 3.0; + FType k = sqrt(9.0 * d); + FType c = 1.0 / k; + // printf("c1 %f\n", uniforms[M - 1]); + for (size_t i = 0; i < M - 1; i++) { + FType u = uniforms[i]; + FType n = normals[i]; + uniforms[i] = FType(-1); + FType ocn = 1+c*n; + FType v = ocn*ocn*ocn; + if (v > 0) { + if (u <= (1 - 0.0331 * (n * n) * (n * n))) { + // rejection sample. The second operation should be + // performed with low probability. This is the "squeeze" + uniforms[i] = FType(d * v * b); + } + if (logf(u) < 0.5 * (n * n) + d * (1 - v + logf(v))) { + // rejection sample. The second operation should be + // performed with low probability. This is the "squeeze" + uniforms[i] = FType(d * v * b); + } + } + } + // printf("c2 %f\n", uniforms[M - 1]); +} + + +template +MSHADOW_XINLINE FType GammaReduce(IType a, FType* uniforms) { + FType u2 = uniforms[M - 1]; + for (size_t i = 0; i < M - 1; i++) { + FType sample = uniforms[i]; + if (sample > 0) { + // printf("a %f\n", a); + // printf("b %f\n",sample); + // printf("c %f\n", u2); + return a < 1 ? 
sample * powf(u2, FType(1.0 / a)) : sample; + } + } + return -1; +} + +template +struct CheckSuccessKernel { + MSHADOW_XINLINE static void Map(int i, OType* out, FType* flag) { + if (out[i] < 0) { + flag[0] = -1.0; + } + } +}; + +template +struct gamma_kernel { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &lstride, const Shape &hstride, + const Shape &oshape, + IType *shape, IType *scale, + FType *uniforms, FType *normals, + OType *out) { + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType shape_value = shape[lidx]; + IType scale_value = scale[hidx]; + // map phase + GammaTransform(shape_value, scale_value, + uniforms + i * M, normals + i * M); + // reduce phase + OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; + } +}; + +template +struct gamma_kernel_r { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &lstride, const Shape &hstride, + const Shape &oshape, + IType *shape, IType *scale, + FType *uniforms, FType *normals, + OType *out, FType* flag) { + flag[0] = 1; + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType shape_value = shape[lidx]; + IType scale_value = scale[hidx]; + if (out[i] < 0) { + // map phase + GammaTransform(shape_value, scale_value, + uniforms + i * M, normals + i * M); + // reduce phase + OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; + } + } +}; + +template +struct gamma_one_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, + const Shape &stride, + const Shape &oshape, + IType *array, float scalar, + FType *uniforms, FType *normals, + OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType shape_value; + IType scale_value; + if (scalar_pos == 0) { + shape_value = scalar; + scale_value = array[idx]; + } else { + shape_value = array[idx]; + scale_value = scalar; + } + // map phase + GammaTransform(shape_value, scale_value, + uniforms + i * M, normals + i * M); + // reduce phase + OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; + } +}; + +template +struct gamma_one_scalar_kernel_r { + MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, + const Shape &stride, + const Shape &oshape, + IType *array, float scalar, + FType *uniforms, FType *normals, + OType *out, FType *flag) { + flag[0] = 1; + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType shape_value; + IType scale_value; + if (scalar_pos == 0) { + shape_value = scalar; + scale_value = array[idx]; + } else { + shape_value = array[idx]; + scale_value = scalar; + } + if (out[i] < 0) { + // map phase + GammaTransform(shape_value, scale_value, + uniforms + i * M, normals + i * M); + // reduce phase + OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; + } + } +}; + +template +struct gamma_two_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, float shape_value, + float scale_value, FType *uniforms_origin, + FType *normals_origin, OType *out) { + // map phase + FType *uniforms = uniforms_origin + i * M; + FType *normals = normals_origin + i * M; + GammaTransform(shape_value, scale_value, uniforms, + normals); + // reduce phase + OType sample = + (OType)GammaReduce(shape_value, uniforms); + out[i] = sample; + } +}; + +template +struct gamma_two_scalar_kernel_r 
{ + MSHADOW_XINLINE static void Map(index_t i, float shape_value, + float scale_value, + FType *uniforms_origin, + FType *normals_origin, OType *out, FType* flag) { + flag[0] = 1; + // map phase + FType *uniforms = uniforms_origin + i * M; + FType *normals = normals_origin + i * M; + if (out[i] < 0) { + GammaTransform(shape_value, scale_value, uniforms, + normals); + // reduce phase + OType sample = + (OType)GammaReduce(shape_value, uniforms); + out[i] = sample; + } + } +}; +} // namespace mxnet_op + +template +void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyGammaParam ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + // Generate base random number. + Random *prnd = ctx.requested[0].get_random(s); + index_t output_len = outputs[0].Size(); + Tensor random_tensor = + ctx.requested[1].get_space_typed(Shape1(output_len * 2 * M + 1), s); + Tensor uniform_tensor = random_tensor.Slice(0, output_len * M); + Tensor normal_tensor = random_tensor.Slice(output_len * M, output_len * 2 * M); + prnd->SampleUniform(&uniform_tensor, 0, 1); + prnd->SampleGaussian(&normal_tensor, 0, 1); + mxnet::TShape new_lshape, new_hshape, new_oshape; + FType failure_indicator = 1.0; + Tensor failure_indic_workspace = + random_tensor.Slice(output_len * 2 * M, output_len * 2 * M + 1); + FType *failure_indicator_device = failure_indic_workspace.dptr_; + // [scalar scalar] case + if (inputs.size() == 0U) { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.shape.value(), param.scale.value(), + uniform_tensor.dptr_, normal_tensor.dptr_, outputs[0].dptr()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(&failure_indicator, failure_indicator_device); + // cout<= 0) { + break; + } else { + prnd->SampleUniform(&uniform_tensor, 0, 1); + prnd->SampleGaussian(&normal_tensor, 0, 1); + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.shape.value(), param.scale.value(), + uniform_tensor.dptr_, normal_tensor.dptr_, + outputs[0].dptr(), failure_indicator_device); + failure_indicator = 1.0; + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(&failure_indicator, failure_indicator_device); + } + } + }); + } else if (inputs.size() == 1U) { + // [scalar tensor], [tensor scalar] case + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + int scalar_pos; + float scalar_value; + // int type_flag = param.t; + if (param.shape.has_value()) { + scalar_pos = 0; + scalar_value = param.shape.value(); + } else { + scalar_pos = 1; + scalar_value = param.scale.value(); + } + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape stride = + mxnet_op::calc_stride(new_lshape.get()); + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, + uniform_tensor.dptr_, normal_tensor.dptr_, + outputs[0].dptr()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(&failure_indicator, failure_indicator_device); + while (1) { + if (failure_indicator >= 0) 
{ + break; + } else { + prnd->SampleUniform(&uniform_tensor, 0, 1); + prnd->SampleGaussian(&normal_tensor, 0, 1); + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, + uniform_tensor.dptr_, normal_tensor.dptr_, + outputs[0].dptr(), failure_indicator_device); + failure_indicator = 1.0; + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(&failure_indicator, failure_indicator_device); + } + } + }); + }); + }); + } else if (inputs.size() == 2U) { + // [tensor tensor] case + int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_hshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = + mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape hstride = + mxnet_op::calc_stride(new_hshape.get()); + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), lstride, hstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + uniform_tensor.dptr_, normal_tensor.dptr_, + outputs[0].dptr()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(&failure_indicator, failure_indicator_device); + while (1) { + if (failure_indicator >= 0) { + break; + } else { + prnd->SampleUniform(&uniform_tensor, 0, 1); + prnd->SampleGaussian(&normal_tensor, 0, 1); + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), lstride, hstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + uniform_tensor.dptr_, normal_tensor.dptr_, + outputs[0].dptr(), failure_indicator_device); + failure_indicator = 1.0; + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(&failure_indicator, failure_indicator_device); + } + } + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_GAMMA_OP_H_ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1c9ebbb32866..ab1b15b7952f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1744,6 +1744,52 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) +@with_seed() +@use_np +def test_np_random_gamma(): + shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] + dtypes = ['float16', 'float32', 'float64'] + op_names = ['gamma'] + for shape in shapes: + for dtype in dtypes: + for op_name in op_names: + print('-------------------------------') + print(op_name) + print(shape) + print(dtype) + op = getattr(np.random, op_name, None) + assert op is not None + out = op(scale=2.0, shape=2.0, size=shape, dtype=dtype) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert out.shape == expected_shape + + class TestRandom(HybridBlock): + def __init__(self, shape, op_name): + super(TestRandom, self).__init__() + self._shape = shape + self._op_name = op_name + + def hybrid_forward(self, F, x): + op = getattr(F.np.random, self._op_name, None) + assert op is not None + return x + op(scale=2.0, shape=2.0, size=shape) + + x = np.ones(()) + for op_name in op_names: + for shape in shapes: + for hybridize in [False, True]: + net = TestRandom(shape, op_name) + if 
hybridize: + net.hybridize() + out = net(x) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + assert out.shape == expected_shape + + if __name__ == '__main__': import nose nose.runmodule() From dd19611739e1abd8a9e6534033a48fa1b9aa27a1 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 13 Sep 2019 13:25:31 +0000 Subject: [PATCH 3/6] retrigger CI --- src/operator/numpy/random/np_gamma_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc index d154024369ee..a5e254e320c4 100644 --- a/src/operator/numpy/random/np_gamma_op.cc +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -41,7 +41,7 @@ void _copy(double *dst, double *src) { } NNVM_REGISTER_OP(_npi_gamma) -.describe("numpy behavior gamma") +.describe("Numpy behavior gamma") .set_num_inputs( [](const nnvm::NodeAttrs& attrs) { const NumpyGammaParam& param = nnvm::get(attrs.parsed); From 05da1597f1ca62c22c82bb12a23c95b4653dc774 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Thu, 9 Jan 2020 05:20:53 +0000 Subject: [PATCH 4/6] simplify gamma kernel --- src/operator/numpy/random/np_gamma_op.cc | 8 +- src/operator/numpy/random/np_gamma_op.h | 350 +++++++++-------------- 2 files changed, 143 insertions(+), 215 deletions(-) diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc index 5d968b059884..72e337b1642b 100644 --- a/src/operator/numpy/random/np_gamma_op.cc +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -30,10 +30,10 @@ namespace op { DMLC_REGISTER_PARAMETER(NumpyGammaParam); -inline bool NumpyGammaOpType(const nnvm::NodeAttrs &attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - const NumpyGammaParam ¶m = nnvm::get(attrs.parsed); +inline bool NumpyGammaOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + const NumpyGammaParam& param = nnvm::get(attrs.parsed); int otype = param.dtype; if (otype != -1) { (*out_attrs)[0] = otype; diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index e274bb327a50..7419838664a6 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -53,21 +53,21 @@ struct NumpyGammaParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(shape); DMLC_DECLARE_FIELD(scale); DMLC_DECLARE_FIELD(size) - .set_default(dmlc::optional>()) - .describe("Output shape. If the given shape is, " - "e.g., (m, n, k), then m * n * k samples are drawn. " - "Default is None, in which case a single value is returned."); + .set_default(dmlc::optional>()) + .describe("Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); DMLC_DECLARE_FIELD(ctx) - .set_default("xpu") - .describe("Context of output, in format [xpu|xpu|xpu_pinned](n)." - " Only used for imperative calls."); + .set_default("cpu") + .describe("Context of output, in format [xpu|xpu|xpu_pinned](n)." + " Only used for imperative calls."); DMLC_DECLARE_FIELD(dtype) - .add_enum("float32", mshadow::kFloat32) - .add_enum("float64", mshadow::kFloat64) - .add_enum("float16", mshadow::kFloat16) - .set_default(mshadow::kFloat32) - .describe("DType of the output in case this can't be inferred. 
" - "Defaults to float32 if not defined (dtype=None)."); + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::kFloat32) + .describe("DType of the output in case this can't be inferred. " + "Defaults to float32 if not defined (dtype=None)."); } }; @@ -77,7 +77,6 @@ namespace mxnet_op { template MSHADOW_XINLINE void GammaTransform(IType a, IType b, FType* uniforms, FType* normals) { - // start FType d = a < 1 ? a + 2.0 / 3.0 : a - 1.0 / 3.0; FType k = sqrt(9.0 * d); FType c = 1.0 / k; @@ -126,47 +125,32 @@ struct CheckSuccessKernel { template struct gamma_kernel { - MSHADOW_XINLINE static void Map(index_t i, - const Shape &lstride, const Shape &hstride, - const Shape &oshape, - IType *shape, IType *scale, - FType *uniforms, FType *normals, - OType *out) { - Shape coord = unravel(i, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto hidx = static_cast(dot(coord, hstride)); - IType shape_value = shape[lidx]; - IType scale_value = scale[hidx]; - // map phase - GammaTransform(shape_value, scale_value, - uniforms + i * M, normals + i * M); - // reduce phase - OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); - out[i] = sample; - } -}; - -template -struct gamma_kernel_r { - MSHADOW_XINLINE static void Map(index_t i, - const Shape &lstride, const Shape &hstride, - const Shape &oshape, - IType *shape, IType *scale, - FType *uniforms, FType *normals, - OType *out, FType* flag) { - flag[0] = 1; - Shape coord = unravel(i, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto hidx = static_cast(dot(coord, hstride)); - IType shape_value = shape[lidx]; - IType scale_value = scale[hidx]; - if (out[i] < 0) { - // map phase - GammaTransform(shape_value, scale_value, - uniforms + i * M, normals + i * M); - // reduce phase - OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); - out[i] = sample; + MSHADOW_XINLINE static void Map(index_t i, const Shape &lstride, + const Shape &hstride, + const Shape &oshape, IType *shape, + IType *scale, FType *uniforms, FType *normals, + OType *out, FType *flag = nullptr) { + // We know the sampling procedure is in its first stage, if `flag` is + // nullptr, i.e. there is no need for reseting the indicator + // variable(flag[0] = 1) nor checking whether a specific element is sampled + // successfully (out[i] < 0). 
+ bool in_first_stage = (flag == nullptr); + if (!in_first_stage) { + flag[0] = 1; + } + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType shape_value = shape[lidx]; + IType scale_value = scale[hidx]; + if (in_first_stage || out[i] < 0) { + // map phase + GammaTransform(shape_value, scale_value, uniforms + i * M, + normals + i * M); + // reduce phase + OType sample = + (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; } } }; @@ -174,58 +158,37 @@ struct gamma_kernel_r { template struct gamma_one_scalar_kernel { MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, - const Shape &stride, - const Shape &oshape, - IType *array, float scalar, - FType *uniforms, FType *normals, - OType *out) { - Shape coord = unravel(i, oshape); - auto idx = static_cast(dot(coord, stride)); - IType shape_value; - IType scale_value; - if (scalar_pos == 0) { - shape_value = scalar; - scale_value = array[idx]; - } else { - shape_value = array[idx]; - scale_value = scalar; - } - // map phase - GammaTransform(shape_value, scale_value, - uniforms + i * M, normals + i * M); - // reduce phase - OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); - out[i] = sample; - } -}; - -template -struct gamma_one_scalar_kernel_r { - MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, - const Shape &stride, - const Shape &oshape, - IType *array, float scalar, - FType *uniforms, FType *normals, - OType *out, FType *flag) { - flag[0] = 1; - Shape coord = unravel(i, oshape); - auto idx = static_cast(dot(coord, stride)); - IType shape_value; - IType scale_value; - if (scalar_pos == 0) { - shape_value = scalar; - scale_value = array[idx]; - } else { - shape_value = array[idx]; - scale_value = scalar; - } - if (out[i] < 0) { - // map phase - GammaTransform(shape_value, scale_value, - uniforms + i * M, normals + i * M); - // reduce phase - OType sample = (OType)GammaReduce(shape_value, uniforms + i * M); - out[i] = sample; + const Shape &stride, + const Shape &oshape, IType *array, + float scalar, FType *uniforms, FType *normals, + OType *out, FType *flag = nullptr) { + // We know the sampling procedure is in its first stage, if `flag` is + // nullptr, i.e. there is no need for reseting the indicator + // variable(flag[0] = 1) nor checking whether a specific element is sampled + // successfully (out[i] < 0). 
+ bool in_first_stage = (flag == nullptr); + if (!in_first_stage) { + flag[0] = 1; + } + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType shape_value; + IType scale_value; + if (scalar_pos == 0) { + shape_value = scalar; + scale_value = array[idx]; + } else { + shape_value = array[idx]; + scale_value = scalar; + } + if (in_first_stage || out[i] < 0) { + // map phase + GammaTransform(shape_value, scale_value, uniforms + i * M, + normals + i * M); + // reduce phase + OType sample = + (OType)GammaReduce(shape_value, uniforms + i * M); + out[i] = sample; } } }; @@ -234,46 +197,34 @@ template struct gamma_two_scalar_kernel { MSHADOW_XINLINE static void Map(index_t i, float shape_value, float scale_value, FType *uniforms_origin, - FType *normals_origin, OType *out) { - // map phase - FType *uniforms = uniforms_origin + i * M; - FType *normals = normals_origin + i * M; - GammaTransform(shape_value, scale_value, uniforms, - normals); - // reduce phase - OType sample = - (OType)GammaReduce(shape_value, uniforms); - out[i] = sample; + FType *normals_origin, OType *out, + FType *flag = nullptr) { + // We know the sampling procedure is in its first stage, if `flag` is + // nullptr, i.e. there is no need for reseting the indicator + // variable(flag[0] = 1) nor checking whether a specific element is sampled + // successfully (out[i] < 0). + bool in_first_stage = (flag == nullptr); + if (!in_first_stage) { + flag[0] = 1; } -}; - -template -struct gamma_two_scalar_kernel_r { - MSHADOW_XINLINE static void Map(index_t i, float shape_value, - float scale_value, - FType *uniforms_origin, - FType *normals_origin, OType *out, FType* flag) { - flag[0] = 1; - // map phase FType *uniforms = uniforms_origin + i * M; FType *normals = normals_origin + i * M; - if (out[i] < 0) { - GammaTransform(shape_value, scale_value, uniforms, - normals); - // reduce phase - OType sample = - (OType)GammaReduce(shape_value, uniforms); + if (in_first_stage || out[i] < 0) { + // map phase + GammaTransform(shape_value, scale_value, uniforms, normals); + // reduce phase + OType sample = (OType)GammaReduce(shape_value, uniforms); out[i] = sample; - } } + } }; } // namespace mxnet_op template void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; using namespace mxnet_op; const NumpyGammaParam ¶m = nnvm::get(attrs.parsed); @@ -295,31 +246,24 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, FType *failure_indicator_device = failure_indic_workspace.dptr_; // [scalar scalar] case if (inputs.size() == 0U) { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { - Kernel, xpu>::Launch( - s, outputs[0].Size(), param.shape.value(), param.scale.value(), - uniform_tensor.dptr_, normal_tensor.dptr_, outputs[0].dptr()); - Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); - _copy(s, &failure_indicator, failure_indicator_device); - while (1) { - if (failure_indicator >= 0) { - break; - } else { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + bool in_resample_stage = false; + do { + if (in_resample_stage) { prnd->SampleUniform(&uniform_tensor, 0, 1); prnd->SampleGaussian(&normal_tensor, 0, 1); - Kernel, xpu>::Launch( - s, outputs[0].Size(), param.shape.value(), param.scale.value(), - uniform_tensor.dptr_, 
normal_tensor.dptr_, - outputs[0].dptr(), failure_indicator_device); - failure_indicator = 1.0; - Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); - _copy(s, &failure_indicator, failure_indicator_device); } - } + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.shape.value(), param.scale.value(), + uniform_tensor.dptr_, normal_tensor.dptr_, outputs[0].dptr(), + in_resample_stage ? failure_indicator_device : nullptr); + failure_indicator = 1.0; + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(s, &failure_indicator, failure_indicator_device); + in_resample_stage = true; + } while (failure_indicator < 0); }); } else if (inputs.size() == 1U) { // [scalar tensor], [tensor scalar] case @@ -327,7 +271,6 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, &new_lshape, &new_lshape, &new_oshape); int scalar_pos; float scalar_value; - // int type_flag = param.t; if (param.shape.has_value()) { scalar_pos = 0; scalar_value = param.shape.value(); @@ -335,39 +278,31 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, scalar_pos = 1; scalar_value = param.scale.value(); } - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape stride = - mxnet_op::calc_stride(new_lshape.get()); - mxnet_op::Kernel, xpu>::Launch( - s, outputs[0].Size(), scalar_pos, stride, oshape, - inputs[0].dptr(), scalar_value, - uniform_tensor.dptr_, normal_tensor.dptr_, - outputs[0].dptr()); - Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); - _copy(s, &failure_indicator, failure_indicator_device); - while (1) { - if (failure_indicator >= 0) { - break; - } else { - prnd->SampleUniform(&uniform_tensor, 0, 1); - prnd->SampleGaussian(&normal_tensor, 0, 1); - mxnet_op::Kernel, xpu>::Launch( - s, outputs[0].Size(), scalar_pos, stride, oshape, - inputs[0].dptr(), scalar_value, - uniform_tensor.dptr_, normal_tensor.dptr_, - outputs[0].dptr(), failure_indicator_device); - failure_indicator = 1.0; - Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); - _copy(s, &failure_indicator, failure_indicator_device); - } - } + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape stride = + mxnet_op::calc_stride(new_lshape.get()); + bool in_resample_stage = false; + do { + if (in_resample_stage) { + prnd->SampleUniform(&uniform_tensor, 0, 1); + prnd->SampleGaussian(&normal_tensor, 0, 1); + } + mxnet_op::Kernel,xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, + uniform_tensor.dptr_, normal_tensor.dptr_, + outputs[0].dptr(), + in_resample_stage ? 
failure_indicator_device : nullptr); + failure_indicator = 1.0; + Kernel, xpu>::Launch( + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); + _copy(s, &failure_indicator, failure_indicator_device); + in_resample_stage = true; + } while (failure_indicator < 0); }); }); }); @@ -375,41 +310,34 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, // [tensor tensor] case int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_hshape, &new_oshape); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); mshadow::Shape hstride = mxnet_op::calc_stride(new_hshape.get()); - mxnet_op::Kernel, xpu>::Launch( - s, outputs[0].Size(), lstride, hstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), - uniform_tensor.dptr_, normal_tensor.dptr_, - outputs[0].dptr()); - Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); - _copy(s, &failure_indicator, failure_indicator_device); - while (1) { - if (failure_indicator >= 0) { - break; - } else { + bool in_resample_stage = false; + do { + if (in_resample_stage) { + prnd->SampleUniform(&uniform_tensor, 0, 1); + prnd->SampleGaussian(&normal_tensor, 0, 1); + } prnd->SampleUniform(&uniform_tensor, 0, 1); prnd->SampleGaussian(&normal_tensor, 0, 1); - mxnet_op::Kernel, xpu>::Launch( + mxnet_op::Kernel, xpu>::Launch( s, outputs[0].Size(), lstride, hstride, oshape, inputs[0].dptr(), inputs[1].dptr(), uniform_tensor.dptr_, normal_tensor.dptr_, - outputs[0].dptr(), failure_indicator_device); + outputs[0].dptr(), in_resample_stage ? 
failure_indicator_device : nullptr); failure_indicator = 1.0; Kernel, xpu>::Launch( s, outputs[0].Size(), outputs[0].dptr(), failure_indicator_device); _copy(s, &failure_indicator, failure_indicator_device); - } - } + in_resample_stage = true; + } while (failure_indicator < 0); }); }); }); From ed387372b67c08fc8819e66e0190d698637155ef Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Thu, 9 Jan 2020 05:23:40 +0000 Subject: [PATCH 5/6] fix lint erropr --- src/operator/numpy/random/np_gamma_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index 7419838664a6..0db0d86e09ff 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -290,7 +290,7 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, prnd->SampleUniform(&uniform_tensor, 0, 1); prnd->SampleGaussian(&normal_tensor, 0, 1); } - mxnet_op::Kernel,xpu>::Launch( + mxnet_op::Kernel, xpu>::Launch( s, outputs[0].Size(), scalar_pos, stride, oshape, inputs[0].dptr(), scalar_value, uniform_tensor.dptr_, normal_tensor.dptr_, From 3fdbb3ad9ab0379a07ce2a4a2d37cfe768993207 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Thu, 9 Jan 2020 08:55:49 +0000 Subject: [PATCH 6/6] fix style --- src/operator/numpy/random/np_gamma_op.h | 41 ++++++++++++------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index 0db0d86e09ff..83e8f1f5242b 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -254,13 +254,13 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, prnd->SampleGaussian(&normal_tensor, 0, 1); } Kernel, xpu>::Launch( - s, outputs[0].Size(), param.shape.value(), param.scale.value(), - uniform_tensor.dptr_, normal_tensor.dptr_, outputs[0].dptr(), - in_resample_stage ? failure_indicator_device : nullptr); + s, outputs[0].Size(), param.shape.value(), param.scale.value(), + uniform_tensor.dptr_, normal_tensor.dptr_, outputs[0].dptr(), + in_resample_stage ? failure_indicator_device : nullptr); failure_indicator = 1.0; Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); _copy(s, &failure_indicator, failure_indicator_device); in_resample_stage = true; } while (failure_indicator < 0); @@ -282,24 +282,23 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape stride = - mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape stride = calc_stride(new_lshape.get()); bool in_resample_stage = false; do { if (in_resample_stage) { prnd->SampleUniform(&uniform_tensor, 0, 1); prnd->SampleGaussian(&normal_tensor, 0, 1); } - mxnet_op::Kernel, xpu>::Launch( - s, outputs[0].Size(), scalar_pos, stride, oshape, - inputs[0].dptr(), scalar_value, - uniform_tensor.dptr_, normal_tensor.dptr_, - outputs[0].dptr(), - in_resample_stage ? failure_indicator_device : nullptr); + Kernel, xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, + uniform_tensor.dptr_, normal_tensor.dptr_, + outputs[0].dptr(), + in_resample_stage ? 
failure_indicator_device : nullptr); failure_indicator = 1.0; Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); _copy(s, &failure_indicator, failure_indicator_device); in_resample_stage = true; } while (failure_indicator < 0); @@ -314,10 +313,8 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = - mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape hstride = - mxnet_op::calc_stride(new_hshape.get()); + mshadow::Shape lstride = calc_stride(new_lshape.get()); + mshadow::Shape hstride = calc_stride(new_hshape.get()); bool in_resample_stage = false; do { if (in_resample_stage) { @@ -326,15 +323,15 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } prnd->SampleUniform(&uniform_tensor, 0, 1); prnd->SampleGaussian(&normal_tensor, 0, 1); - mxnet_op::Kernel, xpu>::Launch( + Kernel, xpu>::Launch( s, outputs[0].Size(), lstride, hstride, oshape, inputs[0].dptr(), inputs[1].dptr(), uniform_tensor.dptr_, normal_tensor.dptr_, outputs[0].dptr(), in_resample_stage ? failure_indicator_device : nullptr); failure_indicator = 1.0; Kernel, xpu>::Launch( - s, outputs[0].Size(), outputs[0].dptr(), - failure_indicator_device); + s, outputs[0].Size(), outputs[0].dptr(), + failure_indicator_device); _copy(s, &failure_indicator, failure_indicator_device); in_resample_stage = true; } while (failure_indicator < 0);
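
A minimal usage sketch of the Python front end added by this series (assuming a build that includes these patches; `npx.set_np()` and the broadcasting behaviour follow the docstrings above rather than anything verified here):

    from mxnet import np, npx
    npx.set_np()

    # scalar shape/scale with an explicit output size and dtype
    s = np.random.gamma(shape=2.0, scale=2.0, size=(3, 4), dtype='float32')
    assert s.shape == (3, 4)

    # ndarray shape broadcast against a scalar scale; when size is None the
    # output follows the broadcast shape of the parameters
    k = np.array([1.0, 2.0, 3.0])
    t = np.random.gamma(shape=k, scale=1.0)
    assert t.shape == (3,)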