
Implement utility to recover marginalized variables from MarginalModel #285

Merged (6 commits) on Dec 25, 2023

Conversation

@zaxtax (Contributor) commented Dec 15, 2023:

This PR adds support for the recover_marginals method, which lets us sample values of, and access the logps of, discrete variables that were marginalized out during sampling.

Closes #286
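
For context, a minimal usage sketch of the workflow this PR enables (the model, data, and import path are made up for illustration; the sampling flag follows the name settled on later in this thread, return_samples):

```python
import pymc as pm
from pymc_experimental import MarginalModel  # import path assumed

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.7)                      # discrete variable to marginalize out
    y = pm.Normal("y", mu=idx, sigma=1.0, observed=1.2)   # toy likelihood

m.marginalize([idx])          # sample without the discrete variable
with m:
    idata = pm.sample()

# Recover the marginalized variable: adds normalized log-probabilities ("lp_idx")
# and, with return_samples=True, draws of "idx" to idata.posterior.
m.recover_marginals(idata, return_samples=True)
```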

@zaxtax marked this pull request as draft December 15, 2023 12:21
@zaxtax requested a review from ricardoV94 December 15, 2023 12:22
@zaxtax force-pushed the sampling_discrete_variable branch from b0c58b4 to 1fc5d55 on December 15, 2023 12:46
@ricardoV94 (Member) left a comment:

Looks great, I will try and see how we can get transforms out of the way

pymc_experimental/model/marginal_model.py (resolved, outdated)
pymc_experimental/tests/model/test_marginal_model.py (resolved, outdated)
@ricardoV94 added the "enhancements" (New feature or request) label Dec 15, 2023

rv_loglike_fn = None
if include_samples:
    sample_rv_outs = pm.Categorical.dist(logit_p=joint_logps)
Member:

Are these joint_logps normalized? pm.Categorical won't do it under the hood

@ricardoV94 (Member), Dec 15, 2023:

It will when you use logit_p, but the logits added to the inferencedata directly will still not be normalized. I think it may be more intuitive if they are but not sure.
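
A small NumPy/SciPy illustration of that distinction (not part of the PR): the draws are based on the softmax of the logits, so they are normalized even when the raw logits stored in the InferenceData are not:

```python
import numpy as np
from scipy.special import softmax, log_softmax

logits = np.log([2.0, 6.0])    # unnormalized log-weights, as passed to logit_p
print(softmax(logits))         # [0.25 0.75] -- the probabilities actually sampled from
print(log_softmax(logits))     # normalized log-probabilities, i.e. what a normalized lp would store
```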

pymc_experimental/model/marginal_model.py (resolved, outdated)
pymc_experimental/model/marginal_model.py (resolved, outdated)
if var_names is None:
    var_names = self.marginalized_rvs

joint_logp = self.logp()
@ricardoV94 (Member), Dec 15, 2023:

self.logp will return the logp with the variables marginalized out, so they won't be part of the graph. I guess that's why you have on_unused_input issues later? I imagine you want the original logp that does not marginalize variables, so you can give them values.

That's why I was asking how you handle multiple related marginalized variables. It seems to me you have to either evaluate one at a time, conditioned on the previous marginalized variables already evaluated, or create a joint logp for all the combinations of the marginalized variables.

Member:

Here is a test that creates nested marginalized variables: https://github.com/pymc-devs/pymc-experimental/blob/8046695e600970bb30a107376281d3e477a66dd0/pymc_experimental/tests/model/test_marginal_model.py#L169-L204

They get represented as an OpFromGraph with more than one output RV without values. You can raise NotImplementedError for these cases for now, so that you know you're working only with independent, simple marginalized variables. Still, I think you don't want to work with self.logp.

@ricardoV94 (Member), Dec 15, 2023:

self._logp() skips the marginalization step, maybe that's what you need?

Contributor (Author):

I chose to marginalise each discrete variable one at a time. Presumably I still need to marginalise for that reason.
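
Roughly, the one-variable-at-a-time idea can be sketched as follows (illustrative only; logp_given_value is a hypothetical helper standing in for the compiled model logp with the marginalized variable's value plugged in):

```python
import numpy as np
from scipy.special import log_softmax

def recover_one(support, logp_given_value, draw, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # Evaluate the joint logp of the posterior draw for every admissible value
    # of the single marginalized discrete variable ...
    logps = np.array([logp_given_value(k, draw) for k in support])
    # ... normalize over the support to get its conditional log-probabilities ...
    lps = log_softmax(logps)
    # ... and optionally draw a value from that conditional distribution.
    value = rng.choice(support, p=np.exp(lps))
    return lps, value
```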

@zaxtax changed the title from "Adding unmarginalize" to "Adding recover_marginals" Dec 15, 2023
self.register_rv(rv, name=rv.name)

def recover_marginals(
    self, idata, var_names=None, include_samples=False, extend_inferencedata=True
Member:

I would perhaps default include_samples to True. Also maybe a different name? return_samples or just sample?

Contributor (Author):

The discrete samples are often not as good for understanding tail probabilities.

This can be seen by comparing the changepoint's logps to the discrete samples discussed at https://mc-stan.org/docs/stan-users-guide/change-point.html
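
A quick numerical illustration of that point (toy numbers, not the change-point model): a state with tiny posterior probability is essentially invisible in a finite number of discrete draws, but its lp still quantifies it:

```python
import numpy as np

p_tail = 1e-6
draws = np.random.default_rng(0).choice(2, size=4_000, p=[1 - p_tail, p_tail])
print(draws.mean())     # almost surely 0.0 -- the tail state never shows up in the samples
print(np.log(p_tail))   # ~ -13.8 -- but the log-probability still captures it
```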

Member:

I know, but I still think it doesn't hurt to include them by default. Logits are not something many users grok intuitively.

Contributor (Author):

Updated name

@ricardoV94 (Member) left a comment:

Looks pretty good, I left 2 comments.

I would add a test with multiple marginalized dependent variables, and after that I think this would be pretty much there (except for the transforms, of course).

pymc_experimental/model/marginal_model.py (resolved)
pymc_experimental/tests/model/test_marginal_model.py (resolved, outdated)
@ricardoV94 changed the title from "Adding recover_marginals" to "Implement utility to recover marginalized variables from MarginalModel" Dec 16, 2023
@zaxtax force-pushed the sampling_discrete_variable branch 3 times, most recently from 9514f90 to e4c9db7, on December 18, 2023 18:35
@zaxtax marked this pull request as ready for review December 19, 2023 14:11
@ricardoV94 (Member) left a comment:

This looks pretty good 😊

Just two comments, plus the question of compiling a single function for when we have multiple marginalized RVs. I am not sure the speed benefits outweigh the extra complexity at this point, so it's fine to leave as is.

self.register_rv(rv, name=rv.name)

def recover_marginals(
    self, idata, var_names=None, return_samples=False, extend_inferencedata=True
Member:

I still think we should return samples by default.

pymc_experimental/model/marginal_model.py (resolved)
@ricardoV94 (Member) commented Dec 20, 2023:

Ah, one reason I see for why we may want to normalize the lps is that we actually don't need to evaluate the joint logp of the whole model, but only of those variables that depend on the marginalized one.

In the future we may want to be more efficient and compile a logp with vars=[marginalized, *dependent_RVs], and the unnormalized lps don't make as much sense there. In contrast, the normalized lps should come out exactly the same.
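
That invariance is easy to check with toy numbers: terms that do not involve the marginalized variable add the same constant to every entry, and log_softmax is unchanged by adding a constant:

```python
import numpy as np
from scipy.special import log_softmax

logps_dependent = np.array([-1.2, -0.3, -2.5])   # terms that involve the marginalized RV
rest_of_model = -57.3                             # everything else, constant across its values

np.testing.assert_allclose(
    log_softmax(logps_dependent),
    log_softmax(logps_dependent + rest_of_model),
)
```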

@zaxtax (Contributor, Author) commented Dec 20, 2023:

> Ah, one reason I see for why we may want to normalize the lps is that we actually don't need to evaluate the joint logp of the whole model, but only of those variables that depend on the marginalized one.
>
> In the future we may want to be more efficient and compile a logp with vars=[marginalized, *dependent_RVs], and the unnormalized lps don't make as much sense there. In contrast, the normalized lps should come out exactly the same.

Yep, I'm convinced. Will make the changes

@zaxtax (Contributor, Author) commented Dec 20, 2023:

I think in the future we can include the optimisations where we compile the joint_logps all at once, as well as only using a logp that includes terms which contain marginalized_value.

But this should work for now

@zaxtax force-pushed the sampling_discrete_variable branch from 2bdae64 to c21ae69 on December 20, 2023 15:01
@ricardoV94 (Member) left a comment:

Small comments, hopefully that's all on my end :)

Comment on lines 296 to 309

var_names : sequence of str, optional
    List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
return_samples : bool, default True
    If True, also return samples of the marginalized variables
Member:

Docstrings are wrong. Also would be nice to add a code example as we do for other methods

Comment on lines 305 to 306

idata : InferenceData
    InferenceData with var_names added to posterior
Member:

Would be good to emphasize lps will be called lp_{varname} and be found in the posterior group, same for samples.

pymc_experimental/model/marginal_model.py (resolved)
Comment on lines 401 to 409

logps = np.array(logvs)
rv_dict["lp_" + rv.name] = log_softmax(
    np.reshape(
        logps,
        tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
    ),
    axis=len(stacked_dims),
)
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
Member:

Since this is done in both branches, move it out and write only once?

Comment on lines 354 to 355

    axis1=0,
    axis2=-1,
Member:

Do you mean perhaps moveaxis(..., -1, 0)? This will fail if rv_shape is larger than 1, no?

Member:

If that were the case, we should add a test as well.

Contributor (Author):

This adopts logic from replace_finite_discrete_marginal_subgraph. Is there a reason this might break and that won't?

Member:

I think you mean finite_discrete_marginal_rv_logp, and I think it's a bug there as well.
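
For reference, swapaxes and moveaxis agree on 2-D arrays and only diverge once there are extra dimensions, which is presumably why the bug goes unnoticed for scalar marginalized RVs:

```python
import numpy as np

a = np.empty((3, 4, 5))
print(np.swapaxes(a, 0, -1).shape)   # (5, 4, 3): the two axes simply trade places
print(np.moveaxis(a, 0, -1).shape)   # (4, 5, 3): axis 0 moves to the end, the others keep their order
```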

pymc_experimental/model/marginal_model.py (resolved, outdated)
Comment on lines 296 to 297

var_names : sequence of str, optional
    List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
Member:

This is wrong; it's a log-probability, right? And the default is "all marginalized variables".

pymc_experimental/model/marginal_model.py (resolved)
pymc_experimental/model/marginal_model.py (resolved, outdated)
pymc_experimental/model/marginal_model.py (resolved, outdated)
pymc_experimental/model/marginal_model.py (resolved, outdated)
"""Test that marginalization works for batched random variables"""
with MarginalModel() as m:
sigma = pm.HalfNormal("sigma")
idx = pm.Bernoulli("idx", p=0.7, shape=(2, 2))
Member:

Here I would give a different length to each dim and check they come out correctly in the idata.

Member:

A nice feature (no need to block this PR, can be a follow-up issue) would be to reuse dims in the computed variables if the user specified dims for the marginalized variables.

pymc_experimental/model/marginal_model.py (resolved)
pymc_experimental/model/marginal_model.py (resolved)
@ricardoV94 (Member) commented:

As a follow-up we may want to standardize the signatures of marginalize and recover_marginals to allow passing either strings or variables in both cases. Right now each is restricted to a different type, which feels suboptimal.

pymc_experimental/model/marginal_model.py (resolved, outdated)
else:
    var_names = {var_names}

var_names = {var if isinstance(var, str) else var.name for var in var_names}
Member:

One reason I don't like sets is they introduce randomness all over the place.

This made me realize we should allow users to pass a seed and split it for each compile_pymc call when we are sampling the Categorical (there's a get_seeds_per_chain utility in PyMC).

But even with a seed, the draws will be different depending on the order in which we end up creating the functions, due to this set.
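
One possible shape for this (a sketch, not the PR's code; it assumes get_seeds_per_chain lives in pymc.util and that compile_pymc accepts a random_seed argument): pin the iteration order first, then give each compiled sampler its own sub-seed:

```python
from pymc.util import get_seeds_per_chain

var_names = {"idx", "switch"}                    # set iteration order is not deterministic
ordered = sorted(var_names)                      # fix the order before compiling anything
seeds = get_seeds_per_chain(123, len(ordered))   # one independent seed per compiled function
for name, seed in zip(ordered, seeds):
    # compile the per-variable sampler with random_seed=seed so draws are reproducible
    ...
```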

sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
rv_loglike_fn = compile_pymc(
    inputs=other_values,
    outputs=[log_softmax(joint_logps, axis=0), sample_rv_outs],
Member:

Nitpick: move the repeated log_softmax before the if/else.

@zaxtax force-pushed the sampling_discrete_variable branch from 8170109 to 9d9daa4 on December 22, 2023 00:52
@ricardoV94 (Member) left a comment:

Small confusion, otherwise everything looks ready

joint_logps = pt.moveaxis(joint_logps, 0, -1)

rv_loglike_fn = None
joint_logps_norm = log_softmax(joint_logps, axis=0)
Member:

Shouldn't this be the last axis now?
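
A quick check with toy shapes of why the normalization axis has to follow the moveaxis: softmaxing over the last axis after the move matches softmaxing over axis 0 before it:

```python
import numpy as np
from scipy.special import log_softmax

logits = np.random.default_rng(0).normal(size=(3, 100))   # (support, draws), support on axis 0
moved = np.moveaxis(logits, 0, -1)                          # (draws, support), support now last

np.testing.assert_allclose(
    log_softmax(moved, axis=-1),
    np.moveaxis(log_softmax(logits, axis=0), 0, -1),
)
```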

    axis=1,
)

np.testing.assert_almost_equal(
Member:

Can you also add a sanity check assert that logsumexp(lps) is close to 0?
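
Something along these lines (the lps array here is just a stand-in for the recovered lp_{varname} values, with the support on the last axis):

```python
import numpy as np
from scipy.special import log_softmax, logsumexp

lps = log_softmax(np.random.default_rng(0).normal(size=(4, 250, 2)), axis=-1)   # (chain, draw, support)
np.testing.assert_allclose(logsumexp(lps, axis=-1), 0.0, atol=1e-8)             # probabilities sum to 1
```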

@zaxtax merged commit 4f75687 into pymc-devs:main Dec 25, 2023
6 checks passed
Labels: enhancements (New feature or request), marginalization
Projects: None yet
Development: Successfully merging this pull request may close these issues:
Add test for MarginalModel where variable depends on two marginalized variables
2 participants