Move pymc/distributions/logprob.py to pymc/logprob/ #6441

Merged
merged 4 commits into from
Feb 7, 2023

@Armavica Armavica commented Jan 7, 2023

What is this PR about?
This PR shuffles around some of the logprob-related functions.

  • joint_logprob, a thin wrapper over factorized_joint_logprob, was only used in tests, so it was moved to the tests.
  • _joint_logprob now lives as joint_logprob in pymc/logprob/joint_logprob.py.
  • _get_scaling, _check_no_rvs, and logp were moved to pymc/logprob/joint_logprob.py.
  • logcdf was moved to pymc/logprob/abstract.py.
  • ignore_logprob was moved to pymc/logprob/utils.py.
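
To illustrate what "a thin wrapper over a factorized function" means in the first bullet, here is a minimal pure-Python sketch. The names toy_factorized_joint_logprob and toy_joint_logprob are hypothetical stand-ins invented for this example; they only mimic the relationship between the two functions and are not PyMC's implementation.

```python
import math

# Hypothetical toy functions: they mimic the "factorized function plus thin
# wrapper" relationship only and are NOT the PyMC implementation.


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of a univariate normal distribution."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def toy_factorized_joint_logprob(values):
    """Return one log-probability term per named variable."""
    return {name: normal_logpdf(x) for name, x in values.items()}


def toy_joint_logprob(values, sum=True):
    """Thin wrapper: collect the factorized terms and optionally add them up."""
    terms = list(toy_factorized_joint_logprob(values).values())
    return math.fsum(terms) if sum else terms


values = {"x": 0.0, "y": 1.0}
total = toy_joint_logprob(values)                    # single scalar
per_variable = toy_joint_logprob(values, sum=False)  # one term per variable
```

Because the wrapper adds nothing beyond combining the factorized terms, keeping it next to the tests that use it does not remove any functionality from the library itself.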

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • Move some internal logprob-related functions.

codecov bot commented Jan 7, 2023

Codecov Report

Merging #6441 (c92b204) into main (c5e4497) will increase coverage by 8.69%.
The diff coverage is 93.54%.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #6441      +/-   ##
==========================================
+ Coverage   86.04%   94.74%   +8.69%     
==========================================
  Files         148      146       -2     
  Lines       27820    27807      -13     
==========================================
+ Hits        23939    26346    +2407     
+ Misses       3881     1461    -2420     

Impacted Files                        Coverage Δ
pymc/distributions/__init__.py        100.00% <ø> (ø)
pymc/tests/logprob/utils.py           47.15% <48.97%> (+0.81%) ⬆️
pymc/distributions/bound.py           100.00% <100.00%> (ø)
pymc/distributions/discrete.py        99.22% <100.00%> (ø)
pymc/distributions/mixture.py         95.42% <100.00%> (+0.02%) ⬆️
pymc/distributions/multivariate.py    92.28% <100.00%> (ø)
pymc/distributions/timeseries.py      94.54% <100.00%> (+0.01%) ⬆️
pymc/logprob/__init__.py              100.00% <100.00%> (ø)
pymc/logprob/joint_logprob.py         99.23% <100.00%> (+21.62%) ⬆️
pymc/logprob/utils.py                 100.00% <100.00%> (+13.79%) ⬆️
... and 46 more

pymc/tests/logprob/utils.py (resolved)
pymc/tests/logprob/test_joint_logprob.py (resolved)
np.testing.assert_almost_equal(logp_vals, exp_obs_logps)


def test_joint_logp_subtensor():

Same here

@Armavica Armavica Feb 3, 2023

I am sorry, but I have trouble identifying where this one and the next one should go. I don't really understand what these tests are about.

@ricardoV94 ricardoV94 Feb 7, 2023

This subtensor test is just testing functionality that is implemented in logprob.test_mixture.py, so we can probably remove it (after confirming with codecov that we are not losing lines that were covered exclusively by these tests).
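
The "covered exclusively by these" check amounts to a set difference between two coverage runs. The sketch below is only an illustration under assumed data: the file names, line numbers, and the exclusively_covered helper are all hypothetical, and codecov performs its own analysis rather than running this code.

```python
# Hypothetical coverage data: file name -> set of covered line numbers.
# Neither the paths nor the line numbers come from the actual PR.


def exclusively_covered(with_test, without_test):
    """Return the lines that are covered only when the candidate test runs."""
    lost = {}
    for path, lines in with_test.items():
        missing = lines - without_test.get(path, set())
        if missing:
            lost[path] = sorted(missing)
    return lost


# Case 1: another test already exercises the same lines, so nothing is lost.
redundant = exclusively_covered(
    {"mixture.py": {10, 11, 12}},
    {"mixture.py": {10, 11, 12}},
)

# Case 2: line 12 is reached only by the candidate test, so removing it
# would reduce coverage.
needed = exclusively_covered(
    {"mixture.py": {10, 11, 12}},
    {"mixture.py": {10, 11}},
)
```

An empty result, as in the first case, is the situation in which removing the test is safe from a coverage point of view.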

assert not any(isinstance(o, RandomVariable) for o in ops)


def test_logprob_join_constant_shapes():

This should be in another file, corresponding to where the rewrite is applied. Also we can remove the aeppl comment below

This test is redundant with

@pytest.mark.parametrize(
    "size1, size2, axis, concatenate",
    [
        ((5,), (3,), 0, True),
        ((5,), (3,), -1, True),
        ((5, 2), (3, 2), 0, True),
        ((2, 5), (2, 3), 1, True),
        ((2, 5), (2, 5), 0, False),
        ((2, 5), (2, 5), 1, False),
        ((2, 5), (2, 5), 2, False),
    ],
)
def test_measurable_join_univariate(size1, size2, axis, concatenate):
    base1_rv = at.random.normal(size=size1, name="base1")
    base2_rv = at.random.exponential(size=size2, name="base2")
    if concatenate:
        y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
    else:
        y_rv = at.stack((base1_rv, base2_rv), axis=axis)
    y_rv.name = "y"
    base1_vv = base1_rv.clone()
    base2_vv = base2_rv.clone()
    y_vv = y_rv.clone()
    base_logps = list(factorized_joint_logprob({base1_rv: base1_vv, base2_rv: base2_vv}).values())
    if concatenate:
        base_logps = at.concatenate(base_logps, axis=axis)
    else:
        base_logps = at.stack(base_logps, axis=axis)
    y_logp = joint_logprob({y_rv: y_vv}, sum=False)
    base1_testval = base1_rv.eval()
    base2_testval = base2_rv.eval()
    if concatenate:
        y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
    else:
        y_testval = np.stack((base1_testval, base2_testval), axis=axis)
    np.testing.assert_allclose(
        base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
        y_logp.eval({y_vv: y_testval}),
    )

@pytest.mark.parametrize(
    "size1, supp_size1, size2, supp_size2, axis, concatenate",
    [
        (None, 2, None, 2, 0, True),
        (None, 2, None, 2, -1, True),
        ((5,), 2, (3,), 2, 0, True),
        ((5,), 2, (3,), 2, -2, True),
        ((2,), 5, (2,), 3, 1, True),
        pytest.param(
            (2,),
            5,
            (2,),
            5,
            0,
            False,
            marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
        ),
        pytest.param(
            (2,),
            5,
            (2,),
            5,
            1,
            False,
            marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
        ),
    ],
)
def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis, concatenate):
    base1_rv = at.random.multivariate_normal(
        np.zeros(supp_size1), np.eye(supp_size1), size=size1, name="base1"
    )
    base2_rv = at.random.dirichlet(np.ones(supp_size2), size=size2, name="base2")
    if concatenate:
        y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
    else:
        y_rv = at.stack((base1_rv, base2_rv), axis=axis)
    y_rv.name = "y"
    base1_vv = base1_rv.clone()
    base2_vv = base2_rv.clone()
    y_vv = y_rv.clone()
    base_logps = [
        at.atleast_1d(logp)
        for logp in factorized_joint_logprob({base1_rv: base1_vv, base2_rv: base2_vv}).values()
    ]
    if concatenate:
        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim)
        base_logps = at.concatenate(base_logps, axis=axis_norm - 1)
    else:
        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
        base_logps = at.stack(base_logps, axis=axis_norm - 1)
    y_logp = joint_logprob({y_rv: y_vv}, sum=False)
    base1_testval = base1_rv.eval()
    base2_testval = base2_rv.eval()
    if concatenate:
        y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
    else:
        y_testval = np.stack((base1_testval, base2_testval), axis=axis)
    np.testing.assert_allclose(
        base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
        y_logp.eval({y_vv: y_testval}),
    )

EXCEPT for the assert_no_rvs(...) line on the logp expression. We should add that line to those tests.
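
For readers unfamiliar with this kind of check, the idea is to walk the graph behind the logp expression and fail if any random-variable operation is still present, in the spirit of the `assert not any(isinstance(o, RandomVariable) for o in ops)` line quoted earlier in this thread. The sketch below uses a toy graph; ToyOp, ToyRandomVariable, ToyNode, and toy_assert_no_rvs are hypothetical names made up for illustration and do not reproduce the real assert_no_rvs helper.

```python
# Toy expression graph. All classes and functions here are hypothetical
# stand-ins and do not reproduce the real assert_no_rvs helper.


class ToyOp:
    def __init__(self, name):
        self.name = name


class ToyRandomVariable(ToyOp):
    """Marks an operation that draws random samples."""


class ToyNode:
    def __init__(self, op=None, inputs=()):
        self.op = op                  # None for plain inputs (value variables)
        self.inputs = tuple(inputs)


def collect_ops(output):
    """Depth-first walk that gathers every op reachable from `output`."""
    seen, stack, ops = set(), [output], []
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        if node.op is not None:
            ops.append(node.op)
        stack.extend(node.inputs)
    return ops


def toy_assert_no_rvs(output):
    """Fail if any random-variable op remains in the graph."""
    rvs = [op.name for op in collect_ops(output) if isinstance(op, ToyRandomVariable)]
    assert not rvs, f"random variables left in graph: {rvs}"


value = ToyNode()                                  # value variable
logp = ToyNode(ToyOp("normal_logpdf"), [value])    # clean logp expression
leaky = ToyNode(ToyOp("add"), [logp, ToyNode(ToyRandomVariable("normal_rv"))])

toy_assert_no_rvs(logp)       # passes: no random variables remain
try:
    toy_assert_no_rvs(leaky)  # fails: a random variable leaked into the graph
    found_rv = False
except AssertionError:
    found_rv = True
```

A check like this catches the case where a logp expression still depends on a fresh random draw instead of only on the value variables.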

@Armavica Armavica force-pushed the rm-joint_logprob branch 2 times, most recently from e5ead95 to 6181e61 Compare February 3, 2023 19:47
* _joint_logp to pymc/logprob/joint_logprob.py
* _get_scaling, _check_no_rvs, logp to logprob/joint_logprob.py
* logcdf to logprob/abstract.py
* ignore_logprob to logprob/utils.py

@ricardoV94 ricardoV94 left a comment

Looks great

Dunno if RTD is a fluke

Armavica commented Feb 7, 2023

Thank you for your help @ricardoV94 !

> Dunno if RTD is a fluke

It looks like the mamba env create step timed out. I don't think this PR could have affected it, so I think it is indeed a fluke.

@ricardoV94 ricardoV94 merged commit 28bac77 into pymc-devs:main Feb 7, 2023
jessegrabowski added a commit to jessegrabowski/pymc-extras that referenced this pull request Apr 15, 2023
@Armavica Armavica deleted the rm-joint_logprob branch October 7, 2024 14:58