Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Add sampling method for bernoulli #16638

Merged
merged 9 commits into from
Nov 12, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def write_all_str(module_file, module_all_list):
_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']

_NP_EXT_OP_PREFIX = '_npx_'
_NP_EXT_OP_SUBMODULE_LIST = ['_image_']
_NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_']

_NP_INTERNAL_OP_PREFIX = '_npi_'

Expand Down
1 change: 1 addition & 0 deletions python/mxnet/ndarray/numpy_extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from . import _op
from . import image
from . import random
from . import _register
from ._op import * # pylint: disable=wildcard-import

Expand Down
57 changes: 57 additions & 0 deletions python/mxnet/ndarray/numpy_extension/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for operators used in Gluon dispatched by F=ndarray."""
from __future__ import absolute_import
from ...context import current_context
from .. import _internal as _npi


__all__ = ['bernoulli']


def bernoulli(probs, logits, size, dtype, ctx, out):
"""
Sampling from bernoulli distribution.
"""
from ...numpy import ndarray as np_ndarray
tensor_type_name = np_ndarray
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both.")
if dtype is None:
dtype = 'float32'
if ctx is None:
ctx = current_context()
if size == ():
size = None
if probs is not None:
is_tensor = isinstance(probs, tensor_type_name)
if is_tensor:
return _npi.bernoulli(probs, probs=None, logits=None, is_logit=False,
size=size, ctx=ctx, dtype=dtype, out=out)
else:
return _npi.bernoulli(probs=probs, logits=None, is_logit=False,
size=size, ctx=ctx, dtype=dtype, out=out)
else:
is_tensor = isinstance(logits, tensor_type_name)
if is_tensor:
return _npi.bernoulli(logits, probs=None, logits=None, is_logit=True,
size=size, ctx=ctx, dtype=dtype, out=out)
else:
return _npi.bernoulli(probs=None, logits=logits, is_logit=True,
size=size, ctx=ctx, dtype=dtype, out=out)
2 changes: 1 addition & 1 deletion python/mxnet/numpy_extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
from __future__ import absolute_import
from . import _op
from . import image
from . import random # pylint: disable=wildcard-import
from . import _register
from ._op import * # pylint: disable=wildcard-import
from ..context import * # pylint: disable=wildcard-import
from ..util import is_np_shape, is_np_array, set_np, reset_np, get_cuda_compute_capability
from ..ndarray import waitall
from .utils import * # pylint: disable=wildcard-import
from . import random # pylint: disable=wildcard-import

__all__ = []
10 changes: 9 additions & 1 deletion python/mxnet/numpy_extension/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

from __future__ import absolute_import
from .. import random as _mx_rand
from ..ndarray import numpy_extension as _mx_nd_npx


__all__ = ['seed']
__all__ = ['seed', 'bernoulli']


def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -72,3 +73,10 @@ def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name
array(0.9894903, ctx=gpu(0))
"""
_mx_rand.seed(seed_state=seed, ctx=ctx)


def bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None):
"""
Sampling from bernoulli distribution.
"""
return _mx_nd_npx.random.bernoulli(probs, logits, size, dtype, ctx, out)
1 change: 1 addition & 0 deletions python/mxnet/symbol/numpy_extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from . import _op
from . import image
from . import random
from . import _register
from ._op import * # pylint: disable=wildcard-import

Expand Down
57 changes: 57 additions & 0 deletions python/mxnet/symbol/numpy_extension/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for operators used in Gluon dispatched by F=symbol."""

from __future__ import absolute_import
from ...context import current_context
from .. import _internal as _npi
xidulu marked this conversation as resolved.
Show resolved Hide resolved

__all__ = ['bernoulli']


def bernoulli(probs=None, logits=None, size=None, dtype=None, ctx=None, out=None):
"""
Sampling from beroulli distributions.
"""
from ..numpy import _Symbol as np_symbol
tensor_type_name = np_symbol
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both.")
if dtype is None:
dtype = 'float32'
if ctx is None:
ctx = current_context()
if size == ():
size = None
if probs is not None:
is_tensor = isinstance(probs, tensor_type_name)
if is_tensor:
return _npi.bernoulli(probs, probs=None, logits=None, is_logit=False,
size=size, ctx=ctx, dtype=dtype, out=out)
else:
return _npi.bernoulli(probs=probs, logits=None, is_logit=False,
size=size, ctx=ctx, dtype=dtype, out=out)
else:
is_tensor = isinstance(logits, tensor_type_name)
if is_tensor:
return _npi.bernoulli(logits, probs=None, logits=None, is_logit=True,
size=size, ctx=ctx, dtype=dtype, out=out)
else:
return _npi.bernoulli(probs=None, logits=logits, is_logit=True,
size=size, ctx=ctx, dtype=dtype, out=out)
31 changes: 28 additions & 3 deletions src/operator/numpy/random/dist_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#ifndef MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_
#define MXNET_OPERATOR_NUMPY_RANDOM_DIST_COMMON_H_

#include <mshadow/base.h>
#include <mxnet/operator_util.h>
#include <algorithm>
#include <string>
Expand Down Expand Up @@ -172,10 +171,36 @@ inline bool TwoparamsDistOpShape(const nnvm::NodeAttrs &attrs,
} else if (in_attrs->size() == 0) {
// Two scalar case.
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1))
return true;
}
}
return out_attrs->at(0).ndim() != 0U;
return shape_is_known(out_attrs->at(0));
}

template <typename DistParam>
inline bool UnaryDistOpShape(const nnvm::NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
xidulu marked this conversation as resolved.
Show resolved Hide resolved
std::vector<TShape> *out_attrs) {
const DistParam &param = nnvm::get<DistParam>(attrs.parsed);
if (param.size.has_value()) {
// Size declared.
std::vector<dim_t> oshape_vec;
const mxnet::Tuple<int> &size = param.size.value();
for (int i = 0; i < size.ndim(); ++i) {
oshape_vec.emplace_back(size[i]);
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
for (size_t input_idx = 0; input_idx < in_attrs->size(); input_idx++) {
CheckBroadcastable((*in_attrs)[input_idx], (*out_attrs)[0]);
}
} else {
if (in_attrs->size() == 1U) {
// One param from ndarray.
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0))
} else {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1))
}
}
return shape_is_known(out_attrs->at(0));
}

} // namespace op
Expand Down
71 changes: 71 additions & 0 deletions src/operator/numpy/random/np_bernoulli_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_bernoulli_op.cc
* \brief Operator for numpy sampling from bernoulli distributions
*/

#include "./np_bernoulli_op.h"
#include "./dist_common.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyBernoulliParam);

NNVM_REGISTER_OP(_npi_bernoulli)
.describe("Sample frmo bernoulli distribution")
xidulu marked this conversation as resolved.
Show resolved Hide resolved
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyBernoulliParam& param = nnvm::get<NumpyBernoulliParam>(attrs.parsed);
int num_inputs = 1;
if (param.logit.has_value() || param.prob.has_value()) {
num_inputs -= 1;
}
return num_inputs;
}
)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyBernoulliParam& param = nnvm::get<NumpyBernoulliParam>(attrs.parsed);
int num_inputs = 1;
if (param.logit.has_value() || param.prob.has_value()) {
num_inputs -= 1;
}
if (num_inputs == 0) return std::vector<std::string>();
return std::vector<std::string>{"input1"};
xidulu marked this conversation as resolved.
Show resolved Hide resolved
})
.set_attr_parser(ParamParser<NumpyBernoulliParam>)
.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyBernoulliParam>)
.set_attr<nnvm::FInferType>("FInferType", NumpyBernoulliOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<ResourceRequest>{
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBernoulliForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyBernoulliParam::__FIELDS__());

} // namespace op
} // namespace mxnet
35 changes: 35 additions & 0 deletions src/operator/numpy/random/np_bernoulli_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_bernoulli_op.cu
* \brief Operator for numpy sampling from bernoulli distributions
*/

#include "./np_bernoulli_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_bernoulli)
.set_attr<FCompute>("FCompute<gpu>", NumpyBernoulliForward<gpu>);

} // namespace op
} // namespace mxnet
Loading