Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy]implement exponential backward (#17401)
Browse files Browse the repository at this point in the history
* add output

* add output

* add ctx

* gpu ok

* format
  • Loading branch information
Yiyan66 authored Feb 10, 2020
1 parent b65db3c commit 9aa5088
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 28 deletions.
9 changes: 6 additions & 3 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)


def exponential(scale, size):
def exponential(scale=1.0, size=None, ctx=None, out=None):
r"""Draw samples from an exponential distribution.
Parameters
----------
Expand All @@ -453,13 +453,16 @@ def exponential(scale, size):
"""
from ...numpy import ndarray as np_ndarray
tensor_type_name = np_ndarray
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(scale, tensor_type_name)
if is_tensor:
return _npi.exponential(scale, scale=None, size=size)
return _npi.exponential(scale, scale=None, size=size,
ctx=ctx, out=out)
else:
return _npi.exponential(scale=scale, size=size)
return _npi.exponential(scale=scale, size=size, ctx=ctx, out=out)


def weibull(a, size=None):
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def rand(*size, **kwargs):
return _mx_nd_np.random.uniform(0, 1, size=output_shape, **kwargs)


def exponential(scale=1.0, size=None):
def exponential(scale=1.0, size=None, ctx=None, out=None):
r"""Draw samples from an exponential distribution.
Parameters
Expand All @@ -481,7 +481,7 @@ def exponential(scale=1.0, size=None):
out : ndarray or scalar
Drawn samples from the parameterized exponential distribution.
"""
return _mx_nd_np.random.exponential(scale, size)
return _mx_nd_np.random.exponential(scale, size=size, ctx=ctx, out=out)


def weibull(a, size=None):
Expand Down
9 changes: 6 additions & 3 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def beta(a, b, size=None, dtype=None, ctx=None):
return out.astype(dtype)


def exponential(scale=1.0, size=None):
def exponential(scale=1.0, size=None, ctx=None, out=None):
r"""Draw samples from an exponential distribution.
Parameters
Expand All @@ -460,13 +460,16 @@ def exponential(scale=1.0, size=None):
"""
from ..numpy import _Symbol as np_symbol
tensor_type_name = np_symbol
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(scale, tensor_type_name)
if is_tensor:
return _npi.exponential(scale, scale=None, size=size)
return _npi.exponential(scale, scale=None, size=size,
ctx=ctx, out=out)
else:
return _npi.exponential(scale=scale, size=size)
return _npi.exponential(scale=scale, size=size, ctx=ctx, out=out)


def weibull(a, size=None):
Expand Down
38 changes: 35 additions & 3 deletions src/operator/numpy/random/np_exponential_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(NumpyExponentialParam);

NNVM_REGISTER_OP(_npi_exponential)
.describe("Numpy behavior exponential")
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
Expand All @@ -41,7 +42,11 @@ NNVM_REGISTER_OP(_npi_exponential)
}
return num_inputs;
})
.set_num_outputs(1)
.set_num_outputs(2)
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
[](const NodeAttrs& attrs) {
return 1;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
Expand All @@ -52,10 +57,11 @@ NNVM_REGISTER_OP(_npi_exponential)
return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
})
.set_attr_parser(ParamParser<NumpyExponentialParam>)
.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyExponentialParam>)
.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyExponentialParam>)
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs, std::vector<int> *out_attrs) {
(*out_attrs)[0] = mshadow::kFloat32;
(*out_attrs)[1] = mshadow::kFloat32;
return true;
})
.set_attr<FResourceRequest>("FResourceRequest",
Expand All @@ -64,9 +70,35 @@ NNVM_REGISTER_OP(_npi_exponential)
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyExponentialForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_exponential"})
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyExponentialParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_broadcast_exponential)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<NumpyExponentialParam>)
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
int num_inputs = 5;
if (param.scale.has_value()) num_inputs -= 1;
return num_inputs;
}
)
.set_num_outputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyExponentialParam& param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
int num_outputs = 1;
if (param.scale.has_value()) num_outputs -= 1;
return num_outputs;
}
)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", ExponentialReparamBackward<cpu>)
.add_arguments(NumpyExponentialParam::__FIELDS__());

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/random/np_exponential_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@ namespace op {
NNVM_REGISTER_OP(_npi_exponential)
.set_attr<FCompute>("FCompute<gpu>", NumpyExponentialForward<gpu>);

NNVM_REGISTER_OP(_backward_broadcast_exponential)
.set_attr<FCompute>("FCompute<gpu>", ExponentialReparamBackward<gpu>);

} // namespace op
} // namespace mxnet
70 changes: 64 additions & 6 deletions src/operator/numpy/random/np_exponential_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace op {
struct NumpyExponentialParam : public dmlc::Parameter<NumpyExponentialParam> {
dmlc::optional<float> scale;
dmlc::optional<mxnet::Tuple<int>> size;
std::string ctx;
DMLC_DECLARE_PARAMETER(NumpyExponentialParam) {
DMLC_DECLARE_FIELD(scale)
.set_default(dmlc::optional<float>(1.0));
Expand All @@ -52,6 +53,9 @@ struct NumpyExponentialParam : public dmlc::Parameter<NumpyExponentialParam> {
.describe("Output shape. If the given shape is, "
"e.g., (m, n, k), then m * n * k samples are drawn. "
"Default is None, in which case a single value is returned.");
DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe(
"Context of output, in format [cpu|gpu|cpu_pinned](n)."
" Only used for imperative calls.");
}
};

Expand Down Expand Up @@ -83,7 +87,8 @@ struct exponential_kernel {
IType *scales, float* threshold, OType *out) {
Shape<ndim> coord = unravel(i, oshape);
auto idx = static_cast<index_t>(dot(coord, stride));
out[i] = -scales[idx] * log(threshold[i]);
threshold[i] = -log(threshold[i]);
out[i] = scales[idx] * threshold[i];
}
};

Expand All @@ -99,16 +104,15 @@ void NumpyExponentialForward(const nnvm::NodeAttrs &attrs,
using namespace mxnet_op;
const NumpyExponentialParam &param = nnvm::get<NumpyExponentialParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
index_t output_len = outputs[0].Size();
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
Tensor<xpu, 1, float> uniform_tensor = workspace.Slice(0, output_len);
Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(1), s);
Tensor<xpu, 1, float> uniform_tensor = outputs[1].FlatTo1D<xpu, float>(s);
Tensor<xpu, 1, float> indicator_device = workspace;
float indicator_host = 1.0;
float *indicator_device_ptr = indicator_device.dptr_;
Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);
prnd->SampleUniform(&workspace, 0.0, 1.0);
prnd->SampleUniform(&uniform_tensor, 0.0, 1.0);
if (param.scale.has_value()) {
CHECK_GE(param.scale.value(), 0.0) << "ValueError: expect scale >= 0";
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Expand Down Expand Up @@ -140,6 +144,60 @@ void NumpyExponentialForward(const nnvm::NodeAttrs &attrs,
}
}

template<typename xpu, int ndim, typename DType>
inline void ScalarExponentialReparamBackwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_ishape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob igrad = outputs[0].reshape(new_ishape);
// inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
// samples, noise]
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob itensor = inputs[2].reshape(new_ishape);
const TBlob samples = inputs[3].reshape(new_oshape);
const TBlob noise = inputs[4].reshape(new_oshape);
size_t workspace_size =
ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
s, igrad, req[0], workspace, ograd, noise, noise);
}

template<typename xpu>
void ExponentialReparamBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
// skip kernel launch for zero-size tensors
if (inputs[0].shape_.Size() == 0U) {
return;
}
// [scalar] case
if (outputs.size() == 0U) {
return;
}
// [tensor] case
if (inputs.size() == 5U) {
mxnet::TShape new_ishape, new_oshape;
int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
&new_ishape, &new_ishape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
ScalarExponentialReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_ishape, new_oshape);
});
});
}
}

} // namespace op
} // namespace mxnet

Expand Down
28 changes: 17 additions & 11 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3635,19 +3635,25 @@ def __init__(self, shape):
def hybrid_forward(self, F, scale):
return F.np.random.exponential(scale, self._shape)

shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None]
output_shapes = [
(3, 2),
(4, 3, 2, 2),
(3, 4, 5)
]
for hybridize in [False, True]:
for shape in shapes:
test_exponential = TestRandomExp(shape)
for out_shape in output_shapes:
test_exponential_grad = TestRandomExp(out_shape)
if hybridize:
test_exponential.hybridize()
np_out = _np.random.exponential(size = shape)
mx_out = test_exponential(np.array([1]))

for shape in shapes:
mx_out = np.random.exponential(np.array([1]), shape)
np_out = _np.random.exponential(np.array([1]).asnumpy(), shape)
assert_almost_equal(mx_out.asnumpy().shape, np_out.shape)
test_exponential_grad.hybridize()
scale = np.ones(out_shape)
scale.attach_grad()
with mx.autograd.record():
mx_out = test_exponential_grad(scale)
np_out = _np.random.exponential(scale = scale.asnumpy(), size = out_shape)
assert_almost_equal(np_out.shape, mx_out.shape)
mx_out.backward()
assert scale.grad.shape == out_shape
assert_almost_equal(scale.grad.asnumpy().sum(), mx_out.asnumpy().sum(), rtol=1e-3, atol=1e-5)

def _test_exponential_exception(scale):
output = np.random.exponential(scale=scale).asnumpy()
Expand Down

0 comments on commit 9aa5088

Please sign in to comment.