Pyro integration #895

Merged · merged 23 commits on Feb 4, 2021
Changes from 8 commits

2 changes: 2 additions & 0 deletions docs/api/core.rst
@@ -14,6 +14,7 @@ Compose
   compose.Encoder
   compose.LossRecorder
   compose.BaseModuleClass
   compose.PyroBaseModuleClass

.. autosummary::
   :toctree: reference/
@@ -68,6 +69,7 @@ steps for modules like `TOTALVAE`, `SCANVAE`, etc.
   lightning.TrainingPlan
   lightning.SemiSupervisedTrainingPlan
   lightning.AdversarialTrainingPlan
   lightning.PyroTrainingPlan
   lightning.Trainer

Utilities
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ pandas = ">=1.0"
pre-commit = {version = ">=2.7.1", optional = true}
prospector = {version = "*", optional = true}
pydata-sphinx-theme = {version = ">=0.4.0", optional = true}
pyro-ppl = ">=1.1.0"
adamgayoso marked this conversation as resolved.
pytest = {version = ">=4.4", optional = true}
python = ">=3.6.1,<4.0"
python-igraph = {version = "*", optional = true}
3 changes: 2 additions & 1 deletion scvi/compose/__init__.py
@@ -9,7 +9,7 @@
    MultiDecoder,
    MultiEncoder,
)
from ._base_module import BaseModuleClass, LossRecorder
from ._base_module import BaseModuleClass, LossRecorder, PyroBaseModuleClass
from ._decorators import auto_move_data
from ._utils import one_hot # Do we want one_hot here?

@@ -27,4 +27,5 @@
"BaseModuleClass",
"one_hot",
"auto_move_data",
"PyroBaseModuleClass",
]
67 changes: 66 additions & 1 deletion scvi/compose/_base_module.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, List, Tuple, Union

import torch
import torch.nn as nn
@@ -194,3 +194,68 @@ def _get_dict_if_none(param):
    param = {} if not isinstance(param, dict) else param

    return param


class PyroBaseModuleClass(nn.Module):
    """
    Base module class for Pyro models.

    The ``forward`` method of this class is the "model" in Pyro terms,
    and the guide is implemented through the ``guide`` method. Subclasses
    implement ``_forward`` and ``_guide``, plus the tensor-parsing helpers
    ``_get_forward_tensors`` and ``_get_guide_tensors``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def _get_guide_tensors(
        self, tensors: Dict[str, torch.Tensor]
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def _guide(self, *args, **kwargs):
        pass

    def guide(self, tensors: Dict[str, torch.Tensor]):
        """
        Pyro guide method.

        This is a wrapper for the `_guide` method. It parses the tensors
        dictionary and passes the result to `_guide` as positional arguments.
        """
        return self._guide(*self._get_guide_tensors(tensors))

    @abstractmethod
    def _get_forward_tensors(
        self, tensors: Dict[str, torch.Tensor]
    ) -> List[torch.Tensor]:
        pass

    @abstractmethod
    def _forward(self, *args, **kwargs):
        pass

    def forward(self, tensors: Dict[str, torch.Tensor]):
        """
        Pyro model method.

        This is a wrapper for the `_forward` method. It parses the tensors
        dictionary and passes the result to `_forward` as positional arguments.
        """
        return self._forward(*self._get_forward_tensors(tensors))
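
For orientation, here is a minimal, self-contained subclass sketch (an editor's illustration only; the `ToyModule` name, the `"X"` key, and the model internals are hypothetical and not part of this diff):

# Hypothetical sketch: a one-latent-variable model/guide pair built on
# PyroBaseModuleClass. The `forward`/`guide` wrappers defined above parse
# the tensors dict and dispatch to these underscored methods.
import pyro
import pyro.distributions as dist
import torch

class ToyModule(PyroBaseModuleClass):
    def _get_forward_tensors(self, tensors):
        return [tensors["X"]]

    def _forward(self, x):
        # model: latent location with a standard-normal prior
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=x.sum(-1))

    def _get_guide_tensors(self, tensors):
        return [tensors["X"]]

    def _guide(self, x):
        # guide: variational approximation for the "loc" site
        q_loc = pyro.param("q_loc", torch.zeros(()))
        pyro.sample("loc", dist.Normal(q_loc, 1.0))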
2 changes: 2 additions & 0 deletions scvi/lightning/__init__.py
@@ -1,13 +1,15 @@
from ._trainer import Trainer
from ._trainingplans import (
    AdversarialTrainingPlan,
    PyroTrainingPlan,
    SemiSupervisedTrainingPlan,
    TrainingPlan,
)

__all__ = [
    "TrainingPlan",
    "Trainer",
    "PyroTrainingPlan",
    "SemiSupervisedTrainingPlan",
    "AdversarialTrainingPlan",
]
57 changes: 55 additions & 2 deletions scvi/lightning/_trainingplans.py
@@ -1,13 +1,14 @@
from inspect import getfullargspec
from typing import Union
from typing import Callable, Union

import pyro
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from scvi import _CONSTANTS
from scvi._compat import Literal
from scvi.compose import BaseModuleClass, one_hot
from scvi.compose import PyroBaseModuleClass, BaseModuleClass, one_hot
from scvi.modules import Classifier


@@ -522,3 +523,55 @@ def validation_step(self, batch, batch_idx, optimizer_idx=0):
            "kl_global": scvi_losses.kl_global,
            "n_obs": reconstruction_loss.shape[0],
        }


class PyroTrainingPlan(pl.LightningModule):
    """
    Lightning module task to train Pyro scvi-tools modules.

    Parameters
    ----------
    pyro_module
        A model instance from class ``PyroBaseModuleClass``.
    lr
        Learning rate used for optimization.
    loss_fn
        Pyro loss to use during training, e.g., an ELBO instance.
    """

    def __init__(
        self,
        pyro_module: PyroBaseModuleClass,
        lr: float = 1e-3,
        loss_fn: Callable = pyro.infer.Trace_ELBO(),
A reviewer commented:
Warning: it's dangerous to set stateful default arguments. In this case the same default Trace_ELBO instance will be shared by all models and (details below) bad things may happen. I'd recommend defaulting to None and doing a standard if loss_fn is None: loss_fn = pyro.infer.Trace_ELBO().

Each Trace_ELBO instance guesses and stores the number of plates in its model, and assumes it will be associated with only a single model. If you use it with a different model with a different number of plates, you might see tensor shape errors.
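
A minimal sketch of the reviewer's suggested pattern (an editor's illustration, not code from this PR; `Optional` would come from `typing`):

# Hypothetical revision of PyroTrainingPlan.__init__: default loss_fn to
# None and build a fresh Trace_ELBO per instance, so the plate-count state
# inside a Trace_ELBO is never shared across models.
def __init__(
    self,
    pyro_module: PyroBaseModuleClass,
    lr: float = 1e-3,
    loss_fn: Optional[Callable] = None,
):
    super().__init__()
    if loss_fn is None:
        loss_fn = pyro.infer.Trace_ELBO()  # fresh instance, never shared
    self.loss_fn = loss_fn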

    ):
        super().__init__()
        self.module = pyro_module
        self.loss_fn = loss_fn
        self.lr = lr

        self.automatic_optimization = False
        self.guide = self.module.guide

        self.svi = pyro.infer.SVI(
            model=self.module,
            guide=self.guide,
            optim=pyro.optim.Adam({"lr": self.lr}),
            loss=self.loss_fn,
        )

    def forward(self, *args, **kwargs):
        """Passthrough to `model.forward()`."""
        return self.module(*args, **kwargs)

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        loss = self.svi.step(batch)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        # Optimization is handled by Pyro's SVI, not by Lightning.
        return None

    def optimizer_step(self, *args, **kwargs):
        pass

    def backward(self, *args, **kwargs):
        pass
77 changes: 77 additions & 0 deletions tests/models/test_pyro.py
@@ -0,0 +1,77 @@
import pyro
import pyro.distributions as dist
import torch
import torch.nn as nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule

from scvi import _CONSTANTS
from scvi.compose import PyroBaseModuleClass
from scvi.data import synthetic_iid
from scvi.dataloaders import AnnDataLoader
from scvi.lightning import PyroTrainingPlan, Trainer


class BayesianRegression(PyroModule, PyroBaseModuleClass):
    def __init__(self, in_features, out_features):
        super().__init__()

        self._auto_guide = AutoDiagonalNormal(self)
@fritzo commented on Feb 1, 2021:
Technical detail: One of the rules of PyroModules is that the model and guide must be separate PyroModules, and cannot be contained in a single PyroModule; however, they can be contained in a single nn.Module. You might consider refactoring to separate PyroBaseModuleClass from some sort of PyroModelClass like this (where s <: t means issubclass(s, t)):

PyroBaseModuleClass <: nn.Module
    .model : PyroBaseModelClass <: PyroModule
    .guide : AutoNormal <: PyroModule

Basically, "a model can't have its guide as an attribute."

The reason is due to PyroModule naming schemes and caching of pyro.sample calls. If model and guide are contained in a single PyroModule, there may be weird conflicts in both names and the pyro.sample cache.
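
A short sketch of that structure (an editor's illustration with hypothetical names, not code from this PR):

# Hypothetical refactor: model and guide are *separate* PyroModules,
# held side by side inside a plain nn.Module container.
import pyro
import pyro.distributions as dist
import torch.nn as nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule

class RegressionModel(PyroModule):
    # the model alone: its own PyroModule
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, y):
        sigma = pyro.sample("sigma", dist.Uniform(0.0, 10.0))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

class RegressionModule(nn.Module):
    # plain nn.Module container: no PyroModule holds both model and guide
    def __init__(self, in_features, out_features):
        super().__init__()
        self.model = RegressionModel(in_features, out_features)
        self.guide = AutoDiagonalNormal(self.model)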

The Member Author replied:
this is super helpful!


self.register_buffer("zero", torch.tensor(0.0, requires_grad=False))
self.register_buffer("one", torch.tensor(1.0, requires_grad=False))
self.register_buffer("ten", torch.tensor(10.0, requires_grad=False))

self.linear = nn.Linear(in_features, out_features)

def _get_forward_tensors(self, tensors):
x = tensors[_CONSTANTS.X_KEY]
y = tensors[_CONSTANTS.LABELS_KEY]

return x, y

def _forward(self, x, y):
sigma = pyro.sample("sigma", dist.Uniform(self.zero, self.ten))
mean = self.linear(x).squeeze(-1)
with pyro.plate("data", x.shape[0]):
pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
return mean

def _get_guide_tensors(self, tensors):
return [tensors]

def _guide(self, tensors):
return self._auto_guide(tensors)


def test_pyro_bayesian_regression():
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegression(adata.shape[1], 1)
    plan = PyroTrainingPlan(model)
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)


def test_pyro_bayesian_regression_jit():
    use_gpu = int(torch.cuda.is_available())
    adata = synthetic_iid()
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    pyro.clear_param_store()
    model = BayesianRegression(adata.shape[1], 1)
    # warm up the guide for JIT tracing
    for tensors in train_dl:
        model.guide(tensors)
        break
    train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)
    plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO())
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=2,
    )
    trainer.fit(plan, train_dl)