Implement AutoHierarchicalNormalMessenger #2955

Merged
46 commits merged on Nov 5, 2021
Changes from 44 commits
Commits (46)
84feca4
Implement Effect_ELBO and AutoRegressiveMessenger
fritzo Oct 28, 2021
afe472c
Fix max_plate_nesting typo
fritzo Oct 28, 2021
75a0a6f
Add docs
fritzo Oct 28, 2021
b656276
Get smoke tests passing
fritzo Oct 28, 2021
6c29e1b
Fix tests
fritzo Oct 28, 2021
2993402
Add AutoNormalMessenger
fritzo Oct 29, 2021
6eb02e4
Add example to docstring
fritzo Oct 29, 2021
84d5e60
Support calling guide(*args, **kwargs)
fritzo Oct 30, 2021
3d5250e
Support init_loc_fn, init_scale
fritzo Oct 30, 2021
0ac9ae1
Add more tests
fritzo Oct 30, 2021
0637e00
Fix jit tests
fritzo Oct 30, 2021
bb092a9
Add more docs
fritzo Oct 30, 2021
bc889e4
Revert unnecessary change
fritzo Oct 30, 2021
4a473fb
draft AutoHierarchicalNormalMessenger
vitkl Oct 30, 2021
a564f5f
added hierarchical loc and hierarchical_sites list
vitkl Oct 30, 2021
56d26af
Document relationship to AutoNormal
fritzo Oct 30, 2021
5fe690a
updated docs, added initialisation for hierarchal sites
vitkl Oct 30, 2021
65c9c3e
renamed class and added rst
vitkl Oct 30, 2021
5683f70
added to autoguide module
vitkl Oct 30, 2021
a9cc1cd
bug fixes
vitkl Oct 31, 2021
bd6b69e
added tests for AutoHierarchicalNormal
vitkl Oct 31, 2021
5db1006
fixed isort
vitkl Oct 31, 2021
9759577
fixed docs error
vitkl Oct 31, 2021
598dca8
added transform.inv(prior.mean) to initialisation
vitkl Oct 31, 2021
26bd9bf
Support subsampling and amortization
fritzo Oct 31, 2021
8556734
lint
fritzo Oct 31, 2021
7715687
Remove debug statement
fritzo Oct 31, 2021
4ea9288
rename AutoHierarchicalNormalMessenger
vitkl Nov 1, 2021
3407c64
element-wise weight;
vitkl Nov 1, 2021
884c435
updated tests
vitkl Nov 1, 2021
da01b98
fixed lint
vitkl Nov 1, 2021
cf41ae6
fixed docs
vitkl Nov 1, 2021
5c3e5bc
Add a poutine.unwrap() helper function
fritzo Nov 1, 2021
f03401a
Eliminate Effect_ELBO and mixin stuff
fritzo Nov 1, 2021
8749c96
Fix docs
fritzo Nov 1, 2021
c161ebb
Revert unnecessary change
fritzo Nov 1, 2021
872b82d
Fix poutine.unwrap()
fritzo Nov 1, 2021
26b443d
merge latest messenger guides;
vitkl Nov 1, 2021
06cc3eb
merge dev
vitkl Nov 1, 2021
96eb005
resolved more conflicts
vitkl Nov 1, 2021
e4fd8af
updated hierarchical guide with latest changes
vitkl Nov 1, 2021
f25365f
fixed lint
vitkl Nov 1, 2021
a3e8d07
fixed isort
vitkl Nov 1, 2021
f48ba6b
updated tests and docs
vitkl Nov 3, 2021
8db479f
changes to initialisation and docs
vitkl Nov 4, 2021
06c6354
fixed initialisation
vitkl Nov 4, 2021
8 changes: 8 additions & 0 deletions docs/source/infer.autoguide.rst
@@ -141,6 +141,14 @@ AutoNormalMessenger
:member-order: bysource
:show-inheritance:

AutoHierarchicalNormalMessenger
-------------------------------
.. autoclass:: pyro.infer.autoguide.AutoHierarchicalNormalMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:

AutoRegressiveMessenger
-----------------------
.. autoclass:: pyro.infer.autoguide.AutoRegressiveMessenger
2 changes: 2 additions & 0 deletions pyro/infer/autoguide/__init__.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from pyro.infer.autoguide.effect import (
AutoHierarchicalNormalMessenger,
AutoMessenger,
AutoNormalMessenger,
AutoRegressiveMessenger,
@@ -50,6 +51,7 @@
"AutoMultivariateNormal",
"AutoNormal",
"AutoNormalMessenger",
"AutoHierarchicalNormalMessenger",
"AutoNormalizingFlow",
"AutoRegressiveMessenger",
"AutoStructured",
159 changes: 154 additions & 5 deletions pyro/infer/autoguide/effect.py
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import torch
from torch.distributions import biject_to, constraints
@@ -63,9 +63,9 @@ def call(self, *args, **kwargs):
return tuple(v for _, v in sorted(result.items()))

@torch.no_grad()
def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
"""
Removes particle plates from initial values of parameters.
Adjusts plates for generating initial values of parameters.
"""
for f in get_plates():
full_size = getattr(f, "full_size", f.size)
@@ -187,7 +187,7 @@ def _get_params(self, name: str, prior: Distribution):
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
init_loc = self._remove_outer_plates(unconstrained, event_dim)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_scale = torch.full_like(init_loc, self._init_scale)

deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
@@ -211,6 +211,155 @@ def _get_posterior_median(self, name, prior):
return transform(loc)


class AutoHierarchicalNormalMessenger(AutoNormalMessenger):
"""
:class:`AutoMessenger` with mean-field normal posterior conditional on all dependencies.

The mean-field posterior at any site is a transformed normal distribution,
the mean of which depends on the value of that site given its dependencies in the model:

loc_total = loc + transform.inv(prior.mean) * weight

where the value of ``prior.mean`` is conditional on upstream sites in the model,
``loc`` is an independent component of the mean in the untransformed space, and
``weight`` is an element-wise factor that scales the prior mean.
This approach does not work for distributions that do not have a well-defined mean.

Derived classes may override particular sites and use this simply as a
default; see the :class:`AutoNormalMessenger` documentation for an example.
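
Example usage (a minimal sketch; ``model``, ``data``, and the site name
``"z"`` are hypothetical placeholders)::

    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoHierarchicalNormalMessenger
    from pyro.optim import Adam

    # Treat only site "z" as hierarchical; other sites fall back to mean field.
    guide = AutoHierarchicalNormalMessenger(model, hierarchical_sites=["z"])
    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    for step in range(1000):
        svi.step(data)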

:param callable model: A Pyro model.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
:param float init_weight: Initial value for the weight of the contribution
of hierarchical sites to the posterior mean of each latent variable.
:param list hierarchical_sites: List of latent variables (model sites)
that have hierarchical dependencies.
If None, all sites are assumed to have hierarchical dependencies. In that
case, for sites that do not have upstream sites, the ``loc`` and ``weight``
of the guide represent (and learn) the deviation from the prior.
"""

# 'element-wise' or 'scalar'
weight_type = "element-wise"

def __init__(
self,
model: Callable,
*,
init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
init_scale: float = 0.1,
amortized_plates: Tuple[str, ...] = (),
init_weight: float = 1.0,
hierarchical_sites: Optional[list] = None,
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
super().__init__(model, amortized_plates=amortized_plates)
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale
self._init_weight = init_weight
self._hierarchical_sites = hierarchical_sites
self._computing_median = False

def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
if self._computing_median:
return self._get_posterior_median(name, prior)

with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
# If hierarchical_sites is not specified, all sites are assumed to be hierarchical.
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.mean) * weight
posterior = dist.TransformedDistribution(
dist.Normal(loc, scale).to_event(transform.domain.event_dim),
transform.with_cache(),
)
return posterior
else:
# Fall back to mean field when hierarchical_sites is specified and the site is not in the list.
return super().get_posterior(name, prior)

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
weight = deep_getattr(self.weights, name)
return loc, scale, weight
else:
return loc, scale
except AttributeError:
pass

# Initialize.
with torch.no_grad():
transform = biject_to(prior.support)
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_scale = torch.full_like(init_loc, self._init_scale)
if self.weight_type == "scalar":
# weight is a single value parameter
init_weight = torch.full((), self._init_weight)
if self.weight_type == "element-wise":
# weight is element-wise
init_weight = torch.full_like(init_loc, self._init_weight)
# If the site is hierarchical, subtract the contribution of its dependencies from init_loc.
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
init_prior_mean = self._adjust_plates(prior.mean, event_dim)
init_loc = init_loc - init_weight * transform.inv(init_prior_mean)

deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
deep_setattr(
self,
"scales." + name,
PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
if self.weight_type == "scalar":
# weight is a single value parameter
deep_setattr(
self,
"weights." + name,
PyroParam(init_weight, constraint=constraints.positive),
)
if self.weight_type == "element-wise":
# weight is element-wise
deep_setattr(
self,
"weights." + name,
PyroParam(
init_weight,
constraint=constraints.positive,
event_dim=event_dim,
),
)
return self._get_params(name, prior)

def median(self, *args, **kwargs):
self._computing_median = True
try:
return self(*args, **kwargs)
finally:
self._computing_median = False

def _get_posterior_median(self, name, prior):
transform = biject_to(prior.support)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.mean) * weight
else:
loc, scale = self._get_params(name, prior)
return transform(loc)


class AutoRegressiveMessenger(AutoMessenger):
"""
:class:`AutoMessenger` with recursively affine-transformed priors using
@@ -291,7 +440,7 @@ def _get_params(self, name: str, prior: Distribution):
unconstrained = transform.inv(constrained)
# Initialize the distribution to be an affine combination:
# init_scale * prior + (1 - init_scale) * init_loc
init_loc = self._remove_outer_plates(unconstrained, event_dim)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_loc = init_loc * (1 - self._init_scale)
init_scale = torch.full_like(init_loc, self._init_scale)

16 changes: 16 additions & 0 deletions tests/infer/test_autoguide.py
@@ -32,6 +32,7 @@
AutoGaussian,
AutoGuide,
AutoGuideList,
AutoHierarchicalNormalMessenger,
AutoIAFNormal,
AutoLaplaceApproximation,
AutoLowRankMultivariateNormal,
@@ -119,6 +120,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -378,6 +380,7 @@ def __init__(self, model):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
],
)
@pytest.mark.parametrize("Elbo", [JitTrace_ELBO, JitTraceGraph_ELBO, JitTraceEnum_ELBO])
@@ -451,6 +454,7 @@ def serialization_model():
],
),
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
xfail_param(AutoRegressiveMessenger, reason="jit does not support _Dirichlet"),
],
)
@@ -718,6 +722,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
],
)
def test_init_loc_fn(auto_class):
@@ -783,6 +788,7 @@ def model():
functools.partial(AutoNormal, init_loc_fn=init_to_median),
functools.partial(AutoGaussian, init_loc_fn=init_to_median),
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
],
)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@@ -874,6 +880,7 @@ def forward(self):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1018,6 +1025,7 @@ def forward(self, x, y=None):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1049,6 +1057,7 @@ def model():
AutoDelta,
AutoNormal,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1086,6 +1095,7 @@ def model(x, y=None, batch_size=None):
"auto_class",
[
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1258,6 +1268,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1293,6 +1304,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1370,6 +1382,7 @@ def __init__(self, model):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1437,6 +1450,7 @@ def model(data):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1500,6 +1514,7 @@ def model(data):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1508,6 +1523,7 @@ def test_exact_tree(Guide):
AutoNormal,
AutoDiagonalNormal,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
)
