diff --git a/docs/api/core.rst b/docs/api/core.rst
index 2ea20ff450..36b253c32a 100644
--- a/docs/api/core.rst
+++ b/docs/api/core.rst
@@ -14,6 +14,7 @@ Compose
    compose.Encoder
    compose.LossRecorder
    compose.BaseModuleClass
+   compose.PyroBaseModuleClass
 
 .. autosummary::
    :toctree: reference/
@@ -80,6 +81,7 @@ steps for modules like `TOTALVAE`, `SCANVAE`, etc.
    lightning.TrainingPlan
    lightning.SemiSupervisedTrainingPlan
    lightning.AdversarialTrainingPlan
+   lightning.PyroTrainingPlan
    lightning.Trainer
 
 Utilities
diff --git a/pyproject.toml b/pyproject.toml
index 3e93c8dd3a..56136d5211 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ openpyxl = ">=3.0"
 pandas = ">=1.0"
 pre-commit = {version = ">=2.7.1", optional = true}
 prospector = {version = "*", optional = true}
+pyro-ppl = ">=1.1.0"
 pytest = {version = ">=4.4", optional = true}
 python = ">=3.6.1,<4.0"
 python-igraph = {version = "*", optional = true}
diff --git a/scvi/compose/__init__.py b/scvi/compose/__init__.py
index 338c4aeb86..f9cac03ff1 100644
--- a/scvi/compose/__init__.py
+++ b/scvi/compose/__init__.py
@@ -9,7 +9,7 @@
     MultiDecoder,
     MultiEncoder,
 )
-from ._base_module import BaseModuleClass, LossRecorder
+from ._base_module import BaseModuleClass, LossRecorder, PyroBaseModuleClass
 from ._decorators import auto_move_data
 from ._utils import one_hot  # Do we want one_hot here?
 
@@ -27,4 +27,5 @@
     "BaseModuleClass",
     "one_hot",
     "auto_move_data",
+    "PyroBaseModuleClass",
 ]
diff --git a/scvi/compose/_base_module.py b/scvi/compose/_base_module.py
index 4bf64efe53..1c20252e58 100644
--- a/scvi/compose/_base_module.py
+++ b/scvi/compose/_base_module.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -194,3 +194,42 @@ def _get_dict_if_none(param):
     param = {} if not isinstance(param, dict) else param
     return param
+
+
+class PyroBaseModuleClass(nn.Module):
+    """
+    Base module class for Pyro models.
+
+    In Pyro, `model` and `guide` should have the same signature. For convenience,
+    the forward function of this class passes through to the forward of the `model`.
+
+    There are two ways this class can be equipped with a model and a guide. First,
+    `model` and `guide` can be class attributes that are :class:`~pyro.nn.PyroModule`
+    instances. Second, `model` and `guide` methods can be written directly (see the
+    Pyro scANVI example: https://pyro.ai/examples/scanvi.html).
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    @abstractmethod
+    def _get_fn_args_from_batch(
+        tensor_dict: Dict[str, torch.Tensor]
+    ) -> Union[Iterable, dict]:
+        """
+        Parse the minibatched data to get the correct inputs for `model` and `guide`.
+
+        In Pyro, `model` and `guide` must have the same signature. This helper method
+        gets the args and kwargs for those two methods, which keeps the signatures of
+        `forward` and `guide` transparent and allows the use of our generic
+        :class:`~scvi.dataloaders.AnnDataLoader`.
+
+        Returns
+        -------
+        args and kwargs for the functions; args should be an Iterable and kwargs a dictionary.
+ """ + + def forward(self, *args, **kwargs): + """Passthrough to Pyro model.""" + return self.model(*args, **kwargs) diff --git a/scvi/lightning/__init__.py b/scvi/lightning/__init__.py index 8ee6e8d039..75d9a06147 100644 --- a/scvi/lightning/__init__.py +++ b/scvi/lightning/__init__.py @@ -1,6 +1,7 @@ from ._trainer import Trainer from ._trainingplans import ( AdversarialTrainingPlan, + PyroTrainingPlan, SemiSupervisedTrainingPlan, TrainingPlan, ) @@ -8,6 +9,7 @@ __all__ = [ "TrainingPlan", "Trainer", + "PyroTrainingPlan", "SemiSupervisedTrainingPlan", "AdversarialTrainingPlan", ] diff --git a/scvi/lightning/_trainingplans.py b/scvi/lightning/_trainingplans.py index 94b4495171..9826ecbfb4 100644 --- a/scvi/lightning/_trainingplans.py +++ b/scvi/lightning/_trainingplans.py @@ -1,13 +1,14 @@ from inspect import getfullargspec -from typing import Union +from typing import Callable, Optional, Union +import pyro import pytorch_lightning as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau from scvi import _CONSTANTS from scvi._compat import Literal -from scvi.compose import BaseModuleClass, one_hot +from scvi.compose import PyroBaseModuleClass, BaseModuleClass, one_hot from scvi.modules import Classifier @@ -522,3 +523,59 @@ def validation_step(self, batch, batch_idx, optimizer_idx=0): "kl_global": scvi_losses.kl_global, "n_obs": reconstruction_loss.shape[0], } + + +class PyroTrainingPlan(pl.LightningModule): + """ + Lightning module task to train Pyro scvi-tools modules. + + Parameters + ---------- + pyro_module + A model instance from class ``PyroBaseModuleClass``. + lr + Learning rate used for optimization. + """ + + def __init__( + self, + pyro_module: PyroBaseModuleClass, + lr: float = 1e-3, + loss_fn: Optional[Callable] = None, + ): + super().__init__() + self.module = pyro_module + self.loss_fn = loss_fn + self.lr = lr + + if loss_fn is None: + self.loss_fn = pyro.infer.Trace_ELBO() + + self.automatic_optimization = False + self.pyro_guide = self.module.guide + self.pyro_model = self.module.model + + self.svi = pyro.infer.SVI( + model=self.pyro_model, + guide=self.pyro_guide, + optim=pyro.optim.Adam({"lr": self.lr}), + loss=self.loss_fn, + ) + + def forward(self, *args, **kwargs): + """Passthrough to `model.forward()`.""" + return self.module(*args, **kwargs) + + def training_step(self, batch, batch_idx, optimizer_idx=0): + args, kwargs = self.module._get_fn_args_from_batch(batch) + loss = self.svi.step(*args, **kwargs) + self.log("train_loss", loss, prog_bar=True, on_epoch=True) + + def configure_optimizers(self): + return None + + def optimizer_step(self, *args, **kwargs): + pass + + def backward(self, *args, **kwargs): + pass diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index daab890162..5d395ea001 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -6,6 +6,7 @@ from typing import Optional, Sequence import numpy as np +import pyro import rich import torch from anndata import AnnData @@ -13,6 +14,7 @@ from sklearn.model_selection._split import _validate_shuffle_split from scvi import _CONSTANTS, settings +from scvi.compose import PyroBaseModuleClass from scvi.data import get_from_registry, transfer_anndata_setup from scvi.data._anndata import _check_anndata_setup_equivalence from scvi.data._utils import _check_nonnegative_integers @@ -434,7 +436,18 @@ def load( for attr, val in attr_dict.items(): setattr(model, attr, val) - model.module.load_state_dict(model_state_dict) + # some Pyro modules with 
AutoGuides may need one training step + try: + model.module.load_state_dict(model_state_dict) + except RuntimeError as err: + if isinstance(model.module, PyroBaseModuleClass): + logger.info("Preparing underlying module for load") + model.train(max_steps=1) + pyro.clear_param_store() + model.module.load_state_dict(model_state_dict) + else: + raise err + if use_gpu: model.module.cuda() diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py new file mode 100644 index 0000000000..c230078874 --- /dev/null +++ b/tests/models/test_pyro.py @@ -0,0 +1,120 @@ +import os + +import numpy as np +import pyro +import pyro.distributions as dist +import torch +import torch.nn as nn +from pyro.infer.autoguide import AutoDiagonalNormal +from pyro.nn import PyroModule + +from scvi import _CONSTANTS +from scvi.compose import PyroBaseModuleClass +from scvi.data import synthetic_iid +from scvi.dataloaders import AnnDataLoader +from scvi.lightning import PyroTrainingPlan, Trainer + + +class BayesianRegressionPyroModel(PyroModule): + def __init__(self, in_features, out_features): + super().__init__() + + self.register_buffer("zero", torch.tensor(0.0)) + self.register_buffer("one", torch.tensor(1.0)) + self.register_buffer("ten", torch.tensor(10.0)) + + self.linear = PyroModule[nn.Linear](in_features, out_features) + + def forward(self, x, y): + sigma = pyro.sample("sigma", dist.Uniform(self.zero, self.ten)) + mean = self.linear(x).squeeze(-1) + with pyro.plate("data", x.shape[0]): + pyro.sample("obs", dist.Normal(mean, sigma), obs=y) + return mean + + +class BayesianRegressionModule(PyroBaseModuleClass): + def __init__(self, in_features, out_features): + + super().__init__() + self.model = BayesianRegressionPyroModel(in_features, out_features) + self.guide = AutoDiagonalNormal(self.model) + + @staticmethod + def _get_fn_args_from_batch(tensor_dict): + x = tensor_dict[_CONSTANTS.X_KEY] + y = tensor_dict[_CONSTANTS.LABELS_KEY] + + return (x, y.squeeze(1)), {} + + +def test_pyro_bayesian_regression(save_path): + use_gpu = 0 + adata = synthetic_iid() + train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) + pyro.clear_param_store() + model = BayesianRegressionModule(adata.shape[1], 1) + plan = PyroTrainingPlan(model) + trainer = Trainer( + gpus=use_gpu, + max_epochs=2, + ) + trainer.fit(plan, train_dl) + + # test save and load + post_dl = AnnDataLoader(adata, shuffle=False, batch_size=128) + mean1 = [] + with torch.no_grad(): + for tensors in post_dl: + args, kwargs = model._get_fn_args_from_batch(tensors) + mean1.append(model(*args, **kwargs).cpu().numpy()) + mean1 = np.concatenate(mean1) + + model_save_path = os.path.join(save_path, "model_params.pt") + torch.save(model.state_dict(), model_save_path) + + pyro.clear_param_store() + new_model = BayesianRegressionModule(adata.shape[1], 1) + # run model one step to get autoguide params + try: + new_model.load_state_dict(torch.load(model_save_path)) + except RuntimeError as err: + if isinstance(new_model, PyroBaseModuleClass): + plan = PyroTrainingPlan(new_model) + trainer = Trainer( + gpus=use_gpu, + max_steps=1, + ) + trainer.fit(plan, train_dl) + new_model.load_state_dict(torch.load(model_save_path)) + else: + raise err + + mean2 = [] + with torch.no_grad(): + for tensors in post_dl: + args, kwargs = new_model._get_fn_args_from_batch(tensors) + mean2.append(new_model(*args, **kwargs).cpu().numpy()) + mean2 = np.concatenate(mean2) + + np.testing.assert_array_equal(mean1, mean2) + + +def test_pyro_bayesian_regression_jit(): + use_gpu = 0 + adata = 
synthetic_iid() + train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) + pyro.clear_param_store() + model = BayesianRegressionModule(adata.shape[1], 1) + # warmup guide for JIT + for tensors in train_dl: + args, kwargs = model._get_fn_args_from_batch(tensors) + model.guide(*args, **kwargs) + break + train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) + plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO()) + trainer = Trainer( + gpus=use_gpu, + max_epochs=2, + ) + trainer.fit(plan, train_dl)
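
For context on how the new pieces fit together, here is a minimal end-to-end sketch condensed from the test above. It assumes the `BayesianRegressionModule` class defined in `tests/models/test_pyro.py` is in scope; everything else is the public API added by this diff.

```python
import pyro

from scvi.data import synthetic_iid
from scvi.dataloaders import AnnDataLoader
from scvi.lightning import PyroTrainingPlan, Trainer

# BayesianRegressionModule: the PyroBaseModuleClass subclass defined
# in tests/models/test_pyro.py above (assumed to be in scope here).

adata = synthetic_iid()
train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128)

pyro.clear_param_store()
module = BayesianRegressionModule(adata.shape[1], 1)

# PyroTrainingPlan wraps the module's model/guide in pyro.infer.SVI and
# disables Lightning's automatic optimization; Trainer drives the epochs.
plan = PyroTrainingPlan(module, lr=1e-3)
Trainer(max_epochs=2).fit(plan, train_dl)
```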