From b6b1de092b2bbc6ab7207a98dcb1c08fe67ca14b Mon Sep 17 00:00:00 2001
From: "D. Roberts"
Date: Sat, 15 Feb 2020 03:15:44 -0500
Subject: [PATCH] Implement Weibull backward (#17590)

---
 python/mxnet/ndarray/numpy/random.py       |  8 +-
 python/mxnet/numpy/random.py               |  4 +-
 python/mxnet/symbol/numpy/random.py        |  8 +-
 src/operator/numpy/random/np_weibull_op.cc | 38 ++++++++-
 src/operator/numpy/random/np_weibull_op.cu |  3 +
 src/operator/numpy/random/np_weibull_op.h  | 94 ++++++++++++++++++----
 tests/python/unittest/test_numpy_op.py     | 33 ++++++++
 7 files changed, 160 insertions(+), 28 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index 8a791a56c259..8d99bb1ab1c8 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -570,7 +570,7 @@ def exponential(scale=1.0, size=None, ctx=None, out=None):
     return _npi.exponential(scale=scale, size=size, ctx=ctx, out=out)
 
 
-def weibull(a, size=None):
+def weibull(a, size=None, ctx=None, out=None):
     r"""Draw samples from a 1-parameter Weibull distribution with given
     parameter a, via inversion.
 
@@ -614,13 +614,15 @@ def weibull(a, size=None):
     """
     from ...numpy import ndarray as np_ndarray
     tensor_type_name = np_ndarray
+    if ctx is None:
+        ctx = current_context()
     if size == ():
         size = None
     is_tensor = isinstance(a, tensor_type_name)
     if is_tensor:
-        return _npi.weibull(a, a=None, size=size)
+        return _npi.weibull(a, a=None, size=size, ctx=ctx, out=out)
     else:
-        return _npi.weibull(a=a, size=size)
+        return _npi.weibull(a=a, size=size, ctx=ctx, out=out)
 
 
 def pareto(a, size=None):
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index 2ff6b9532189..3c05ff17de0f 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -611,7 +611,7 @@ def exponential(scale=1.0, size=None, ctx=None, out=None):
     return _mx_nd_np.random.exponential(scale, size=size, ctx=ctx, out=out)
 
 
-def weibull(a, size=None):
+def weibull(a, size=None, ctx=None, out=None):
     r"""Draw samples from a 1-parameter Weibull distribution with given
     parameter a via inversion.
 
@@ -653,7 +653,7 @@ def weibull(a, size=None):
     model time to failure, in modeling particle sizes, in information retrieval
     to model dwell time on pages, in quantitative finance to model risk etc.
     """
-    return _mx_nd_np.random.weibull(a, size)
+    return _mx_nd_np.random.weibull(a, size=size, ctx=ctx, out=out)
 
 
 def pareto(a, size=None):
diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py
index fae9c037ded2..5885488550f5 100644
--- a/python/mxnet/symbol/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -638,7 +638,7 @@ def exponential(scale=1.0, size=None, ctx=None, out=None):
     return _npi.exponential(scale=scale, size=size, ctx=ctx, out=out)
 
 
-def weibull(a, size=None):
+def weibull(a, size=None, ctx=None, out=None):
     r"""Draw samples from a 1-parameter Weibull distribution with given
     parameter a via inversion.
 
@@ -684,13 +684,15 @@ def weibull(a, size=None):
     """
     from ..numpy import _Symbol as np_symbol
     tensor_type_name = np_symbol
+    if ctx is None:
+        ctx = current_context()
     if size == ():
         size = None
     is_tensor = isinstance(a, tensor_type_name)
     if is_tensor:
-        return _npi.weibull(a, a=None, size=size)
+        return _npi.weibull(a, a=None, size=size, ctx=ctx, out=out)
     else:
-        return _npi.weibull(a=a, size=size)
+        return _npi.weibull(a=a, size=size, ctx=ctx, out=out)
 
 
 def pareto(a, size=None):
diff --git a/src/operator/numpy/random/np_weibull_op.cc b/src/operator/numpy/random/np_weibull_op.cc
index e204c7c95100..7e8d8bf5fabc 100644
--- a/src/operator/numpy/random/np_weibull_op.cc
+++ b/src/operator/numpy/random/np_weibull_op.cc
@@ -32,6 +32,7 @@ namespace op {
 DMLC_REGISTER_PARAMETER(NumpyWeibullParam);
 
 NNVM_REGISTER_OP(_npi_weibull)
+.describe("Numpy behavior Weibull")
 .set_num_inputs(
   [](const nnvm::NodeAttrs& attrs) {
     const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
@@ -41,7 +42,11 @@ NNVM_REGISTER_OP(_npi_weibull)
     }
     return num_inputs;
   })
-.set_num_outputs(1)
+.set_num_outputs(2)
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+  [](const NodeAttrs& attrs) {
+    return 1;
+  })
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
@@ -52,10 +57,11 @@ NNVM_REGISTER_OP(_npi_weibull)
     return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
   })
 .set_attr_parser(ParamParser<NumpyWeibullParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape)
+.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyWeibullParam>)
 .set_attr<nnvm::FInferType>("FInferType",
   [](const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs, std::vector<int> *out_attrs) {
     (*out_attrs)[0] = mshadow::kFloat32;
+    (*out_attrs)[1] = mshadow::kFloat32;
     return true;
   })
 .set_attr<FResourceRequest>("FResourceRequest",
@@ -64,9 +70,35 @@ NNVM_REGISTER_OP(_npi_weibull)
     ResourceRequest::kRandom, ResourceRequest::kTempSpace};
   })
 .set_attr<FCompute>("FCompute", NumpyWeibullForward<cpu>)
-.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_weibull"})
 .add_argument("input1", "NDArray-or-Symbol", "Source input")
 .add_arguments(NumpyWeibullParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_backward_broadcast_weibull)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr_parser(ParamParser<NumpyWeibullParam>)
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
+    int num_inputs = 5;
+    if (param.a.has_value()) num_inputs -= 1;
+    return num_inputs;
+  })
+.set_num_outputs(
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
+    int num_outputs = 1;
+    if (param.a.has_value()) num_outputs -= 1;
+    return num_outputs;
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", WeibullReparamBackward<cpu>)
+.add_arguments(NumpyWeibullParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/random/np_weibull_op.cu b/src/operator/numpy/random/np_weibull_op.cu
index 1baf78f38fa7..57d609d62768 100644
--- a/src/operator/numpy/random/np_weibull_op.cu
+++ b/src/operator/numpy/random/np_weibull_op.cu
@@ -31,5 +31,8 @@ namespace op {
 NNVM_REGISTER_OP(_npi_weibull)
 .set_attr<FCompute>("FCompute", NumpyWeibullForward<gpu>);
 
+NNVM_REGISTER_OP(_backward_broadcast_weibull)
+.set_attr<FCompute>("FCompute", WeibullReparamBackward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/random/np_weibull_op.h b/src/operator/numpy/random/np_weibull_op.h
index ece730aad180..afb37288b04e 100644
--- a/src/operator/numpy/random/np_weibull_op.h
+++ b/src/operator/numpy/random/np_weibull_op.h
@@ -44,6 +44,7 @@ namespace op {
 struct NumpyWeibullParam : public dmlc::Parameter<NumpyWeibullParam> {
   dmlc::optional<float> a;
   dmlc::optional<mxnet::Tuple<int>> size;
+  std::string ctx;
   DMLC_DECLARE_PARAMETER(NumpyWeibullParam) {
     DMLC_DECLARE_FIELD(a)
     .set_default(dmlc::optional<float>());
@@ -52,14 +53,17 @@ struct NumpyWeibullParam : public dmlc::Parameter<NumpyWeibullParam> {
     .describe("Output shape. If the given shape is, "
               "e.g., (m, n, k), then m * n * k samples are drawn. "
               "Default is None, in which case a single value is returned.");
+    DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe(
+        "Context of output, in format [cpu|gpu|cpu_pinned](n)."
+        " Only used for imperative calls.");
   }
 };
 
 template <typename DType>
 struct scalar_weibull_kernel {
-  MSHADOW_XINLINE static void Map(index_t i, float a, float *threshold,
+  MSHADOW_XINLINE static void Map(index_t i, float a, float *noise,
                                   DType *out) {
-    out[i] = powf(-log(threshold[i]), DType(1.0/a));
+    out[i] = powf(-log(noise[i]), DType(1.0/a));
   }
 };
 
@@ -67,8 +71,8 @@ namespace mxnet_op {
 
 template <typename IType>
 struct check_legal_a_kernel {
-  MSHADOW_XINLINE static void Map(index_t i, IType *a, float* flag) {
-    if (a[i] < 0.0) {
+  MSHADOW_XINLINE static void Map(index_t i, IType *a, float *flag) {
+    if (a[i] <= 0.0) {
       flag[0] = -1.0;
     }
   }
@@ -80,10 +84,13 @@ struct weibull_kernel {
   MSHADOW_XINLINE static void Map(index_t i,
                                   const Shape<ndim> &stride,
                                   const Shape<ndim> &oshape,
-                                  IType *aparams, float* threshold, OType *out) {
+                                  IType *aparams, float *noise, OType *out) {
     Shape<ndim> coord = unravel(i, oshape);
     auto idx = static_cast<index_t>(dot(coord, stride));
-    out[i] = powf(-log(threshold[i]), IType(1.0/aparams[idx]));
+    noise[i] = -log(noise[i]);
+    out[i] = powf(noise[i], IType(1.0/aparams[idx]));
+    // get grad
+    noise[i] = -log(noise[i]) * out[i] * (1.0/(aparams[idx] * aparams[idx]));
   }
 };
 
@@ -91,26 +98,25 @@ struct weibull_kernel {
 
 template <typename xpu>
 void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
-                          const OpContext &ctx,
-                          const std::vector<TBlob> &inputs,
-                          const std::vector<OpReqType> &req,
-                          const std::vector<TBlob> &outputs) {
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
   using namespace mshadow;
   using namespace mxnet_op;
   const NumpyWeibullParam &param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  index_t output_len = outputs[0].Size();
   Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
   Tensor<xpu, 1, float> workspace =
-      ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
-  Tensor<xpu, 1, float> uniform_tensor = workspace.Slice(0, output_len);
-  Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
+      ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(1), s);
+  Tensor<xpu, 1, float> uniform_tensor = outputs[1].FlatTo1D<xpu, float>(s);
+  Tensor<xpu, 1, float> indicator_device = workspace;
   float indicator_host = 1.0;
   float *indicator_device_ptr = indicator_device.dptr_;
   Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);
-  prnd->SampleUniform(&workspace, 0.0, 1.0);
+  prnd->SampleUniform(&uniform_tensor, 0.0, 1.0);
   if (param.a.has_value()) {
-    CHECK_GE(param.a.value(), 0.0) << "ValueError: expect a >= 0";
+    CHECK_GT(param.a.value(), 0.0) << "ValueError: expect a > 0";
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       Kernel<scalar_weibull_kernel<DType>, xpu>::Launch(
           s, outputs[0].Size(), param.a.value(),
@@ -122,7 +128,7 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
           s, inputs[0].Size(), inputs[0].dptr<IType>(), indicator_device_ptr);
     });
     _copy<xpu>(s, &indicator_host, indicator_device_ptr);
-    CHECK_GE(indicator_host, 0.0) << "ValueError: expect a >= 0";
+    CHECK_GE(indicator_host, 0.0) << "ValueError: expect a > 0";
     mxnet::TShape new_lshape, new_oshape;
     int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_,
                          &new_lshape, &new_lshape, &new_oshape);
@@ -140,6 +146,60 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
   }
 }
 
+template <typename xpu>
+inline void ScalarWeibullReparamBackwardImpl(const OpContext& ctx,
+                                             const std::vector<TBlob>& inputs,
+                                             const std::vector<OpReqType>& req,
+                                             const std::vector<TBlob>& outputs,
+                                             const mxnet::TShape& new_ishape,
+                                             const mxnet::TShape& new_oshape) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob igrad = outputs[0].reshape(new_ishape);
+  // inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
+  //          samples, noise]
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob itensor = inputs[2].reshape(new_ishape);
+  const TBlob samples = inputs[3].reshape(new_oshape);
+  const TBlob noise = inputs[4].reshape(new_oshape);
+  size_t workspace_size =
+      ReduceWorkspaceSize<2, float>(s, igrad.shape_, req[0], ograd.shape_);
+  Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+  Reduce<red::sum, 2, float, op::mul, mshadow_op::left>(
+      s, igrad, req[0], workspace, ograd, noise, noise);
+}
+
+template <typename xpu>
+void WeibullReparamBackward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& reqs,
+                            const std::vector<TBlob>& outputs) {
+  // skip kernel launch for zero-size tensors
+  if (inputs[0].shape_.Size() == 0U) {
+    return;
+  }
+  // [scalar] case
+  if (outputs.size() == 0U) {
+    return;
+  }
+  // [tensor] case
+  if (inputs.size() == 5U) {
+    mxnet::TShape new_ishape, new_oshape;
+    int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
+                         &new_ishape, &new_ishape, &new_oshape);
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+        ScalarWeibullReparamBackwardImpl<xpu>(
+            ctx, inputs, reqs, outputs, new_ishape, new_oshape);
+      });
+    });
+  }
+}
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 6f94ea22265e..759732e27aea 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -4046,6 +4046,39 @@ def _test_exception(a):
     assertRaises(ValueError, _test_exception, 0)
 
 
+@with_seed()
+@use_np
+def test_np_weibull_grad():
+    class TestRandomW(HybridBlock):
+        def __init__(self, shape):
+            super(TestRandomW, self).__init__()
+            self._shape = shape
+
+        def hybrid_forward(self, F, a):
+            return F.np.random.weibull(a, self._shape)
+
+    output_shapes = [
+        (3, 2),
+        (4, 3, 2, 2),
+        (3, 4, 5)
+    ]
+    for hybridize in [False, True]:
+        for out_shape in output_shapes:
+            test_w_grad = TestRandomW(out_shape)
+            if hybridize:
+                test_w_grad.hybridize()
+            a = np.ones(out_shape)
+            a.attach_grad()
+            with mx.autograd.record():
+                mx_out = test_w_grad(a)
+            mx_out.backward()
+
+            # analytical gradient at a = 1: d/da (-log(U))^(1/a) = -x * log(x)
+            formula_grad = - mx_out * np.log(mx_out)
+            assert a.grad.shape == out_shape
+            assert_almost_equal(a.grad.asnumpy().sum(), formula_grad.asnumpy().sum(), rtol=1e-3, atol=1e-5)
+
+
 @with_seed()
 @use_np
 def test_np_randn():
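
Note on the gradient formula used above: the forward pass draws X = (-log U)^(1/a) with U ~ Uniform(0, 1), so differentiating through this reparameterization gives dX/da = -log(-log U) * X / a^2, which is the per-sample value that weibull_kernel caches in noise[i]; for a broadcastable `a`, ScalarWeibullReparamBackwardImpl then sum-reduces these values over the broadcast axes. At a = 1 the expression collapses to dX/da = -X log X, because X = -log U there, which is the formula_grad checked in test_np_weibull_grad. The following is a minimal plain-NumPy sketch, independent of MXNet (the names sample, analytic, and numeric are illustrative, not from the patch), that checks the same identity against a central finite difference:

    import numpy as np

    rng = np.random.default_rng(0)
    u = rng.uniform(0.05, 0.95, size=1000)  # stand-in for the op's uniform noise
    a, eps = 1.5, 1e-5                      # illustrative shape parameter and FD step

    def sample(a):
        # inversion sampling, same formula as scalar_weibull_kernel
        return np.power(-np.log(u), 1.0 / a)

    x = sample(a)
    # per-sample gradient, matching what weibull_kernel stores in noise[i]
    analytic = -np.log(-np.log(u)) * x / (a * a)
    numeric = (sample(a + eps) - sample(a - eps)) / (2 * eps)
    assert np.allclose(analytic, numeric, rtol=1e-4, atol=1e-6)

The analytic and finite-difference gradients should agree to within finite-difference error, consistent with the rtol=1e-3/atol=1e-5 tolerances used by the unit test.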