From d84a34be3e189a0325da0514bb86af937780fd6e Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 12 Aug 2019 09:07:08 +0000 Subject: [PATCH 1/3] numpy behavior uniform sampling implemented --- python/mxnet/ndarray/numpy/random.py | 62 +++++- python/mxnet/numpy/random.py | 37 +++- python/mxnet/symbol/numpy/random.py | 62 +++++- src/operator/numpy/random/dist_common.h | 180 ++++++++++++++++ src/operator/numpy/random/np_uniform_op.cc | 63 ++++++ src/operator/numpy/random/np_uniform_op.cu | 35 ++++ src/operator/numpy/random/np_uniform_op.h | 218 ++++++++++++++++++++ tests/python/unittest/test_numpy_ndarray.py | 36 ++++ 8 files changed, 690 insertions(+), 3 deletions(-) create mode 100644 src/operator/numpy/random/dist_common.h create mode 100644 src/operator/numpy/random/np_uniform_op.cc create mode 100644 src/operator/numpy/random/np_uniform_op.cu create mode 100644 src/operator/numpy/random/np_uniform_op.h diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 339fb1e77920..0fc2de6a9c8e 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -17,5 +17,65 @@ """Namespace for operators used in Gluon dispatched by F=ndarray.""" from __future__ import absolute_import +from ...context import current_context +from . import _internal as _npi -__all__ = [] +__all__ = ['uniform'] + + +def uniform(low=0.0, high=1.0, size=None, ctx=None, dtype=None, out=None): + """Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + + Parameters + ---------- + low : float, ndarray, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float, ndarray, optional + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a scalar tensor containing a single value is returned if + ``low`` and ``high`` are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized uniform distribution. 
+ """ + from ...numpy import ndarray as np_ndarray + input_type = (isinstance(low, np_ndarray), isinstance(high, np_ndarray)) + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.uniform(low, high, low=None, high=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.uniform(high, low=low, high=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.uniform(low, low=None, high=high, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.uniform(low=low, high=high, size=size, + ctx=ctx, dtype=dtype, out=out) + + raise ValueError( + "Distribution parameters must be either mxnet.numpy.ndarray or numbers") diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index c4109378e146..feea13e2fd5f 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -18,5 +18,40 @@ """Namespace for ops used in imperative programming.""" from __future__ import absolute_import +from ..ndarray import numpy as _mx_nd_np -__all__ = [] +__all__ = ['uniform'] + + +def uniform(low=0.0, high=1.0, size=None, ctx=None, dtype=None, out=None): + """Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + + Parameters + ---------- + low : float, ndarray, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float, ndarray, optional + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a scalar tensor containing a single value is returned if + ``low`` and ``high`` are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized uniform distribution. + """ + return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 28cfd0f3806a..40e6e8d2d603 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -18,5 +18,65 @@ """Namespace for operators used in Gluon dispatched by F=symbol.""" from __future__ import absolute_import +from ...context import current_context +from . import _internal as _npi -__all__ = [] +__all__ = ['uniform'] + + +def uniform(low=0.0, high=1.0, size=None, ctx=None, dtype=None, out=None): + """Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + + Parameters + ---------- + low : float, ndarray, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. 
+ high : float, ndarray, optional + Upper boundary of the output interval. All values generated will be + less than high. The default value is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a scalar tensor containing a single value is returned if + ``low`` and ``high`` are both scalars. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32' + ctx : Context, optional + Device context of output. Default is current context. + + Returns + ------- + out : ndarray + Drawn samples from the parameterized uniform distribution. + """ + from ._symbol import _Symbol as np_symbol + input_type = (isinstance(low, np_symbol), isinstance(high, np_symbol)) + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if out is not None: + size = out.shape + if size == (): + size = None + if input_type == (True, True): + return _npi.uniform(low, high, low=None, high=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (False, True): + return _npi.uniform(high, low=low, high=None, size=size, + ctx=ctx, dtype=dtype, out=out) + elif input_type == (True, False): + return _npi.uniform(low, low=None, high=high, size=size, + ctx=ctx, dtype=dtype, out=out) + else: + return _npi.uniform(low=low, high=high, size=size, + ctx=ctx, dtype=dtype, out=out) + + raise ValueError( + "Distribution parameters must be either mxnet.numpy.ndarray or numbers") diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h new file mode 100644 index 000000000000..c8978b5b380c --- /dev/null +++ b/src/operator/numpy/random/dist_common.h @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file etwoparams_dist_common.h + * \brief Function definition of common functions for distributions + * \with two parameters. 
+ */ + +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_ + +#include +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" + +namespace mxnet { +namespace op { + +inline int FillShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape, + const mxnet::TShape &oshape, mxnet::TShape *new_lshape, + mxnet::TShape *new_rshape, mxnet::TShape *new_oshape) { + const int odim = std::max(oshape.ndim(), broadcast::MAX_DIM); + *new_lshape = mxnet::TShape(odim, 1); + *new_rshape = mxnet::TShape(odim, 1); + *new_oshape = mxnet::TShape(odim, 1); + int bl = oshape.ndim() - lshape.ndim(); + int br = oshape.ndim() - rshape.ndim(); + int j = 0, lprod = 1, rprod = 1, oprod = 1; + for (int i = 0; i < oshape.ndim(); ++i) { + int l = 1; + int r = 1; + int o = oshape[i]; + if (i >= bl) l = lshape[i - bl]; + if (i >= br) r = rshape[i - br]; + if ((lprod != rprod || lprod != oprod || l != r || l != o) && + (lprod * l > 1 || rprod * r > 1 || oprod * o > 1)) { + (*new_lshape)[j] = lprod; + (*new_rshape)[j] = rprod; + (*new_oshape)[j] = oprod; + lprod = rprod = oprod = 1; ++j; + } + lprod *= l; + rprod *= r; + oprod *= o; + } + if (lprod > 1 || rprod > 1 || oprod > 1) { + (*new_lshape)[j] = lprod; + (*new_rshape)[j] = rprod; + (*new_oshape)[j] = oprod; + ++j; + } + if (j <= broadcast::MAX_DIM) { + BROADCAST_NDIM_SWITCH(j, NDim, { + new_lshape->assign(new_lshape->begin(), new_lshape->begin() + NDim); + new_rshape->assign(new_rshape->begin(), new_rshape->begin() + NDim); + new_oshape->assign(new_oshape->begin(), new_oshape->begin() + NDim); + }); + } else { + LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape << " " << rshape; + } + return j; +} + +inline void CheckBroadcastable(const mxnet::TShape &from, const mxnet::TShape &to) { + const int bl = to.ndim() - from.ndim(); + const int br = 0; + for (int i = 0; i < to.ndim(); ++i) { + int l = 1, r = 1; + if (i >= bl) + l = from[i - bl]; + if (i >= br) + r = to[i - br]; + if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) + continue; + if (l != r) { + // Make it compatible with NumPy. + // For example, (2, 3) cannot broadcast to (2, 0, 3), but (1, 3) can + // broadcast to (2, 0, 3). + CHECK(l == 1 || r == 1) + << "operands could not be broadcast together with shapes " << from + << " " << to; + } + } +} + +inline void InferBroadcastShape(const mxnet::TShape &lhs, const mxnet::TShape &rhs, + mxnet::TShape* out_ptr) { + mxnet::TShape& out = (*out_ptr); + const int bl = out.ndim() - lhs.ndim(); + const int br = out.ndim() - rhs.ndim(); + for (int i = 0; i < out.ndim(); ++i) { + int l = 1, r = 1; + if (i >= bl) + l = lhs[i - bl]; + if (i >= br) + r = rhs[i - br]; + if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) + continue; + if (l != r) { + // Make it compatible with NumPy. + // For example, (2, 3) cannot broadcast to (2, 0, 3), but (1, 3) can + // broadcast to (2, 0, 3). + CHECK(l == 1 || r == 1) + << "operands could not be broadcast together with shapes " << lhs + << " " << rhs; + out[i] = (l == 1 ? 
r : l); + } else { + out[i] = l; + } + } +} + +template +inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const DistParam ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size(), 1U); + if (param.size.has_value()) { + // Size declared. + std::vector oshape_vec; + const mxnet::Tuple &size = param.size.value(); + for (int i = 0; i < size.ndim(); ++i) { + oshape_vec.emplace_back(size[i]); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec)); + for (size_t input_idx = 0; input_idx < in_attrs->size(); input_idx++) { + CheckBroadcastable((*in_attrs)[input_idx], (*out_attrs)[0]); + } + } else { + // Size undeclared. + if (in_attrs->size() == 2U) { + // Both params from ndarray. + mxnet::TShape& low = (*in_attrs)[0]; + mxnet::TShape& high = (*in_attrs)[1]; + mxnet::TShape out(std::max(low.ndim(), high.ndim()), -1); + InferBroadcastShape(low, high, &out); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); + } else if (in_attrs->size() == 1U) { + // One param from ndarray. + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)) + } else if (in_attrs->size() == 0) { + // Two scalar case. + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) + return true; + } + } + return out_attrs->at(0).ndim() != 0U; +} + + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_ */ diff --git a/src/operator/numpy/random/np_uniform_op.cc b/src/operator/numpy/random/np_uniform_op.cc new file mode 100644 index 000000000000..01e7d9c4d8de --- /dev/null +++ b/src/operator/numpy/random/np_uniform_op.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file np_uniform_op.h + * \brief Operator for numpy sampling from uniform distributions + */ +#include "./np_uniform_op.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyUniformParam); + +NNVM_REGISTER_OP(_npi_uniform) +.describe("numpy behavior uniform") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyUniformParam& param = nnvm::get(attrs.parsed); + int num_inputs = 2; + if (param.low.has_value()) num_inputs -= 1; + if (param.high.has_value()) num_inputs -= 1; + return num_inputs; + } +) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"input1", "input2"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", TwoparamsDistOpShape) +.set_attr("FInferType", NumpyUniformOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyUniformForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_argument("input2", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyUniformParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_uniform_op.cu b/src/operator/numpy/random/np_uniform_op.cu new file mode 100644 index 000000000000..befaa407804e --- /dev/null +++ b/src/operator/numpy/random/np_uniform_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_uniform_op.cu + * \brief Operator for numpy sampling from uniform distributions + */ + +#include "./np_uniform_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_uniform) +.set_attr("FCompute", NumpyUniformForward); + +} +} diff --git a/src/operator/numpy/random/np_uniform_op.h b/src/operator/numpy/random/np_uniform_op.h new file mode 100644 index 000000000000..24032dd0bc5d --- /dev/null +++ b/src/operator/numpy/random/np_uniform_op.h @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_uniform_op.h + * \brief Operator for numpy sampling from uniform distributions + */ +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ + +#include +#include +#include +#include +#include +#include "./dist_common.h" +#include "../../elemwise_op_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" + +namespace mxnet { +namespace op { + +struct NumpyUniformParam : public dmlc::Parameter { + dmlc::optional low; + dmlc::optional high; + std::string ctx; + int dtype; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyUniformParam) { + DMLC_DECLARE_FIELD(low); + DMLC_DECLARE_FIELD(high); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe("Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + DMLC_DECLARE_FIELD(ctx) + .set_default("cpu") + .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." + " Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::kFloat32) + .describe("DType of the output in case this can't be inferred. 
" + "Defaults to float32 if not defined (dtype=None)."); + } +}; + +inline bool NumpyUniformOpType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyUniformParam ¶m = nnvm::get(attrs.parsed); + int otype = param.dtype; + if (otype != -1) { + (*out_attrs)[0] = otype; + } else { + (*out_attrs)[0] = mshadow::kFloat32; + } + return true; +} + +namespace mxnet_op { +template +struct uniform_kernel { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &lstride, const Shape &hstride, + const Shape &oshape, + IType *low, IType *high, + float *uniform, OType *out) { + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType low_value = low[lidx]; + IType high_value = high[hidx]; + out[i] = low_value + uniform[i] * (high_value - low_value); + } +}; +} // namespace mxnet_op + +namespace mxnet_op { +template +struct uniform_one_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, + const Shape &stride, + const Shape &oshape, + IType *array, float scalar, + float *uniform, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType low_value; + IType high_value; + if (scalar_pos == 0) { + low_value = scalar; + high_value = array[idx]; + } else { + low_value = array[idx]; + high_value = scalar; + } + out[i] = low_value + uniform[i] * (high_value - low_value); + } +}; +} // namespace mxnet_op + +namespace mxnet_op { +template +struct uniform_two_scalar_kernel { + MSHADOW_XINLINE static void Map(index_t i, + float low, float high, + float *uniform, OType *out) { + out[i] = low + uniform[i] * (high - low); + } +}; +} // namespace mxnet_op + + + + +template +void NumpyUniformForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyUniformParam ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + + // Generate base random number. 
+ Random *prnd = ctx.requested[0].get_random(s); + Tensor uniform_tensor = + ctx.requested[1].get_space_typed(Shape1(outputs[0].Size()), s); + prnd->SampleUniform(&uniform_tensor, 0, 1); + mxnet::TShape new_lshape, new_hshape, new_oshape; + + // [scalar scalar] case + if (inputs.size() == 0U) { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), + param.low.value(), param.high.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else if (inputs.size() == 1U) { + // [scalar tensor], [tensor scalar] case + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + int scalar_pos; + float scalar_value; + // int type_flag = param.t; + if (param.low.has_value()) { + scalar_pos = 0; + scalar_value = param.low.value(); + } else { + scalar_pos = 1; + scalar_value = param.high.value(); + } + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape stride = + mxnet_op::calc_stride(new_lshape.get()); + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, + uniform_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } else if (inputs.size() == 2U) { + // [tensor tensor] case + int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_hshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = + mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape hstride = + mxnet_op::calc_stride(new_hshape.get()); + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), lstride, hstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } +} + +}; // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 6024ac9b4acd..7fa05206ac09 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -25,6 +25,8 @@ from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np from common import with_seed, TemporaryDirectory +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf +import scipy.stats as ss @with_seed() @@ -667,6 +669,40 @@ def test_np_save_load_ndarrays(): assert _np.array_equal(v.asnumpy(), arr_dict[k].asnumpy()) +@retry(5) +@with_seed() +@use_np +def test_np_uniform(): + types = [None, "float32", "float64"] + ctx = mx.context.current_context() + samples = 1000000 + # Generation test + trials = 8 + num_buckets = 5 + for dtype in types: + for low, high in [(-100.0, -98.0), (99.0, 101.0)]: + scale = high - low + buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), num_buckets) + buckets = np.array(buckets, dtype=dtype).tolist() + probs = [(buckets[i][1] - buckets[i][0])/scale for i in range(num_buckets)] + generator_mx_np = lambda x: mx.np.random.uniform(low, high, size=x, ctx=ctx, dtype=dtype).asnumpy() + verify_generator(generator=generator_mx_np, 
buckets=buckets, probs=probs, nsamples=samples, nrepeat=trials) + + # Broadcasting test + params = [ + (1.0, mx.np.ones((4,4)) + 2.0), + (mx.np.zeros((4,4)) + 1, 2.0), + (mx.np.zeros((1,4)), mx.np.ones((4,4)) + mx.np.array([1, 2, 3, 4])), + (mx.np.array([1, 2, 3, 4]), mx.np.ones((2,4,4)) * 5) + ] + for dtype in types: + for low, high in params: + expect_mean = (low + high) / 2 + expanded_size = (samples,) + expect_mean.shape + uniform_samples = mx.np.random.uniform(low, high, size=expanded_size, dtype=dtype) + mx.test_utils.assert_almost_equal(uniform_samples.asnumpy().mean(0), expect_mean.asnumpy(), rtol=0.20, atol=1e-1) + + if __name__ == '__main__': import nose nose.runmodule() From bf195e9216954317929c75ba9f1186d20ac99ab2 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 13 Aug 2019 03:09:36 +0000 Subject: [PATCH 2/3] improve code style according to review comment --- python/mxnet/ndarray/numpy/random.py | 2 +- python/mxnet/numpy/random.py | 2 +- python/mxnet/symbol/numpy/random.py | 2 +- src/operator/numpy/random/dist_common.h | 61 ++++----- src/operator/numpy/random/np_uniform_op.cc | 2 +- src/operator/numpy/random/np_uniform_op.cu | 4 +- src/operator/numpy/random/np_uniform_op.h | 151 ++++++++++----------- 7 files changed, 107 insertions(+), 117 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 0fc2de6a9c8e..be918615bfd9 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -23,7 +23,7 @@ __all__ = ['uniform'] -def uniform(low=0.0, high=1.0, size=None, ctx=None, dtype=None, out=None): +def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): """Draw samples from a uniform distribution. Samples are uniformly distributed over the half-open interval diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index feea13e2fd5f..f85936345b7f 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -23,7 +23,7 @@ __all__ = ['uniform'] -def uniform(low=0.0, high=1.0, size=None, ctx=None, dtype=None, out=None): +def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): """Draw samples from a uniform distribution. Samples are uniformly distributed over the half-open interval diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 40e6e8d2d603..338a5e28be4e 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -24,7 +24,7 @@ __all__ = ['uniform'] -def uniform(low=0.0, high=1.0, size=None, ctx=None, dtype=None, out=None): +def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): """Draw samples from a uniform distribution. Samples are uniformly distributed over the half-open interval diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index c8978b5b380c..42d7507f7e74 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -19,7 +19,7 @@ /*! * Copyright (c) 2015 by Contributors - * \file etwoparams_dist_common.h + * \file dist_common.h * \brief Function definition of common functions for distributions * \with two parameters. 
*/ @@ -27,16 +27,16 @@ #ifndef MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_ #define MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_ -#include #include -#include -#include +#include #include +#include +#include #include "../../elemwise_op_common.h" -#include "../../tensor/elemwise_binary_broadcast_op.h" #include "../../mshadow_op.h" #include "../../mxnet_op.h" #include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" namespace mxnet { namespace op { @@ -55,14 +55,15 @@ inline int FillShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape, int l = 1; int r = 1; int o = oshape[i]; - if (i >= bl) l = lshape[i - bl]; - if (i >= br) r = rshape[i - br]; + if (i >= bl) l = lshape[i - bl]; + if (i >= br) r = rshape[i - br]; if ((lprod != rprod || lprod != oprod || l != r || l != o) && (lprod * l > 1 || rprod * r > 1 || oprod * o > 1)) { (*new_lshape)[j] = lprod; (*new_rshape)[j] = rprod; (*new_oshape)[j] = oprod; - lprod = rprod = oprod = 1; ++j; + lprod = rprod = oprod = 1; + ++j; } lprod *= l; rprod *= r; @@ -81,22 +82,21 @@ inline int FillShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape, new_oshape->assign(new_oshape->begin(), new_oshape->begin() + NDim); }); } else { - LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape << " " << rshape; + LOG(FATAL) << "Too many broadcast dimensions with operands " << lshape + << " " << rshape; } return j; } -inline void CheckBroadcastable(const mxnet::TShape &from, const mxnet::TShape &to) { +inline void CheckBroadcastable(const mxnet::TShape &from, + const mxnet::TShape &to) { const int bl = to.ndim() - from.ndim(); const int br = 0; for (int i = 0; i < to.ndim(); ++i) { int l = 1, r = 1; - if (i >= bl) - l = from[i - bl]; - if (i >= br) - r = to[i - br]; - if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) - continue; + if (i >= bl) l = from[i - bl]; + if (i >= br) r = to[i - br]; + if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) continue; if (l != r) { // Make it compatible with NumPy. // For example, (2, 3) cannot broadcast to (2, 0, 3), but (1, 3) can @@ -108,19 +108,17 @@ inline void CheckBroadcastable(const mxnet::TShape &from, const mxnet::TShape &t } } -inline void InferBroadcastShape(const mxnet::TShape &lhs, const mxnet::TShape &rhs, - mxnet::TShape* out_ptr) { - mxnet::TShape& out = (*out_ptr); +inline void InferBroadcastShape(const mxnet::TShape &lhs, + const mxnet::TShape &rhs, + mxnet::TShape *out_ptr) { + mxnet::TShape &out = (*out_ptr); const int bl = out.ndim() - lhs.ndim(); const int br = out.ndim() - rhs.ndim(); for (int i = 0; i < out.ndim(); ++i) { int l = 1, r = 1; - if (i >= bl) - l = lhs[i - bl]; - if (i >= br) - r = rhs[i - br]; - if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) - continue; + if (i >= bl) l = lhs[i - bl]; + if (i >= br) r = rhs[i - br]; + if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) continue; if (l != r) { // Make it compatible with NumPy. 
// For example, (2, 3) cannot broadcast to (2, 0, 3), but (1, 3) can @@ -135,10 +133,10 @@ inline void InferBroadcastShape(const mxnet::TShape &lhs, const mxnet::TShape &r } } -template +template inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { const DistParam ¶m = nnvm::get(attrs.parsed); CHECK_EQ(out_attrs->size(), 1U); if (param.size.has_value()) { @@ -156,8 +154,8 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs, // Size undeclared. if (in_attrs->size() == 2U) { // Both params from ndarray. - mxnet::TShape& low = (*in_attrs)[0]; - mxnet::TShape& high = (*in_attrs)[1]; + mxnet::TShape &low = (*in_attrs)[0]; + mxnet::TShape &high = (*in_attrs)[1]; mxnet::TShape out(std::max(low.ndim(), high.ndim()), -1); InferBroadcastShape(low, high, &out); SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); @@ -165,7 +163,7 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs, // One param from ndarray. SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)) } else if (in_attrs->size() == 0) { - // Two scalar case. + // Two scalar case. SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) return true; } @@ -173,7 +171,6 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs, return out_attrs->at(0).ndim() != 0U; } - } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_uniform_op.cc b/src/operator/numpy/random/np_uniform_op.cc index 01e7d9c4d8de..394626d07596 100644 --- a/src/operator/numpy/random/np_uniform_op.cc +++ b/src/operator/numpy/random/np_uniform_op.cc @@ -19,7 +19,7 @@ /*! * Copyright (c) 2019 by Contributors - * \file np_uniform_op.h + * \file np_uniform_op.cc * \brief Operator for numpy sampling from uniform distributions */ #include "./np_uniform_op.h" diff --git a/src/operator/numpy/random/np_uniform_op.cu b/src/operator/numpy/random/np_uniform_op.cu index befaa407804e..d997bc57d3be 100644 --- a/src/operator/numpy/random/np_uniform_op.cu +++ b/src/operator/numpy/random/np_uniform_op.cu @@ -31,5 +31,5 @@ namespace op { NNVM_REGISTER_OP(_npi_uniform) .set_attr("FCompute", NumpyUniformForward); -} -} +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_uniform_op.h b/src/operator/numpy/random/np_uniform_op.h index 24032dd0bc5d..a664bb5a8e71 100644 --- a/src/operator/numpy/random/np_uniform_op.h +++ b/src/operator/numpy/random/np_uniform_op.h @@ -25,17 +25,17 @@ #ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ #define MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ -#include #include -#include -#include +#include #include -#include "./dist_common.h" +#include +#include #include "../../elemwise_op_common.h" -#include "../../tensor/elemwise_binary_broadcast_op.h" #include "../../mshadow_op.h" #include "../../mxnet_op.h" #include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" namespace mxnet { namespace op { @@ -51,26 +51,27 @@ struct NumpyUniformParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(high); DMLC_DECLARE_FIELD(size) .set_default(dmlc::optional>()) - .describe("Output shape. If the given shape is, " - "e.g., (m, n, k), then m * n * k samples are drawn. " - "Default is None, in which case a single value is returned."); - DMLC_DECLARE_FIELD(ctx) - .set_default("cpu") - .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." - " Only used for imperative calls."); + .describe( + "Output shape. 
If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe( + "Context of output, in format [cpu|gpu|cpu_pinned](n)." + " Only used for imperative calls."); DMLC_DECLARE_FIELD(dtype) - .add_enum("float32", mshadow::kFloat32) - .add_enum("float64", mshadow::kFloat64) - .add_enum("float16", mshadow::kFloat16) - .set_default(mshadow::kFloat32) - .describe("DType of the output in case this can't be inferred. " - "Defaults to float32 if not defined (dtype=None)."); + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::kFloat32) + .describe( + "DType of the output in case this can't be inferred. " + "Defaults to float32 if not defined (dtype=None)."); } }; inline bool NumpyUniformOpType(const nnvm::NodeAttrs &attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { const NumpyUniformParam ¶m = nnvm::get(attrs.parsed); int otype = param.dtype; if (otype != -1) { @@ -84,61 +85,52 @@ inline bool NumpyUniformOpType(const nnvm::NodeAttrs &attrs, namespace mxnet_op { template struct uniform_kernel { - MSHADOW_XINLINE static void Map(index_t i, - const Shape &lstride, const Shape &hstride, - const Shape &oshape, - IType *low, IType *high, - float *uniform, OType *out) { - Shape coord = unravel(i, oshape); - auto lidx = static_cast(dot(coord, lstride)); - auto hidx = static_cast(dot(coord, hstride)); - IType low_value = low[lidx]; - IType high_value = high[hidx]; - out[i] = low_value + uniform[i] * (high_value - low_value); + MSHADOW_XINLINE static void Map(index_t i, const Shape &lstride, + const Shape &hstride, + const Shape &oshape, IType *low, + IType *high, float *uniform, OType *out) { + Shape coord = unravel(i, oshape); + auto lidx = static_cast(dot(coord, lstride)); + auto hidx = static_cast(dot(coord, hstride)); + IType low_value = low[lidx]; + IType high_value = high[hidx]; + out[i] = low_value + uniform[i] * (high_value - low_value); } }; -} // namespace mxnet_op -namespace mxnet_op { template struct uniform_one_scalar_kernel { MSHADOW_XINLINE static void Map(index_t i, int scalar_pos, - const Shape &stride, - const Shape &oshape, - IType *array, float scalar, - float *uniform, OType *out) { - Shape coord = unravel(i, oshape); - auto idx = static_cast(dot(coord, stride)); - IType low_value; - IType high_value; - if (scalar_pos == 0) { - low_value = scalar; - high_value = array[idx]; - } else { - low_value = array[idx]; - high_value = scalar; - } - out[i] = low_value + uniform[i] * (high_value - low_value); + const Shape &stride, + const Shape &oshape, IType *array, + float scalar, float *uniform, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + IType low_value; + IType high_value; + if (scalar_pos == 0) { + low_value = scalar; + high_value = array[idx]; + } else { + low_value = array[idx]; + high_value = scalar; + } + out[i] = low_value + uniform[i] * (high_value - low_value); } }; -} // namespace mxnet_op -namespace mxnet_op { template struct uniform_two_scalar_kernel { - MSHADOW_XINLINE static void Map(index_t i, - float low, float high, + MSHADOW_XINLINE static void Map(index_t i, float low, float high, float *uniform, OType *out) { - out[i] = low + uniform[i] * (high - low); + out[i] = low + uniform[i] * (high - low); } }; } // namespace mxnet_op - - - 
template -void NumpyUniformForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, +void NumpyUniformForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { @@ -151,7 +143,8 @@ void NumpyUniformForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, // Generate base random number. Random *prnd = ctx.requested[0].get_random(s); Tensor uniform_tensor = - ctx.requested[1].get_space_typed(Shape1(outputs[0].Size()), s); + ctx.requested[1].get_space_typed(Shape1(outputs[0].Size()), + s); prnd->SampleUniform(&uniform_tensor, 0, 1); mxnet::TShape new_lshape, new_hshape, new_oshape; @@ -159,9 +152,8 @@ void NumpyUniformForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, if (inputs.size() == 0U) { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { mxnet_op::Kernel, xpu>::Launch( - s, outputs[0].Size(), - param.low.value(), param.high.value(), - uniform_tensor.dptr_, outputs[0].dptr()); + s, outputs[0].Size(), param.low.value(), param.high.value(), + uniform_tensor.dptr_, outputs[0].dptr()); }); } else if (inputs.size() == 1U) { // [scalar tensor], [tensor scalar] case @@ -180,13 +172,14 @@ void NumpyUniformForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape stride = - mxnet_op::calc_stride(new_lshape.get()); - mxnet_op::Kernel, xpu>::Launch( - s, outputs[0].Size(), scalar_pos, stride, oshape, - inputs[0].dptr(), scalar_value, - uniform_tensor.dptr_, outputs[0].dptr()); + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape stride = + mxnet_op::calc_stride(new_lshape.get()); + mxnet_op::Kernel, + xpu>::Launch(s, outputs[0].Size(), scalar_pos, + stride, oshape, inputs[0].dptr(), + scalar_value, uniform_tensor.dptr_, + outputs[0].dptr()); }); }); }); @@ -197,22 +190,22 @@ void NumpyUniformForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = - mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape hstride = - mxnet_op::calc_stride(new_hshape.get()); - mxnet_op::Kernel, xpu>::Launch( - s, outputs[0].Size(), lstride, hstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), - uniform_tensor.dptr_, outputs[0].dptr()); + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = + mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape hstride = + mxnet_op::calc_stride(new_hshape.get()); + mxnet_op::Kernel, xpu>::Launch( + s, outputs[0].Size(), lstride, hstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + uniform_tensor.dptr_, outputs[0].dptr()); }); }); }); } } -}; // namespace op +} // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ From d02f0d7c0604940dab4e78f24e426e946eeb0071 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 13 Aug 2019 05:12:19 +0000 Subject: [PATCH 3/3] remove rebundant namespace --- src/operator/numpy/random/np_uniform_op.h | 26 ++++++++++------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/operator/numpy/random/np_uniform_op.h b/src/operator/numpy/random/np_uniform_op.h index a664bb5a8e71..580dc5d05eaa 100644 --- a/src/operator/numpy/random/np_uniform_op.h +++ 
b/src/operator/numpy/random/np_uniform_op.h @@ -151,7 +151,7 @@ void NumpyUniformForward(const nnvm::NodeAttrs &attrs, // [scalar scalar] case if (inputs.size() == 0U) { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { - mxnet_op::Kernel, xpu>::Launch( + Kernel, xpu>::Launch( s, outputs[0].Size(), param.low.value(), param.high.value(), uniform_tensor.dptr_, outputs[0].dptr()); }); @@ -172,14 +172,12 @@ void NumpyUniformForward(const nnvm::NodeAttrs &attrs, MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape stride = - mxnet_op::calc_stride(new_lshape.get()); - mxnet_op::Kernel, - xpu>::Launch(s, outputs[0].Size(), scalar_pos, - stride, oshape, inputs[0].dptr(), - scalar_value, uniform_tensor.dptr_, - outputs[0].dptr()); + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), scalar_pos, stride, oshape, + inputs[0].dptr(), scalar_value, uniform_tensor.dptr_, + outputs[0].dptr()); }); }); }); @@ -190,12 +188,10 @@ void NumpyUniformForward(const nnvm::NodeAttrs &attrs, MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = - mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape hstride = - mxnet_op::calc_stride(new_hshape.get()); - mxnet_op::Kernel, xpu>::Launch( + Shape oshape = new_oshape.get(); + Shape lstride = calc_stride(new_lshape.get()); + Shape hstride = calc_stride(new_hshape.get()); + Kernel, xpu>::Launch( s, outputs[0].Size(), lstride, hstride, oshape, inputs[0].dptr(), inputs[1].dptr(), uniform_tensor.dptr_, outputs[0].dptr());
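
For reference, a minimal usage sketch of the `uniform` front-end added in this patch series. The shapes and bounds below are arbitrary illustrations, not values taken from the patch; it assumes an MXNet build that contains these changes and reuses the `use_np` helper that the new test imports from mxnet.test_utils to enable numpy semantics.

import mxnet as mx
from mxnet.test_utils import use_np

@use_np
def demo_uniform():
    # Scalar low/high with an explicit output shape and dtype.
    a = mx.np.random.uniform(-1.0, 1.0, size=(2, 3), dtype='float32')
    # ndarray low and scalar high; with size=None the output takes the
    # shape of the ndarray parameter, here (4, 4).
    b = mx.np.random.uniform(mx.np.zeros((4, 4)), 2.0)
    # Both parameters as ndarrays; (1, 4) broadcasts against (4, 4).
    c = mx.np.random.uniform(mx.np.zeros((1, 4)), mx.np.ones((4, 4)) + 2.0)
    print(a.shape, b.shape, c.shape)  # (2, 3) (4, 4) (4, 4)

demo_uniform()

When size is omitted, the output shape comes from broadcasting low and high against each other, which is the path handled by TwoparamsDistOpShape in dist_common.h.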
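
The per-element computation shared by the three kernels in np_uniform_op.h (uniform_kernel, uniform_one_scalar_kernel, uniform_two_scalar_kernel) is out[i] = low + u[i] * (high - low) over a base U(0, 1) sample, with the unravel/stride indexing standing in for broadcasting. A rough NumPy sketch of that math, purely illustrative and not part of the patch:

import numpy as onp

def uniform_broadcast_sketch(low, high, size, rng=onp.random):
    # Base samples, analogous to prnd->SampleUniform(&uniform_tensor, 0, 1).
    u = rng.uniform(0.0, 1.0, size=size)
    # Broadcasting low/high to the output shape plays the role of the
    # unravel/dot stride indexing in the C++ kernels.
    low_b = onp.broadcast_to(onp.asarray(low), size)
    high_b = onp.broadcast_to(onp.asarray(high), size)
    return low_b + u * (high_b - low_b)

print(uniform_broadcast_sketch(onp.zeros((1, 4)), 3.0, (4, 4)).shape)  # (4, 4)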