From 2c21603f6dad7c46af5899e2f9ca29ea469791ec Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 29 Jul 2020 06:05:57 +0000 Subject: [PATCH 1/8] gamma grad wip --- src/operator/mshadow_op.h | 26 +++++++++++++ src/operator/numpy/random/np_gamma_op.cc | 31 ++++++++++++++- src/operator/numpy/random/np_gamma_op.h | 48 ++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 1 deletion(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 55f26b08fcc1..ec72f291285d 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1167,6 +1167,32 @@ struct smooth_l1_gradient : public mxnet_op::tunable { } }; // struct smooth_l1_derivative +/* Implicti reparameterization gradient for standard x ~ Gamma(\alpha, 1) + * according to dx/da = -cdf(x;alpha) / pdf(x;alpha) + */ +struct gamma_implicit_grad : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + auto numer = 1; + auto denom = a; + auto series1 = numer / denom; + auto series2 = numer / (denom * denom); + for (int i = 1; i <= 5; i++) { + numer *= -x / static_cast(i); + denom += 1; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + auto pow_x_alpha = math::pow(x, a); + auto gamma_pdf = math::pow(x, a - 1) * math::exp(-x); + auto gamma_cdf = pow_x_alpha * series1; + auto gamma_cdf_alpha = (math::log(x) - + DType(special_functions::cephes::psi(a)) * gamma_cdf - + pow_x_alpha * series2); + return IsNan(result) ? static_cast(0.f) : result; + } +}; // gamma_implicit_grad + /*! \brief product reducer */ struct product { /*! \brief do reduction into dst */ diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc index f87e997d549e..748d36a86526 100644 --- a/src/operator/numpy/random/np_gamma_op.cc +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -75,10 +75,39 @@ NNVM_REGISTER_OP(_npi_gamma) ResourceRequest::kTempSpace}; }) .set_attr("FCompute", NumpyGammaForward) -.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_gamma"}) .add_argument("input1", "NDArray-or-Symbol", "Source input") .add_argument("input2", "NDArray-or-Symbol", "Source input") .add_arguments(NumpyGammaParam::__FIELDS__()); +NNVM_REGISTER_OP(_backward_broadcast_gamma) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser) +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyNormalParam& param = nnvm::get(attrs.parsed); + int num_inputs = 5; + if (param.shape.has_value()) num_inputs -= 1; + if (param.scale.has_value()) num_inputs -= 1; + return num_inputs; + } +) +.set_num_outputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyNormalParam& param = nnvm::get(attrs.parsed); + int num_outputs = 2; + if (param.shape.has_value()) num_outputs -= 1; + if (param.scale.has_value()) num_outputs -= 1; + return num_outputs; + } +) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NormalReparamBackward) +.add_arguments(NumpyNormalParam::__FIELDS__()); + + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index 58ca4c7c52c0..10da84458f7d 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -236,6 +236,27 @@ struct gamma_two_scalar_kernel { out[i] = sample; } } + +// Backward utils +template +MSHADOW_XINLINE void StandardGammaPdf(IType a, IType x) { + return 
pow(x, alpha - 1) * exp(-x) +} + +template +MSHADOW_XINLINE void StandardGammaCdf(IType a, IType x) { + // Approximate the Gamma cdf via taylor series + IType numer = 1; + IType denom = a; + + for (int i = 1; i <= 5; i++) { + numer *= -x / i; + denom += 1; + series1 + } +} + + }; } // namespace mxnet_op @@ -394,6 +415,33 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } } +// Allow gamma sampling to be differentiable, +// using implicit reparameterization gradient: +// -(d/d\alpha cdf(x;alpha)) / pdf(x;alpha) +template +void NumpyGammaGrad(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // skip kernel launch for zero-size tensors + if (inputs[0].shape_.Size() == 0U) { + return; + } + + // [scalar, scalar] case + if (outputs.size() == 0U) { + return; + } + const NumpyGammaParam ¶m = nnvm::get(attrs.parsed) + + // [tensor tensor] case + if (inputs.size() == 5U) { + + } + +} + } // namespace op } // namespace mxnet From 4cc80177275b8d9a93fbe92250b62eebad1a2ea1 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Thu, 30 Jul 2020 10:56:41 +0000 Subject: [PATCH 2/8] gamma grad wip --- src/operator/mshadow_op.h | 21 ++++--- src/operator/numpy/random/np_gamma_op.cc | 14 ++--- src/operator/numpy/random/np_gamma_op.cu | 3 + src/operator/numpy/random/np_gamma_op.h | 79 +++++++++++++++++------- src/operator/operator_tune.cc | 1 + 5 files changed, 78 insertions(+), 40 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index ec72f291285d..52428cd36804 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1172,27 +1172,30 @@ struct smooth_l1_gradient : public mxnet_op::tunable { */ struct gamma_implicit_grad : public mxnet_op::tunable { template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - auto numer = 1; - auto denom = a; - auto series1 = numer / denom; - auto series2 = numer / (denom * denom); + MSHADOW_XINLINE static DType Map(DType a, DType x) { + DType numer = 1; + DType denom = a; + DType series1 = numer / denom; + DType series2 = numer / (denom * denom); for (int i = 1; i <= 5; i++) { numer *= -x / static_cast(i); denom += 1; series1 += numer / denom; series2 += numer / (denom * denom); } - auto pow_x_alpha = math::pow(x, a); - auto gamma_pdf = math::pow(x, a - 1) * math::exp(-x); - auto gamma_cdf = pow_x_alpha * series1; - auto gamma_cdf_alpha = (math::log(x) - + DType pow_x_alpha = math::pow(x, a); + DType gamma_pdf = math::pow(x, a - 1) * math::exp(-x); + DType gamma_cdf = pow_x_alpha * series1; + DType gamma_cdf_alpha = (math::log(x) - DType(special_functions::cephes::psi(a)) * gamma_cdf - pow_x_alpha * series2); + DType result = -gamma_cdf_alpha / gamma_pdf; return IsNan(result) ? static_cast(0.f) : result; } }; // gamma_implicit_grad +MXNET_BINARY_MATH_OP_NC_WITH_BOOL(neg_div, -a / b); + /*! \brief product reducer */ struct product { /*! 
\brief do reduction into dst */ diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc index 748d36a86526..d4673bb807d4 100644 --- a/src/operator/numpy/random/np_gamma_op.cc +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -75,18 +75,18 @@ NNVM_REGISTER_OP(_npi_gamma) ResourceRequest::kTempSpace}; }) .set_attr("FCompute", NumpyGammaForward) -.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_gamma"}) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_gamma"}) .add_argument("input1", "NDArray-or-Symbol", "Source input") .add_argument("input2", "NDArray-or-Symbol", "Source input") .add_arguments(NumpyGammaParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_broadcast_gamma) +NNVM_REGISTER_OP(_backward_gamma) .set_attr("TIsBackward", true) .set_attr_parser(ParamParser) .set_num_inputs( [](const nnvm::NodeAttrs& attrs) { - const NumpyNormalParam& param = nnvm::get(attrs.parsed); - int num_inputs = 5; + const NumpyGammaParam& param = nnvm::get(attrs.parsed); + int num_inputs = 4; if (param.shape.has_value()) num_inputs -= 1; if (param.scale.has_value()) num_inputs -= 1; return num_inputs; @@ -94,7 +94,7 @@ NNVM_REGISTER_OP(_backward_broadcast_gamma) ) .set_num_outputs( [](const nnvm::NodeAttrs& attrs) { - const NumpyNormalParam& param = nnvm::get(attrs.parsed); + const NumpyGammaParam& param = nnvm::get(attrs.parsed); int num_outputs = 2; if (param.shape.has_value()) num_outputs -= 1; if (param.scale.has_value()) num_outputs -= 1; @@ -105,8 +105,8 @@ NNVM_REGISTER_OP(_backward_broadcast_gamma) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FCompute", NormalReparamBackward) -.add_arguments(NumpyNormalParam::__FIELDS__()); +.set_attr("FCompute", NumpyGammaGrad) +.add_arguments(NumpyGammaParam::__FIELDS__()); } // namespace op diff --git a/src/operator/numpy/random/np_gamma_op.cu b/src/operator/numpy/random/np_gamma_op.cu index 5be15c7b9d13..e043b2b9a33b 100644 --- a/src/operator/numpy/random/np_gamma_op.cu +++ b/src/operator/numpy/random/np_gamma_op.cu @@ -32,5 +32,8 @@ namespace op { NNVM_REGISTER_OP(_npi_gamma) .set_attr("FCompute", NumpyGammaForward); +NNVM_REGISTER_OP(_backward_gamma) +.set_attr("FCompute", NumpyGammaGrad); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index 10da84458f7d..ba99bf11cc1c 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -136,6 +136,13 @@ struct CheckSuccessKernel { } }; +template +struct StandarizeKernel { + MSHADOW_XINLINE static void Map(int i, DType* samples, float scale) { + samples[i] /= scale; + } +}; + template struct gamma_kernel { MSHADOW_XINLINE static void Map(index_t i, const Shape &lstride, @@ -236,27 +243,6 @@ struct gamma_two_scalar_kernel { out[i] = sample; } } - -// Backward utils -template -MSHADOW_XINLINE void StandardGammaPdf(IType a, IType x) { - return pow(x, alpha - 1) * exp(-x) -} - -template -MSHADOW_XINLINE void StandardGammaCdf(IType a, IType x) { - // Approximate the Gamma cdf via taylor series - IType numer = 1; - IType denom = a; - - for (int i = 1; i <= 5; i++) { - numer *= -x / i; - denom += 1; - series1 - } -} - - }; } // namespace mxnet_op @@ -415,6 +401,35 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } } +template +inline void GammaReparamBackwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& 
outputs, + const mxnet::TShape& new_ishape, + const mxnet::TShape& new_oshape, + const float scale) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace broadcast; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const TBlob igrad = outputs[0].reshape(new_ishape); + // inputs: [grad_from_samples, alpha_tensor, samples] + const TBlob ograd = inputs[0].reshape(new_oshape); + const TBlob alpha = inputs[1].reshape(new_ishape); + const TBlob samples = inputs[2].reshape(new_oshape); + size_t workspace_size = + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + // Convert samples to standard gamma + Kernel, xpu>::Launch( + s, samples.Size(), samples.dptr(), scale); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Reduce( + s, igrad, req[0], workspace, ograd, alpha, samples); +} + // Allow gamma sampling to be differentiable, // using implicit reparameterization gradient: // -(d/d\alpha cdf(x;alpha)) / pdf(x;alpha) @@ -433,11 +448,27 @@ void NumpyGammaGrad(const nnvm::NodeAttrs& attrs, if (outputs.size() == 0U) { return; } - const NumpyGammaParam ¶m = nnvm::get(attrs.parsed) - - // [tensor tensor] case + const NumpyGammaParam ¶m = nnvm::get(attrs.parsed); + // [tensor tensor] case, not supported. if (inputs.size() == 5U) { + LOG(FATAL) << "ValueError: two tensor case not supported"; + } + // [tensor, scalar] case, only scalar scale is supported. + if (inputs.size() == 4U) { + if (param.shape.has_value()) { + LOG(FATAL) << "ValueError: tensor scale case not supported"; + } + mxnet::TShape new_ishape, new_oshape; + int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_, + &new_ishape, &new_ishape, &new_oshape); + auto scale = param.scale.value(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + GammaReparamBackwardImpl( + ctx, inputs, req, outputs, new_ishape, new_oshape, scale); + }); + }); } } diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index b5e253a1872e..fb3d6782bb3e 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -417,6 +417,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_and); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma_implicit_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT() From 6fdf0b7a559150cb22a6b491a168dee0fd0c791a Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 3 Aug 2020 16:42:27 +0000 Subject: [PATCH 3/8] test tbd --- src/operator/mshadow_op.h | 79 +++++++++++++++++------- src/operator/numpy/random/np_gamma_op.cc | 4 +- src/operator/numpy/random/np_gamma_op.cu | 2 +- src/operator/numpy/random/np_gamma_op.h | 8 +-- src/operator/operator_tune.cc | 2 +- tests/python/unittest/test_numpy_op.py | 51 +++++++++++++++ 6 files changed, 116 insertions(+), 30 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 52428cd36804..5d20a9ffe9bf 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1170,32 
+1170,69 @@ struct smooth_l1_gradient : public mxnet_op::tunable { /* Implicti reparameterization gradient for standard x ~ Gamma(\alpha, 1) * according to dx/da = -cdf(x;alpha) / pdf(x;alpha) */ -struct gamma_implicit_grad : public mxnet_op::tunable { - template +struct gamma_implicit_grad { + template MSHADOW_XINLINE static DType Map(DType a, DType x) { - DType numer = 1; - DType denom = a; - DType series1 = numer / denom; - DType series2 = numer / (denom * denom); - for (int i = 1; i <= 5; i++) { - numer *= -x / static_cast(i); - denom += 1; - series1 += numer / denom; - series2 += numer / (denom * denom); + if (x < 0.8f) { + DType numer = 1; + DType denom = a; + DType series1 = numer / denom; + DType series2 = numer / (denom * denom); + for (int i = 1; i <= 5; i++) { + numer *= -x / static_cast(i); + denom += 1; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + DType pow_x_alpha = math::pow(x, a); + DType gamma_pdf = math::pow(x, a - 1) * math::exp(-x); + DType gamma_cdf = pow_x_alpha * series1; + DType gamma_cdf_alpha = + (math::log(x) - DType(special_functions::cephes::psi(a))) * + gamma_cdf - + pow_x_alpha * series2; + DType result = -gamma_cdf_alpha / gamma_pdf; + return IsNan(result) ? static_cast( 0.f ) : static_cast(result); + } + if (a > 8.0f) { + if (0.9f * a <= x && x <= 1.1f * a) { + DType numer_1 = 1 + 24 * a * (1 + 12 * a); + DType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) - + 65 * x * x / a + a * (107 + 3600 * x); + DType denom = 1244160 * (a * a) * (a * a); + return static_cast(numer_1 * numer_2 / denom); + } + DType denom = math::sqrt(8 * a); + DType term2 = denom / (a - x); + DType term3 = + math::pow(x - a - a * math::log(x / a), static_cast(-1.5)); + DType term23 = (x < a) ? term2 - term3 : term2 + term3; + DType term1 = math::log(x / a) * term23 - + math::sqrt(2 / a) * (a + x) / ((a - x) * (a - x)); + DType stirling = 1 + 1 / (12 * a) * (1 + 1 / (24 * a)); + DType numer = x * term1; + return static_cast(-stirling * numer / denom); + } + DType u = math::log(x / a); + DType v = math::log(a); + DType coef_uv[3][8] = { + {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115, + 0.10406089, 0.0014179084}, + {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465, + 0.020070113, -0.0035938915, -0.00058392623}, + {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642, + -0.0021309326, 0.00085092367, -1.5247877e-07}, + }; + DType coef_v[8]; + for (int i = 0; i < 8; i++) { + coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]); } - DType pow_x_alpha = math::pow(x, a); - DType gamma_pdf = math::pow(x, a - 1) * math::exp(-x); - DType gamma_cdf = pow_x_alpha * series1; - DType gamma_cdf_alpha = (math::log(x) - - DType(special_functions::cephes::psi(a)) * gamma_cdf - - pow_x_alpha * series2); - DType result = -gamma_cdf_alpha / gamma_pdf; - return IsNan(result) ? static_cast(0.f) : result; + DType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); + DType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); + return static_cast(math::exp(p / q)); } }; // gamma_implicit_grad -MXNET_BINARY_MATH_OP_NC_WITH_BOOL(neg_div, -a / b); - /*! \brief product reducer */ struct product { /*! 
\brief do reduction into dst */ diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc index d4673bb807d4..847243da9491 100644 --- a/src/operator/numpy/random/np_gamma_op.cc +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -75,12 +75,12 @@ NNVM_REGISTER_OP(_npi_gamma) ResourceRequest::kTempSpace}; }) .set_attr("FCompute", NumpyGammaForward) -.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_gamma"}) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_gamma_sample"}) .add_argument("input1", "NDArray-or-Symbol", "Source input") .add_argument("input2", "NDArray-or-Symbol", "Source input") .add_arguments(NumpyGammaParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_gamma) +NNVM_REGISTER_OP(_backward_gamma_sample) .set_attr("TIsBackward", true) .set_attr_parser(ParamParser) .set_num_inputs( diff --git a/src/operator/numpy/random/np_gamma_op.cu b/src/operator/numpy/random/np_gamma_op.cu index e043b2b9a33b..2d8b5b204dd2 100644 --- a/src/operator/numpy/random/np_gamma_op.cu +++ b/src/operator/numpy/random/np_gamma_op.cu @@ -32,7 +32,7 @@ namespace op { NNVM_REGISTER_OP(_npi_gamma) .set_attr("FCompute", NumpyGammaForward); -NNVM_REGISTER_OP(_backward_gamma) +NNVM_REGISTER_OP(_backward_gamma_sample) .set_attr("FCompute", NumpyGammaGrad); } // namespace op diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index ba99bf11cc1c..44ff3bc9bd6e 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -422,12 +422,12 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx, size_t workspace_size = ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); // Convert samples to standard gamma - Kernel, xpu>::Launch( - s, samples.Size(), samples.dptr(), scale); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Reduce( s, igrad, req[0], workspace, ograd, alpha, samples); + Kernel, xpu>::Launch( + s, igrad.Size(), igrad.dptr(), scale); } // Allow gamma sampling to be differentiable, @@ -443,7 +443,6 @@ void NumpyGammaGrad(const nnvm::NodeAttrs& attrs, if (inputs[0].shape_.Size() == 0U) { return; } - // [scalar, scalar] case if (outputs.size() == 0U) { return; @@ -455,7 +454,7 @@ void NumpyGammaGrad(const nnvm::NodeAttrs& attrs, } // [tensor, scalar] case, only scalar scale is supported. 
- if (inputs.size() == 4U) { + if (inputs.size() == 3U) { if (param.shape.has_value()) { LOG(FATAL) << "ValueError: tensor scale case not supported"; } @@ -470,7 +469,6 @@ void NumpyGammaGrad(const nnvm::NodeAttrs& attrs, }); }); } - } } // namespace op diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index fb3d6782bb3e..9d409f4cb2d9 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -417,7 +417,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_and); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma_implicit_grad); // NOLINT() +// IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma_implicit_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 88ad77fc978d..c776a057c4f7 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4777,6 +4777,57 @@ def _test_gamma_exception(shape, scale): assertRaises(ValueError, _test_gamma_exception, shape, scale) +@with_seed() +@use_np +def test_gamma_grad(): + class TestRandomGamma(HybridBlock): + def __init__(self, size, beta): + super(TestRandomGamma, self).__init__() + self._size = size + self._beta = beta + + def hybrid_forward(self, F, a): + return F.np.random.gamma(a, 1.0, self._size) * self._beta + + shapes = [ + # shape(alpha), shape(samples) + ((1,), (1,)), + ((4,), (4,)), + ((2, 2), (2, 2)), + ] + alpha = [2.0, 5.0, 10.0] + beta = [0.5, 1.0, 1.5] + for (shape, a, b) in itertools.product(shapes, alpha, beta): + for hybridize in [True, False]: + param = np.ones(shape[0]) * a + param.attach_grad() + sample_shape = shape[1] + net = TestRandomGamma(sample_shape, b) + if hybridize: + net.hybridize() + with mx.autograd.record(): + samples = net(param) + samples.backward() + # Check shape + assert param.grad.shape == param.shape + # Check correctness + cdf = ss.gamma.cdf + log_pdf = ss.gamma.logpdf + eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy() + x = samples.asnumpy().astype('float64') + # d(cdf(x;alpha,beta))/d(alpha) + cdf_alpha = (cdf(x, param.asnumpy() + eps) - + cdf(x, param.asnumpy() - eps)) / (2 * eps) + # d(cdf(x;alpha,beta))/d(x) + log_cdf_x = log_pdf(x, _np.ones_like(x) * a) + expected_grad = -cdf_alpha / _np.exp(log_cdf_x) + # print("*************") + # print(a) + # print(b) + # print(param.grad) + # print(expected_grad / b) + + @with_seed() @use_np @pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/18600') From e86bf4756ef097edd79ee866de65da237004cb61 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 4 Aug 2020 09:52:49 +0000 Subject: [PATCH 4/8] fix grad --- src/operator/numpy/random/np_gamma_op.h | 4 +++- tests/python/unittest/test_numpy_op.py | 14 +++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index 44ff3bc9bd6e..c4f43ea62b23 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ 
b/src/operator/numpy/random/np_gamma_op.h @@ -422,12 +422,14 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx, size_t workspace_size = ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); // Convert samples to standard gamma + Kernel, xpu>::Launch( + s, samples.Size(), samples.dptr(), scale); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Reduce( s, igrad, req[0], workspace, ograd, alpha, samples); Kernel, xpu>::Launch( - s, igrad.Size(), igrad.dptr(), scale); + s, igrad.Size(), igrad.dptr(), 1 / scale); } // Allow gamma sampling to be differentiable, diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index c776a057c4f7..6b58bad5eeb5 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4787,7 +4787,7 @@ def __init__(self, size, beta): self._beta = beta def hybrid_forward(self, F, a): - return F.np.random.gamma(a, 1.0, self._size) * self._beta + return F.np.random.gamma(a, size=self._size) * self._beta shapes = [ # shape(alpha), shape(samples) @@ -4816,16 +4816,12 @@ def hybrid_forward(self, F, a): eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy() x = samples.asnumpy().astype('float64') # d(cdf(x;alpha,beta))/d(alpha) - cdf_alpha = (cdf(x, param.asnumpy() + eps) - - cdf(x, param.asnumpy() - eps)) / (2 * eps) + cdf_alpha = (cdf(x, param.asnumpy() + eps, scale=b) - + cdf(x, param.asnumpy() - eps, scale=b)) / (2 * eps) # d(cdf(x;alpha,beta))/d(x) - log_cdf_x = log_pdf(x, _np.ones_like(x) * a) + log_cdf_x = log_pdf(x, param.asnumpy(), scale=b) expected_grad = -cdf_alpha / _np.exp(log_cdf_x) - # print("*************") - # print(a) - # print(b) - # print(param.grad) - # print(expected_grad / b) + assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) @with_seed() From d41bf51a9fcb6e66e37fa4ba5346bed0a83b259c Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 5 Aug 2020 11:48:53 +0000 Subject: [PATCH 5/8] change scale to the frontend --- python/mxnet/ndarray/numpy/random.py | 2 +- python/mxnet/symbol/numpy/random.py | 16 ++++++++-------- tests/python/unittest/test_numpy_op.py | 12 ++++++------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 41e573c76111..7f57a1cee925 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -732,7 +732,7 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): ctx = str(ctx) if dtype is not None and not isinstance(dtype, str): dtype = np.dtype(dtype).name - return _api_internal.gamma(shape, scale, size, ctx, dtype, out) + return _api_internal.gamma(shape, 1.0, size, ctx, dtype, out) * scale def beta(a, b, size=None, dtype=None, ctx=None): diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 75780df173e9..e74fb1e55874 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -512,17 +512,17 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): if size == (): size = None if input_type == (True, True): - return _npi.gamma(shape, scale, shape=None, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + return _npi.gamma(shape, shape=None, scale=1.0, size=size, + ctx=ctx, dtype=dtype, out=out) * scale elif input_type == (False, True): - return _npi.gamma(scale, shape=shape, scale=None, size=size, - ctx=ctx, dtype=dtype, out=out) + return 
_npi.gamma(shape=shape, scale=1.0, size=size, + ctx=ctx, dtype=dtype, out=out) * scale elif input_type == (True, False): - return _npi.gamma(shape, shape=None, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + return _npi.gamma(shape, shape=None, scale=1.0, size=size, + ctx=ctx, dtype=dtype, out=out) * scale else: - return _npi.gamma(shape=shape, scale=scale, size=size, - ctx=ctx, dtype=dtype, out=out) + return _npi.gamma(shape=shape, scale=1.0, size=size, + ctx=ctx, dtype=dtype, out=out) * scale raise ValueError("Distribution parameters must be either _Symbol or numbers") diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 6b58bad5eeb5..0dfd666929dc 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4787,7 +4787,7 @@ def __init__(self, size, beta): self._beta = beta def hybrid_forward(self, F, a): - return F.np.random.gamma(a, size=self._size) * self._beta + return F.np.random.gamma(a, self._beta, size=self._size) shapes = [ # shape(alpha), shape(samples) @@ -4814,14 +4814,14 @@ def hybrid_forward(self, F, a): cdf = ss.gamma.cdf log_pdf = ss.gamma.logpdf eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy() - x = samples.asnumpy().astype('float64') + x = samples.asnumpy().astype('float64') / b # d(cdf(x;alpha,beta))/d(alpha) - cdf_alpha = (cdf(x, param.asnumpy() + eps, scale=b) - - cdf(x, param.asnumpy() - eps, scale=b)) / (2 * eps) + cdf_alpha = (cdf(x, param.asnumpy() + eps) - + cdf(x, param.asnumpy() - eps)) / (2 * eps) # d(cdf(x;alpha,beta))/d(x) - log_cdf_x = log_pdf(x, param.asnumpy(), scale=b) + log_cdf_x = log_pdf(x, param.asnumpy()) expected_grad = -cdf_alpha / _np.exp(log_cdf_x) - assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) + assert_almost_equal(expected_grad * b, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) @with_seed() From 3fe6814d72ef9d5d42c35198d5ee594aa2287d21 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 7 Aug 2020 09:04:15 +0000 Subject: [PATCH 6/8] fix bugs --- python/mxnet/ndarray/numpy/random.py | 2 +- python/mxnet/symbol/numpy/random.py | 16 ++++++++-------- src/operator/numpy/random/np_gamma_op.h | 15 ++++++++++----- tests/python/unittest/test_numpy_op.py | 16 +++++----------- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 7f57a1cee925..41e573c76111 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -732,7 +732,7 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): ctx = str(ctx) if dtype is not None and not isinstance(dtype, str): dtype = np.dtype(dtype).name - return _api_internal.gamma(shape, 1.0, size, ctx, dtype, out) * scale + return _api_internal.gamma(shape, scale, size, ctx, dtype, out) def beta(a, b, size=None, dtype=None, ctx=None): diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index e74fb1e55874..75780df173e9 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -512,17 +512,17 @@ def gamma(shape, scale=1.0, size=None, dtype=None, ctx=None, out=None): if size == (): size = None if input_type == (True, True): - return _npi.gamma(shape, shape=None, scale=1.0, size=size, - ctx=ctx, dtype=dtype, out=out) * scale + return _npi.gamma(shape, scale, shape=None, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) elif input_type == (False, True): - return 
_npi.gamma(shape=shape, scale=1.0, size=size, - ctx=ctx, dtype=dtype, out=out) * scale + return _npi.gamma(scale, shape=shape, scale=None, size=size, + ctx=ctx, dtype=dtype, out=out) elif input_type == (True, False): - return _npi.gamma(shape, shape=None, scale=1.0, size=size, - ctx=ctx, dtype=dtype, out=out) * scale + return _npi.gamma(shape, shape=None, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) else: - return _npi.gamma(shape=shape, scale=1.0, size=size, - ctx=ctx, dtype=dtype, out=out) * scale + return _npi.gamma(shape=shape, scale=scale, size=size, + ctx=ctx, dtype=dtype, out=out) raise ValueError("Distribution parameters must be either _Symbol or numbers") diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index c4f43ea62b23..aa9913e75f94 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -418,18 +418,23 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx, // inputs: [grad_from_samples, alpha_tensor, samples] const TBlob ograd = inputs[0].reshape(new_oshape); const TBlob alpha = inputs[1].reshape(new_ishape); - const TBlob samples = inputs[2].reshape(new_oshape); + TBlob samples = inputs[2].reshape(new_oshape); size_t workspace_size = ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); // Convert samples to standard gamma - Kernel, xpu>::Launch( - s, samples.Size(), samples.dptr(), scale); + // Kernel, xpu>::Launch( + // s, samples.Size(), samples.dptr(), scale); + Kernel, xpu>::Launch( + s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); Tensor workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); Reduce( s, igrad, req[0], workspace, ograd, alpha, samples); - Kernel, xpu>::Launch( - s, igrad.Size(), igrad.dptr(), 1 / scale); + Kernel, xpu>::Launch( + s, igrad.Size(), igrad.dptr(), igrad.dptr(), DType(scale)); + // Convert samples back, otherwise the output would be corrupted. 
+ Kernel, xpu>::Launch( + s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); } // Allow gamma sampling to be differentiable, diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 0dfd666929dc..3eca41092fe0 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4789,20 +4789,14 @@ def __init__(self, size, beta): def hybrid_forward(self, F, a): return F.np.random.gamma(a, self._beta, size=self._size) - shapes = [ - # shape(alpha), shape(samples) - ((1,), (1,)), - ((4,), (4,)), - ((2, 2), (2, 2)), - ] + shapes = [(1,), (2, 2), (4, 2, 2)] alpha = [2.0, 5.0, 10.0] beta = [0.5, 1.0, 1.5] for (shape, a, b) in itertools.product(shapes, alpha, beta): for hybridize in [True, False]: - param = np.ones(shape[0]) * a + param = np.ones(shape) * a param.attach_grad() - sample_shape = shape[1] - net = TestRandomGamma(sample_shape, b) + net = TestRandomGamma(shape, b) if hybridize: net.hybridize() with mx.autograd.record(): @@ -4820,8 +4814,8 @@ def hybrid_forward(self, F, a): cdf(x, param.asnumpy() - eps)) / (2 * eps) # d(cdf(x;alpha,beta))/d(x) log_cdf_x = log_pdf(x, param.asnumpy()) - expected_grad = -cdf_alpha / _np.exp(log_cdf_x) - assert_almost_equal(expected_grad * b, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) + expected_grad = -b * cdf_alpha / _np.exp(log_cdf_x) + assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) @with_seed() From e8ff581eb0e43999d9f6012cc06b93c3388fd02e Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 7 Aug 2020 09:10:09 +0000 Subject: [PATCH 7/8] change distributions.gamma --- python/mxnet/gluon/probability/distributions/gamma.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/probability/distributions/gamma.py b/python/mxnet/gluon/probability/distributions/gamma.py index 348a0a51f0a4..bdb328cf2242 100644 --- a/python/mxnet/gluon/probability/distributions/gamma.py +++ b/python/mxnet/gluon/probability/distributions/gamma.py @@ -77,10 +77,10 @@ def broadcast_to(self, batch_shape): return new_instance def sample(self, size=None): - return self.F.np.random.gamma(self.shape, self.scale, size) + return self.F.np.random.gamma(self.shape, 1, size) * self.scale def sample_n(self, size=None): - return self.F.np.random.gamma(self.shape, self.scale, sample_n_shape_converter(size)) + return self.F.np.random.gamma(self.shape, 1, sample_n_shape_converter(size)) * self.scale @property def mean(self): From faeb5a5e73304caaf4c72c60e5de8c25f2a6c965 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 12 Aug 2020 04:49:09 +0000 Subject: [PATCH 8/8] fix test and operator tune --- src/operator/numpy/random/np_gamma_op.h | 2 - src/operator/operator_tune.cc | 1 - tests/python/unittest/test_numpy_op.py | 59 ++++++++++++------------- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index aa9913e75f94..e1d031fb9a0b 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -422,8 +422,6 @@ inline void GammaReparamBackwardImpl(const OpContext& ctx, size_t workspace_size = ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); // Convert samples to standard gamma - // Kernel, xpu>::Launch( - // s, samples.Size(), samples.dptr(), scale); Kernel, xpu>::Launch( s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); Tensor workspace = diff --git a/src/operator/operator_tune.cc 
b/src/operator/operator_tune.cc index 9d409f4cb2d9..b5e253a1872e 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -417,7 +417,6 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_and); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT() -// IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gamma_implicit_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 3eca41092fe0..ce9bd13757bc 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4779,43 +4779,42 @@ def _test_gamma_exception(shape, scale): @with_seed() @use_np -def test_gamma_grad(): - class TestRandomGamma(HybridBlock): +@pytest.mark.parametrize("shape", [(1,), (2, 2), (4, 2, 2)]) +@pytest.mark.parametrize("a", [2.0, 5.0, 10.0]) +@pytest.mark.parametrize("b", [0.5, 1.0, 1.5]) +def test_gamma_grad(shape, a, b): + class TestGammaGrad(HybridBlock): def __init__(self, size, beta): - super(TestRandomGamma, self).__init__() + super(TestGammaGrad, self).__init__() self._size = size self._beta = beta def hybrid_forward(self, F, a): return F.np.random.gamma(a, self._beta, size=self._size) - shapes = [(1,), (2, 2), (4, 2, 2)] - alpha = [2.0, 5.0, 10.0] - beta = [0.5, 1.0, 1.5] - for (shape, a, b) in itertools.product(shapes, alpha, beta): - for hybridize in [True, False]: - param = np.ones(shape) * a - param.attach_grad() - net = TestRandomGamma(shape, b) - if hybridize: - net.hybridize() - with mx.autograd.record(): - samples = net(param) - samples.backward() - # Check shape - assert param.grad.shape == param.shape - # Check correctness - cdf = ss.gamma.cdf - log_pdf = ss.gamma.logpdf - eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy() - x = samples.asnumpy().astype('float64') / b - # d(cdf(x;alpha,beta))/d(alpha) - cdf_alpha = (cdf(x, param.asnumpy() + eps) - - cdf(x, param.asnumpy() - eps)) / (2 * eps) - # d(cdf(x;alpha,beta))/d(x) - log_cdf_x = log_pdf(x, param.asnumpy()) - expected_grad = -b * cdf_alpha / _np.exp(log_cdf_x) - assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) + for hybridize in [True, False]: + param = np.ones(shape) * a + param.attach_grad() + net = TestGammaGrad(shape, b) + if hybridize: + net.hybridize() + with mx.autograd.record(): + samples = net(param) + samples.backward() + # Check shape + assert param.grad.shape == param.shape + # Check correctness + cdf = ss.gamma.cdf + log_pdf = ss.gamma.logpdf + eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy() + x = samples.asnumpy().astype('float64') / b + # d(cdf(x;alpha,beta))/d(alpha) + cdf_alpha = (cdf(x, param.asnumpy() + eps) - + cdf(x, param.asnumpy() - eps)) / (2 * eps) + # d(cdf(x;alpha,beta))/d(x) + log_cdf_x = log_pdf(x, param.asnumpy()) + expected_grad = -b * cdf_alpha / _np.exp(log_cdf_x) + assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) @with_seed()
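Reference notes on the gradient implemented by this series (not part of the committed patches). The kernel gamma_implicit_grad computes the implicit reparameterization gradient of a standard gamma sample; with alpha the shape, psi the digamma function, and F(x; alpha) the Gamma(alpha, 1) CDF, the quantity it returns is

\frac{dX}{d\alpha}
  = -\,\frac{\partial F(x;\alpha)/\partial \alpha}{\partial F(x;\alpha)/\partial x}
  = -\,\frac{\partial F(x;\alpha)/\partial \alpha}{p(x;\alpha)},
\qquad
p(x;\alpha) = \frac{x^{\alpha-1} e^{-x}}{\Gamma(\alpha)} .

For the x < 0.8 branch the CDF is expanded through the lower incomplete gamma function,

\gamma(\alpha, x) = x^{\alpha} \sum_{k \ge 0} \frac{(-x)^k}{k!\,(\alpha+k)},
\qquad
\frac{\partial F}{\partial \alpha}
  = \frac{\bigl(\ln x - \psi(\alpha)\bigr)\,\gamma(\alpha,x)
          - x^{\alpha} \sum_{k \ge 0} \frac{(-x)^k}{k!\,(\alpha+k)^2}}{\Gamma(\alpha)},

so series1 and series2 in the kernel are these two sums truncated at k = 5, and the 1/Gamma(alpha) factor cancels against the unnormalized pdf x^(alpha-1) * exp(-x); the x >= 0.8 branches replace the series with asymptotic and rational approximations. For a scaled sample Y = beta * X ~ Gamma(alpha, scale=beta) the chain rule gives dY/dalpha = beta * dX/dalpha evaluated at the standardized value x = y / beta, which is why GammaReparamBackwardImpl divides the samples by scale before the reduction and multiplies the reduced gradient by scale afterwards.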
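A minimal NumPy/SciPy sketch of that small-x branch, useful for sanity-checking the series against a central finite difference of the gamma CDF. The function name and the use of SciPy as the reference are choices made here for illustration, not part of the patches:

import numpy as np
from scipy.special import digamma
from scipy.stats import gamma

def gamma_implicit_grad_small_x(a, x, terms=6):
    # Taylor-series branch (x < 0.8) of gamma_implicit_grad: returns
    # d(sample)/d(alpha) for x ~ Gamma(alpha, 1).  Both the incomplete
    # gamma and the pdf are left unnormalized; the 1/Gamma(alpha)
    # factors cancel in the ratio, exactly as in the C++ kernel.
    numer, denom = 1.0, a
    series1 = numer / denom            # sum of (-x)^k / (k! * (a + k))
    series2 = numer / denom ** 2       # sum of (-x)^k / (k! * (a + k)^2)
    for i in range(1, terms):
        numer *= -x / i
        denom += 1.0
        series1 += numer / denom
        series2 += numer / denom ** 2
    pow_x_alpha = x ** a
    gamma_pdf = x ** (a - 1) * np.exp(-x)
    gamma_cdf = pow_x_alpha * series1
    gamma_cdf_alpha = (np.log(x) - digamma(a)) * gamma_cdf - pow_x_alpha * series2
    return -gamma_cdf_alpha / gamma_pdf

a, x, eps = 2.0, 0.5, 1e-5
reference = -(gamma.cdf(x, a + eps) - gamma.cdf(x, a - eps)) / (2 * eps) / gamma.pdf(x, a)
print(gamma_implicit_grad_small_x(a, x), reference)  # the two values should agree to ~1e-4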
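The check in test_gamma_grad can likewise be read as a standalone reference for the scaled case; restated as a helper (the name expected_gamma_grad is chosen here for illustration), it standardizes the sample, differentiates the CDF with respect to the shape numerically, divides by the pdf, and rescales by beta:

from scipy.stats import gamma

def expected_gamma_grad(alpha, beta, samples, eps_scale=0.01):
    # Finite-difference reference for d(samples)/d(alpha) when
    # samples ~ Gamma(alpha, scale=beta), mirroring test_gamma_grad.
    eps = eps_scale * alpha / (1.0 + alpha ** 0.5)
    x = samples / beta                              # standard-gamma sample
    dcdf_dalpha = (gamma.cdf(x, alpha + eps) -
                   gamma.cdf(x, alpha - eps)) / (2 * eps)
    return -beta * dcdf_dalpha / gamma.pdf(x, alpha)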