Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] add op random.exponential (#17280)
Browse files Browse the repository at this point in the history
* C++ ok

* before rebase

* sanity

* change sth

* change sth

* change sth
  • Loading branch information
Yiyan66 authored and haojin2 committed Jan 19, 2020
1 parent d95d74f commit 28742cf
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 3 deletions.
30 changes: 29 additions & 1 deletion python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..ndarray import NDArray


__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle", 'gamma']
__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "shuffle", 'gamma', 'exponential']


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -319,6 +319,34 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)


def exponential(scale, size):
r"""Draw samples from an exponential distribution.
Parameters
----------
scale : float or array_like of floats
The scale parameter, :math:`\beta = 1/\lambda`. Must be
non-negative.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``scale`` is a scalar. Otherwise,
``np.array(scale).size`` samples are drawn.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized exponential distribution.
"""
from ...numpy import ndarray as np_ndarray
tensor_type_name = np_ndarray
if size == ():
size = None
is_tensor = isinstance(scale, tensor_type_name)
if is_tensor:
return _npi.exponential(scale, scale=None, size=size)
else:
return _npi.exponential(scale=scale, size=size)


def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
"""Draw samples from a Gamma distribution.
Expand Down
25 changes: 24 additions & 1 deletion python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np


__all__ = ["randint", "uniform", "normal", "choice", "rand", "multinomial", "shuffle", "randn",
"gamma"]
"gamma", "exponential"]


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -324,6 +325,28 @@ def rand(*size, **kwargs):
return _mx_nd_np.random.uniform(0, 1, size=output_shape, **kwargs)


def exponential(scale=1.0, size=None):
r"""Draw samples from an exponential distribution.
Parameters
----------
scale : float or array_like of floats
The scale parameter, :math:`\beta = 1/\lambda`. Must be
non-negative.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``scale`` is a scalar. Otherwise,
``np.array(scale).size`` samples are drawn.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized exponential distribution.
"""
return _mx_nd_np.random.exponential(scale, size)


def shuffle(x):
"""
Modify a sequence in-place by shuffling its contents.
Expand Down
33 changes: 32 additions & 1 deletion python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from ...context import current_context
from . import _internal as _npi

__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle', 'gamma']

__all__ = ['randint', 'uniform', 'normal', 'rand', 'shuffle', 'gamma', 'exponential']


def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -347,6 +348,36 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None):
raise ValueError("Distribution parameters must be either _Symbol or numbers")


def exponential(scale=1.0, size=None):
r"""Draw samples from an exponential distribution.
Parameters
----------
scale : float or array_like of floats
The scale parameter, :math:`\beta = 1/\lambda`. Must be
non-negative.
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. If size is ``None`` (default),
a single value is returned if ``scale`` is a scalar. Otherwise,
``np.array(scale).size`` samples are drawn.
Returns
-------
out : ndarray or scalar
Drawn samples from the parameterized exponential distribution.
"""
from ..numpy import _Symbol as np_symbol
tensor_type_name = np_symbol
if size == ():
size = None
is_tensor = isinstance(scale, tensor_type_name)
if is_tensor:
return _npi.exponential(scale, scale=None, size=size)
else:
return _npi.exponential(scale=scale, size=size)


def shuffle(x):
"""
Modify a sequence in-place by shuffling its contents.
Expand Down
72 changes: 72 additions & 0 deletions src/operator/numpy/random/np_exponential_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_exponential_op.cc
* \brief Operator for numpy sampling from exponential distributions
*/

#include "./np_exponential_op.h"
#include "./dist_common.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyExponentialParam);

NNVM_REGISTER_OP(_npi_exponential)
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
int num_inputs = 1;
if (param.scale.has_value()) {
num_inputs -= 1;
}
return num_inputs;
})
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
int num_inputs = 1;
if (param.scale.has_value()) {
num_inputs -= 1;
}
return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
})
.set_attr_parser(ParamParser<NumpyExponentialParam>)
.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyExponentialParam>)
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs, std::vector<int> *out_attrs) {
(*out_attrs)[0] = mshadow::kFloat32;
return true;
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<ResourceRequest>{
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyExponentialForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyExponentialParam::__FIELDS__());

} // namespace op
} // namespace mxnet
35 changes: 35 additions & 0 deletions src/operator/numpy/random/np_exponential_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_exponential_op.cu
* \brief Operator for numpy sampling from exponential distributions
*/

#include "./np_exponential_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_exponential)
.set_attr<FCompute>("FCompute<gpu>", NumpyExponentialForward<gpu>);

} // namespace op
} // namespace mxnet
146 changes: 146 additions & 0 deletions src/operator/numpy/random/np_exponential_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_exponential_op.h
* \brief Operator for numpy sampling from exponential distribution.
*/

#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_
#define MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_

#include <mxnet/operator_util.h>
#include <algorithm>
#include <string>
#include <vector>
#include <cmath>
#include "../../elemwise_op_common.h"
#include "../../mshadow_op.h"
#include "../../mxnet_op.h"
#include "../../operator_common.h"
#include "../../tensor/elemwise_binary_broadcast_op.h"
#include "./dist_common.h"

namespace mxnet {
namespace op {

struct NumpyExponentialParam : public dmlc::Parameter<NumpyExponentialParam> {
dmlc::optional<float> scale;
dmlc::optional<mxnet::Tuple<int>> size;
DMLC_DECLARE_PARAMETER(NumpyExponentialParam) {
DMLC_DECLARE_FIELD(scale)
.set_default(dmlc::optional<float>(1.0));
DMLC_DECLARE_FIELD(size)
.set_default(dmlc::optional<mxnet::Tuple<int>>())
.describe("Output shape. If the given shape is, "
"e.g., (m, n, k), then m * n * k samples are drawn. "
"Default is None, in which case a single value is returned.");
}
};

template <typename DType>
struct scalar_exponential_kernel {
MSHADOW_XINLINE static void Map(index_t i, float scale, float *threshold,
DType *out) {
out[i] = -scale * log(threshold[i]);
}
};

namespace mxnet_op {

template <typename IType>
struct check_legal_scale_kernel {
MSHADOW_XINLINE static void Map(index_t i, IType *scalar, float* flag) {
if (scalar[i] < 0.0) {
flag[0] = -1.0;
}
}
};


template <int ndim, typename IType, typename OType>
struct exponential_kernel {
MSHADOW_XINLINE static void Map(index_t i,
const Shape<ndim> &stride,
const Shape<ndim> &oshape,
IType *scales, float* threshold, OType *out) {
Shape<ndim> coord = unravel(i, oshape);
auto idx = static_cast<index_t>(dot(coord, stride));
out[i] = -scales[idx] * log(threshold[i]);
}
};

} // namespace mxnet_op

template <typename xpu>
void NumpyExponentialForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mxnet_op;
const NumpyExponentialParam &param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
index_t output_len = outputs[0].Size();
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
Tensor<xpu, 1, float> uniform_tensor = workspace.Slice(0, output_len);
Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
float indicator_host = 1.0;
float *indicator_device_ptr = indicator_device.dptr_;
Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);
prnd->SampleUniform(&workspace, 0.0, 1.0);
if (param.scale.has_value()) {
CHECK_GE(param.scale.value(), 0.0) << "ValueError: expect scale >= 0";
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Kernel<scalar_exponential_kernel<DType>, xpu>::Launch(
s, outputs[0].Size(), param.scale.value(),
uniform_tensor.dptr_, outputs[0].dptr<DType>());
});
} else {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
Kernel<check_legal_scale_kernel<IType>, xpu>::Launch(
s, inputs[0].Size(), inputs[0].dptr<IType>(), indicator_device_ptr);
});
_copy<xpu>(s, &indicator_host, indicator_device_ptr);
CHECK_GE(indicator_host, 0.0) << "ValueError: expect scale >= 0";
mxnet::TShape new_lshape, new_oshape;
int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_,
&new_lshape, &new_lshape, &new_oshape);
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
Shape<NDim> oshape = new_oshape.get<NDim>();
Shape<NDim> stride = calc_stride(new_lshape.get<NDim>());
Kernel<exponential_kernel<NDim, IType, OType>, xpu>::Launch(
s, outputs[0].Size(), stride, oshape, inputs[0].dptr<IType>(),
uniform_tensor.dptr_, outputs[0].dptr<OType>());
});
});
});
}
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_EXPONENTIAL_OP_H_
Loading

0 comments on commit 28742cf

Please sign in to comment.