diff --git a/python/mxnet/base.py b/python/mxnet/base.py index cbd9abe9d754..96f099146490 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -752,7 +752,7 @@ def write_all_str(module_file, module_all_list): _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_'] _NP_EXT_OP_PREFIX = '_npx_' -_NP_EXT_OP_SUBMODULE_LIST = ['_image_'] +_NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_'] _NP_INTERNAL_OP_PREFIX = '_npi_' diff --git a/python/mxnet/ndarray/numpy_extension/__init__.py b/python/mxnet/ndarray/numpy_extension/__init__.py index 5be34ac9b3d5..afa81edc3820 100644 --- a/python/mxnet/ndarray/numpy_extension/__init__.py +++ b/python/mxnet/ndarray/numpy_extension/__init__.py @@ -19,6 +19,7 @@ from . import _op from . import image +from . import random from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/ndarray/numpy_extension/random.py b/python/mxnet/ndarray/numpy_extension/random.py new file mode 100644 index 000000000000..b0472a4ab122 --- /dev/null +++ b/python/mxnet/ndarray/numpy_extension/random.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None):
    """Draw samples from a Bernoulli distribution.

    The distribution is parameterized by exactly one of `prob` or `logit`.
    Samples are binary (0 or 1): they take the value `1` with probability
    `p` and `0` with probability `1 - p`.

    Parameters
    ----------
    prob : float or ndarray, optional
        The probability of sampling '1'.
        Only one of prob or logit should be passed in.
    logit : float or ndarray, optional
        The log-odds of sampling '1'.
        Only one of prob or logit should be passed in.
    size : int or tuple of ints, optional
        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn. Default is None. An empty tuple
        ``()`` is treated the same as None.
    dtype : dtype, optional
        Desired dtype of the result, given by name (e.g. 'int32').
        Defaults to 'float32' when None.
    ctx : Context, optional
        Device context of output. Defaults to the current context when None.
    out : ndarray, optional
        Forwarded unchanged to the backend operator (default is `None`).

    Returns
    -------
    out : ndarray
        Drawn samples from the parameterized bernoulli distribution.

    Raises
    ------
    ValueError
        If both or neither of `prob` and `logit` are given.

    Examples
    --------
    >>> prob = np.random.uniform(size=(4,4))
    >>> logit = np.log(prob) - np.log(1 - prob)
    >>> npx.random.bernoulli(logit=logit)
    array([[0., 1., 1., 1.],
           [0., 1., 1., 1.],
           [0., 1., 0., 0.],
           [1., 0., 1., 0.]])

    >>> npx.random.bernoulli(prob=prob)
    array([[0., 1., 0., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 0.],
           [1., 0., 1., 0.]])
    """
    # Reject bad argument combinations first, so the error does not depend on
    # the lazy import or on context lookup below.
    if (prob is None) == (logit is None):
        raise ValueError(
            "Either `prob` or `logit` must be specified, but not both. " +
            "Received prob={}, logit={}".format(prob, logit))

    # Imported lazily, as in the original implementation
    # (presumably to avoid a circular import -- TODO confirm).
    from ...numpy import ndarray as np_ndarray

    if dtype is None:
        dtype = 'float32'
    if ctx is None:
        ctx = current_context()
    if size == ():
        size = None

    is_logit = prob is None
    param = logit if is_logit else prob
    common = dict(is_logit=is_logit, size=size, ctx=ctx, dtype=dtype, out=out)

    if isinstance(param, np_ndarray):
        # Tensor parameter: passed as the operator's positional input, with
        # both scalar slots left empty.
        return _npi.bernoulli(param, prob=None, logit=None, **common)
    # Scalar parameter: passed through the matching keyword slot.
    if is_logit:
        return _npi.bernoulli(prob=None, logit=param, **common)
    return _npi.bernoulli(prob=param, logit=None, **common)
def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None):
    """Sample from a Bernoulli distribution given by `prob` or `logit`.

    Thin front-end: every argument is forwarded unchanged to the ndarray
    implementation in ``mxnet.ndarray.numpy_extension.random``.

    Samples are binary (0 or 1). They take the value `1` with probability
    `p` and `0` with probability `1 - p`.

    Parameters
    ----------
    prob : float or ndarray, optional
        The probability of sampling '1'.
        Only one of prob or logit should be passed in.
    logit : float or ndarray, optional
        The log-odds of sampling '1'.
        Only one of prob or logit should be passed in.
    size : int or tuple of ints, optional
        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn. Default is None.
    dtype : dtype, optional
        Desired dtype of the result, given by name (e.g. 'int32').
    ctx : Context, optional
        Device context of output.
    out : ndarray, optional
        Forwarded unchanged to the delegate (default is `None`).

    Returns
    -------
    out : ndarray
        Drawn samples from the parameterized bernoulli distribution.

    Examples
    --------
    >>> prob = np.random.uniform(size=(4,4))
    >>> logit = np.log(prob) - np.log(1 - prob)
    >>> npx.random.bernoulli(logit=logit)
    array([[0., 1., 1., 1.],
           [0., 1., 1., 1.],
           [0., 1., 0., 0.],
           [1., 0., 1., 0.]])

    >>> npx.random.bernoulli(prob=prob)
    array([[0., 1., 0., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 0.],
           [1., 0., 1., 0.]])
    """
    sampler = _mx_nd_npx.random.bernoulli
    return sampler(prob=prob, logit=logit, size=size,
                   dtype=dtype, ctx=ctx, out=out)
def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None):
    """Build a Bernoulli sampling node parameterized by `prob` or `logit`.

    Exactly one of `prob` and `logit` must be given. Samples are binary
    (0 or 1): `1` with probability `p` and `0` with probability `1 - p`.

    Parameters
    ----------
    prob : float or _Symbol, optional
        The probability of sampling '1'.
        Only one of prob or logit should be passed in.
    logit : float or _Symbol, optional
        The log-odds of sampling '1'.
        Only one of prob or logit should be passed in.
    size : int or tuple of ints, optional
        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn. Default is None; an empty tuple
        ``()`` is treated the same as None.
    dtype : dtype, optional
        Desired dtype of the result, given by name (e.g. 'int32').
        Defaults to 'float32' when None.
    ctx : Context, optional
        Device context of output. Defaults to the current context when None.
    out : symbol, optional
        The output symbol (default is `None`).

    Returns
    -------
    out : _Symbol
        Drawn samples from the parameterized bernoulli distribution.

    Raises
    ------
    ValueError
        If both or neither of `prob` and `logit` are given.

    Examples
    --------
    Imperative usage through ``npx.random.bernoulli``, shown for reference:

    >>> prob = np.random.uniform(size=(4,4))
    >>> logit = np.log(prob) - np.log(1 - prob)
    >>> npx.random.bernoulli(logit=logit)
    array([[0., 1., 1., 1.],
           [0., 1., 1., 1.],
           [0., 1., 0., 0.],
           [1., 0., 1., 0.]])

    >>> npx.random.bernoulli(prob=prob)
    array([[0., 1., 0., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 0.],
           [1., 0., 1., 0.]])
    """
    from ..numpy import _Symbol as np_symbol

    have_prob = prob is not None
    have_logit = logit is not None
    if have_prob == have_logit:
        raise ValueError(
            "Either `prob` or `logit` must be specified, but not both. " +
            "Received prob={}, logit={}".format(prob, logit))

    dtype = 'float32' if dtype is None else dtype
    ctx = current_context() if ctx is None else ctx
    size = None if size == () else size

    # `source` is whichever parameter was supplied; the flag tells the
    # backend operator how to interpret it.
    use_logit = not have_prob
    source = logit if use_logit else prob

    if isinstance(source, np_symbol):
        # Symbolic parameter: passed as the operator's positional input.
        return _npi.bernoulli(source, prob=None, logit=None, is_logit=use_logit,
                              size=size, ctx=ctx, dtype=dtype, out=out)
    if use_logit:
        return _npi.bernoulli(prob=None, logit=source, is_logit=True,
                              size=size, ctx=ctx, dtype=dtype, out=out)
    return _npi.bernoulli(prob=source, logit=None, is_logit=False,
                          size=size, ctx=ctx, dtype=dtype, out=out)
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) - return true; } } - return out_attrs->at(0).ndim() != 0U; + return shape_is_known(out_attrs->at(0)); +} + +template +inline bool UnaryDistOpShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const DistParam ¶m = nnvm::get(attrs.parsed); + if (param.size.has_value()) { + // Size declared. + std::vector oshape_vec; + const mxnet::Tuple &size = param.size.value(); + for (int i = 0; i < size.ndim(); ++i) { + oshape_vec.emplace_back(size[i]); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec)); + for (size_t input_idx = 0; input_idx < in_attrs->size(); input_idx++) { + CheckBroadcastable((*in_attrs)[input_idx], (*out_attrs)[0]); + } + } else { + if (in_attrs->size() == 1U) { + // One param from ndarray. + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)) + } else { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) + } + } + return shape_is_known(out_attrs->at(0)); } } // namespace op diff --git a/src/operator/numpy/random/np_bernoulli_op.cc b/src/operator/numpy/random/np_bernoulli_op.cc new file mode 100644 index 000000000000..d67ad1b8d7f6 --- /dev/null +++ b/src/operator/numpy/random/np_bernoulli_op.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bernoulli_op.cc + * \brief Operator for numpy sampling from bernoulli distributions + */ + +#include "./np_bernoulli_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyBernoulliParam); + +NNVM_REGISTER_OP(_npi_bernoulli) +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyBernoulliParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.logit.has_value() || param.prob.has_value()) { + num_inputs -= 1; + } + return num_inputs; + } +) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyBernoulliParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.logit.has_value() || param.prob.has_value()) { + num_inputs -= 1; + } + return (num_inputs == 0) ? std::vector() : std::vector{"input1"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", UnaryDistOpShape) +.set_attr("FInferType", NumpyBernoulliOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBernoulliForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyBernoulliParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_bernoulli_op.cu b/src/operator/numpy/random/np_bernoulli_op.cu new file mode 100644 index 000000000000..a73bf9929db3 --- /dev/null +++ b/src/operator/numpy/random/np_bernoulli_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bernoulli_op.cu + * \brief Operator for numpy sampling from bernoulli distributions + */ + +#include "./np_bernoulli_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_bernoulli) +.set_attr("FCompute", NumpyBernoulliForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_bernoulli_op.h b/src/operator/numpy/random/np_bernoulli_op.h new file mode 100644 index 000000000000..aa8e344e2842 --- /dev/null +++ b/src/operator/numpy/random/np_bernoulli_op.h @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bernoulli_op.h + * \brief Operator for numpy sampling from bernoulli distribution. + */ +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_BERNOULLI_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_BERNOULLI_OP_H_ + +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +struct NumpyBernoulliParam : public dmlc::Parameter { + dmlc::optional prob; + dmlc::optional logit; + std::string ctx; + int dtype; + bool is_logit; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyBernoulliParam) { + DMLC_DECLARE_FIELD(prob); + DMLC_DECLARE_FIELD(logit); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe( + "Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe( + "Context of output, in format [cpu|gpu|cpu_pinned](n)." + " Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("uint8", mshadow::kUint8) + .add_enum("int32", mshadow::kInt32) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .add_enum("bool", mshadow::kBool) + .set_default(mshadow::kFloat32) + .describe( + "DType of the output in case this can't be inferred. 
" + "Defaults to float32 if not defined (dtype=None)."); + DMLC_DECLARE_FIELD(is_logit); + } +}; + +inline bool NumpyBernoulliOpType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyBernoulliParam ¶m = nnvm::get(attrs.parsed); + int otype = param.dtype; + (*out_attrs)[0] = otype; + return true; +} + +namespace mxnet_op { + +struct prob_to_logit { + MSHADOW_XINLINE static void Map(index_t i, float* uniforms) { + float prob = uniforms[i]; + uniforms[i] = log(prob) - log(1 - prob); + } +}; + +template +struct bernoulli_kernel { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &stride, + const Shape &oshape, + IType *inputs, float* threshold, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + out[i] = inputs[idx] > threshold[i] ? OType(1) : OType(0); + } +}; + +template +struct scalar_bernoulli_kernel { + MSHADOW_XINLINE static void Map(index_t i, float inputs, float *threshold, + OType *out) { + out[i] = inputs > threshold[i] ? 
OType(1) : OType(0); + } +}; + +template +struct check_legal_prob_kernel { + MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) { + if (scalar[i] < 0.0 || scalar[i] > 1.0) { + flag[0] = -1.0; + } + } +}; + +} // namespace mxnet_op + +template +void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyBernoulliParam ¶m = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + index_t output_len = outputs[0].Size(); + Random *prnd = ctx.requested[0].get_random(s); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); + Tensor uniform_tensor = workspace.Slice(0, output_len); + Tensor indicator_device = workspace.Slice(output_len, output_len + 1); + float indicator_host = 1.0; + float *indicator_device_ptr = indicator_device.dptr_; + Kernel::Launch(s, 1, indicator_device_ptr); + prnd->SampleUniform(&uniform_tensor, 0.0, 1.0); + if (param.prob.has_value()) { + // scalar prob input + CHECK_LE(param.prob.value(), 1.0) << "ValueError: expect probs >= 0 && probs <= 1"; + CHECK_GE(param.prob.value(), 0.0) << "ValueError: expect probs >= 0 && probs <= 1"; + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.prob.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else if (param.logit.has_value()) { + // scalar logit input + // sigmoid(x) > u <=> x > logit(u) + Kernel::Launch(s, outputs[0].Size(), + uniform_tensor.dptr_); + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.logit.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else { + if (param.is_logit) { + // tensor logit input + Kernel::Launch(s, outputs[0].Size(), + uniform_tensor.dptr_); + } else { + // tensor prob input + // sigmoid(x) > u 
<=> x > logit(u) + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr); + }); + _copy(&indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) + << "ValueError: expect probs >= 0 && probs <= 1"; + } + mxnet::TShape new_lshape, new_oshape; + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), stride, oshape, inputs[0].dptr(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_BERNOULLI_OP_H_ diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index 678f0ec2fb47..deb603ac5270 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -158,6 +158,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, Tensor indicator_device = workspace.Slice(output_len, output_len + 1); float indicator_host = 1.0; float *indicator_device_ptr = indicator_device.dptr_; + Kernel::Launch(s, 1, indicator_device_ptr); prnd->SampleGaussian(&normal_tensor, 0.0, 1.0); mxnet::TShape new_lshape, new_hshape, new_oshape; diff --git a/src/operator/numpy/random/np_uniform_op.h b/src/operator/numpy/random/np_uniform_op.h index 580dc5d05eaa..1df0c39d4e57 100644 --- a/src/operator/numpy/random/np_uniform_op.h +++ b/src/operator/numpy/random/np_uniform_op.h @@ -25,7 +25,6 @@ #ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ #define MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ -#include #include #include #include diff --git 
@with_seed()
@use_np
def test_npx_random_bernoulli():
    """Check shapes, value range and error handling of npx.random.bernoulli."""
    def _test_bernoulli_exception(prob, logit):
        # The result is converted with asnumpy() and discarded; presumably
        # this forces evaluation so backend errors surface here -- confirm.
        npx.random.bernoulli(prob=prob, logit=logit).asnumpy()

    def _only_zeros_and_ones(samples):
        # Convert once and reuse the host copy for both comparisons.
        values = samples.asnumpy()
        return int((values == 0).sum() + (values == 1).sum()) == samples.size

    # Supplying neither parameter must be rejected, just like supplying both.
    assertRaises(ValueError, _test_bernoulli_exception, None, None)

    shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None]
    dtypes = ['float16', 'float32', 'float64', 'int32', 'bool']
    for shape, dtype in itertools.product(shapes, dtypes):
        prob = np.random.uniform(size=shape)
        logit = np.log(prob) - np.log(1 - prob)
        expected_shape = shape
        if not isinstance(shape, tuple):
            expected_shape = () if shape is None else (shape,)
        out_prob = npx.random.bernoulli(prob=prob, size=shape, dtype=dtype)
        assert out_prob.shape == expected_shape
        assert _only_zeros_and_ones(out_prob)
        out_logit = npx.random.bernoulli(logit=logit, size=shape, dtype=dtype)
        assert out_logit.shape == expected_shape
        assert _only_zeros_and_ones(out_logit)
        # Test Exception: both parameters given at once.
        assertRaises(ValueError, _test_bernoulli_exception, prob, logit)
        if prob.size > 0:
            # larger than 1
            assertRaises(MXNetError, _test_bernoulli_exception, prob + 2.0, None)
            # smaller than 0
            assertRaises(MXNetError, _test_bernoulli_exception, prob - 2.0, None)
            # mixed case
            low, high = (-1.0, 2.0)
            # uniform(-1, 2)
            scaled_prob = low + (high - low) * prob
            scaled_values = scaled_prob.asnumpy()
            # Only expect an error when at least one value is really illegal.
            if not ((scaled_values >= 0).all() and (scaled_values <= 1).all()):
                assertRaises(MXNetError, _test_bernoulli_exception, scaled_prob, None)