Move pymc/distributions/logprob.py to pymc/logprob/ #6441
Conversation
Force-pushed from b0e154a to 9763ae9
Codecov Report

@@            Coverage Diff             @@
##             main    #6441      +/-   ##
==========================================
+ Coverage   86.04%   94.74%   +8.69%
==========================================
  Files         148      146       -2
  Lines       27820    27807      -13
==========================================
+ Hits        23939    26346    +2407
+ Misses       3881     1461    -2420
Force-pushed from 9763ae9 to 05f1698
    np.testing.assert_almost_equal(logp_vals, exp_obs_logps)


def test_joint_logp_subtensor():
Same here
I am sorry, but I have trouble identifying where this one and the next one should go; I don't really understand what these tests are about.
This subtensor test is just testing functionality that is implemented in logprob.test_mixture.py, so we can probably remove it (confirming with codecov that we are not missing lines that were covered exclusively by these tests).
    assert not any(isinstance(o, RandomVariable) for o in ops)


def test_logprob_join_constant_shapes():
This should be in another file, corresponding to where the rewrite is applied. Also, we can remove the aeppl comment below.
This test is redundant with the following block from pymc/pymc/tests/logprob/test_tensor.py (lines 218 to 326 in c5e4497):
@pytest.mark.parametrize(
    "size1, size2, axis, concatenate",
    [
        ((5,), (3,), 0, True),
        ((5,), (3,), -1, True),
        ((5, 2), (3, 2), 0, True),
        ((2, 5), (2, 3), 1, True),
        ((2, 5), (2, 5), 0, False),
        ((2, 5), (2, 5), 1, False),
        ((2, 5), (2, 5), 2, False),
    ],
)
def test_measurable_join_univariate(size1, size2, axis, concatenate):
    base1_rv = at.random.normal(size=size1, name="base1")
    base2_rv = at.random.exponential(size=size2, name="base2")
    if concatenate:
        y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
    else:
        y_rv = at.stack((base1_rv, base2_rv), axis=axis)
    y_rv.name = "y"

    base1_vv = base1_rv.clone()
    base2_vv = base2_rv.clone()
    y_vv = y_rv.clone()

    base_logps = list(factorized_joint_logprob({base1_rv: base1_vv, base2_rv: base2_vv}).values())
    if concatenate:
        base_logps = at.concatenate(base_logps, axis=axis)
    else:
        base_logps = at.stack(base_logps, axis=axis)
    y_logp = joint_logprob({y_rv: y_vv}, sum=False)

    base1_testval = base1_rv.eval()
    base2_testval = base2_rv.eval()
    if concatenate:
        y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
    else:
        y_testval = np.stack((base1_testval, base2_testval), axis=axis)
    np.testing.assert_allclose(
        base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
        y_logp.eval({y_vv: y_testval}),
    )


@pytest.mark.parametrize(
    "size1, supp_size1, size2, supp_size2, axis, concatenate",
    [
        (None, 2, None, 2, 0, True),
        (None, 2, None, 2, -1, True),
        ((5,), 2, (3,), 2, 0, True),
        ((5,), 2, (3,), 2, -2, True),
        ((2,), 5, (2,), 3, 1, True),
        pytest.param(
            (2,),
            5,
            (2,),
            5,
            0,
            False,
            marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
        ),
        pytest.param(
            (2,),
            5,
            (2,),
            5,
            1,
            False,
            marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
        ),
    ],
)
def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis, concatenate):
    base1_rv = at.random.multivariate_normal(
        np.zeros(supp_size1), np.eye(supp_size1), size=size1, name="base1"
    )
    base2_rv = at.random.dirichlet(np.ones(supp_size2), size=size2, name="base2")
    if concatenate:
        y_rv = at.concatenate((base1_rv, base2_rv), axis=axis)
    else:
        y_rv = at.stack((base1_rv, base2_rv), axis=axis)
    y_rv.name = "y"

    base1_vv = base1_rv.clone()
    base2_vv = base2_rv.clone()
    y_vv = y_rv.clone()

    base_logps = [
        at.atleast_1d(logp)
        for logp in factorized_joint_logprob({base1_rv: base1_vv, base2_rv: base2_vv}).values()
    ]
    if concatenate:
        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim)
        base_logps = at.concatenate(base_logps, axis=axis_norm - 1)
    else:
        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
        base_logps = at.stack(base_logps, axis=axis_norm - 1)
    y_logp = joint_logprob({y_rv: y_vv}, sum=False)

    base1_testval = base1_rv.eval()
    base2_testval = base2_rv.eval()
    if concatenate:
        y_testval = np.concatenate((base1_testval, base2_testval), axis=axis)
    else:
        y_testval = np.stack((base1_testval, base2_testval), axis=axis)
    np.testing.assert_allclose(
        base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
        y_logp.eval({y_vv: y_testval}),
    )
EXCEPT for the assert_no_rvs(...) check on the logp expression. We should add that line to those tests.
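For concreteness, the suggested change would be a one-line check right after the logp graph is built in each of the two tests above. A minimal sketch, assuming assert_no_rvs is importable from the shared test utilities (the exact import path below is hypothetical, not confirmed by this PR):

# Sketch only: the import location of assert_no_rvs is an assumption.
from pymc.tests.logprob.utils import assert_no_rvs  # hypothetical path

y_logp = joint_logprob({y_rv: y_vv}, sum=False)
assert_no_rvs(y_logp)  # fail if any RandomVariable op remains in the logp graph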
Force-pushed from e5ead95 to 6181e61
* _joint_logp to pymc/logprob/joint_logprob.py
* _get_scaling, _check_no_rvs, logp to logprob/joint_logprob.py
* logcdf to logprob/abstract.py
* ignore_logprob to logprob/utils.py
Force-pushed from 6181e61 to c92b204
Looks great
Not sure if the RTD failure is a fluke.
Thank you for your help @ricardoV94!
It looks like it is the […]
What is this PR about?
This PR shuffles around some of the logprob-related functions:

* joint_logprob, a thin wrapper over factorized_joint_logprob, was only used in tests, so it was moved to tests.
* _joint_logprob now lives as joint_logprob in pymc/logprob/joint_logprob.py.
* _get_scaling, _check_no_rvs, and logp moved to pymc/logprob/joint_logprob.py.
* logcdf moved to logprob/abstract.py.
* ignore_logprob moved to logprob/utils.py.
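For downstream code, the practical effect is mostly a change of import paths. A minimal sketch derived from the list above; the exact public names re-exported at these paths are assumptions, not confirmed against the final diff:

# Hypothetical import locations after this PR, based on the module moves listed above.
from pymc.logprob.joint_logprob import joint_logprob  # formerly the private _joint_logprob
from pymc.logprob.abstract import logcdf               # moved out of pymc/distributions/logprob.py
from pymc.logprob.utils import ignore_logprob          # moved out of pymc/distributions/logprob.py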