Implement AutoHierarchicalNormalMessenger (#2955)
vitkl authored Nov 5, 2021
1 parent 33332bb commit 2984a7e
Showing 4 changed files with 181 additions and 5 deletions.
8 changes: 8 additions & 0 deletions docs/source/infer.autoguide.rst
@@ -141,6 +141,14 @@ AutoNormalMessenger
:member-order: bysource
:show-inheritance:

AutoHierarchicalNormalMessenger
-------------------------------
.. autoclass:: pyro.infer.autoguide.AutoHierarchicalNormalMessenger
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:

AutoRegressiveMessenger
-----------------------
.. autoclass:: pyro.infer.autoguide.AutoRegressiveMessenger
2 changes: 2 additions & 0 deletions pyro/infer/autoguide/__init__.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from pyro.infer.autoguide.effect import (
AutoHierarchicalNormalMessenger,
AutoMessenger,
AutoNormalMessenger,
AutoRegressiveMessenger,
@@ -50,6 +51,7 @@
"AutoMultivariateNormal",
"AutoNormal",
"AutoNormalMessenger",
"AutoHierarchicalNormalMessenger",
"AutoNormalizingFlow",
"AutoRegressiveMessenger",
"AutoStructured",
160 changes: 155 additions & 5 deletions pyro/infer/autoguide/effect.py
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import torch
from torch.distributions import biject_to, constraints
@@ -63,9 +63,9 @@ def call(self, *args, **kwargs):
return tuple(v for _, v in sorted(result.items()))

@torch.no_grad()
def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
"""
Removes particle plates from initial values of parameters.
Adjusts plates for generating initial values of parameters.
"""
for f in get_plates():
full_size = getattr(f, "full_size", f.size)
@@ -187,7 +187,7 @@ def _get_params(self, name: str, prior: Distribution):
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
init_loc = self._remove_outer_plates(unconstrained, event_dim)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_scale = torch.full_like(init_loc, self._init_scale)

deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
@@ -211,6 +211,156 @@ def _get_posterior_median(self, name, prior):
return transform(loc)


class AutoHierarchicalNormalMessenger(AutoNormalMessenger):
"""
:class:`AutoMessenger` with mean-field normal posterior conditional on all dependencies.

The mean-field posterior at any site is a transformed normal distribution
whose mean depends on the values of that site's dependencies in the model::

    loc_total = loc + transform.inv(prior.mean) * weight

where ``prior.mean`` is conditional on upstream sites in the model,
``loc`` is an independent component of the mean in the untransformed space,
and ``weight`` is an element-wise factor that scales the prior mean
(see the example at the end of this docstring).
This approach does not work for distributions that have no mean.

Derived classes may override particular sites and use this simply as a
default; see the :class:`AutoNormalMessenger` documentation for an example.

:param callable model: A Pyro model.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param float init_scale: Initial scale for the standard deviation of each
(unconstrained transformed) latent variable.
:param float init_weight: Initial value for the weight of the contribution
of hierarchical sites to the posterior mean for each latent variable.
:param list hierarchical_sites: List of latent variables (model sites)
that have hierarchical dependencies.
If None, all sites are assumed to have hierarchical dependencies; in that
case, for sites without upstream sites, the ``loc`` and ``weight`` of the
guide represent the learned deviation from the prior.
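
Example (a minimal usage sketch; the model below is illustrative and is
not part of this commit)::

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model():
        # "inner" depends hierarchically on "outer"
        outer = pyro.sample("outer", dist.Normal(0.0, 1.0))
        with pyro.plate("data", 3):
            inner = pyro.sample("inner", dist.Normal(outer, 1.0))
            pyro.sample("obs", dist.Normal(inner, 1.0), obs=torch.zeros(3))

    guide = AutoHierarchicalNormalMessenger(model)
    # or restrict the hierarchical parametrization to chosen sites:
    # guide = AutoHierarchicalNormalMessenger(model, hierarchical_sites=["inner"])
    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    for step in range(100):
        svi.step()
    medians = guide.median()  # per-site posterior medians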
"""

# 'element-wise' or 'scalar'
weight_type = "element-wise"

def __init__(
self,
model: Callable,
*,
init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible),
init_scale: float = 0.1,
amortized_plates: Tuple[str, ...] = (),
init_weight: float = 1.0,
hierarchical_sites: Optional[list] = None,
):
if not isinstance(init_scale, float) or not (init_scale > 0):
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
super().__init__(model, amortized_plates=amortized_plates)
self.init_loc_fn = init_loc_fn
self._init_scale = init_scale
self._init_weight = init_weight
self._hierarchical_sites = hierarchical_sites
self._computing_median = False

def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
if self._computing_median:
return self._get_posterior_median(name, prior)

with helpful_support_errors({"name": name, "fn": prior}):
transform = biject_to(prior.support)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
# If hierarchical_sites is not specified, all sites are assumed to be hierarchical.
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.mean) * weight
posterior = dist.TransformedDistribution(
dist.Normal(loc, scale).to_event(transform.domain.event_dim),
transform.with_cache(),
)
return posterior
else:
# Fall back to mean field when hierarchical_sites is given and this site is not in it.
return super().get_posterior(name, prior)

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
weight = deep_getattr(self.weights, name)
return loc, scale, weight
else:
return loc, scale
except AttributeError:
pass

# Initialize.
with torch.no_grad():
transform = biject_to(prior.support)
event_dim = transform.domain.event_dim
constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
unconstrained = transform.inv(constrained)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_scale = torch.full_like(init_loc, self._init_scale)
if self.weight_type == "scalar":
# weight is a single value parameter
init_weight = torch.full((), self._init_weight)
if self.weight_type == "element-wise":
# weight is element-wise
init_weight = torch.full_like(init_loc, self._init_weight)
# If the site is hierarchical, subtract the contribution of dependencies from init_loc.
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
init_prior_mean = transform.inv(prior.mean)
init_prior_mean = self._adjust_plates(init_prior_mean, event_dim)
init_loc = init_loc - init_weight * init_prior_mean

deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim))
deep_setattr(
self,
"scales." + name,
PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim),
)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
if self.weight_type == "scalar":
# weight is a single value parameter
deep_setattr(
self,
"weights." + name,
PyroParam(init_weight, constraint=constraints.positive),
)
if self.weight_type == "element-wise":
# weight is element-wise
deep_setattr(
self,
"weights." + name,
PyroParam(
init_weight,
constraint=constraints.positive,
event_dim=event_dim,
),
)
return self._get_params(name, prior)

def median(self, *args, **kwargs):
self._computing_median = True
try:
return self(*args, **kwargs)
finally:
self._computing_median = False

def _get_posterior_median(self, name, prior):
transform = biject_to(prior.support)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
loc, scale, weight = self._get_params(name, prior)
loc = loc + transform.inv(prior.mean) * weight
else:
loc, scale = self._get_params(name, prior)
return transform(loc)


class AutoRegressiveMessenger(AutoMessenger):
"""
:class:`AutoMessenger` with recursively affine-transformed priors using
@@ -291,7 +441,7 @@ def _get_params(self, name: str, prior: Distribution):
unconstrained = transform.inv(constrained)
# Initialize the distribution to be an affine combination:
# init_scale * prior + (1 - init_scale) * init_loc
init_loc = self._remove_outer_plates(unconstrained, event_dim)
init_loc = self._adjust_plates(unconstrained, event_dim)
init_loc = init_loc * (1 - self._init_scale)
init_scale = torch.full_like(init_loc, self._init_scale)

16 changes: 16 additions & 0 deletions tests/infer/test_autoguide.py
@@ -32,6 +32,7 @@
AutoGaussian,
AutoGuide,
AutoGuideList,
AutoHierarchicalNormalMessenger,
AutoIAFNormal,
AutoLaplaceApproximation,
AutoLowRankMultivariateNormal,
@@ -119,6 +120,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -378,6 +380,7 @@ def __init__(self, model):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
],
)
@pytest.mark.parametrize("Elbo", [JitTrace_ELBO, JitTraceGraph_ELBO, JitTraceEnum_ELBO])
@@ -451,6 +454,7 @@ def serialization_model():
],
),
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
xfail_param(AutoRegressiveMessenger, reason="jit does not support _Dirichlet"),
],
)
@@ -718,6 +722,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
],
)
def test_init_loc_fn(auto_class):
@@ -783,6 +788,7 @@ def model():
functools.partial(AutoNormal, init_loc_fn=init_to_median),
functools.partial(AutoGaussian, init_loc_fn=init_to_median),
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
],
)
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
@@ -874,6 +880,7 @@ def forward(self):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1018,6 +1025,7 @@ def forward(self, x, y=None):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1049,6 +1057,7 @@ def model():
AutoDelta,
AutoNormal,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1086,6 +1095,7 @@ def model(x, y=None, batch_size=None):
"auto_class",
[
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1258,6 +1268,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1293,6 +1304,7 @@ def model():
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1370,6 +1382,7 @@ def __init__(self, model):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1437,6 +1450,7 @@ def model(data):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1500,6 +1514,7 @@ def model(data):
AutoGaussian,
AutoGaussianFunsor,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
],
)
@@ -1508,6 +1523,7 @@ def test_exact_tree(Guide):
AutoNormal,
AutoDiagonalNormal,
AutoNormalMessenger,
AutoHierarchicalNormalMessenger,
AutoRegressiveMessenger,
)

