Fixes and features for SubnetLaplace #87
Conversation
Sorry for the late review; looks good to me! I have raised one general concern in my comment.
@@ -117,14 +111,66 @@ def prior_precision_diag(self):
        else:
            raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')

    @property
    def mean_subnet(self):
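For reference, a property like this would presumably just pick out the subnetwork entries of the full parameter mean. A minimal sketch, assuming the indices are stored on the backend (the attribute name `backend.subnetwork_indices` is an assumption, not taken from this diff):

```python
@property
def mean_subnet(self):
    # Sketch only: select the subnetwork entries of the full MAP
    # estimate using the stored subnetwork indices; the attribute
    # holding the indices is assumed, as it is not shown in this hunk.
    return self.mean[self.backend.subnetwork_indices]
```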
One issue that was not introduced in this PR, but I just thought about it: the `mean` property should maybe be `mean_subset`, since "mean" implies a distribution and, by definition, in subnet LA we only have a distribution over the subset. Hence, "mean" is only meaningful in the context of the subset. I know this is mostly semantics, but it could actually confuse the user: e.g. when calling `subnet_la.mean` and `subnet_la.posterior_covariance`, the dimensions won't match.
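To make the concern concrete, here is a small sketch of the mismatch (a hypothetical toy setup, assuming the library's `Laplace` factory with `subset_of_weights='subnetwork'`; the shapes shown are for an 11-parameter model with a 3-weight subnetwork):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

model = torch.nn.Linear(10, 1)  # 11 parameters in total (10 weights + 1 bias)
subnetwork_indices = torch.LongTensor([0, 3, 7])  # treat 3 of them as the subnetwork

subnet_la = Laplace(model, 'regression',
                    subset_of_weights='subnetwork',
                    hessian_structure='full',
                    subnetwork_indices=subnetwork_indices)
X, y = torch.randn(32, 10), torch.randn(32, 1)
subnet_la.fit(DataLoader(TensorDataset(X, y), batch_size=32))

# The mean spans all parameters, but the posterior covariance only
# spans the subnetwork, so the dimensions don't match:
print(subnet_la.mean.shape)                  # torch.Size([11])
print(subnet_la.posterior_covariance.shape)  # torch.Size([3, 3])
```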
Thanks for looking over this, Runa! Yes, this is a good point. I've also thought about this and considered the solution you proposed, but it would have made some parts of the subnet Laplace implementation more complex (I can't remember what exactly). But I agree that it makes more sense conceptually and am happy to change this -- I'll open a separate issue for it!
@aleximmer Feel free to take a look once you have some time (no rush of course). Would be great to merge this in at some point, as there seems to be some interest in this, now again with issue #90. Thanks a lot!
Sorry for holding this back. Looks good to me.
This PR aims to address #86. Namely, it implements the following changes (a rough usage sketch follows the list):

- Fixes the handling of subnetwork indices, which previously had to be `torch.LongTensor`s but not `torch.cuda.LongTensor`s.
- Adds support for the marginal likelihood to `SubnetLaplace` (i.e. essentially makes sure that there are no dimension errors when calling `la.log_marginal_likelihood()` on a `SubnetLaplace` model); however, we should generally not recommend using this, as the marginal likelihood is probably not meaningful when computed over a subnetwork (without any additional care that would likely require original research), so I'm not sure whether we should have this feature if we don't know how it'll behave.
- Adds `DiagSubnetLaplace`, a subclass of `SubnetLaplace` that uses a diagonal Hessian approximation (the existing full Hessian variant is now called `FullSubnetLaplace`); we initially thought that this wouldn't be useful, as a full diagonal LA would probably be better in most cases and not incur much overhead, but this feature was explicitly requested in Questions about Subnetwork #86, so why not.

Let me know what you think!
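To make the listed changes concrete, here is a rough usage sketch (assuming the `Laplace` factory interface with `subset_of_weights='subnetwork'`; the argument names are taken from the discussion above and should be treated as illustrative, not as this PR's actual test code; requires a CUDA device):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

# Toy model and data on the GPU.
model = torch.nn.Linear(10, 1).cuda()
X, y = torch.randn(32, 10).cuda(), torch.randn(32, 1).cuda()

# After this PR, the subnetwork indices may live on the GPU as a
# torch.cuda.LongTensor; previously only a CPU torch.LongTensor worked.
subnetwork_indices = torch.LongTensor([0, 3, 7]).cuda()

# hessian_structure='diag' selects the new DiagSubnetLaplace;
# hessian_structure='full' selects the renamed FullSubnetLaplace.
la = Laplace(model, 'regression',
             subset_of_weights='subnetwork',
             hessian_structure='diag',
             subnetwork_indices=subnetwork_indices)
la.fit(DataLoader(TensorDataset(X, y), batch_size=32))

# Now runs without dimension errors, but interpret with care: the
# marginal likelihood computed over a subnetwork may not be meaningful.
print(la.log_marginal_likelihood())
```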