Skip to content

Commit

Permalink
Remove gammainc(c) safeguards in logcdf methods of Gamma and `I…
Browse files Browse the repository at this point in the history
…nverseGamma`

Closes pymc-devs#4467
  • Loading branch information
ricardoV94 committed Jun 4, 2021
1 parent 6fa8506 commit dc4f211
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 21 deletions.
13 changes: 2 additions & 11 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2392,13 +2392,8 @@ def logcdf(value, alpha, inv_beta):
"""
beta = at.inv(inv_beta)

# Avoid C-assertion when the gammainc function is called with invalid values (#4340)
safe_alpha = at.switch(at.lt(alpha, 0), 0, alpha)
safe_beta = at.switch(at.lt(beta, 0), 0, beta)
safe_value = at.switch(at.lt(value, 0), 0, value)

return bound(
at.log(at.gammainc(safe_alpha, safe_beta * safe_value)),
at.log(at.gammainc(alpha, beta * value)),
0 <= value,
0 < alpha,
0 < beta,
Expand Down Expand Up @@ -2540,13 +2535,9 @@ def logcdf(value, alpha, beta):
-------
TensorVariable
"""
# Avoid C-assertion when the gammaincc function is called with invalid values (#4340)
safe_alpha = at.switch(at.lt(alpha, 0), 0, alpha)
safe_beta = at.switch(at.lt(beta, 0), 0, beta)
safe_value = at.switch(at.lt(value, 0), 0, value)

return bound(
at.log(at.gammaincc(safe_alpha, safe_beta / safe_value)),
at.log(at.gammaincc(alpha, beta / value)),
0 <= value,
0 < alpha,
0 < beta,
Expand Down
10 changes: 0 additions & 10 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,15 +1437,11 @@ def test_fun(value, mu, sigma):
reason="Fails on float32 due to numerical issues",
)
def test_gamma_logcdf(self):
# pymc-devs/aesara#224: skip_paramdomain_outside_edge_test has to be set
# True to avoid triggering a C-level assertion in the Aesara GammaQ function
# in gamma.c file. Can be set back to False (default) once that issue is solved
self.check_logcdf(
Gamma,
Rplus,
{"alpha": Rplusbig, "beta": Rplusbig},
lambda value, alpha, beta: sp.gamma.logcdf(value, alpha, scale=1.0 / beta),
skip_paramdomain_outside_edge_test=True,
)

def test_inverse_gamma_logp(self):
Expand All @@ -1455,23 +1451,17 @@ def test_inverse_gamma_logp(self):
{"alpha": Rplus, "beta": Rplus},
lambda value, alpha, beta: sp.invgamma.logpdf(value, alpha, scale=beta),
)
# pymc-devs/aesara#224: skip_paramdomain_outside_edge_test has to be set
# True to avoid triggering a C-level assertion in the Aesara GammaQ function

@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Fails on float32 due to numerical issues",
)
def test_inverse_gamma_logcdf(self):
# pymc-devs/aesara#224: skip_paramdomain_outside_edge_test has to be set
# True to avoid triggering a C-level assertion in the Aesara GammaQ function
# in gamma.c file. Can be set back to False (default) once that issue is solved
self.check_logcdf(
InverseGamma,
Rplus,
{"alpha": Rplus, "beta": Rplus},
lambda value, alpha, beta: sp.invgamma.logcdf(value, alpha, scale=beta),
skip_paramdomain_outside_edge_test=True,
)

@pytest.mark.xfail(
Expand Down

0 comments on commit dc4f211

Please sign in to comment.