Implement utility to recover marginalized variables from MarginalModel
#285
Conversation
Force-pushed from b0c58b4 to 1fc5d55
Looks great, I will try and see how we can get transforms out of the way
rv_loglike_fn = None
if include_samples:
    sample_rv_outs = pm.Categorical.dist(logit_p=joint_logps)
Are these joint_logps normalized? pm.Categorical won't do it under the hood.
It will when you use logit_p, but the logits added to the inferencedata directly will still not be normalized. I think it may be more intuitive if they are, but I'm not sure.
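For reference, a quick standalone check of the logit_p normalization (this snippet is illustrative, not part of the PR diff):

```python
import numpy as np
import pymc as pm

# Shift the logits by an arbitrary constant: if Categorical softmax-normalizes
# logit_p internally, the logp is unchanged by the shift.
logits = np.log([0.2, 0.3, 0.5]) + 7.0
x = pm.Categorical.dist(logit_p=logits)

print(pm.logp(x, 1).eval())  # approximately log(0.3)
print(np.log(0.3))
```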
if var_names is None:
    var_names = self.marginalized_rvs

joint_logp = self.logp()
self.logp will return the logp with the variables marginalized out, so they won't be part of the graph. I guess that's why you have the on_unused_input issues later? I imagine you want the original logp that does not marginalize the variables, so you can give them values.
That's why I was asking how you handle multiple related marginalized variables. It seems to me you have to either evaluate them one at a time, conditioned on the marginalized variables already evaluated, or create a joint logp for all the combinations of the marginalized variables.
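To make the two options concrete, here is a toy numpy illustration (not the PR code) with two dependent binary marginalized variables z1 and z2 and an unnormalized joint logp table over their four combinations:

```python
import numpy as np
from scipy.special import logsumexp

# Unnormalized logp(z1=i, z2=j | data) for two dependent binary variables.
logp_table = np.log([[0.1, 0.2],
                     [0.3, 0.4]])

# Option 1: enumerate all combinations jointly and normalize once.
joint = logp_table - logsumexp(logp_table)

# Option 2: recover one variable at a time, conditioning on the previous one.
lp_z1 = logsumexp(logp_table, axis=1) - logsumexp(logp_table)               # p(z1 | data)
lp_z2_given_z1 = logp_table - logsumexp(logp_table, axis=1, keepdims=True)  # p(z2 | z1, data)
```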
Here is a test that creates nested marginalized variables: https://github.com/pymc-devs/pymc-experimental/blob/8046695e600970bb30a107376281d3e477a66dd0/pymc_experimental/tests/model/test_marginal_model.py#L169-L204
These get represented as an OpFromGraph with more than one output RV without values. You can raise NotImplementedError for those cases for now, so that you know you're only working with independent, simple marginalized variables. Still, I think you don't want to work with self.logp.
self._logp() skips the marginalization step, maybe that's what you need?
I chose to marginalise each discrete variable one at a time. Presumably I still need to marginalise for that reason.
self.register_rv(rv, name=rv.name)

def recover_marginals(
    self, idata, var_names=None, include_samples=False, extend_inferencedata=True
I would perhaps default to include_samples=True. Also maybe a different name? return_samples, or just sample?
The discrete samples are often not as good for understanding the tail probabilities. This can be seen by comparing the changepoint's logps vs the discrete samples discussed at https://mc-stan.org/docs/stan-users-guide/change-point.html
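A toy illustration of the point (standalone, nothing to do with the PR code): a low-probability state essentially never shows up in a finite number of discrete draws, while its log-probability is available exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.9999, 0.0001])

# Estimate P(state == 1) from 4000 discrete draws vs. reading it off the logp.
draws = rng.choice(2, size=4000, p=p)
print(draws.mean())   # often exactly 0.0, i.e. the tail state is never observed
print(np.log(p[1]))   # the logp still encodes P(state == 1) = 1e-4
```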
I know that, but I still think it doesn't hurt to include them by default. Logits are not something many users grok intuitively.
Updated name
Looks pretty good, I left 2 comments.
I would add a test with multiple marginalized dependent variables, and after that I think this would be pretty much there (except for the transforms, of course).
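A rough sketch of the kind of test being suggested, loosely following the nested-marginalization test linked above (the import path, priors, and exact call signatures are assumptions here, not taken from the PR):

```python
import pymc as pm
import pytensor.tensor as pt

from pymc_experimental import MarginalModel  # assumed import path

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.75)
    sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
    x = pm.Normal("x", mu=idx + sub_idx, sigma=1.0)

m.marginalize([idx, sub_idx])

with m:
    idata = pm.sample()

# recover_marginals should then either return lp_idx / lp_sub_idx consistently,
# or raise NotImplementedError while dependent marginalized RVs are unsupported.
m.recover_marginals(idata)
```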
Force-pushed the MarginalModel branch from 9514f90 to e4c9db7
This looks pretty good 😊
Just two comments plus the question of compiling a single function for when we have multiple marginalized RVs. I am not sure the speed benefits outweigh the extra complexity at this point so fine to leave as is
self.register_rv(rv, name=rv.name)

def recover_marginals(
    self, idata, var_names=None, return_samples=False, extend_inferencedata=True
I still think we should return samples by default.
Ah, one reason I see for why we may want to normalize the lps is that we actually don't need to evaluate the joint logp of the whole model, but only of those variables that depend on the marginalized one. In the future we may want to be more efficient and compile a logp with only those terms.
Yep, I'm convinced. Will make the changes.
I think in the future we can include the optimisations where we compile the joint_logps all at once, as well as only using a logp that includes terms which contain the marginalized variable. But this should work for now.
Force-pushed from 2bdae64 to c21ae69
Small comments, hopefully that's all on my end :)
var_names : sequence of str, optional
    List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
return_samples : bool, default True
    If True, also return samples of the marginalized variables
The docstrings are wrong. It would also be nice to add a code example, as we do for other methods.
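For instance, something along these lines could work as a docstring example (a sketch assuming the API introduced in this PR; the import path and data are illustrative):

```python
import pymc as pm

from pymc_experimental import MarginalModel  # assumed import path

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.7)
    y = pm.Normal("y", mu=idx, sigma=1.0, observed=[0.9, 1.2, 1.1])

m.marginalize([idx])

with m:
    idata = pm.sample()

m.recover_marginals(idata, return_samples=True)
# idata.posterior now also contains "idx" (sampled values) and "lp_idx"
# (its log-probabilities), following the naming discussed below.
```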
idata : InferenceData
    InferenceData with var_names added to posterior
It would be good to emphasize that the lps will be called lp_{varname} and will be found in the posterior group; same for the samples.
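For example, the Returns entry could read something like this (wording is only a suggestion):

```
idata : InferenceData
    InferenceData with the log-probabilities added to the posterior group as
    ``lp_{varname}`` and, if ``return_samples=True``, the sampled values
    added as ``{varname}``.
```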
logps = np.array(logvs)
rv_dict["lp_" + rv.name] = log_softmax(
    np.reshape(
        logps,
        tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
    ),
    axis=len(stacked_dims),
)
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
Since this is done in both branches, move it out and write only once?
    axis1=0,
    axis2=-1,
Do you mean perhaps moveaxis(..., -1, 0)? This will fail if rv_shape is larger than 1, no?
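A quick numpy illustration of the difference (standalone, just to show why a 2D case hides the problem):

```python
import numpy as np

a = np.empty((4, 2, 3))

print(np.swapaxes(a, 0, -1).shape)   # (3, 2, 4): only the two axes trade places
print(np.moveaxis(a, -1, 0).shape)   # (3, 4, 2): the other axes keep their relative order

# The two only coincide for 2D arrays, so an rv_shape with more than one
# dimension would expose the difference.
```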
If that were the case we should add a test as well.
This adopts logic from replace_finite_discrete_marginal_subgraph; is there a reason this might break when that doesn't?
I think you mean the finite_discrete_marginal_rv_logp, and I think it's a bug there as well.
var_names : sequence of str, optional
    List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
This is wrong; it's a log-probability, right? And the default is "all marginalized variables".
"""Test that marginalization works for batched random variables""" | ||
with MarginalModel() as m: | ||
sigma = pm.HalfNormal("sigma") | ||
idx = pm.Bernoulli("idx", p=0.7, shape=(2, 2)) |
Here I would give a different length to each dim and check that they come out correctly in the idata.
A nice feature (no need to block this PR, it can be a follow-up issue) would be to reuse dims in the computed variables if the user specified dims for the marginalized variables.
As a follow-up we may want to standardize the signatures of marginalize and recover_marginals to allow passing strings or the variables in either case. Right now each is restricted to a different type, which feels suboptimal.
else:
    var_names = {var_names}

var_names = {var if isinstance(var, str) else var.name for var in var_names}
One reason I don't like sets is that they introduce randomness all over the place.
This made me realize we should allow users to pass a seed and split it for each of the compile_pymc calls when we are sampling the Categorical (there's a get_seeds_per_chain utility in PyMC).
But even with a seed, the draws will be different depending on the order in which we end up creating the functions, due to this set.
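A hypothetical sketch of that seeding idea, outside the PR code (compile_pymc accepts a random_seed; the seed-splitting here just uses numpy, though the PyMC utility mentioned above could do it instead):

```python
import numpy as np
import pymc as pm
from pymc.pytensorf import compile_pymc

# One seed per marginalized RV, derived from a single user-provided seed, so the
# Categorical draws are reproducible and independent of construction order.
seeds = np.random.SeedSequence(123).generate_state(2)

logits = np.log([0.2, 0.3, 0.5])
sampling_fns = [
    compile_pymc([], pm.Categorical.dist(logit_p=logits), random_seed=int(seed))
    for seed in seeds
]
print([fn() for fn in sampling_fns])  # same draws on every run
```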
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
rv_loglike_fn = compile_pymc(
    inputs=other_values,
    outputs=[log_softmax(joint_logps, axis=0), sample_rv_outs],
Nitpick: move the repeated log_softmax before the if/else.
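Roughly this reshuffle, reusing the names from the quoted diff (an untested sketch, not the PR's actual code):

```python
joint_logps_norm = log_softmax(joint_logps, axis=0)

if return_samples:
    sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
    rv_loglike_fn = compile_pymc(
        inputs=other_values,
        outputs=[joint_logps_norm, sample_rv_outs],
    )
else:
    rv_loglike_fn = compile_pymc(inputs=other_values, outputs=joint_logps_norm)
```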
Force-pushed from 8170109 to 9d9daa4
Small confusion, otherwise everything looks ready
joint_logps = pt.moveaxis(joint_logps, 0, -1)

rv_loglike_fn = None
joint_logps_norm = log_softmax(joint_logps, axis=0)
Shouldn't this be the last axis now?
    axis=1,
)

np.testing.assert_almost_equal(
Can you also add a sanity check asserting that logsumexp(lps) is close to 0?
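e.g. something like the following inside the test, assuming the support ends up on the last axis and the recovered logps live in the posterior group as lp_{varname}:

```python
import numpy as np
from scipy.special import logsumexp

np.testing.assert_allclose(
    logsumexp(idata.posterior["lp_idx"].values, axis=-1),
    0.0,
    atol=1e-6,
)
```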
This is a PR to add support for the recover_marginals method. This allows us to sample values and get access to the logps of discrete variables which we marginalized out during sampling.

Closes #286