Implement AutoHierarchicalNormalMessenger #2955
Conversation
Looks great @vitkl! Before merging, could you also add docs, tests, and imports? I'll add a task list to the PR description.
Would it make sense to constrain the weight between 0 and 1? `loc = loc * (1 - weight) + transform.inv(prior.mean) * weight`

Yes, I think that would make sense. OTOH, weights outside of that range should not result in NaN, and my intuition is that it should be easy to learn the weight, so I'd opt for the cheaper solution of an unconstrained weight.
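A minimal sketch (plain Python, hypothetical helper names) of the two candidate parameterizations being discussed: a convex combination with the weight constrained to [0, 1], versus the unconstrained additive weight, which stays finite for any weight value:

```python
def loc_convex(loc, prior_mean_u, weight):
    # weight in [0, 1]: posterior loc interpolates between the learned loc
    # and the unconstrained-space prior mean of the site
    return loc * (1 - weight) + prior_mean_u * weight

def loc_unconstrained(loc, prior_mean_u, weight):
    # unconstrained weight: loc is learned as a deviation from
    # weight * prior_mean_u; values outside [0, 1] cannot produce NaN here
    return loc + weight * prior_mean_u

print(loc_convex(0.0, 2.0, 0.5))          # 1.0, halfway to the prior mean
print(loc_unconstrained(-1.0, 2.0, 1.0))  # 1.0, a deviation of -1 from the prior mean
```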
@vitkl you should now be able to merge the
Logic looks great, I have only minor comments about docs & tests.

Tests: could you make a second pass at `test_autoguide.py` and try adding `AutoHierarchicalNormalMessenger` to every test supported by `AutoNormalMessenger`? In particular, the `test_exact*` tests are good at catching math errors. You may need to add logic to avoid jit elbos in those tests, following the `AutoRegressiveMessenger` logic.
Adding tests now. Everything seems to work (no need to avoid jit) except:

```
__________________ test_subsample_model_amortized[AutoHierarchicalNormalMessenger] __________________

auto_class = <class 'pyro.infer.autoguide.effect.AutoHierarchicalNormalMessenger'>

    @pytest.mark.parametrize(
        "auto_class",
        [
            AutoNormalMessenger,
            AutoHierarchicalNormalMessenger,
            AutoRegressiveMessenger,
        ],
    )
    def test_subsample_model_amortized(auto_class):
        def model(x, y=None, batch_size=None):
            loc = pyro.param("loc", lambda: torch.tensor(0.0))
            scale = pyro.param(
                "scale", lambda: torch.tensor(1.0), constraint=constraints.positive
            )
            with pyro.plate("batch", len(x), subsample_size=batch_size):
                batch_x = pyro.subsample(x, event_dim=0)
                batch_y = pyro.subsample(y, event_dim=0) if y is not None else None
                mean = loc + scale * batch_x
                sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
                return pyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)

        guide1 = auto_class(model)
        guide2 = auto_class(model, amortized_plates=("batch",))

        full_size = 50
        batch_size = 20
        pyro.set_rng_seed(123456789)
        x = torch.randn(full_size)
        with torch.no_grad():
            y = model(x)
        assert y.shape == x.shape

        for guide in guide1, guide2:
            pyro.get_param_store().clear()
            pyro.set_rng_seed(123456789)
            svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
            for step in range(5):
>               svi.step(x, y, batch_size=batch_size)

tests/infer/test_autoguide.py:1131:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/svi.py:145: in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/trace_elbo.py:140: in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
pyro/infer/elbo.py:182: in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
pyro/infer/trace_elbo.py:57: in _get_trace
    model_trace, guide_trace = get_importance_trace(
pyro/infer/enum.py:57: in get_importance_trace
    guide(*args, **kwargs)
pyro/nn/module.py:636: in cached_fn
    return fn(self, *args, **kwargs)
pyro/infer/autoguide/effect.py:46: in __call__
    return super().__call__(*args, **kwargs)
pyro/poutine/guide.py:45: in __call__
    self.model(*args, **kwargs)
tests/infer/test_autoguide.py:1112: in model
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
pyro/primitives.py:164: in sample
    apply_stack(msg)
pyro/poutine/runtime.py:213: in apply_stack
    frame._process_message(msg)
pyro/poutine/messenger.py:154: in _process_message
    return method(msg)
pyro/poutine/guide.py:62: in _pyro_sample
    posterior = self.get_posterior(msg["name"], prior)
pyro/infer/autoguide/effect.py:276: in get_posterior
    loc, scale, weight = self._get_params(name, prior)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AutoHierarchicalNormalMessenger(), name = 'sigma', prior = LogNormal()

    def _get_params(self, name: str, prior: Distribution):
        try:
            loc = deep_getattr(self.locs, name)
            scale = deep_getattr(self.scales, name)
            if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
                weight = deep_getattr(self.weights, name)
                return loc, scale, weight
            else:
                return loc, scale
        except AttributeError:
            pass

        # Initialize.
        with torch.no_grad():
            transform = biject_to(prior.support)
            event_dim = transform.domain.event_dim
            constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
            unconstrained = transform.inv(constrained)
            init_loc = self._remove_outer_plates(unconstrained, event_dim)
            init_scale = torch.full_like(init_loc, self._init_scale)
            if self.weight_type == "scalar":
                # weight is a single value parameter
                init_weight = torch.full((), self._init_weight)
            if self.weight_type == "element-wise":
                # weight is element-wise
                init_weight = torch.full_like(init_loc, self._init_weight)
            # if site is hierarchical substract contribution of dependencies from init_loc
            if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
>               init_loc = init_loc - init_weight * transform.inv(prior.mean)
E               RuntimeError: The size of tensor a (50) must match the size of tensor b (20) at non-singleton dimension 0

pyro/infer/autoguide/effect.py:315: RuntimeError
```

I don't understand why this is happening because the logic is the same as
@vitkl can you push so I can inspect in a debugger? Happy to pair-code through the error.
Aha, I think you just need to:

```diff
- init_loc = init_loc - init_weight * transform.inv(prior.mean)
+ mean = transform.inv(prior.mean)
+ mean = self._remove_outer_plates(mean, event_dim=event_dim)
+ init_loc = init_loc - init_weight * mean
```
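A hypothetical standalone repro of the shape error above, using the test's sizes (full plate 50, subsample 20); the `[:1]` reduction is only a stand-in for what `self._remove_outer_plates` effectively does to the plate dimension:

```python
import torch

full_size, batch_size = 50, 20
init_loc = torch.zeros(full_size)        # site initialized under the full plate
prior_mean_u = torch.zeros(batch_size)   # stand-in for transform.inv(prior.mean)
                                         # computed under the subsampled plate

# Broadcasting shape (50,) against shape (20,) fails:
try:
    init_loc - 0.5 * prior_mean_u
except RuntimeError as e:
    print("RuntimeError:", e)

# Reducing the plate dimension to size 1 (the effect of _remove_outer_plates
# on that dim) lets the subtraction broadcast cleanly over the full plate:
mean = prior_mean_u[:1]
fixed = init_loc - 0.5 * mean
print(fixed.shape)  # torch.Size([50])
```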
I will add the code with other changes soon. I see - but why should this fail for the non-amortised test? In that test, the subsampling plate is not supposed to be removed, right?
I added the changes you requested and renamed
Hmm, looks like the docs build is failing due to a failure to build the

Blocked by pypa/setuptools#2849
Looks great! Thanks for adding the new tests.
Based on #2953
The mean-field posterior at any site is a transformed normal distribution, the mean of which depends on the value of that site given its dependencies in the model: `loc = loc + transform.inv(prior.mean) * weight`, where the value of `prior.mean` is conditional on upstream sites in the model. This approach doesn't work for distributions that don't have a mean, as discussed in #2953 (comment).

Currently, for sites that don't have upstream sites, the guide represents/learns a deviation from the prior. Would it be useful to auto-detect sites that do not have upstream sites? Alternatively, hierarchical sites could be specified by the user. @fritzo WDYT?
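A condensed sketch of the parameterization described above, written as hypothetical standalone torch code (not the PR's exact implementation): the posterior for one site is a transformed Normal whose loc is shifted by the weighted, unconstrained-space prior mean.

```python
import torch
import torch.distributions as dist
from torch.distributions import biject_to

def hierarchical_posterior(prior, loc, scale, weight):
    # Map between the prior's constrained support and unconstrained space
    transform = biject_to(prior.support)
    # Shift the learned loc by the (unconstrained) prior mean, which is
    # conditional on upstream sites when the prior depends on them
    posterior_loc = loc + weight * transform.inv(prior.mean)
    return dist.TransformedDistribution(
        dist.Normal(posterior_loc, scale), [transform]
    )

prior = dist.LogNormal(0.0, 1.0)  # LogNormal has a well-defined mean
q = hierarchical_posterior(
    prior, loc=torch.tensor(0.0), scale=torch.tensor(0.1), weight=torch.tensor(1.0)
)
print(q.sample().item() > 0)  # True: samples respect the positive support
```

For a prior without a mean (e.g. Cauchy), `prior.mean` is undefined, which is exactly the limitation discussed in #2953.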
Tasks

- [ ] add docs, following `AutoNormalMessenger`
- [ ] import in `pyro/infer/autoguide/__init__.py` and add an entry in that file's `__all__`
- [ ] add tests, following `AutoNormalMessenger`