
Implement AutoHierarchicalNormalMessenger #2955

Merged · 46 commits · Nov 5, 2021

Conversation

@vitkl (Contributor) commented Oct 30, 2021

Based on #2953

The mean-field posterior at any site is a transformed normal distribution whose mean depends on the prior mean of that site given its upstream dependencies in the model:

loc = loc + transform.inv(prior.mean) * weight

where the value of prior.mean is conditioned on upstream sites in the model. This approach doesn't work for distributions that don't have a mean, as discussed in #2953 (comment).

Currently, for sites that have no upstream sites, the guide represents/learns a deviation from the prior. Would it be useful to auto-detect sites that have no upstream sites? Alternatively, hierarchical sites could be specified by the user. @fritzo WDYT?
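For context, here is a minimal usage sketch of the new guide on a two-level model (the model and data below are purely illustrative, not part of this PR, and assume a Pyro version that includes this guide):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoHierarchicalNormalMessenger
from pyro.optim import Adam

def model(data):
    # global site with no upstream dependencies
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        # hierarchical site: its prior mean depends on the upstream sample mu
        theta = pyro.sample("theta", dist.Normal(mu, 1.0))
        pyro.sample("obs", dist.Normal(theta, 0.5), obs=data)

guide = AutoHierarchicalNormalMessenger(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
data = torch.randn(100) + 3.0
for step in range(1000):
    svi.step(data)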

Tasks

  • add a section to infer.autoguide.rst, mirroring the AutoNormalMessenger entry
  • add an import in pyro/infer/autoguide/__init__.py and an entry in that file's __all__ (see the sketch after this list)
  • add to the tests in tests/infer/test_autoguide.py, following AutoNormalMessenger
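A sketch of the registration item above (existing exports are abbreviated; only the AutoHierarchicalNormalMessenger lines are new):

# pyro/infer/autoguide/__init__.py (abbreviated sketch)
from pyro.infer.autoguide.effect import (
    AutoHierarchicalNormalMessenger,
    AutoNormalMessenger,
    AutoRegressiveMessenger,
)

__all__ = [
    # ...existing autoguide names...
    "AutoHierarchicalNormalMessenger",
    "AutoNormalMessenger",
    "AutoRegressiveMessenger",
]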

@fritzo (Member) left a comment

Looks great @vitkl! Before merging, could you also add docs, tests, and imports? I'll add a task list to the PR description.

(Five inline review comments on pyro/infer/autoguide/effect.py, resolved.)
@vitkl (Contributor, Author) commented Oct 30, 2021

Would it make sense to constrain the weight between 0 and 1?

loc = loc * (1 - weight) + transform.inv(prior.mean) * weight

@fritzo (Member) commented Oct 31, 2021

> Would it make sense to constrain the weight between 0 and 1?

Yes, I think that would make sense. On the other hand, weights outside that range should not result in NaN, and my intuition is that the weight should be easy to learn, so I'd opt for the cheaper solution of an unconstrained weight.
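For reference, a minimal sketch of the two options, assuming the guide stores its weights as PyroParams on a PyroModule, as the other Auto*Messenger guides do (the class and attribute names here are illustrative, not the PR's actual code):

import torch
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam

class WeightHolder(PyroModule):
    def __init__(self, init_weight=0.5):
        super().__init__()
        # cheaper option: unconstrained weight
        self.weight_unconstrained = PyroParam(torch.tensor(init_weight))
        # alternative: weight constrained to (0, 1), so that
        # loc * (1 - weight) + transform.inv(prior.mean) * weight
        # is always a convex combination
        self.weight_constrained = PyroParam(
            torch.tensor(init_weight), constraint=constraints.unit_interval
        )

holder = WeightHolder()
print(holder.weight_constrained)  # reads back a tensor constrained to (0, 1)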

@fritzo fritzo removed the Blocked label Nov 1, 2021
@fritzo (Member) commented Nov 1, 2021

@vitkl you should now be able to merge the guide-messenger branch and then the dev branch (I'd recommend doing so in that order to minimize merge conflicts due to our squash-merge policy).

@fritzo (Member) left a comment

Logic looks great; I have only minor comments about docs & tests.

Tests: could you make a second pass at test_autoguide.py and try adding AutoHierarchicalNormalMessenger to every test supported by AutoNormalMessenger? In particular, the test_exact* tests are good at catching math errors. You may need to add logic to avoid JIT ELBOs in those tests, following the AutoRegressiveMessenger logic.

(Two inline review comments on pyro/infer/autoguide/effect.py, resolved.)
@vitkl (Contributor, Author) commented Nov 3, 2021

Adding tests now. Everything seems to work (no need to avoid JIT) except test_subsample_model, and I don't understand the source of this error:

___________________________________________________________ test_subsample_model_amortized[AutoHierarchicalNormalMessenger] ___________________________________________________________

auto_class = <class 'pyro.infer.autoguide.effect.AutoHierarchicalNormalMessenger'>

    @pytest.mark.parametrize(
        "auto_class",
        [
            AutoNormalMessenger,
            AutoHierarchicalNormalMessenger,
            AutoRegressiveMessenger,
        ],
    )
    def test_subsample_model_amortized(auto_class):
        def model(x, y=None, batch_size=None):
            loc = pyro.param("loc", lambda: torch.tensor(0.0))
            scale = pyro.param(
                "scale", lambda: torch.tensor(1.0), constraint=constraints.positive
            )
            with pyro.plate("batch", len(x), subsample_size=batch_size):
                batch_x = pyro.subsample(x, event_dim=0)
                batch_y = pyro.subsample(y, event_dim=0) if y is not None else None
                mean = loc + scale * batch_x
                sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
                return pyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)

        guide1 = auto_class(model)
        guide2 = auto_class(model, amortized_plates=("batch",))

        full_size = 50
        batch_size = 20
        pyro.set_rng_seed(123456789)
        x = torch.randn(full_size)
        with torch.no_grad():
            y = model(x)
        assert y.shape == x.shape

        for guide in guide1, guide2:
            pyro.get_param_store().clear()
            pyro.set_rng_seed(123456789)
            svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
            for step in range(5):
>               svi.step(x, y, batch_size=batch_size)

tests/infer/test_autoguide.py:1131:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyro/infer/svi.py:145: in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/trace_elbo.py:140: in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
pyro/infer/elbo.py:182: in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
pyro/infer/trace_elbo.py:57: in _get_trace
    model_trace, guide_trace = get_importance_trace(
pyro/infer/enum.py:57: in get_importance_trace
    guide(*args, **kwargs)
pyro/nn/module.py:636: in cached_fn
    return fn(self, *args, **kwargs)
pyro/infer/autoguide/effect.py:46: in __call__
    return super().__call__(*args, **kwargs)
pyro/poutine/guide.py:45: in __call__
    self.model(*args, **kwargs)
tests/infer/test_autoguide.py:1112: in model
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
pyro/primitives.py:164: in sample
    apply_stack(msg)
pyro/poutine/runtime.py:213: in apply_stack
    frame._process_message(msg)
pyro/poutine/messenger.py:154: in _process_message
    return method(msg)
pyro/poutine/guide.py:62: in _pyro_sample
    posterior = self.get_posterior(msg["name"], prior)
pyro/infer/autoguide/effect.py:276: in get_posterior
    loc, scale, weight = self._get_params(name, prior)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = AutoHierarchicalNormalMessenger(), name = 'sigma', prior = LogNormal()

    def _get_params(self, name: str, prior: Distribution):
        try:
            loc = deep_getattr(self.locs, name)
            scale = deep_getattr(self.scales, name)
            if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
                weight = deep_getattr(self.weights, name)
                return loc, scale, weight
            else:
                return loc, scale
        except AttributeError:
            pass

        # Initialize.
        with torch.no_grad():
            transform = biject_to(prior.support)
            event_dim = transform.domain.event_dim
            constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
            unconstrained = transform.inv(constrained)
            init_loc = self._remove_outer_plates(unconstrained, event_dim)
            init_scale = torch.full_like(init_loc, self._init_scale)
            if self.weight_type == "scalar":
                # weight is a single value parameter
                init_weight = torch.full((), self._init_weight)
            if self.weight_type == "element-wise":
                # weight is element-wise
                init_weight = torch.full_like(init_loc, self._init_weight)
            # if site is hierarchical substract contribution of dependencies from init_loc
            if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
>               init_loc = init_loc - init_weight * transform.inv(prior.mean)
E               RuntimeError: The size of tensor a (50) must match the size of tensor b (20) at non-singleton dimension 0

pyro/infer/autoguide/effect.py:315: RuntimeError

I don't understand why this is happening, because the logic is the same as in AutoNormalMessenger, and the new guide works for the cell2location model.

@fritzo (Member) commented Nov 3, 2021

@vitkl can you push so I can inspect in a debugger? Happy to pair-code through the error.

@fritzo (Member) commented Nov 3, 2021

Aha, I think you just need to pass prior.mean through self._remove_outer_plates():

- init_loc = init_loc - init_weight * transform.inv(prior.mean)
+ mean = transform.inv(prior.mean)
+ mean = self._remove_outer_plates(mean, event_dim=event_dim)
+ init_loc = init_loc - init_weight * mean
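For anyone hitting the same error, a minimal illustration of the mismatch as the traceback suggests it (the tensors below are stand-ins, not pyro internals; the sizes come from the test):

import torch

full_size, batch_size = 50, 20
init_loc = torch.zeros(full_size)     # guide parameter created at the full plate size
prior_mean = torch.zeros(batch_size)  # prior.mean evaluated inside the subsampled plate

# Subtracting directly reproduces "The size of tensor a (50) must match the
# size of tensor b (20)"; bringing prior_mean to the full plate size first
# (what the suggested _remove_outer_plates call does) avoids it.
try:
    init_loc - prior_mean
except RuntimeError as e:
    print(e)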

@vitkl (Contributor, Author) commented Nov 3, 2021 via email

@fritzo (Member) commented Nov 3, 2021

> subsampling plate is not supposed to be removed, right?

._remove_outer_plates() does a little more than remove outer plates now; it also fixes subsampled plates. Feel free to rename it to ._adjust_plates() or something.

@vitkl (Contributor, Author) commented Nov 3, 2021

I added the changes you requested and renamed the method to ._adjust_plates(), but I don't know why the docs test fails.

(Three inline review comments on pyro/infer/autoguide/effect.py, resolved.)
@fritzo (Member) commented Nov 4, 2021

Hmm, it looks like the docs build is failing because the lap wheel fails to build, which is unrelated to your changes. Probably some change in an upstream library. I'll look into it...

@fritzo (Member) commented Nov 4, 2021

Blocked by pypa/setuptools#2849

@fritzo (Member) left a comment
Looks great! Thanks for adding the new tests.
