From 26777a29ff42a350e92ac17170df4a9ed58038ac Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sat, 26 Oct 2019 12:55:12 +0000 Subject: [PATCH 1/9] add frontend interface for bernoulli --- python/mxnet/ndarray/numpy/random.py | 4 ++++ python/mxnet/numpy/random.py | 5 +++++ python/mxnet/symbol/numpy/random.py | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 583f56e046f3..85b262000fe1 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -344,3 +344,7 @@ def rand(*size, **kwargs): for s in size: output_shape += (s,) return uniform(0, 1, size=output_shape, **kwargs) + + +def _bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): + pass diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index d0ae237a5b92..d5a6a5258f89 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -257,3 +257,8 @@ def rand(*size, **kwargs): for s in size: output_shape += (s,) return _mx_nd_np.random.uniform(0, 1, size=output_shape, **kwargs) + + +def _bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): + pass + diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index d891ea0c21a0..ce42f47c645c 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -288,3 +288,7 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out) else: return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) + + +def _bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): + pass From d6805a64dcd340ffb2702c36b5c128f4220edacd Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 4 Nov 2019 08:35:34 +0000 Subject: [PATCH 2/9] bernoulli backend done --- python/mxnet/base.py | 2 +- python/mxnet/ndarray/numpy/random.py | 4 - .../mxnet/ndarray/numpy_extension/__init__.py | 1 + .../mxnet/ndarray/numpy_extension/random.py | 57 +++++ python/mxnet/numpy/random.py | 5 - python/mxnet/numpy_extension/__init__.py | 2 +- python/mxnet/numpy_extension/random.py | 10 +- python/mxnet/symbol/numpy/random.py | 4 - .../mxnet/symbol/numpy_extension/__init__.py | 1 + python/mxnet/symbol/numpy_extension/random.py | 57 +++++ src/operator/numpy/random/dist_common.h | 31 ++- src/operator/numpy/random/np_bernoulli_op.cc | 71 ++++++ src/operator/numpy/random/np_bernoulli_op.cu | 35 +++ src/operator/numpy/random/np_bernoulli_op.h | 205 ++++++++++++++++++ src/operator/numpy/random/np_uniform_op.h | 1 - 15 files changed, 466 insertions(+), 20 deletions(-) create mode 100644 python/mxnet/ndarray/numpy_extension/random.py create mode 100644 python/mxnet/symbol/numpy_extension/random.py create mode 100644 src/operator/numpy/random/np_bernoulli_op.cc create mode 100644 src/operator/numpy/random/np_bernoulli_op.cu create mode 100644 src/operator/numpy/random/np_bernoulli_op.h diff --git a/python/mxnet/base.py b/python/mxnet/base.py index cbd9abe9d754..96f099146490 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -752,7 +752,7 @@ def write_all_str(module_file, module_all_list): _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_'] _NP_EXT_OP_PREFIX = '_npx_' -_NP_EXT_OP_SUBMODULE_LIST = ['_image_'] +_NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_'] _NP_INTERNAL_OP_PREFIX = '_npi_' diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 85b262000fe1..583f56e046f3 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -344,7 +344,3 @@ def rand(*size, **kwargs): for s in size: output_shape += (s,) return uniform(0, 1, size=output_shape, **kwargs) - - -def _bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): - pass diff --git a/python/mxnet/ndarray/numpy_extension/__init__.py b/python/mxnet/ndarray/numpy_extension/__init__.py index 5be34ac9b3d5..afa81edc3820 100644 --- a/python/mxnet/ndarray/numpy_extension/__init__.py +++ b/python/mxnet/ndarray/numpy_extension/__init__.py @@ -19,6 +19,7 @@ from . import _op from . import image +from . import random from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/ndarray/numpy_extension/random.py b/python/mxnet/ndarray/numpy_extension/random.py new file mode 100644 index 000000000000..f61809beb2ab --- /dev/null +++ b/python/mxnet/ndarray/numpy_extension/random.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for operators used in Gluon dispatched by F=ndarray.""" +from __future__ import absolute_import +from ...context import current_context +from .. import _internal as _npi + + +__all__ = ['bernoulli'] + + +def bernoulli(probs, logits, size, dtype, ctx, out): + """ + Sampling from bernoulli distribution. + """ + from ...numpy import ndarray as np_ndarray + tensor_type_name = np_ndarray + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both.") + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if size == (): + size = None + if probs is not None: + is_tensor = isinstance(probs, tensor_type_name) + if is_tensor: + return _npi.bernoulli(probs, probs=None, logits=None, is_logit=False, + size=size, ctx=ctx, dtype=dtype, out=out) + else: + return _npi.bernoulli(probs=probs, logits=None, is_logit=False, + size=size, ctx=ctx, dtype=dtype, out=out) + else: + is_tensor = isinstance(logits, tensor_type_name) + if is_tensor: + return _npi.bernoulli(logits, probs=None, logits=None, is_logit=True, + size=size, ctx=ctx, dtype=dtype, out=out) + else: + return _npi.bernoulli(probs=None, logits=logits, is_logit=True, + size=size, ctx=ctx, dtype=dtype, out=out) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index d5a6a5258f89..d0ae237a5b92 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -257,8 +257,3 @@ def rand(*size, **kwargs): for s in size: output_shape += (s,) return _mx_nd_np.random.uniform(0, 1, size=output_shape, **kwargs) - - -def _bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): - pass - diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index 879ab4d56a46..c4b90f2f8d35 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -22,12 +22,12 @@ from __future__ import absolute_import from . import _op from . import image +from . import random # pylint: disable=wildcard-import from . import _register from ._op import * # pylint: disable=wildcard-import from ..context import * # pylint: disable=wildcard-import from ..util import is_np_shape, is_np_array, set_np, reset_np, get_cuda_compute_capability from ..ndarray import waitall from .utils import * # pylint: disable=wildcard-import -from . import random # pylint: disable=wildcard-import __all__ = [] diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py index 5aa58a0cc69d..3d27cb886d2b 100644 --- a/python/mxnet/numpy_extension/random.py +++ b/python/mxnet/numpy_extension/random.py @@ -19,9 +19,10 @@ from __future__ import absolute_import from .. import random as _mx_rand +from ..ndarray import numpy_extension as _mx_nd_npx -__all__ = ['seed'] +__all__ = ['seed', 'bernoulli'] def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name @@ -72,3 +73,10 @@ def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name array(0.9894903, ctx=gpu(0)) """ _mx_rand.seed(seed_state=seed, ctx=ctx) + + +def bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): + """ + Sampling from bernoulli distribution. + """ + return _mx_nd_npx.random.bernoulli(probs, logits, size, dtype, ctx, out) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index ce42f47c645c..d891ea0c21a0 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -288,7 +288,3 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out) else: return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) - - -def _bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): - pass diff --git a/python/mxnet/symbol/numpy_extension/__init__.py b/python/mxnet/symbol/numpy_extension/__init__.py index 5be34ac9b3d5..afa81edc3820 100644 --- a/python/mxnet/symbol/numpy_extension/__init__.py +++ b/python/mxnet/symbol/numpy_extension/__init__.py @@ -19,6 +19,7 @@ from . import _op from . import image +from . import random from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/symbol/numpy_extension/random.py b/python/mxnet/symbol/numpy_extension/random.py new file mode 100644 index 000000000000..e0cba123c227 --- /dev/null +++ b/python/mxnet/symbol/numpy_extension/random.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for operators used in Gluon dispatched by F=symbol.""" + +from __future__ import absolute_import +from ...context import current_context +from .. import _internal as _npi + +__all__ = ['bernoulli'] + + +def bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): + """ + Sampling from beroulli distributions. + """ + from ..numpy import _Symbol as np_symbol + tensor_type_name = np_symbol + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both.") + if dtype is None: + dtype = 'float32' + if ctx is None: + ctx = current_context() + if size == (): + size = None + if probs is not None: + is_tensor = isinstance(probs, tensor_type_name) + if is_tensor: + return _npi.bernoulli(probs, probs=None, logits=None, is_logit=False, + size=size, ctx=ctx, dtype=dtype, out=out) + else: + return _npi.bernoulli(probs=probs, logits=None, is_logit=False, + size=size, ctx=ctx, dtype=dtype, out=out) + else: + is_tensor = isinstance(logits, tensor_type_name) + if is_tensor: + return _npi.bernoulli(logits, probs=None, logits=None, is_logit=True, + size=size, ctx=ctx, dtype=dtype, out=out) + else: + return _npi.bernoulli(probs=None, logits=logits, is_logit=True, + size=size, ctx=ctx, dtype=dtype, out=out) diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index a3c48f5b4d5c..cfde9686d272 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -27,7 +27,6 @@ #ifndef MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_ #define MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_ -#include #include #include #include @@ -172,10 +171,36 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs, } else if (in_attrs->size() == 0) { // Two scalar case. SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) - return true; } } - return out_attrs->at(0).ndim() != 0U; + return shape_is_known(out_attrs->at(0)); +} + +template +inline bool UnaryDistOpShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const DistParam ¶m = nnvm::get(attrs.parsed); + if (param.size.has_value()) { + // Size declared. + std::vector oshape_vec; + const mxnet::Tuple &size = param.size.value(); + for (int i = 0; i < size.ndim(); ++i) { + oshape_vec.emplace_back(size[i]); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec)); + for (size_t input_idx = 0; input_idx < in_attrs->size(); input_idx++) { + CheckBroadcastable((*in_attrs)[input_idx], (*out_attrs)[0]); + } + } else { + if (in_attrs->size() == 1U) { + // One param from ndarray. + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)) + } else { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) + } + } + return shape_is_known(out_attrs->at(0)); } } // namespace op diff --git a/src/operator/numpy/random/np_bernoulli_op.cc b/src/operator/numpy/random/np_bernoulli_op.cc new file mode 100644 index 000000000000..0b2988d7f975 --- /dev/null +++ b/src/operator/numpy/random/np_bernoulli_op.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bernoulli_op.cc + * \brief Operator for numpy sampling from bernoulli distributions + */ + +#include "./np_bernoulli_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyBernoulliParam); + +NNVM_REGISTER_OP(_npi_bernoulli) +.describe("Sample frmo bernoulli distribution") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyBernoulliParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.logit.has_value() || param.prob.has_value()) { + num_inputs -= 1; + } + return num_inputs; + } +) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyBernoulliParam& param = nnvm::get(attrs.parsed); + int num_inputs = 1; + if (param.logit.has_value() || param.prob.has_value()) { + num_inputs -= 1; + } + if (num_inputs == 0) return std::vector(); + return std::vector{"input1"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", UnaryDistOpShape) +.set_attr("FInferType", NumpyBernoulliOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBernoulliForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyBernoulliParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_bernoulli_op.cu b/src/operator/numpy/random/np_bernoulli_op.cu new file mode 100644 index 000000000000..a73bf9929db3 --- /dev/null +++ b/src/operator/numpy/random/np_bernoulli_op.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bernoulli_op.cu + * \brief Operator for numpy sampling from bernoulli distributions + */ + +#include "./np_bernoulli_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_bernoulli) +.set_attr("FCompute", NumpyBernoulliForward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_bernoulli_op.h b/src/operator/numpy/random/np_bernoulli_op.h new file mode 100644 index 000000000000..c5e701d53acb --- /dev/null +++ b/src/operator/numpy/random/np_bernoulli_op.h @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bernoulli_op.h + * \brief Operator for numpy sampling from bernoulli distribution. + */ +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_BERNOULLI_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_BERNOULLI_OP_H_ + +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "./dist_common.h" + +namespace mxnet { +namespace op { + +struct NumpyBernoulliParam : public dmlc::Parameter { + dmlc::optional prob; + dmlc::optional logit; + std::string ctx; + int dtype; + bool is_logit; + dmlc::optional> size; + DMLC_DECLARE_PARAMETER(NumpyBernoulliParam) { + DMLC_DECLARE_FIELD(prob); + DMLC_DECLARE_FIELD(logit); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional>()) + .describe( + "Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. " + "Default is None, in which case a single value is returned."); + DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe( + "Context of output, in format [cpu|gpu|cpu_pinned](n)." + " Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .set_default(mshadow::kFloat32) + .describe( + "DType of the output in case this can't be inferred. " + "Defaults to float32 if not defined (dtype=None)."); + DMLC_DECLARE_FIELD(is_logit); + } +}; + +inline bool NumpyBernoulliOpType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyBernoulliParam ¶m = nnvm::get(attrs.parsed); + int otype = param.dtype; + if (otype != -1) { + (*out_attrs)[0] = otype; + } else { + // Following torch.distributions, + // the default type will be float32. + (*out_attrs)[0] = mshadow::kFloat32; + } + return true; +} + +namespace mxnet_op { + +struct prob_to_logit { + MSHADOW_XINLINE static void Map(index_t i, float* uniforms) { + float prob = uniforms[i]; + uniforms[i] = log(prob) - log(1 - prob); + } +}; + +template +struct bernoulli_kernel { + MSHADOW_XINLINE static void Map(index_t i, + const Shape &stride, + const Shape &oshape, + IType *inputs, float* threshold, OType *out) { + Shape coord = unravel(i, oshape); + auto idx = static_cast(dot(coord, stride)); + out[i] = inputs[idx] > threshold[i] ? OType(1) : OType(0); + } +}; + +template +struct scalar_bernoulli_kernel { + MSHADOW_XINLINE static void Map(index_t i, float inputs, float *threshold, + OType *out) { + out[i] = inputs > threshold[i] ? OType(1) : OType(0); + } +}; + +template +struct check_legal_prob_kernel { + MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) { + if (scalar[i] < 0.0 || scalar[i] > 1.0) { + flag[0] = -1.0; + } + } +}; + +} // namespace mxnet_op + +template +void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyBernoulliParam ¶m = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + index_t output_len = outputs[0].Size(); + Random *prnd = ctx.requested[0].get_random(s); + Tensor workspace = + ctx.requested[1].get_space_typed(Shape1(output_len + 1), s); + Tensor uniform_tensor = workspace.Slice(0, output_len); + Tensor indicator_device = workspace.Slice(output_len, output_len + 1); + float indicator_host = 1.0; + float *indicator_device_ptr = indicator_device.dptr_; + prnd->SampleUniform(&uniform_tensor, 0.0, 1.0); + if (param.prob.has_value()) { + // scalar prob input + CHECK_LE(param.prob.value(), 1.0) << "ValueError: expect probs >= 0 && probs <= 1"; + CHECK_GE(param.prob.value(), 0.0) << "ValueError: expect probs >= 0 && probs <= 1"; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.prob.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else if (param.logit.has_value()) { + // scalar logit input + // sigmoid(x) > u <=> x > logit(u) + Kernel::Launch(s, outputs[0].Size(), + uniform_tensor.dptr_); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + Kernel, xpu>::Launch( + s, outputs[0].Size(), param.logit.value(), + uniform_tensor.dptr_, outputs[0].dptr()); + }); + } else { + if (param.is_logit) { + // tensor logit input + Kernel::Launch(s, outputs[0].Size(), + uniform_tensor.dptr_); + } else { + // tensor prob input + // sigmoid(x) > u <=> x > logit(u) + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + Kernel, xpu>::Launch( + s, inputs[0].Size(), inputs[0].dptr(), indicator_device_ptr); + }); + _copy(&indicator_host, indicator_device_ptr); + CHECK_GE(indicator_host, 0.0) + << "ValueError: expect probs >= 0 && probs <= 1"; + } + mxnet::TShape new_lshape, new_oshape; + int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, + &new_lshape, &new_lshape, &new_oshape); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + Shape oshape = new_oshape.get(); + Shape stride = calc_stride(new_lshape.get()); + Kernel, xpu>::Launch( + s, outputs[0].Size(), stride, oshape, + inputs[0].dptr(), uniform_tensor.dptr_, + outputs[0].dptr()); + }); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + + + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_BERNOULLI_OP_H_ diff --git a/src/operator/numpy/random/np_uniform_op.h b/src/operator/numpy/random/np_uniform_op.h index 580dc5d05eaa..1df0c39d4e57 100644 --- a/src/operator/numpy/random/np_uniform_op.h +++ b/src/operator/numpy/random/np_uniform_op.h @@ -25,7 +25,6 @@ #ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ #define MXNET_OPERATOR_NUMPY_RANDOM_NP_UNIFORM_OP_H_ -#include #include #include #include From 70ae946a16571b152793e7c851e41ca5e0aeb5c6 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 5 Nov 2019 16:20:24 +0000 Subject: [PATCH 3/9] frontend done, test to be added --- .../mxnet/ndarray/numpy_extension/random.py | 70 +++++++++++++++---- python/mxnet/numpy_extension/random.py | 52 ++++++++++++-- python/mxnet/symbol/numpy_extension/random.py | 70 +++++++++++++++---- 3 files changed, 162 insertions(+), 30 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/random.py b/python/mxnet/ndarray/numpy_extension/random.py index f61809beb2ab..bf1fd55f7da5 100644 --- a/python/mxnet/ndarray/numpy_extension/random.py +++ b/python/mxnet/ndarray/numpy_extension/random.py @@ -18,40 +18,84 @@ """Namespace for operators used in Gluon dispatched by F=ndarray.""" from __future__ import absolute_import from ...context import current_context -from .. import _internal as _npi +from ..numpy import _internal as _npi __all__ = ['bernoulli'] -def bernoulli(probs, logits, size, dtype, ctx, out): - """ - Sampling from bernoulli distribution. +def bernoulli(prob, logit, size, dtype, ctx, out): + """Creates a Bernoulli distribution parameterized by :attr:`prob` + or :attr:`logit` (but not both). + + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + Parameters + ---------- + prob : float, ndarray + The probability of sampling '1'. + logit : float, ndarray + The log-odds of sampling '1'. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. All dtypes are determined by their + name, i.e., 'int64', 'int', etc, so byteorder is not available + and a specific precision may have different C types depending + on the platform. The default value is 'np.float32'. + ctx : Context, optional + Device context of output. Default is current context. + out : symbol, optional + The output symbol (default is `None`). + + Returns + ------- + out : ndarray + Drawn samples from the parameterized bernoulli distribution. + + Examples + -------- + >>> prob = np.random.uniform(size=(4,4)) + >>> logit = np.log(prob) - np.log(1 - prob) + >>> npx.random.bernoulli(logit=logit) + array([[0., 1., 1., 1.], + [0., 1., 1., 1.], + [0., 1., 0., 0.], + [1., 0., 1., 0.]]) + + >>> npx.random.bernoulli(prob=prob) + array([[0., 1., 0., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 0.], + [1., 0., 1., 0.]]) """ from ...numpy import ndarray as np_ndarray tensor_type_name = np_ndarray - if (probs is None) == (logits is None): + if (prob is None) == (logit is None): raise ValueError( - "Either `probs` or `logits` must be specified, but not both.") + "Either `prob` or `logit` must be specified, but not both.") if dtype is None: dtype = 'float32' if ctx is None: ctx = current_context() if size == (): size = None - if probs is not None: - is_tensor = isinstance(probs, tensor_type_name) + if prob is not None: + is_tensor = isinstance(prob, tensor_type_name) if is_tensor: - return _npi.bernoulli(probs, probs=None, logits=None, is_logit=False, + return _npi.bernoulli(prob, prob=None, logit=None, is_logit=False, size=size, ctx=ctx, dtype=dtype, out=out) else: - return _npi.bernoulli(probs=probs, logits=None, is_logit=False, + return _npi.bernoulli(prob=prob, logit=None, is_logit=False, size=size, ctx=ctx, dtype=dtype, out=out) else: - is_tensor = isinstance(logits, tensor_type_name) + is_tensor = isinstance(logit, tensor_type_name) if is_tensor: - return _npi.bernoulli(logits, probs=None, logits=None, is_logit=True, + return _npi.bernoulli(logit, prob=None, logit=None, is_logit=True, size=size, ctx=ctx, dtype=dtype, out=out) else: - return _npi.bernoulli(probs=None, logits=logits, is_logit=True, + return _npi.bernoulli(prob=None, logit=logit, is_logit=True, size=size, ctx=ctx, dtype=dtype, out=out) diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py index 3d27cb886d2b..7a34eaa087ef 100644 --- a/python/mxnet/numpy_extension/random.py +++ b/python/mxnet/numpy_extension/random.py @@ -75,8 +75,52 @@ def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name _mx_rand.seed(seed_state=seed, ctx=ctx) -def bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): - """ - Sampling from bernoulli distribution. +def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None): + """Creates a Bernoulli distribution parameterized by :attr:`prob` + or :attr:`logit` (but not both). + + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + Parameters + ---------- + prob : float, ndarray + The probability of sampling '1'. + logit : float, ndarray + The log-odds of sampling '1'. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. All dtypes are determined by their + name, i.e., 'int64', 'int', etc, so byteorder is not available + and a specific precision may have different C types depending + on the platform. The default value is 'np.float32'. + ctx : Context, optional + Device context of output. Default is current context. + out : symbol, optional + The output symbol (default is `None`). + + Returns + ------- + out : ndarray + Drawn samples from the parameterized bernoulli distribution. + + Examples + -------- + >>> prob = np.random.uniform(size=(4,4)) + >>> logit = np.log(prob) - np.log(1 - prob) + >>> npx.random.bernoulli(logit=logit) + array([[0., 1., 1., 1.], + [0., 1., 1., 1.], + [0., 1., 0., 0.], + [1., 0., 1., 0.]]) + + >>> npx.random.bernoulli(prob=prob) + array([[0., 1., 0., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 0.], + [1., 0., 1., 0.]]) """ - return _mx_nd_npx.random.bernoulli(probs, logits, size, dtype, ctx, out) + return _mx_nd_npx.random.bernoulli(prob, logit, size, dtype, ctx, out) diff --git a/python/mxnet/symbol/numpy_extension/random.py b/python/mxnet/symbol/numpy_extension/random.py index e0cba123c227..a04b0504d54e 100644 --- a/python/mxnet/symbol/numpy_extension/random.py +++ b/python/mxnet/symbol/numpy_extension/random.py @@ -19,39 +19,83 @@ from __future__ import absolute_import from ...context import current_context -from .. import _internal as _npi +from ..numpy import _internal as _npi __all__ = ['bernoulli'] -def bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None): - """ - Sampling from beroulli distributions. +def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None): + """Creates a Bernoulli distribution parameterized by :attr:`prob` + or :attr:`logit` (but not both). + + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + Parameters + ---------- + prob : float, _Symbol + The probability of sampling '1'. + logit : float, _Symbol + The log-odds of sampling '1'. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. All dtypes are determined by their + name, i.e., 'int64', 'int', etc, so byteorder is not available + and a specific precision may have different C types depending + on the platform. The default value is 'np.float32'. + ctx : Context, optional + Device context of output. Default is current context. + out : symbol, optional + The output symbol (default is `None`). + + Returns + ------- + out : _Symbol + Drawn samples from the parameterized bernoulli distribution. + + Examples + -------- + >>> prob = np.random.uniform(size=(4,4)) + >>> logit = np.log(prob) - np.log(1 - prob) + >>> npx.random.bernoulli(logit=logit) + array([[0., 1., 1., 1.], + [0., 1., 1., 1.], + [0., 1., 0., 0.], + [1., 0., 1., 0.]]) + + >>> npx.random.bernoulli(prob=prob) + array([[0., 1., 0., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 0.], + [1., 0., 1., 0.]]) """ from ..numpy import _Symbol as np_symbol tensor_type_name = np_symbol - if (probs is None) == (logits is None): + if (prob is None) == (logit is None): raise ValueError( - "Either `probs` or `logits` must be specified, but not both.") + "Either `prob` or `logit` must be specified, but not both.") if dtype is None: dtype = 'float32' if ctx is None: ctx = current_context() if size == (): size = None - if probs is not None: - is_tensor = isinstance(probs, tensor_type_name) + if prob is not None: + is_tensor = isinstance(prob, tensor_type_name) if is_tensor: - return _npi.bernoulli(probs, probs=None, logits=None, is_logit=False, + return _npi.bernoulli(prob, prob=None, logit=None, is_logit=False, size=size, ctx=ctx, dtype=dtype, out=out) else: - return _npi.bernoulli(probs=probs, logits=None, is_logit=False, + return _npi.bernoulli(prob=prob, logit=None, is_logit=False, size=size, ctx=ctx, dtype=dtype, out=out) else: - is_tensor = isinstance(logits, tensor_type_name) + is_tensor = isinstance(logit, tensor_type_name) if is_tensor: - return _npi.bernoulli(logits, probs=None, logits=None, is_logit=True, + return _npi.bernoulli(logit, prob=None, logit=None, is_logit=True, size=size, ctx=ctx, dtype=dtype, out=out) else: - return _npi.bernoulli(probs=None, logits=logits, is_logit=True, + return _npi.bernoulli(prob=None, logit=logit, is_logit=True, size=size, ctx=ctx, dtype=dtype, out=out) From 7330e82ac67dd70b1e3f0357ec876240e639102d Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 6 Nov 2019 06:21:20 +0000 Subject: [PATCH 4/9] finish tests, fix indicator initialization bug --- src/operator/numpy/random/np_bernoulli_op.h | 3 +++ src/operator/numpy/random/np_normal_op.h | 1 + tests/python/unittest/test_numpy_op.py | 20 ++++++++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/src/operator/numpy/random/np_bernoulli_op.h b/src/operator/numpy/random/np_bernoulli_op.h index c5e701d53acb..6029d6f14978 100644 --- a/src/operator/numpy/random/np_bernoulli_op.h +++ b/src/operator/numpy/random/np_bernoulli_op.h @@ -59,6 +59,8 @@ struct NumpyBernoulliParam : public dmlc::Parameter { "Context of output, in format [cpu|gpu|cpu_pinned](n)." " Only used for imperative calls."); DMLC_DECLARE_FIELD(dtype) + .add_enum("uint8", mshadow::kUint8) + .add_enum("int32", mshadow::kInt32) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .add_enum("float16", mshadow::kFloat16) @@ -143,6 +145,7 @@ void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, Tensor indicator_device = workspace.Slice(output_len, output_len + 1); float indicator_host = 1.0; float *indicator_device_ptr = indicator_device.dptr_; + Kernel::Launch(s, 1, indicator_device_ptr); prnd->SampleUniform(&uniform_tensor, 0.0, 1.0); if (param.prob.has_value()) { // scalar prob input diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index 678f0ec2fb47..deb603ac5270 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -158,6 +158,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, Tensor indicator_device = workspace.Slice(output_len, output_len + 1); float indicator_host = 1.0; float *indicator_device_ptr = indicator_device.dptr_; + Kernel::Launch(s, 1, indicator_device_ptr); prnd->SampleGaussian(&normal_tensor, 0.0, 1.0); mxnet::TShape new_lshape, new_hshape, new_oshape; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 5476fbee8be4..f1dce2466f26 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2294,6 +2294,26 @@ def hybrid_forward(self, F, x): assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) +@with_seed() +@use_np +def test_npx_bernoulli(): + shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] + dtypes = ['float16', 'float32', 'float64', 'int32'] + epsilon = 1e-4 + for shape, dtype in itertools.product(shapes, dtypes): + prob = np.random.uniform(size=shape) + logit = np.log(prob) - np.log(1 - prob) + expected_shape = shape + if not isinstance(shape, tuple): + expected_shape = () if shape is None else (shape,) + out_prob = npx.random.bernoulli(prob=prob, size=shape, dtype=dtype) + assert out_prob.shape == expected_shape + assert int((out_prob == 0).sum() + (out_prob == 1).sum()) == out_prob.size + out_logit = npx.random.bernoulli(logit=logit, size=shape, dtype=dtype) + assert out_logit.shape == expected_shape + assert int((out_logit == 0).sum() + (out_logit == 1).sum()) == out_logit.size + + @with_seed() @use_np def test_np_random(): From 2f5ec0985389d7309e733e6c1b9376102dd220d9 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Thu, 7 Nov 2019 05:43:28 +0000 Subject: [PATCH 5/9] test with native numpy --- tests/python/unittest/test_numpy_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index f1dce2466f26..7e23eec51fad 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2308,10 +2308,10 @@ def test_npx_bernoulli(): expected_shape = () if shape is None else (shape,) out_prob = npx.random.bernoulli(prob=prob, size=shape, dtype=dtype) assert out_prob.shape == expected_shape - assert int((out_prob == 0).sum() + (out_prob == 1).sum()) == out_prob.size + assert int((out_prob.asnumpy() == 0).sum() + (out_prob.asnumpy() == 1).sum()) == out_prob.size out_logit = npx.random.bernoulli(logit=logit, size=shape, dtype=dtype) assert out_logit.shape == expected_shape - assert int((out_logit == 0).sum() + (out_logit == 1).sum()) == out_logit.size + assert int((out_logit.asnumpy() == 0).sum() + (out_logit.asnumpy() == 1).sum()) == out_logit.size @with_seed() From 51dec12e16ba637600f9e6dd80a24ef198a2e32b Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sat, 9 Nov 2019 06:07:38 +0000 Subject: [PATCH 6/9] fix indent, change test name --- src/operator/numpy/random/dist_common.h | 4 ++-- src/operator/numpy/random/np_bernoulli_op.h | 2 -- tests/python/unittest/test_numpy_op.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index cfde9686d272..199b55b9afb3 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -178,8 +178,8 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs, template inline bool UnaryDistOpShape(const nnvm::NodeAttrs &attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { const DistParam ¶m = nnvm::get(attrs.parsed); if (param.size.has_value()) { // Size declared. diff --git a/src/operator/numpy/random/np_bernoulli_op.h b/src/operator/numpy/random/np_bernoulli_op.h index 6029d6f14978..b375c0cc00ad 100644 --- a/src/operator/numpy/random/np_bernoulli_op.h +++ b/src/operator/numpy/random/np_bernoulli_op.h @@ -203,6 +203,4 @@ void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, } // namespace op } // namespace mxnet - - #endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_BERNOULLI_OP_H_ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 7e23eec51fad..617e4d780fbe 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2296,7 +2296,7 @@ def hybrid_forward(self, F, x): @with_seed() @use_np -def test_npx_bernoulli(): +def test_npx_random_bernoulli(): shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] dtypes = ['float16', 'float32', 'float64', 'int32'] epsilon = 1e-4 From ff4073c292e4973628fdcf47de12907a4e12bc3c Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sun, 10 Nov 2019 12:29:13 +0000 Subject: [PATCH 7/9] resolve comments --- python/mxnet/ndarray/numpy_extension/random.py | 5 ++++- python/mxnet/numpy_extension/random.py | 2 ++ python/mxnet/symbol/numpy_extension/random.py | 9 ++++++--- src/operator/numpy/random/np_bernoulli_op.h | 15 +++++---------- tests/python/unittest/test_numpy_op.py | 2 +- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/python/mxnet/ndarray/numpy_extension/random.py b/python/mxnet/ndarray/numpy_extension/random.py index bf1fd55f7da5..b0472a4ab122 100644 --- a/python/mxnet/ndarray/numpy_extension/random.py +++ b/python/mxnet/ndarray/numpy_extension/random.py @@ -35,8 +35,10 @@ def bernoulli(prob, logit, size, dtype, ctx, out): ---------- prob : float, ndarray The probability of sampling '1'. + Only one of prob or logit should be passed in. logit : float, ndarray The log-odds of sampling '1'. + Only one of prob or logit should be passed in. size : int or tuple of ints, optional Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` samples are drawn. Default is None, in which case a @@ -76,7 +78,8 @@ def bernoulli(prob, logit, size, dtype, ctx, out): tensor_type_name = np_ndarray if (prob is None) == (logit is None): raise ValueError( - "Either `prob` or `logit` must be specified, but not both.") + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) if dtype is None: dtype = 'float32' if ctx is None: diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py index 7a34eaa087ef..300d0b22792f 100644 --- a/python/mxnet/numpy_extension/random.py +++ b/python/mxnet/numpy_extension/random.py @@ -86,8 +86,10 @@ def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None): ---------- prob : float, ndarray The probability of sampling '1'. + Only one of prob or logit should be passed in. logit : float, ndarray The log-odds of sampling '1'. + Only one of prob or logit should be passed in. size : int or tuple of ints, optional Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` samples are drawn. Default is None, in which case a diff --git a/python/mxnet/symbol/numpy_extension/random.py b/python/mxnet/symbol/numpy_extension/random.py index a04b0504d54e..a557a75d56f7 100644 --- a/python/mxnet/symbol/numpy_extension/random.py +++ b/python/mxnet/symbol/numpy_extension/random.py @@ -33,10 +33,12 @@ def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None): Parameters ---------- - prob : float, _Symbol + prob : float, ndarray The probability of sampling '1'. - logit : float, _Symbol + Only one of prob or logit should be passed in. + logit : float, ndarray The log-odds of sampling '1'. + Only one of prob or logit should be passed in. size : int or tuple of ints, optional Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` samples are drawn. Default is None, in which case a @@ -76,7 +78,8 @@ def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None): tensor_type_name = np_symbol if (prob is None) == (logit is None): raise ValueError( - "Either `prob` or `logit` must be specified, but not both.") + "Either `prob` or `logit` must be specified, but not both. " + + "Received prob={}, logit={}".format(prob, logit)) if dtype is None: dtype = 'float32' if ctx is None: diff --git a/src/operator/numpy/random/np_bernoulli_op.h b/src/operator/numpy/random/np_bernoulli_op.h index b375c0cc00ad..7f0103839e25 100644 --- a/src/operator/numpy/random/np_bernoulli_op.h +++ b/src/operator/numpy/random/np_bernoulli_op.h @@ -64,6 +64,7 @@ struct NumpyBernoulliParam : public dmlc::Parameter { .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) .add_enum("float16", mshadow::kFloat16) + .add_enum("bool", mshadow::kBool) .set_default(mshadow::kFloat32) .describe( "DType of the output in case this can't be inferred. " @@ -77,13 +78,7 @@ inline bool NumpyBernoulliOpType(const nnvm::NodeAttrs &attrs, std::vector *out_attrs) { const NumpyBernoulliParam ¶m = nnvm::get(attrs.parsed); int otype = param.dtype; - if (otype != -1) { - (*out_attrs)[0] = otype; - } else { - // Following torch.distributions, - // the default type will be float32. - (*out_attrs)[0] = mshadow::kFloat32; - } + (*out_attrs)[0] = otype; return true; } @@ -151,7 +146,7 @@ void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, // scalar prob input CHECK_LE(param.prob.value(), 1.0) << "ValueError: expect probs >= 0 && probs <= 1"; CHECK_GE(param.prob.value(), 0.0) << "ValueError: expect probs >= 0 && probs <= 1"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { Kernel, xpu>::Launch( s, outputs[0].Size(), param.prob.value(), uniform_tensor.dptr_, outputs[0].dptr()); @@ -161,7 +156,7 @@ void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, // sigmoid(x) > u <=> x > logit(u) Kernel::Launch(s, outputs[0].Size(), uniform_tensor.dptr_); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { Kernel, xpu>::Launch( s, outputs[0].Size(), param.logit.value(), uniform_tensor.dptr_, outputs[0].dptr()); @@ -186,7 +181,7 @@ void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, &new_lshape, &new_lshape, &new_oshape); MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { Shape oshape = new_oshape.get(); Shape stride = calc_stride(new_lshape.get()); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 617e4d780fbe..1aa42fdea0b9 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2298,7 +2298,7 @@ def hybrid_forward(self, F, x): @use_np def test_npx_random_bernoulli(): shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] - dtypes = ['float16', 'float32', 'float64', 'int32'] + dtypes = ['float16', 'float32', 'float64', 'int32', 'bool'] epsilon = 1e-4 for shape, dtype in itertools.product(shapes, dtypes): prob = np.random.uniform(size=shape) From 899faba14d8a182f77d1385712c56c80a32753ba Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 12 Nov 2019 03:27:08 +0000 Subject: [PATCH 8/9] add raise test --- src/operator/numpy/random/np_bernoulli_op.cc | 4 +--- src/operator/numpy/random/np_bernoulli_op.h | 7 +++---- tests/python/unittest/test_numpy_op.py | 5 +++++ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/operator/numpy/random/np_bernoulli_op.cc b/src/operator/numpy/random/np_bernoulli_op.cc index 0b2988d7f975..d67ad1b8d7f6 100644 --- a/src/operator/numpy/random/np_bernoulli_op.cc +++ b/src/operator/numpy/random/np_bernoulli_op.cc @@ -32,7 +32,6 @@ namespace op { DMLC_REGISTER_PARAMETER(NumpyBernoulliParam); NNVM_REGISTER_OP(_npi_bernoulli) -.describe("Sample frmo bernoulli distribution") .set_num_inputs( [](const nnvm::NodeAttrs& attrs) { const NumpyBernoulliParam& param = nnvm::get(attrs.parsed); @@ -51,8 +50,7 @@ NNVM_REGISTER_OP(_npi_bernoulli) if (param.logit.has_value() || param.prob.has_value()) { num_inputs -= 1; } - if (num_inputs == 0) return std::vector(); - return std::vector{"input1"}; + return (num_inputs == 0) ? std::vector() : std::vector{"input1"}; }) .set_attr_parser(ParamParser) .set_attr("FInferShape", UnaryDistOpShape) diff --git a/src/operator/numpy/random/np_bernoulli_op.h b/src/operator/numpy/random/np_bernoulli_op.h index 7f0103839e25..aa8e344e2842 100644 --- a/src/operator/numpy/random/np_bernoulli_op.h +++ b/src/operator/numpy/random/np_bernoulli_op.h @@ -181,14 +181,13 @@ void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_, &new_lshape, &new_lshape, &new_oshape); MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { - MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { Shape oshape = new_oshape.get(); Shape stride = calc_stride(new_lshape.get()); Kernel, xpu>::Launch( - s, outputs[0].Size(), stride, oshape, - inputs[0].dptr(), uniform_tensor.dptr_, - outputs[0].dptr()); + s, outputs[0].Size(), stride, oshape, inputs[0].dptr(), + uniform_tensor.dptr_, outputs[0].dptr()); }); }); }); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1aa42fdea0b9..b995b2a749be 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2312,6 +2312,11 @@ def test_npx_random_bernoulli(): out_logit = npx.random.bernoulli(logit=logit, size=shape, dtype=dtype) assert out_logit.shape == expected_shape assert int((out_logit.asnumpy() == 0).sum() + (out_logit.asnumpy() == 1).sum()) == out_logit.size + # Test Exception. + assertRaises(ValueError, npx.random.bernoulli, prob=prob, logit=logit) + if prob.size > 0: + assertRaises(MXNetError, npx.random.bernoulli, prob=prob + 2.0) + assertRaises(MXNetError, npx.random.bernoulli, prob=prob - 2.0) @with_seed() From 0bf24cc140aa6590397a6efd6446ba75ec46d9c5 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 12 Nov 2019 06:14:02 +0000 Subject: [PATCH 9/9] modify raise test --- tests/python/unittest/test_numpy_op.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b995b2a749be..45ee8d2bbd84 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2297,9 +2297,11 @@ def hybrid_forward(self, F, x): @with_seed() @use_np def test_npx_random_bernoulli(): + def _test_bernoulli_exception(prob, logit): + output = npx.random.bernoulli(prob=prob, logit=logit).asnumpy() + shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None] dtypes = ['float16', 'float32', 'float64', 'int32', 'bool'] - epsilon = 1e-4 for shape, dtype in itertools.product(shapes, dtypes): prob = np.random.uniform(size=shape) logit = np.log(prob) - np.log(1 - prob) @@ -2313,10 +2315,18 @@ def test_npx_random_bernoulli(): assert out_logit.shape == expected_shape assert int((out_logit.asnumpy() == 0).sum() + (out_logit.asnumpy() == 1).sum()) == out_logit.size # Test Exception. - assertRaises(ValueError, npx.random.bernoulli, prob=prob, logit=logit) + assertRaises(ValueError, _test_bernoulli_exception, prob, logit) if prob.size > 0: - assertRaises(MXNetError, npx.random.bernoulli, prob=prob + 2.0) - assertRaises(MXNetError, npx.random.bernoulli, prob=prob - 2.0) + # larger than 1 + assertRaises(MXNetError, _test_bernoulli_exception, prob + 2.0, None) + # smaller than 0 + assertRaises(MXNetError, _test_bernoulli_exception, prob - 2.0, None) + # mixed case + low, high = (-1.0, 2.0) + # uniform(-1, 2) + scaled_prob = low + (high - low) * prob + if not ((scaled_prob.asnumpy() >= 0).all() and (scaled_prob.asnumpy() <= 1).all()): + assertRaises(MXNetError, _test_bernoulli_exception, scaled_prob, None) @with_seed()