Fixes and features for SubnetLaplace #87

Merged
merged 6 commits into from
Mar 6, 2022

edaxberger
Collaborator

This PR aims to address #86. Namely, this implements the following changes:

  • It fixes an issue where the subnet indices validity check would fail on GPU, because we only allowed torch.LongTensor but not torch.cuda.LongTensor.
  • It technically enables marginal likelihood computation, and therefore optimization, with SubnetLaplace (i.e. it essentially makes sure that there are no dimension errors when calling la.log_marginal_likelihood() on a SubnetLaplace model). However, we should generally not recommend using this, as the marginal likelihood is probably not meaningful when computed over a subnetwork (at least not without additional care that would likely require original research). I'm therefore not sure whether we should have this feature if we don't know how it'll behave.
  • It adds DiagSubnetLaplace, a subclass of SubnetLaplace that uses a diagonal Hessian approximation (the existing full Hessian variant is now called FullSubnetLaplace). We initially thought that this wouldn't be useful, as a diagonal LA over the full network would probably be better in most cases and not incur much overhead, but this feature was explicitly requested in #86 (Questions about Subnetwork), so why not.
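
To illustrate the first point, here is a minimal sketch of a device-agnostic index check. It is not the code from this PR: the function name check_subnet_indices and the specific validity rules are assumptions made for illustration. The idea is that comparing the tensor's dtype against torch.long accepts CPU and CUDA tensors alike, whereas an isinstance check against torch.LongTensor only matches CPU tensors.

```python
import torch


def check_subnet_indices(indices, n_params):
    """Hypothetical validity check for subnetwork indices (illustrative only)."""
    # Compare the dtype rather than the tensor class, so that both
    # CPU and CUDA index tensors pass the type check.
    if not isinstance(indices, torch.Tensor) or indices.dtype != torch.long:
        raise ValueError('Subnetwork indices must be a tensor of dtype torch.long.')
    if indices.dim() != 1 or indices.numel() == 0:
        raise ValueError('Subnetwork indices must be a non-empty 1-D tensor.')
    if indices.min() < 0 or indices.max() >= n_params:
        raise ValueError('Subnetwork indices are out of range.')
    if indices.unique().numel() != indices.numel():
        raise ValueError('Subnetwork indices must not contain duplicates.')
    return indices


# Runs on GPU when one is available, otherwise on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
idx = torch.tensor([0, 3, 7], device=device)
check_subnet_indices(idx, n_params=10)
```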

Let me know what you think!

@edaxberger edaxberger linked an issue Jan 31, 2022 that may be closed by this pull request
Collaborator

@runame runame left a comment

Sorry for the late review, looks good to me! I have raised one general concern in my comment.

@@ -117,14 +111,66 @@ def prior_precision_diag(self):

        else:
            raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')

    @property
    def mean_subnet(self):
Collaborator

One issue that was not introduced in this PR, but that I just thought of: the mean property should maybe be mean_subset, since "mean" implies a distribution and, by definition, in subnet LA we only have a distribution over the subset. Hence, "mean" is only meaningful in the context of the subset. I know this is mostly semantics, but it could actually confuse the user, e.g. the dimensions of subnet_la.mean and subnet_la.posterior_covariance won't match.
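
The mismatch described here can be sketched with plain tensors. This is only an illustration under assumed shapes (a full parameter vector of size 6 and a subnetwork of 2 weights); the variables below are stand-ins and do not call the library.

```python
import torch

n_params = 6                            # size of the full parameter vector (assumed)
subnet_indices = torch.tensor([1, 4])   # hypothetical subnetwork selection

# Stand-ins for the quantities discussed above: a mean over all
# parameters and a covariance over the subnetwork only.
mean = torch.arange(n_params, dtype=torch.float32)
posterior_covariance = torch.eye(len(subnet_indices))

# The two cannot parameterize one Gaussian: 6 entries vs. a 2 x 2 matrix.
assert mean.shape[0] != posterior_covariance.shape[0]

# Restricting the mean to the subnetwork makes the dimensions agree.
mean_subnet = mean[subnet_indices]
dist = torch.distributions.MultivariateNormal(
    mean_subnet, covariance_matrix=posterior_covariance
)
sample = dist.sample()                  # one draw over the subnetwork weights
```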

Collaborator Author

Thanks for looking over this, Runa! Yes, this is a good point. I've also thought about this, and have considered the solution you proposed, but that would have made some parts of the subnet laplace implementation more complex (can't remember what exactly). But I agree that it makes more sense conceptually and am happy to change this -- I'll open a separate issue for this!

@edaxberger
Collaborator Author

@aleximmer Feel free to take a look once you have some time (no rush of course). Would be great to merge this in at some point, as there seems to be some interest in this, now again with issue #90. Thanks a lot!

@aleximmer
Owner

Sorry for holding this back. Looks good to me.

@runame runame merged commit 5ea2d4b into main Mar 6, 2022
@runame runame deleted the subnetlaplace branch March 6, 2022 20:55
Development

Successfully merging this pull request may close these issues.

Questions about Subnetwork
3 participants