Skip to content

Add simple RandomWalkKernel #3311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/pyro.infer.mcmc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ NUTS
:undoc-members:
:show-inheritance:

RandomWalkKernel
----------------

.. autoclass:: pyro.infer.mcmc.RandomWalkKernel
:members:
:undoc-members:
:show-inheritance:

BlockMassMatrix
---------------

Expand Down
2 changes: 2 additions & 0 deletions pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.rwkernel import RandomWalkKernel
from pyro.infer.predictive import Predictive
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.rws import ReweightedWakeSleep
Expand Down Expand Up @@ -45,6 +46,7 @@
"MCMC",
"NUTS",
"Predictive",
"RandomWalkKernel",
"RBFSteinKernel",
"RenyiELBO",
"ReweightedWakeSleep",
Expand Down
2 changes: 2 additions & 0 deletions pyro/infer/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from pyro.infer.mcmc.api import MCMC, StreamingMCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.rwkernel import RandomWalkKernel

__all__ = [
"ArrowheadMassMatrix",
"BlockMassMatrix",
"HMC",
"MCMC",
"NUTS",
"RandomWalkKernel",
"StreamingMCMC",
]
143 changes: 143 additions & 0 deletions pyro/infer/mcmc/rwkernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import OrderedDict

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model


class RandomWalkKernel(MCMCKernel):
r"""
Simple gradient-free kernel that utilizes an isotropic gaussian random walk in the unconstrained
latent space of the model. The step size that controls the variance of the kernel is adapted during
warm-up with a simple adaptation scheme that targets a user-provided acceptance probability.

:param model: Python callable containing Pyro primitives.
:param float init_step_size: A positive float that controls the initial step size. Defaults to 0.1.
:param float target_accept_prob: The target acceptance probability used during adaptation of
the step size. Defaults to 0.234.

Example:

>>> true_coefs = torch.tensor([1., 2., 3.])
>>> data = torch.randn(2000, 3)
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
>>>
>>> def model(data):
... coefs_mean = torch.zeros(dim)
... coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
... y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
... return y
>>>
>>> hmc_kernel = RandomWalkKernel(model, init_step_size=0.2)
>>> mcmc = MCMC(hmc_kernel, num_samples=200, warmup_steps=100)
>>> mcmc.run(data)
>>> mcmc.get_samples()['beta'].mean(0) # doctest: +SKIP
tensor([ 0.9819, 1.9258, 2.9737])
"""

def __init__(
self, model, init_step_size: float = 0.1, target_accept_prob: float = 0.234
):
if not isinstance(init_step_size, float) or init_step_size <= 0.0:
raise ValueError("init_step_size must be a positive float.")

if (
not isinstance(target_accept_prob, float)
or target_accept_prob <= 0.0
or target_accept_prob >= 1.0
):
raise ValueError(
"target_accept_prob must be a float in the interval (0, 1)."
)

self.model = model
self.init_step_size = init_step_size
self.target_accept_prob = target_accept_prob

self._t = 0
self._log_step_size = math.log(init_step_size)
self._accept_cnt = 0
self._mean_accept_prob = 0.0
super().__init__()

def setup(self, warmup_steps, *args, **kwargs):
self._warmup_steps = warmup_steps
(
self._initial_params,
self.potential_fn,
self.transforms,
self._prototype_trace,
) = initialize_model(
self.model,
model_args=args,
model_kwargs=kwargs,
)
self._energy_last = self.potential_fn(self._initial_params)

def sample(self, params):
step_size = math.exp(self._log_step_size)
new_params = {
k: v + step_size * torch.randn(v.shape, dtype=v.dtype, device=v.device)
for k, v in params.items()
}
energy_proposal = self.potential_fn(new_params)
delta_energy = energy_proposal - self._energy_last

accept_prob = (-delta_energy).exp().clamp(max=1.0).item()
rand = pyro.sample(
"rand_t={}".format(self._t),
dist.Uniform(0.0, 1.0),
)
accepted = False
if rand < accept_prob:
accepted = True
params = new_params
self._energy_last = energy_proposal

if self._t <= self._warmup_steps:
adaptation_speed = max(0.001, 0.1 / math.sqrt(1 + self._t))
self._log_step_size += adaptation_speed * (
accept_prob - self.target_accept_prob
)

self._t += 1

if self._t > self._warmup_steps:
n = self._t - self._warmup_steps
if accepted:
self._accept_cnt += 1
else:
n = self._t

self._mean_accept_prob += (accept_prob - self._mean_accept_prob) / n

return params.copy()

@property
def initial_params(self):
return self._initial_params

@initial_params.setter
def initial_params(self, params):
self._initial_params = params

def logging(self):
return OrderedDict(
[
("step size", "{:.2e}".format(math.exp(self._log_step_size))),
("acc. prob", "{:.3f}".format(self._mean_accept_prob)),
]
)

def diagnostics(self):
return {
"acceptance rate": self._accept_cnt / (self._t - self._warmup_steps),
}
40 changes: 40 additions & 0 deletions tests/infer/mcmc/test_rwkernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.rwkernel import RandomWalkKernel
from tests.common import assert_equal


def test_beta_bernoulli():
alpha = torch.tensor([1.1, 2.2])
beta = torch.tensor([1.1, 2.2])

def model(data):
p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta))
with pyro.plate("data", data.shape[0], dim=-2):
pyro.sample("obs", dist.Bernoulli(p_latent), obs=data)

num_data = 5
true_probs = torch.tensor([0.9, 0.1])
data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((num_data,))))

kernel = RandomWalkKernel(model)
mcmc = MCMC(kernel, num_samples=2000, warmup_steps=500)
mcmc.run(data)
samples = mcmc.get_samples()

data_sum = data.sum(0)
alpha_post = alpha + data_sum
beta_post = beta + num_data - data_sum
expected_mean = alpha_post / (alpha_post + beta_post)
expected_var = (
expected_mean.pow(2) * beta_post / (alpha_post * (1 + alpha_post + beta_post))
)

assert_equal(samples["p_latent"].mean(0), expected_mean, prec=0.03)
assert_equal(samples["p_latent"].var(0), expected_var, prec=0.005)