Skip to content

List more math function in API docs #7211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 27, 2024
1 change: 1 addition & 0 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def scipy_log_pdf(value, a, b):
return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a)

def scipy_log_cdf(value, a, b):
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True)

check_logp(
Expand Down
12 changes: 9 additions & 3 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
kron_solve_lower,
kronecker,
log1mexp,
log1mexp_numpy,
log1mexp_numpy, # to be deprecated
logdet,
logdiffexp,
logdiffexp_numpy,
logdiffexp_numpy, # to be deprecated
probit,
)
from pymc.pytensorf import floatX
Expand Down Expand Up @@ -148,6 +148,8 @@ def test_log1mexp():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning)

warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
actual_ = log1mexp_numpy(-vals, negative_input=True)
npt.assert_allclose(actual_, expected)
# Check that input was not changed in place
Expand All @@ -158,10 +160,12 @@ def test_log1mexp_numpy_no_warning():
"""Assert RuntimeWarning is not raised for very small numbers"""
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
log1mexp_numpy(-1e-25, negative_input=True)


def test_log1mexp_numpy_integer_input():
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())


Expand All @@ -170,10 +174,12 @@ def test_log1mexp_deprecation_warnings():
FutureWarning,
match="pymc.math.log1mexp_numpy will expect a negative input",
):
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
res_pos = log1mexp_numpy(2)

with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
res_neg = log1mexp_numpy(-2, negative_input=True)

with pytest.warns(
Expand All @@ -196,7 +202,7 @@ def test_logdiffexp():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
b = np.log([0, 1, 2, 3])

warnings.warn("pymc.math.logdiffexp_numpy is being deprecated.", FutureWarning)
assert np.allclose(logdiffexp_numpy(a, b), 0)
assert np.allclose(logdiffexp(a, b).eval(), 0)

Expand Down