diff --git a/python/mxnet/gluon/probability/distributions/gamma.py b/python/mxnet/gluon/probability/distributions/gamma.py index 348a0a51f0a4..bdb328cf2242 100644 --- a/python/mxnet/gluon/probability/distributions/gamma.py +++ b/python/mxnet/gluon/probability/distributions/gamma.py @@ -77,10 +77,10 @@ def broadcast_to(self, batch_shape): return new_instance def sample(self, size=None): - return self.F.np.random.gamma(self.shape, self.scale, size) + return self.F.np.random.gamma(self.shape, 1, size) * self.scale def sample_n(self, size=None): - return self.F.np.random.gamma(self.shape, self.scale, sample_n_shape_converter(size)) + return self.F.np.random.gamma(self.shape, 1, sample_n_shape_converter(size)) * self.scale @property def mean(self): diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 55f26b08fcc1..5d20a9ffe9bf 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1167,6 +1167,72 @@ struct smooth_l1_gradient : public mxnet_op::tunable { } }; // struct smooth_l1_derivative +/* Implicti reparameterization gradient for standard x ~ Gamma(\alpha, 1) + * according to dx/da = -cdf(x;alpha) / pdf(x;alpha) + */ +struct gamma_implicit_grad { + template + MSHADOW_XINLINE static DType Map(DType a, DType x) { + if (x < 0.8f) { + DType numer = 1; + DType denom = a; + DType series1 = numer / denom; + DType series2 = numer / (denom * denom); + for (int i = 1; i <= 5; i++) { + numer *= -x / static_cast(i); + denom += 1; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + DType pow_x_alpha = math::pow(x, a); + DType gamma_pdf = math::pow(x, a - 1) * math::exp(-x); + DType gamma_cdf = pow_x_alpha * series1; + DType gamma_cdf_alpha = + (math::log(x) - DType(special_functions::cephes::psi(a))) * + gamma_cdf - + pow_x_alpha * series2; + DType result = -gamma_cdf_alpha / gamma_pdf; + return IsNan(result) ? static_cast( 0.f ) : static_cast(result); + } + if (a > 8.0f) { + if (0.9f * a <= x && x <= 1.1f * a) { + DType numer_1 = 1 + 24 * a * (1 + 12 * a); + DType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) - + 65 * x * x / a + a * (107 + 3600 * x); + DType denom = 1244160 * (a * a) * (a * a); + return static_cast(numer_1 * numer_2 / denom); + } + DType denom = math::sqrt(8 * a); + DType term2 = denom / (a - x); + DType term3 = + math::pow(x - a - a * math::log(x / a), static_cast(-1.5)); + DType term23 = (x < a) ? term2 - term3 : term2 + term3; + DType term1 = math::log(x / a) * term23 - + math::sqrt(2 / a) * (a + x) / ((a - x) * (a - x)); + DType stirling = 1 + 1 / (12 * a) * (1 + 1 / (24 * a)); + DType numer = x * term1; + return static_cast(-stirling * numer / denom); + } + DType u = math::log(x / a); + DType v = math::log(a); + DType coef_uv[3][8] = { + {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115, + 0.10406089, 0.0014179084}, + {0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465, + 0.020070113, -0.0035938915, -0.00058392623}, + {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642, + -0.0021309326, 0.00085092367, -1.5247877e-07}, + }; + DType coef_v[8]; + for (int i = 0; i < 8; i++) { + coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]); + } + DType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); + DType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); + return static_cast(math::exp(p / q)); + } +}; // gamma_implicit_grad + /*! \brief product reducer */ struct product { /*! \brief do reduction into dst */ diff --git a/src/operator/numpy/random/np_gamma_op.cc b/src/operator/numpy/random/np_gamma_op.cc index f87e997d549e..847243da9491 100644 --- a/src/operator/numpy/random/np_gamma_op.cc +++ b/src/operator/numpy/random/np_gamma_op.cc @@ -75,10 +75,39 @@ NNVM_REGISTER_OP(_npi_gamma) ResourceRequest::kTempSpace}; }) .set_attr("FCompute", NumpyGammaForward) -.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_gamma_sample"}) .add_argument("input1", "NDArray-or-Symbol", "Source input") .add_argument("input2", "NDArray-or-Symbol", "Source input") .add_arguments(NumpyGammaParam::__FIELDS__()); +NNVM_REGISTER_OP(_backward_gamma_sample) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser) +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyGammaParam& param = nnvm::get(attrs.parsed); + int num_inputs = 4; + if (param.shape.has_value()) num_inputs -= 1; + if (param.scale.has_value()) num_inputs -= 1; + return num_inputs; + } +) +.set_num_outputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyGammaParam& param = nnvm::get(attrs.parsed); + int num_outputs = 2; + if (param.shape.has_value()) num_outputs -= 1; + if (param.scale.has_value()) num_outputs -= 1; + return num_outputs; + } +) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyGammaGrad) +.add_arguments(NumpyGammaParam::__FIELDS__()); + + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_gamma_op.cu b/src/operator/numpy/random/np_gamma_op.cu index 5be15c7b9d13..2d8b5b204dd2 100644 --- a/src/operator/numpy/random/np_gamma_op.cu +++ b/src/operator/numpy/random/np_gamma_op.cu @@ -32,5 +32,8 @@ namespace op { NNVM_REGISTER_OP(_npi_gamma) .set_attr("FCompute", NumpyGammaForward); +NNVM_REGISTER_OP(_backward_gamma_sample) +.set_attr("FCompute", NumpyGammaGrad); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_gamma_op.h b/src/operator/numpy/random/np_gamma_op.h index 58ca4c7c52c0..e1d031fb9a0b 100644 --- a/src/operator/numpy/random/np_gamma_op.h +++ b/src/operator/numpy/random/np_gamma_op.h @@ -136,6 +136,13 @@ struct CheckSuccessKernel { } }; +template +struct StandarizeKernel { + MSHADOW_XINLINE static void Map(int i, DType* samples, float scale) { + samples[i] /= scale; + } +}; + template struct gamma_kernel { MSHADOW_XINLINE static void Map(index_t i, const Shape &lstride, @@ -394,6 +401,81 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } } +template +inline void GammaReparamBackwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mxnet::TShape& new_ishape, + const mxnet::TShape& new_oshape, + const float scale) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace broadcast; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const TBlob igrad = outputs[0].reshape(new_ishape); + // inputs: [grad_from_samples, alpha_tensor, samples] + const TBlob ograd = inputs[0].reshape(new_oshape); + const TBlob alpha = inputs[1].reshape(new_ishape); + TBlob samples = inputs[2].reshape(new_oshape); + size_t workspace_size = + ReduceWorkspaceSize(s, igrad.shape_, req[0], ograd.shape_); + // Convert samples to standard gamma + Kernel, xpu>::Launch( + s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Reduce( + s, igrad, req[0], workspace, ograd, alpha, samples); + Kernel, xpu>::Launch( + s, igrad.Size(), igrad.dptr(), igrad.dptr(), DType(scale)); + // Convert samples back, otherwise the output would be corrupted. + Kernel, xpu>::Launch( + s, samples.Size(), samples.dptr(), samples.dptr(), DType(scale)); +} + +// Allow gamma sampling to be differentiable, +// using implicit reparameterization gradient: +// -(d/d\alpha cdf(x;alpha)) / pdf(x;alpha) +template +void NumpyGammaGrad(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // skip kernel launch for zero-size tensors + if (inputs[0].shape_.Size() == 0U) { + return; + } + // [scalar, scalar] case + if (outputs.size() == 0U) { + return; + } + const NumpyGammaParam ¶m = nnvm::get(attrs.parsed); + // [tensor tensor] case, not supported. + if (inputs.size() == 5U) { + LOG(FATAL) << "ValueError: two tensor case not supported"; + } + + // [tensor, scalar] case, only scalar scale is supported. + if (inputs.size() == 3U) { + if (param.shape.has_value()) { + LOG(FATAL) << "ValueError: tensor scale case not supported"; + } + mxnet::TShape new_ishape, new_oshape; + int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_, + &new_ishape, &new_ishape, &new_oshape); + auto scale = param.scale.value(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + GammaReparamBackwardImpl( + ctx, inputs, req, outputs, new_ishape, new_oshape, scale); + }); + }); + } +} + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index e4564e92510e..93d3e7085148 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4801,6 +4801,46 @@ def _test_gamma_exception(shape, scale): assertRaises(ValueError, _test_gamma_exception, shape, scale) +@with_seed() +@use_np +@pytest.mark.parametrize("shape", [(1,), (2, 2), (4, 2, 2)]) +@pytest.mark.parametrize("a", [2.0, 5.0, 10.0]) +@pytest.mark.parametrize("b", [0.5, 1.0, 1.5]) +def test_gamma_grad(shape, a, b): + class TestGammaGrad(HybridBlock): + def __init__(self, size, beta): + super(TestGammaGrad, self).__init__() + self._size = size + self._beta = beta + + def hybrid_forward(self, F, a): + return F.np.random.gamma(a, self._beta, size=self._size) + + for hybridize in [True, False]: + param = np.ones(shape) * a + param.attach_grad() + net = TestGammaGrad(shape, b) + if hybridize: + net.hybridize() + with mx.autograd.record(): + samples = net(param) + samples.backward() + # Check shape + assert param.grad.shape == param.shape + # Check correctness + cdf = ss.gamma.cdf + log_pdf = ss.gamma.logpdf + eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy() + x = samples.asnumpy().astype('float64') / b + # d(cdf(x;alpha,beta))/d(alpha) + cdf_alpha = (cdf(x, param.asnumpy() + eps) - + cdf(x, param.asnumpy() - eps)) / (2 * eps) + # d(cdf(x;alpha,beta))/d(x) + log_cdf_x = log_pdf(x, param.asnumpy()) + expected_grad = -b * cdf_alpha / _np.exp(log_cdf_x) + assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3) + + @with_seed() @use_np @pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/18600')