From 2984a7e3b31d6719946fcf86663d6d99570bd434 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Fri, 5 Nov 2021 17:45:22 +0000 Subject: [PATCH] Implement AutoHierarchicalNormalMessenger (#2955) --- docs/source/infer.autoguide.rst | 8 ++ pyro/infer/autoguide/__init__.py | 2 + pyro/infer/autoguide/effect.py | 160 ++++++++++++++++++++++++++++++- tests/infer/test_autoguide.py | 16 ++++ 4 files changed, 181 insertions(+), 5 deletions(-) diff --git a/docs/source/infer.autoguide.rst b/docs/source/infer.autoguide.rst index 0b85e65179..aaa02a93a5 100644 --- a/docs/source/infer.autoguide.rst +++ b/docs/source/infer.autoguide.rst @@ -141,6 +141,14 @@ AutoNormalMessenger :member-order: bysource :show-inheritance: +AutoHierarchicalNormalMessenger +------------------------------- +.. autoclass:: pyro.infer.autoguide.AutoHierarchicalNormalMessenger + :members: + :undoc-members: + :member-order: bysource + :show-inheritance: + AutoRegressiveMessenger ----------------------- .. autoclass:: pyro.infer.autoguide.AutoRegressiveMessenger diff --git a/pyro/infer/autoguide/__init__.py b/pyro/infer/autoguide/__init__.py index 30f928943c..6d23ebadf7 100644 --- a/pyro/infer/autoguide/__init__.py +++ b/pyro/infer/autoguide/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pyro.infer.autoguide.effect import ( + AutoHierarchicalNormalMessenger, AutoMessenger, AutoNormalMessenger, AutoRegressiveMessenger, @@ -50,6 +51,7 @@ "AutoMultivariateNormal", "AutoNormal", "AutoNormalMessenger", + "AutoHierarchicalNormalMessenger", "AutoNormalizingFlow", "AutoRegressiveMessenger", "AutoStructured", diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 8c12e72197..d7abe537b0 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch.distributions import biject_to, constraints @@ -63,9 +63,9 @@ def call(self, *args, **kwargs): return tuple(v for _, v in sorted(result.items())) @torch.no_grad() - def _remove_outer_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor: + def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor: """ - Removes particle plates from initial values of parameters. + Adjusts plates for generating initial values of parameters. """ for f in get_plates(): full_size = getattr(f, "full_size", f.size) @@ -187,7 +187,7 @@ def _get_params(self, name: str, prior: Distribution): event_dim = transform.domain.event_dim constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() unconstrained = transform.inv(constrained) - init_loc = self._remove_outer_plates(unconstrained, event_dim) + init_loc = self._adjust_plates(unconstrained, event_dim) init_scale = torch.full_like(init_loc, self._init_scale) deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim)) @@ -211,6 +211,156 @@ def _get_posterior_median(self, name, prior): return transform(loc) +class AutoHierarchicalNormalMessenger(AutoNormalMessenger): + """ + :class:`AutoMessenger` with mean-field normal posterior conditional on all dependencies. + + The mean-field posterior at any site is a transformed normal distribution, + the mean of which depends on the value of that site given its dependencies in the model:: + + loc_total = loc + transform.inv(prior.mean) * weight + + Where the value of ``prior.mean`` is conditional on upstream sites in the model, + ``loc`` is independent component of the mean in the untransformed space, + ``weight`` is element-wise factor that scales the prior mean. + This approach doesn't work for distributions that don't have the mean. + + Derived classes may override particular sites and use this simply as a + default, see :class:`AutoNormalMessenger` documentation for example. + + :param callable model: A Pyro model. + :param callable init_loc_fn: A per-site initialization function. + See :ref:`autoguide-initialization` section for available functions. + :param float init_scale: Initial scale for the standard deviation of each + (unconstrained transformed) latent variable. + :param float init_weight: Initial value for the weight of the contribution + of hierarchical sites to posterior mean for each latent variable. + :param list hierarchical_sites: List of latent variables (model sites) + that have hierarchical dependencies. + If None, all sites are assumed to have hierarchical dependencies. If None, for the sites + that don't have upstream sites, the loc and weight of the guide + are representing/learning deviation from the prior. + """ + + # 'element-wise' or 'scalar' + weight_type = "element-wise" + + def __init__( + self, + model: Callable, + *, + init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible), + init_scale: float = 0.1, + amortized_plates: Tuple[str, ...] = (), + init_weight: float = 1.0, + hierarchical_sites: Optional[list] = None, + ): + if not isinstance(init_scale, float) or not (init_scale > 0): + raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) + super().__init__(model, amortized_plates=amortized_plates) + self.init_loc_fn = init_loc_fn + self._init_scale = init_scale + self._init_weight = init_weight + self._hierarchical_sites = hierarchical_sites + self._computing_median = False + + def get_posterior( + self, name: str, prior: Distribution + ) -> Union[Distribution, torch.Tensor]: + if self._computing_median: + return self._get_posterior_median(name, prior) + + with helpful_support_errors({"name": name, "fn": prior}): + transform = biject_to(prior.support) + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + # If hierarchical_sites not specified all sites are assumed to be hierarchical + loc, scale, weight = self._get_params(name, prior) + loc = loc + transform.inv(prior.mean) * weight + posterior = dist.TransformedDistribution( + dist.Normal(loc, scale).to_event(transform.domain.event_dim), + transform.with_cache(), + ) + return posterior + else: + # Fall back to mean field when hierarchical_sites list is not empty and site not in the list. + return super().get_posterior(name, prior) + + def _get_params(self, name: str, prior: Distribution): + try: + loc = deep_getattr(self.locs, name) + scale = deep_getattr(self.scales, name) + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + weight = deep_getattr(self.weights, name) + return loc, scale, weight + else: + return loc, scale + except AttributeError: + pass + + # Initialize. + with torch.no_grad(): + transform = biject_to(prior.support) + event_dim = transform.domain.event_dim + constrained = self.init_loc_fn({"name": name, "fn": prior}).detach() + unconstrained = transform.inv(constrained) + init_loc = self._adjust_plates(unconstrained, event_dim) + init_scale = torch.full_like(init_loc, self._init_scale) + if self.weight_type == "scalar": + # weight is a single value parameter + init_weight = torch.full((), self._init_weight) + if self.weight_type == "element-wise": + # weight is element-wise + init_weight = torch.full_like(init_loc, self._init_weight) + # if site is hierarchical substract contribution of dependencies from init_loc + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + init_prior_mean = transform.inv(prior.mean) + init_prior_mean = self._adjust_plates(init_prior_mean, event_dim) + init_loc = init_loc - init_weight * init_prior_mean + + deep_setattr(self, "locs." + name, PyroParam(init_loc, event_dim=event_dim)) + deep_setattr( + self, + "scales." + name, + PyroParam(init_scale, constraint=constraints.positive, event_dim=event_dim), + ) + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + if self.weight_type == "scalar": + # weight is a single value parameter + deep_setattr( + self, + "weights." + name, + PyroParam(init_weight, constraint=constraints.positive), + ) + if self.weight_type == "element-wise": + # weight is element-wise + deep_setattr( + self, + "weights." + name, + PyroParam( + init_weight, + constraint=constraints.positive, + event_dim=event_dim, + ), + ) + return self._get_params(name, prior) + + def median(self, *args, **kwargs): + self._computing_median = True + try: + return self(*args, **kwargs) + finally: + self._computing_median = False + + def _get_posterior_median(self, name, prior): + transform = biject_to(prior.support) + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + loc, scale, weight = self._get_params(name, prior) + loc = loc + transform.inv(prior.mean) * weight + else: + loc, scale = self._get_params(name, prior) + return transform(loc) + + class AutoRegressiveMessenger(AutoMessenger): """ :class:`AutoMessenger` with recursively affine-transformed priors using @@ -291,7 +441,7 @@ def _get_params(self, name: str, prior: Distribution): unconstrained = transform.inv(constrained) # Initialize the distribution to be an affine combination: # init_scale * prior + (1 - init_scale) * init_loc - init_loc = self._remove_outer_plates(unconstrained, event_dim) + init_loc = self._adjust_plates(unconstrained, event_dim) init_loc = init_loc * (1 - self._init_scale) init_scale = torch.full_like(init_loc, self._init_scale) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 8e39528bff..76a190fd91 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -32,6 +32,7 @@ AutoGaussian, AutoGuide, AutoGuideList, + AutoHierarchicalNormalMessenger, AutoIAFNormal, AutoLaplaceApproximation, AutoLowRankMultivariateNormal, @@ -119,6 +120,7 @@ def model(): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -378,6 +380,7 @@ def __init__(self, model): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, ], ) @pytest.mark.parametrize("Elbo", [JitTrace_ELBO, JitTraceGraph_ELBO, JitTraceEnum_ELBO]) @@ -451,6 +454,7 @@ def serialization_model(): ], ), AutoNormalMessenger, + AutoHierarchicalNormalMessenger, xfail_param(AutoRegressiveMessenger, reason="jit does not support _Dirichlet"), ], ) @@ -718,6 +722,7 @@ def model(): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, ], ) def test_init_loc_fn(auto_class): @@ -783,6 +788,7 @@ def model(): functools.partial(AutoNormal, init_loc_fn=init_to_median), functools.partial(AutoGaussian, init_loc_fn=init_to_median), AutoNormalMessenger, + AutoHierarchicalNormalMessenger, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) @@ -874,6 +880,7 @@ def forward(self): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1018,6 +1025,7 @@ def forward(self, x, y=None): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1049,6 +1057,7 @@ def model(): AutoDelta, AutoNormal, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1086,6 +1095,7 @@ def model(x, y=None, batch_size=None): "auto_class", [ AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1258,6 +1268,7 @@ def model(): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1293,6 +1304,7 @@ def model(): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1370,6 +1382,7 @@ def __init__(self, model): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1437,6 +1450,7 @@ def model(data): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1500,6 +1514,7 @@ def model(data): AutoGaussian, AutoGaussianFunsor, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, ], ) @@ -1508,6 +1523,7 @@ def test_exact_tree(Guide): AutoNormal, AutoDiagonalNormal, AutoNormalMessenger, + AutoHierarchicalNormalMessenger, AutoRegressiveMessenger, )