Skip to content

Commit

Permalink
StreamingMCMC class (#2857)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol authored Jun 16, 2021
1 parent a6d120d commit 8fd0bf5
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 78 deletions.
10 changes: 10 additions & 0 deletions docs/source/pyro.infer.mcmc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ MCMC
:undoc-members:
:show-inheritance:

StreamingMCMC
-------------

.. autoclass:: pyro.infer.mcmc.api.StreamingMCMC
:members:
:undoc-members:
:show-inheritance:

MCMCKernel
----------
.. autoclass:: pyro.infer.mcmc.mcmc_kernel.MCMCKernel
Expand Down Expand Up @@ -43,3 +51,5 @@ Utilities
.. autofunction:: pyro.infer.mcmc.util.initialize_model

.. autofunction:: pyro.infer.mcmc.util.diagnostics

.. autofunction:: pyro.infer.mcmc.util.select_samples
3 changes: 2 additions & 1 deletion pyro/infer/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from pyro.infer.mcmc.adaptation import ArrowheadMassMatrix, BlockMassMatrix
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.api import MCMC, StreamingMCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS

Expand All @@ -12,4 +12,5 @@
"HMC",
"MCMC",
"NUTS",
"StreamingMCMC",
]
204 changes: 155 additions & 49 deletions pyro/infer/mcmc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
code that works with different backends.
- minimal memory consumption with multiprocessing and CUDA.
"""

import copy
import json
import logging
import queue
import signal
import threading
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict

import torch
import torch.multiprocessing as mp
Expand All @@ -31,7 +33,8 @@
initialize_logger,
)
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.util import diagnostics, print_summary
from pyro.infer.mcmc.util import diagnostics, print_summary, select_samples
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict, StreamingStats
from pyro.util import optional

MAX_SEED = 2**32 - 1
Expand Down Expand Up @@ -257,7 +260,46 @@ def run(self, *args, **kwargs):
self.terminate(terminate_workers=exc_raised)


class MCMC:
class AbstractMCMC(ABC):
"""
Base class for MCMC methods.
"""
def __init__(self, kernel, num_chains, transforms):
self.kernel = kernel
self.num_chains = num_chains
self.transforms = transforms

@abstractmethod
def run(self, *args, **kwargs):
raise NotImplementedError

def _set_transforms(self, *args, **kwargs):
# Use `kernel.transforms` when available
if getattr(self.kernel, "transforms", None) is not None:
self.transforms = self.kernel.transforms
# Else, get transforms from model (e.g. in multiprocessing).
elif self.kernel.model:
warmup_steps = 0
self.kernel.setup(warmup_steps, *args, **kwargs)
self.transforms = self.kernel.transforms
# Assign default value
else:
self.transforms = {}

def _validate_kernel(self, initial_params):
if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None:
if initial_params is None:
raise ValueError("Must provide valid initial parameters to begin sampling"
" when using `potential_fn` in HMC/NUTS kernel.")

def _validate_initial_params(self, initial_params):
for v in initial_params.values():
if v.shape[0] != self.num_chains:
raise ValueError("The leading dimension of tensors in `initial_params` "
"must match the number of chains.")


class MCMC(AbstractMCMC):
"""
Wrapper class for Markov Chain Monte Carlo algorithms. Specific MCMC algorithms
are TraceKernel instances and need to be supplied as a ``kernel`` argument
Expand Down Expand Up @@ -307,28 +349,21 @@ class MCMC:
def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
num_chains=1, hook_fn=None, mp_context=None, disable_progbar=False,
disable_validation=True, transforms=None, save_params=None):
super().__init__(kernel, num_chains, transforms)
self.warmup_steps = num_samples if warmup_steps is None else warmup_steps # Stan
self.num_samples = num_samples
self.kernel = kernel
self.transforms = transforms
self.disable_validation = disable_validation
self._samples = None
self._args = None
self._kwargs = None
if save_params is not None:
kernel.save_params = save_params
if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None:
if initial_params is None:
raise ValueError("Must provide valid initial parameters to begin sampling"
" when using `potential_fn` in HMC/NUTS kernel.")
self._validate_kernel(initial_params)
parallel = False
if num_chains > 1:
# check that initial_params is different for each chain
if initial_params:
for v in initial_params.values():
if v.shape[0] != num_chains:
raise ValueError("The leading dimension of tensors in `initial_params` "
"must match the number of chains.")
self._validate_initial_params(initial_params)
# FIXME: probably we want to use "spawn" method by default to avoid the error
# CUDA initialization error https://github.com/pytorch/pytorch/issues/2517
# even that we run MCMC in CPU.
Expand All @@ -348,10 +383,7 @@ def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
else:
if initial_params:
initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}

self.num_chains = num_chains
self._diagnostics = [None] * num_chains

if parallel:
self.sampler = _MultiSampler(kernel, num_samples, self.warmup_steps, num_chains, mp_context,
disable_progbar, initial_params=initial_params, hook=hook_fn)
Expand Down Expand Up @@ -422,17 +454,7 @@ def model(data):
# If transforms is not explicitly provided, infer automatically using
# model args, kwargs.
if self.transforms is None:
# Use `kernel.transforms` when available
if getattr(self.kernel, "transforms", None) is not None:
self.transforms = self.kernel.transforms
# Else, get transforms from model (e.g. in multiprocessing).
elif self.kernel.model:
warmup_steps = 0
self.kernel.setup(warmup_steps, *args, **kwargs)
self.transforms = self.kernel.transforms
# Assign default value
else:
self.transforms = {}
self._set_transforms(*args, **kwargs)

# transform samples back to constrained space
for name, z in z_acc.items():
Expand All @@ -447,30 +469,10 @@ def get_samples(self, num_samples=None, group_by_chain=False):
"""
Get samples from the MCMC run, potentially resampling with replacement.
:param int num_samples: Number of samples to return. If `None`, all the samples
from an MCMC chain are returned in their original ordering.
:param bool group_by_chain: Whether to preserve the chain dimension. If True,
all samples will have num_chains as the size of their leading dimension.
:return: dictionary of samples keyed by site name.
For parameter details see: :meth:`select_samples <pyro.infer.mcmc.util.select_samples>`.
"""
samples = self._samples
if num_samples is None:
# reshape to collapse chain dim when group_by_chain=False
if not group_by_chain:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
else:
if not samples:
raise ValueError("No samples found from MCMC run.")
if group_by_chain:
batch_dim = 1
else:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
batch_dim = 0
sample_tensor = list(samples.values())[0]
batch_size, device = sample_tensor.shape[batch_dim], sample_tensor.device
idxs = torch.randint(0, batch_size, size=(num_samples,), device=device)
samples = {k: v.index_select(batch_dim, idxs) for k, v in samples.items()}
return samples
return select_samples(samples, num_samples, group_by_chain)

def diagnostics(self):
"""
Expand All @@ -496,3 +498,107 @@ def summary(self, prob=0.9):
if 'divergences' in self._diagnostics[0]:
print("Number of divergences: {}".format(
sum([len(self._diagnostics[i]['divergences']) for i in range(self.num_chains)])))


class StreamingMCMC(AbstractMCMC):
"""
MCMC that computes required statistics in a streaming fashion. For this class no samples are retained
but only aggregated statistics. This is useful for running memory expensive models where we care only
about specific statistics (especially useful in a memory constrained environments like GPU).
For available streaming ops please see :mod:`~pyro.ops.streaming`.
"""
def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
statistics=None, num_chains=1, hook_fn=None, disable_progbar=False,
disable_validation=True, transforms=None, save_params=None):
super().__init__(kernel, num_chains, transforms)
self.warmup_steps = num_samples if warmup_steps is None else warmup_steps # Stan
self.num_samples = num_samples
self.disable_validation = disable_validation
self._samples = None
self._args = None
self._kwargs = None
if statistics is None:
statistics = StatsOfDict(default=CountMeanVarianceStats)
self._statistics = statistics
self._default_statistics = copy.deepcopy(statistics)
if save_params is not None:
kernel.save_params = save_params
self._validate_kernel(initial_params)
if num_chains > 1:
if initial_params:
self._validate_initial_params(initial_params)
else:
if initial_params:
initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}
self._diagnostics = [None] * num_chains
self.sampler = _UnarySampler(kernel, num_samples, self.warmup_steps, num_chains, disable_progbar,
initial_params=initial_params, hook=hook_fn)

@poutine.block
def run(self, *args, **kwargs):
"""
Run StreamingMCMC to compute required `self._statistics`.
"""
self._args, self._kwargs = args, kwargs
num_samples = [0] * self.num_chains

with optional(pyro.validation_enabled(not self.disable_validation),
self.disable_validation is not None):
args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
for x, chain_id in self.sampler.run(*args, **kwargs):
if num_samples[chain_id] == 0:
# If transforms is not explicitly provided, infer automatically using
# model args, kwargs.
if self.transforms is None:
self._set_transforms(*args, **kwargs)
num_samples[chain_id] += 1
z_structure = x
elif num_samples[chain_id] == self.num_samples + 1:
self._diagnostics[chain_id] = x
else:
num_samples[chain_id] += 1
if self.num_chains > 1:
x_cloned = x.clone()
del x
else:
x_cloned = x

# unpack latent
pos = 0
z_acc = z_structure.copy()
for k in sorted(z_structure):
shape = z_structure[k]
next_pos = pos + shape.numel()
z_acc[k] = x_cloned[pos:next_pos].reshape(shape)
pos = next_pos

for name, z in z_acc.items():
if name in self.transforms:
z_acc[name] = self.transforms[name].inv(z)

self._statistics.update({
(chain_id, name): transformed_sample for name, transformed_sample in z_acc.items()
})

# terminate the sampler (shut down worker processes)
self.sampler.terminate(True)

def get_statistics(self, group_by_chain=True):
"""
Returns a dict of statistics defined by those passed to the class constructor.
:param bool group_by_chain: Whether statistics should be chain-wise or merged together.
"""
if group_by_chain:
return self._statistics.get()
else:
# merge all chains with respect to names
merged_dict: Dict[str, StreamingStats] = {}
for (_, name), stat in self._statistics.stats.items():
if name in merged_dict:
merged_dict[name] = merged_dict[name].merge(stat)
else:
merged_dict[name] = stat

return {k: v.get() for k, v in merged_dict.items()}
30 changes: 30 additions & 0 deletions pyro/infer/mcmc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,33 @@ def wrapped_fn(*args, **kwargs):
predictions[site] = value.reshape(shape)

return predictions


def select_samples(samples, num_samples=None, group_by_chain=False):
"""
Performs selection from given MCMC samples.
:param dictionary samples: Samples object to sample from.
:param int num_samples: Number of samples to return. If `None`, all the samples
from an MCMC chain are returned in their original ordering.
:param bool group_by_chain: Whether to preserve the chain dimension. If True,
all samples will have num_chains as the size of their leading dimension.
:return: dictionary of samples keyed by site name.
"""
if num_samples is None:
# reshape to collapse chain dim when group_by_chain=False
if not group_by_chain:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
else:
if not samples:
raise ValueError("No samples found from MCMC run.")
if group_by_chain:
batch_dim = 1
else:
samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}
batch_dim = 0
sample_tensor = list(samples.values())[0]
batch_size, device = sample_tensor.shape[batch_dim], sample_tensor.device
idxs = torch.randint(0, batch_size, size=(num_samples,), device=device)
samples = {k: v.index_select(batch_dim, idxs) for k, v in samples.items()}
return samples
Loading

0 comments on commit 8fd0bf5

Please sign in to comment.